Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:14

0001 /*
0002  * Project: RooFit
0003  * Authors:
0004  *   Garima Singh, CERN 2023
0005  *   Jonas Rembser, CERN 2023
0006  *
0007  * Copyright (c) 2023, CERN
0008  *
0009  * Redistribution and use in source and binary forms,
0010  * with or without modification, are permitted according to the terms
0011  * listed in LICENSE (http://roofit.sourceforge.net/license.txt)
0012  */
0013 
0014 #ifndef RooFit_Detail_CodeSquashContext_h
0015 #define RooFit_Detail_CodeSquashContext_h
0016 
#include <RooAbsCollection.h>
#include <RooFit/EvalContext.h>
#include <RooNumber.h>

#include <ROOT/RSpan.hxx>

#include <cstddef>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
0029 
/// Forward declaration: only needed by the getResult()/buildArg() overloads
/// below that accept a RooTemplateProxy.
template <typename T>
class RooTemplateProxy;
0032 
0033 namespace RooFit {
0034 
/// Forward declaration: the wrapper is only referenced by reference/pointer
/// in this header (constructor parameter and `_wrapper` member).
namespace Experimental {
class RooFuncWrapper;
} // namespace Experimental
0038 
0039 namespace Detail {
0040 
0041 /// @brief A class to maintain the context for squashing of RooFit models into code.
0042 class CodeSquashContext {
0043 public:
0044    CodeSquashContext(std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes, std::vector<double> &xlarr, Experimental::RooFuncWrapper &wrapper);
0045 
0046    void addResult(RooAbsArg const *key, std::string const &value);
0047    void addResult(const char *key, std::string const &value);
0048 
0049    std::string const &getResult(RooAbsArg const &arg);
0050 
0051    template <class T>
0052    std::string const &getResult(RooTemplateProxy<T> const &key)
0053    {
0054       return getResult(key.arg());
0055    }
0056 
0057    /// @brief Figure out the output size of a node. It is the size of the
0058    /// vector observable that it depends on, or 1 if it doesn't depend on any
0059    /// or is a reducer node.
0060    /// @param key The node to look up the size for.
0061    std::size_t outputSize(RooFit::Detail::DataKey key) const
0062    {
0063       auto found = _nodeOutputSizes.find(key);
0064       if (found != _nodeOutputSizes.end())
0065          return found->second;
0066       return 1;
0067    }
0068 
0069    void addToGlobalScope(std::string const &str);
0070    std::string assembleCode(std::string const &returnExpr);
0071    void addVecObs(const char *key, int idx);
0072 
0073    void addToCodeBody(RooAbsArg const *klass, std::string const &in);
0074 
0075    void addToCodeBody(std::string const &in, bool isScopeIndep = false);
0076 
0077    /// @brief Build the code to call the function with name `funcname`, passing some arguments.
0078    /// The arguments can either be doubles or some RooFit arguments whose
0079    /// results will be looked up in the context.
0080    template <typename... Args_t>
0081    std::string buildCall(std::string const &funcname, Args_t const &...args)
0082    {
0083       std::stringstream ss;
0084       ss << funcname << "(" << buildArgs(args...) << ")";
0085       return ss.str();
0086    }
0087 
0088    /// @brief A class to manage loop scopes using the RAII technique. To wrap your code around a loop,
0089    /// simply place it between a brace inclosed scope with a call to beginLoop at the top. For e.g.
0090    /// {
0091    ///   auto scope = ctx.beginLoop({<-set of vector observables to loop over->});
0092    ///   // your loop body code goes here.
0093    /// }
0094    class LoopScope {
0095    public:
0096       LoopScope(CodeSquashContext &ctx, std::vector<TNamed const *> &&vars) : _ctx{ctx}, _vars{vars} {}
0097       ~LoopScope() { _ctx.endLoop(*this); }
0098 
0099       std::vector<TNamed const *> const &vars() const { return _vars; }
0100 
0101    private:
0102       CodeSquashContext &_ctx;
0103       const std::vector<TNamed const *> _vars;
0104    };
0105 
0106    std::unique_ptr<LoopScope> beginLoop(RooAbsArg const *in);
0107 
0108    std::string getTmpVarName() const;
0109 
0110    std::string buildArg(RooAbsCollection const &x);
0111 
0112    std::string buildArg(std::span<const double> arr);
0113    std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }
0114 
0115    Experimental::RooFuncWrapper *_wrapper = nullptr;
0116 
0117 private:
0118    template <class T>
0119    std::string buildArgSpanImpl(std::span<const T> arr);
0120 
0121    bool isScopeIndependent(RooAbsArg const *in) const;
0122 
0123    void endLoop(LoopScope const &scope);
0124 
0125    void addResult(TNamed const *key, std::string const &value);
0126 
0127    template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
0128    std::string buildArg(T x)
0129    {
0130       return RooNumber::toString(x);
0131    }
0132 
0133    // If input is integer, we want to print it into the code like one (i.e. avoid the unnecessary '.0000').
0134    template <class T, typename std::enable_if<std::is_integral<T>{}, bool>::type = true>
0135    std::string buildArg(T x)
0136    {
0137       return std::to_string(x);
0138    }
0139 
0140    std::string buildArg(std::string const &x) { return x; }
0141 
0142    std::string buildArg(std::nullptr_t) { return "nullptr"; }
0143 
0144    std::string buildArg(RooAbsArg const &arg) { return getResult(arg); }
0145 
0146    template <class T>
0147    std::string buildArg(RooTemplateProxy<T> const &arg)
0148    {
0149       return getResult(arg);
0150    }
0151 
0152    std::string buildArgs() { return ""; }
0153 
0154    template <class Arg_t>
0155    std::string buildArgs(Arg_t const &arg)
0156    {
0157       return buildArg(arg);
0158    }
0159 
0160    template <typename Arg_t, typename... Args_t>
0161    std::string buildArgs(Arg_t const &arg, Args_t const &...args)
0162    {
0163       return buildArg(arg) + ", " + buildArgs(args...);
0164    }
0165 
0166    template <class T>
0167    std::string typeName() const;
0168 
0169    /// @brief Map of node names to their result strings.
0170    std::unordered_map<const TNamed *, std::string> _nodeNames;
0171    /// @brief Block of code that is placed before the rest of the function body.
0172    std::string _globalScope;
0173    /// @brief A map to keep track of the observable indices if they are non scalar.
0174    std::unordered_map<const TNamed *, int> _vecObsIndices;
0175    /// @brief Map of node output sizes.
0176    std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
0177    /// @brief Stores the squashed code body.
0178    std::string _code;
0179    /// @brief The current number of for loops the started.
0180    int _loopLevel = 0;
0181    /// @brief Index to get unique names for temporary variables.
0182    mutable int _tmpVarIdx = 0;
0183    /// @brief Keeps track of the position to go back and insert code to.
0184    int _scopePtr = -1;
0185    /// @brief Stores code that eventually gets injected into main code body.
0186    /// Mainly used for placing decls outside of loops.
0187    std::string _tempScope;
0188    /// @brief A map to keep track of list names as assigned by addResult.
0189    std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> listNames;
0190    std::vector<double> &_xlArr;
0191 };
0192 
0193 template <>
0194 inline std::string CodeSquashContext::typeName<double>() const
0195 {
0196    return "double";
0197 }
0198 template <>
0199 inline std::string CodeSquashContext::typeName<int>() const
0200 {
0201    return "int";
0202 }
0203 
0204 template <class T>
0205 std::string CodeSquashContext::buildArgSpanImpl(std::span<const T> arr)
0206 {
0207    unsigned int n = arr.size();
0208    std::string arrName = getTmpVarName();
0209    std::string arrDecl = typeName<T>() + " " + arrName + "[" + std::to_string(n) + "] = {";
0210    for (unsigned int i = 0; i < n; i++) {
0211       arrDecl += " " + std::to_string(arr[i]) + ",";
0212    }
0213    arrDecl.back() = '}';
0214    arrDecl += ";\n";
0215    addToCodeBody(arrDecl, true);
0216 
0217    return arrName;
0218 }
0219 
0220 } // namespace Detail
0221 
0222 } // namespace RooFit
0223 
0224 #endif