Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:29:43

0001 /*
0002  * Project: RooFit
0003  * Authors:
0004  *   Garima Singh, CERN 2023
0005  *   Jonas Rembser, CERN 2023
0006  *
0007  * Copyright (c) 2023, CERN
0008  *
0009  * Redistribution and use in source and binary forms,
0010  * with or without modification, are permitted according to the terms
0011  * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
0012  */
0013 
0014 #ifndef RooFit_Detail_CodegenContext_h
0015 #define RooFit_Detail_CodegenContext_h
0016 
#include <RooAbsCollection.h>
#include <RooFit/EvalContext.h>

#include <ROOT/RSpan.hxx>

#include <cstddef>
#include <iomanip>
#include <limits>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
0029 
// RooTemplateProxy is only named through const references in the templated
// getResult()/buildArg() overloads of CodegenContext, so a declaration suffices.
template <typename T>
class RooTemplateProxy;
0032 
0033 namespace RooFit {
0034 namespace Experimental {
0035 
/// @brief Compile-time priority tag in the range [1, 10]; `Prio<1>` is the highest
/// level and `Prio<10>` the lowest.
template <int P>
struct Prio {
   static_assert(!(P < 1 || P > 10), "P must be 1 <= P <= 10!");

   /// @brief Tag for the following level, i.e. `Prio<P + 1>`.
   static auto next()
   {
      using Next = Prio<P + 1>;
      return Next{};
   }
};

/// @brief The highest priority level.
using PrioHighest = Prio<1>;
/// @brief The lowest priority level.
using PrioLowest = Prio<10>;
0044 
0045 /// @brief A class to maintain the context for squashing of RooFit models into code.
0046 class CodegenContext {
0047 public:
0048    void addResult(RooAbsArg const *key, std::string const &value);
0049    void addResult(const char *key, std::string const &value);
0050 
0051    std::string const &getResult(RooAbsArg const &arg);
0052 
0053    template <class T>
0054    std::string const &getResult(RooTemplateProxy<T> const &key)
0055    {
0056       return getResult(key.arg());
0057    }
0058 
0059    /// @brief Figure out the output size of a node. It is the size of the
0060    /// vector observable that it depends on, or 1 if it doesn't depend on any
0061    /// or is a reducer node.
0062    /// @param key The node to look up the size for.
0063    std::size_t outputSize(RooFit::Detail::DataKey key) const
0064    {
0065       auto found = _nodeOutputSizes.find(key);
0066       if (found != _nodeOutputSizes.end())
0067          return found->second;
0068       return 1;
0069    }
0070 
0071    void addToGlobalScope(std::string const &str);
0072    void addVecObs(const char *key, int idx);
0073    int observableIndexOf(const RooAbsArg &arg) const;
0074 
0075    void addToCodeBody(RooAbsArg const *klass, std::string const &in);
0076 
0077    void addToCodeBody(std::string const &in, bool isScopeIndep = false);
0078 
0079    /// @brief Build the code to call the function with name `funcname`, passing some arguments.
0080    /// The arguments can either be doubles or some RooFit arguments whose
0081    /// results will be looked up in the context.
0082    template <typename... Args_t>
0083    std::string buildCall(std::string const &funcname, Args_t const &...args)
0084    {
0085       std::stringstream ss;
0086       ss << funcname << "(" << buildArgs(args...) << ")";
0087       return ss.str();
0088    }
0089 
0090    /// @brief A class to manage loop scopes using the RAII technique. To wrap your code around a loop,
0091    /// simply place it between a brace inclosed scope with a call to beginLoop at the top. For e.g.
0092    /// {
0093    ///   auto scope = ctx.beginLoop({<-set of vector observables to loop over->});
0094    ///   // your loop body code goes here.
0095    /// }
0096    class LoopScope {
0097    public:
0098       LoopScope(CodegenContext &ctx, std::vector<TNamed const *> &&vars) : _ctx{ctx}, _vars{vars} {}
0099       ~LoopScope() { _ctx.endLoop(*this); }
0100 
0101       std::vector<TNamed const *> const &vars() const { return _vars; }
0102 
0103    private:
0104       CodegenContext &_ctx;
0105       const std::vector<TNamed const *> _vars;
0106    };
0107 
0108    std::unique_ptr<LoopScope> beginLoop(RooAbsArg const *in);
0109 
0110    std::string getTmpVarName() const;
0111 
0112    std::string buildArg(RooAbsCollection const &x, std::string const &arrayType = "double");
0113 
0114    std::string buildArg(std::span<const double> arr);
0115    std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }
0116 
0117    std::vector<double> const &xlArr() { return _xlArr; }
0118 
0119    void collectFunction(std::string const &name);
0120    std::string const &collectedCode() { return _collectedCode; }
0121    std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
0122 
0123    std::string
0124    buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes = {});
0125 
0126    auto const &outputSizes() const { return _nodeOutputSizes; }
0127 
0128    struct ScopeRAII {
0129       std::string _fn;
0130       CodegenContext &_ctx;
0131       RooAbsArg const *_arg;
0132 
0133    public:
0134       ScopeRAII(RooAbsArg const *arg, CodegenContext &ctx);
0135       ~ScopeRAII();
0136    };
0137    ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg) { return {arg, *this}; }
0138 
0139 private:
0140    void pushScope();
0141    void popScope();
0142    template <class T>
0143    std::string buildArgSpanImpl(std::span<const T> arr);
0144 
0145    bool isScopeIndependent(RooAbsArg const *in) const;
0146 
0147    void endLoop(LoopScope const &scope);
0148 
0149    void addResult(TNamed const *key, std::string const &value);
0150 
0151    template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
0152    std::string buildArg(T x)
0153    {
0154       std::stringstream ss;
0155       ss << std::setprecision(std::numeric_limits<double>::max_digits10) << x;
0156       return ss.str();
0157    }
0158 
0159    // If input is integer, we want to print it into the code like one (i.e. avoid the unnecessary '.0000').
0160    template <class T, typename std::enable_if<std::is_integral<T>{}, bool>::type = true>
0161    std::string buildArg(T x)
0162    {
0163       return std::to_string(x);
0164    }
0165 
0166    std::string buildArg(std::string const &x) { return x; }
0167 
0168    std::string buildArg(std::nullptr_t) { return "nullptr"; }
0169 
0170    std::string buildArg(RooAbsArg const &arg) { return getResult(arg); }
0171 
0172    template <class T>
0173    std::string buildArg(RooTemplateProxy<T> const &arg)
0174    {
0175       return getResult(arg);
0176    }
0177 
0178    std::string buildArgs() { return ""; }
0179 
0180    template <class Arg_t>
0181    std::string buildArgs(Arg_t const &arg)
0182    {
0183       return buildArg(arg);
0184    }
0185 
0186    template <typename Arg_t, typename... Args_t>
0187    std::string buildArgs(Arg_t const &arg, Args_t const &...args)
0188    {
0189       return buildArg(arg) + ", " + buildArgs(args...);
0190    }
0191 
0192    template <class T>
0193    std::string typeName() const;
0194 
0195    /// @brief Map of node names to their result strings.
0196    std::unordered_map<const TNamed *, std::string> _nodeNames;
0197    /// @brief A map to keep track of the observable indices if they are non scalar.
0198    std::unordered_map<const TNamed *, int> _vecObsIndices;
0199    /// @brief Map of node output sizes.
0200    std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
0201    /// @brief The code layered by lexical scopes used as a stack.
0202    std::vector<std::string> _code;
0203    /// @brief The indentation level for pretty-printing.
0204    unsigned _indent = 0;
0205    /// @brief Index to get unique names for temporary variables.
0206    mutable int _tmpVarIdx = 0;
0207    /// @brief A map to keep track of list names as assigned by addResult.
0208    std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
0209    std::vector<double> _xlArr;
0210    std::vector<std::string> _collectedFunctions;
0211    std::string _collectedCode;
0212 };
0213 
0214 template <>
0215 inline std::string CodegenContext::typeName<double>() const
0216 {
0217    return "double";
0218 }
0219 template <>
0220 inline std::string CodegenContext::typeName<int>() const
0221 {
0222    return "int";
0223 }
0224 
0225 template <class T>
0226 std::string CodegenContext::buildArgSpanImpl(std::span<const T> arr)
0227 {
0228    unsigned int n = arr.size();
0229    std::string arrName = getTmpVarName();
0230    std::stringstream ss;
0231    ss << typeName<T>() << " " << arrName << "[" << n << "] = {";
0232    for (unsigned int i = 0; i < n; i++) {
0233       ss << " " << arr[i] << ",";
0234    }
0235    std::string arrDecl = ss.str();
0236    arrDecl.back() = '}';
0237    arrDecl += ";\n";
0238    addToCodeBody(arrDecl, true);
0239 
0240    return arrName;
0241 }
0242 
0243 void declareDispatcherCode(std::string const &funcName);
0244 
0245 void codegen(RooAbsArg &arg, CodegenContext &ctx);
0246 
0247 } // namespace Experimental
0248 } // namespace RooFit
0249 
0250 #endif