Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 09:14:10

0001 /*
0002  * Project: RooFit
0003  * Authors:
0004  *   Garima Singh, CERN 2023
0005  *   Jonas Rembser, CERN 2023
0006  *
0007  * Copyright (c) 2023, CERN
0008  *
0009  * Redistribution and use in source and binary forms,
0010  * with or without modification, are permitted according to the terms
0011  * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
0012  */
0013 
0014 #ifndef RooFit_Detail_CodegenContext_h
0015 #define RooFit_Detail_CodegenContext_h
0016 
#include <RooAbsCollection.h>
#include <RooFit/EvalContext.h>
#include <RooNumber.h>

#include <ROOT/RSpan.hxx>

#include <cstddef>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
0029 
0030 template <class T>
0031 class RooTemplateProxy;
0032 
0033 namespace RooFit {
0034 namespace Experimental {
0035 
/// @brief Compile-time priority tag. A smaller `P` means a higher priority
/// (Prio<1> is the highest, Prio<10> the lowest); only 1 <= P <= 10 is allowed.
template <int P>
struct Prio {
   static_assert(!(P < 1 || P > 10), "P must be 1 <= P <= 10!");

   /// @brief Returns the tag one priority level lower, i.e. a Prio<P + 1>.
   static auto next()
   {
      using Successor = Prio<P + 1>;
      return Successor{};
   }
};

/// @brief Tag for the highest priority level.
using PrioHighest = Prio<1>;
/// @brief Tag for the lowest priority level.
using PrioLowest = Prio<10>;
0044 
0045 /// @brief A class to maintain the context for squashing of RooFit models into code.
0046 class CodegenContext {
0047 public:
0048    void addResult(RooAbsArg const *key, std::string const &value);
0049    void addResult(const char *key, std::string const &value);
0050 
0051    std::string const &getResult(RooAbsArg const &arg);
0052 
0053    template <class T>
0054    std::string const &getResult(RooTemplateProxy<T> const &key)
0055    {
0056       return getResult(key.arg());
0057    }
0058 
0059    /// @brief Figure out the output size of a node. It is the size of the
0060    /// vector observable that it depends on, or 1 if it doesn't depend on any
0061    /// or is a reducer node.
0062    /// @param key The node to look up the size for.
0063    std::size_t outputSize(RooFit::Detail::DataKey key) const
0064    {
0065       auto found = _nodeOutputSizes.find(key);
0066       if (found != _nodeOutputSizes.end())
0067          return found->second;
0068       return 1;
0069    }
0070 
0071    void addToGlobalScope(std::string const &str);
0072    void addVecObs(const char *key, int idx);
0073 
0074    void addToCodeBody(RooAbsArg const *klass, std::string const &in);
0075 
0076    void addToCodeBody(std::string const &in, bool isScopeIndep = false);
0077 
0078    /// @brief Build the code to call the function with name `funcname`, passing some arguments.
0079    /// The arguments can either be doubles or some RooFit arguments whose
0080    /// results will be looked up in the context.
0081    template <typename... Args_t>
0082    std::string buildCall(std::string const &funcname, Args_t const &...args)
0083    {
0084       std::stringstream ss;
0085       ss << funcname << "(" << buildArgs(args...) << ")";
0086       return ss.str();
0087    }
0088 
0089    /// @brief A class to manage loop scopes using the RAII technique. To wrap your code around a loop,
0090    /// simply place it between a brace inclosed scope with a call to beginLoop at the top. For e.g.
0091    /// {
0092    ///   auto scope = ctx.beginLoop({<-set of vector observables to loop over->});
0093    ///   // your loop body code goes here.
0094    /// }
0095    class LoopScope {
0096    public:
0097       LoopScope(CodegenContext &ctx, std::vector<TNamed const *> &&vars) : _ctx{ctx}, _vars{vars} {}
0098       ~LoopScope() { _ctx.endLoop(*this); }
0099 
0100       std::vector<TNamed const *> const &vars() const { return _vars; }
0101 
0102    private:
0103       CodegenContext &_ctx;
0104       const std::vector<TNamed const *> _vars;
0105    };
0106 
0107    std::unique_ptr<LoopScope> beginLoop(RooAbsArg const *in);
0108 
0109    std::string getTmpVarName() const;
0110 
0111    std::string buildArg(RooAbsCollection const &x);
0112 
0113    std::string buildArg(std::span<const double> arr);
0114    std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }
0115 
0116    std::vector<double> const &xlArr() { return _xlArr; }
0117 
0118    void collectFunction(std::string const &name);
0119    std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
0120 
0121    std::string
0122    buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes = {});
0123 
0124    auto const &outputSizes() const { return _nodeOutputSizes; }
0125 
0126    struct ScopeRAII {
0127       std::string _fn;
0128       CodegenContext &_ctx;
0129       RooAbsArg const *_arg;
0130 
0131    public:
0132       ScopeRAII(RooAbsArg const *arg, CodegenContext &ctx);
0133       ~ScopeRAII();
0134    };
0135    ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg) { return {arg, *this}; }
0136 
0137 private:
0138    void pushScope();
0139    void popScope();
0140    template <class T>
0141    std::string buildArgSpanImpl(std::span<const T> arr);
0142 
0143    bool isScopeIndependent(RooAbsArg const *in) const;
0144 
0145    void endLoop(LoopScope const &scope);
0146 
0147    void addResult(TNamed const *key, std::string const &value);
0148 
0149    template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
0150    std::string buildArg(T x)
0151    {
0152       return RooNumber::toString(x);
0153    }
0154 
0155    // If input is integer, we want to print it into the code like one (i.e. avoid the unnecessary '.0000').
0156    template <class T, typename std::enable_if<std::is_integral<T>{}, bool>::type = true>
0157    std::string buildArg(T x)
0158    {
0159       return std::to_string(x);
0160    }
0161 
0162    std::string buildArg(std::string const &x) { return x; }
0163 
0164    std::string buildArg(std::nullptr_t) { return "nullptr"; }
0165 
0166    std::string buildArg(RooAbsArg const &arg) { return getResult(arg); }
0167 
0168    template <class T>
0169    std::string buildArg(RooTemplateProxy<T> const &arg)
0170    {
0171       return getResult(arg);
0172    }
0173 
0174    std::string buildArgs() { return ""; }
0175 
0176    template <class Arg_t>
0177    std::string buildArgs(Arg_t const &arg)
0178    {
0179       return buildArg(arg);
0180    }
0181 
0182    template <typename Arg_t, typename... Args_t>
0183    std::string buildArgs(Arg_t const &arg, Args_t const &...args)
0184    {
0185       return buildArg(arg) + ", " + buildArgs(args...);
0186    }
0187 
0188    template <class T>
0189    std::string typeName() const;
0190 
0191    /// @brief Map of node names to their result strings.
0192    std::unordered_map<const TNamed *, std::string> _nodeNames;
0193    /// @brief A map to keep track of the observable indices if they are non scalar.
0194    std::unordered_map<const TNamed *, int> _vecObsIndices;
0195    /// @brief Map of node output sizes.
0196    std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
0197    /// @brief The code layered by lexical scopes used as a stack.
0198    std::vector<std::string> _code;
0199    /// @brief The indentation level for pretty-printing.
0200    unsigned _indent = 0;
0201    /// @brief Index to get unique names for temporary variables.
0202    mutable int _tmpVarIdx = 0;
0203    /// @brief A map to keep track of list names as assigned by addResult.
0204    std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
0205    std::vector<double> _xlArr;
0206    std::vector<std::string> _collectedFunctions;
0207 };
0208 
0209 template <>
0210 inline std::string CodegenContext::typeName<double>() const
0211 {
0212    return "double";
0213 }
0214 template <>
0215 inline std::string CodegenContext::typeName<int>() const
0216 {
0217    return "int";
0218 }
0219 
0220 template <class T>
0221 std::string CodegenContext::buildArgSpanImpl(std::span<const T> arr)
0222 {
0223    unsigned int n = arr.size();
0224    std::string arrName = getTmpVarName();
0225    std::string arrDecl = typeName<T>() + " " + arrName + "[" + std::to_string(n) + "] = {";
0226    for (unsigned int i = 0; i < n; i++) {
0227       arrDecl += " " + std::to_string(arr[i]) + ",";
0228    }
0229    arrDecl.back() = '}';
0230    arrDecl += ";\n";
0231    addToCodeBody(arrDecl, true);
0232 
0233    return arrName;
0234 }
0235 
// NOTE(review): implemented out of line; judging by the name, this declares dispatcher code
// for the function `funcName` — confirm against the implementation.
void declareDispatcherCode(std::string const &funcName);

// NOTE(review): implemented out of line; presumably generates the code for `arg` and records
// it in `ctx` — confirm against the implementation.
void codegen(RooAbsArg &arg, CodegenContext &ctx);
0239 
0240 } // namespace Experimental
0241 } // namespace RooFit
0242 
0243 #endif