File indexing completed on 2025-09-17 09:14:10
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014 #ifndef RooFit_Detail_CodegenContext_h
0015 #define RooFit_Detail_CodegenContext_h
0016
#include <RooAbsCollection.h>
#include <RooFit/EvalContext.h>
#include <RooNumber.h>

#include <ROOT/RSpan.hxx>

#include <cstddef>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>
0029
0030 template <class T>
0031 class RooTemplateProxy;
0032
0033 namespace RooFit {
0034 namespace Experimental {
0035
/// Compile-time priority tag. Prio<1> is the highest level and Prio<10> the
/// lowest; any other value of P is rejected when the tag is instantiated.
template <int P>
struct Prio {
   static_assert(1 <= P && P <= 10, "P must be 1 <= P <= 10!");

   /// Returns the tag one step further down the priority scale.
   /// Calling this on the lowest level instantiates Prio<11>, which fails the
   /// static_assert above.
   static Prio<P + 1> next() { return Prio<P + 1>(); }
};

// Named endpoints of the priority scale.
using PrioLowest = Prio<10>;
using PrioHighest = Prio<1>;
0044
0045
0046 class CodegenContext {
0047 public:
0048 void addResult(RooAbsArg const *key, std::string const &value);
0049 void addResult(const char *key, std::string const &value);
0050
0051 std::string const &getResult(RooAbsArg const &arg);
0052
0053 template <class T>
0054 std::string const &getResult(RooTemplateProxy<T> const &key)
0055 {
0056 return getResult(key.arg());
0057 }
0058
0059
0060
0061
0062
0063 std::size_t outputSize(RooFit::Detail::DataKey key) const
0064 {
0065 auto found = _nodeOutputSizes.find(key);
0066 if (found != _nodeOutputSizes.end())
0067 return found->second;
0068 return 1;
0069 }
0070
0071 void addToGlobalScope(std::string const &str);
0072 void addVecObs(const char *key, int idx);
0073
0074 void addToCodeBody(RooAbsArg const *klass, std::string const &in);
0075
0076 void addToCodeBody(std::string const &in, bool isScopeIndep = false);
0077
0078
0079
0080
0081 template <typename... Args_t>
0082 std::string buildCall(std::string const &funcname, Args_t const &...args)
0083 {
0084 std::stringstream ss;
0085 ss << funcname << "(" << buildArgs(args...) << ")";
0086 return ss.str();
0087 }
0088
0089
0090
0091
0092
0093
0094
0095 class LoopScope {
0096 public:
0097 LoopScope(CodegenContext &ctx, std::vector<TNamed const *> &&vars) : _ctx{ctx}, _vars{vars} {}
0098 ~LoopScope() { _ctx.endLoop(*this); }
0099
0100 std::vector<TNamed const *> const &vars() const { return _vars; }
0101
0102 private:
0103 CodegenContext &_ctx;
0104 const std::vector<TNamed const *> _vars;
0105 };
0106
0107 std::unique_ptr<LoopScope> beginLoop(RooAbsArg const *in);
0108
0109 std::string getTmpVarName() const;
0110
0111 std::string buildArg(RooAbsCollection const &x);
0112
0113 std::string buildArg(std::span<const double> arr);
0114 std::string buildArg(std::span<const int> arr) { return buildArgSpanImpl(arr); }
0115
0116 std::vector<double> const &xlArr() { return _xlArr; }
0117
0118 void collectFunction(std::string const &name);
0119 std::vector<std::string> const &collectedFunctions() { return _collectedFunctions; }
0120
0121 std::string
0122 buildFunction(RooAbsArg const &arg, std::map<RooFit::Detail::DataKey, std::size_t> const &outputSizes = {});
0123
0124 auto const &outputSizes() const { return _nodeOutputSizes; }
0125
0126 struct ScopeRAII {
0127 std::string _fn;
0128 CodegenContext &_ctx;
0129 RooAbsArg const *_arg;
0130
0131 public:
0132 ScopeRAII(RooAbsArg const *arg, CodegenContext &ctx);
0133 ~ScopeRAII();
0134 };
0135 ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg) { return {arg, *this}; }
0136
0137 private:
0138 void pushScope();
0139 void popScope();
0140 template <class T>
0141 std::string buildArgSpanImpl(std::span<const T> arr);
0142
0143 bool isScopeIndependent(RooAbsArg const *in) const;
0144
0145 void endLoop(LoopScope const &scope);
0146
0147 void addResult(TNamed const *key, std::string const &value);
0148
0149 template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
0150 std::string buildArg(T x)
0151 {
0152 return RooNumber::toString(x);
0153 }
0154
0155
0156 template <class T, typename std::enable_if<std::is_integral<T>{}, bool>::type = true>
0157 std::string buildArg(T x)
0158 {
0159 return std::to_string(x);
0160 }
0161
0162 std::string buildArg(std::string const &x) { return x; }
0163
0164 std::string buildArg(std::nullptr_t) { return "nullptr"; }
0165
0166 std::string buildArg(RooAbsArg const &arg) { return getResult(arg); }
0167
0168 template <class T>
0169 std::string buildArg(RooTemplateProxy<T> const &arg)
0170 {
0171 return getResult(arg);
0172 }
0173
0174 std::string buildArgs() { return ""; }
0175
0176 template <class Arg_t>
0177 std::string buildArgs(Arg_t const &arg)
0178 {
0179 return buildArg(arg);
0180 }
0181
0182 template <typename Arg_t, typename... Args_t>
0183 std::string buildArgs(Arg_t const &arg, Args_t const &...args)
0184 {
0185 return buildArg(arg) + ", " + buildArgs(args...);
0186 }
0187
0188 template <class T>
0189 std::string typeName() const;
0190
0191
0192 std::unordered_map<const TNamed *, std::string> _nodeNames;
0193
0194 std::unordered_map<const TNamed *, int> _vecObsIndices;
0195
0196 std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
0197
0198 std::vector<std::string> _code;
0199
0200 unsigned _indent = 0;
0201
0202 mutable int _tmpVarIdx = 0;
0203
0204 std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
0205 std::vector<double> _xlArr;
0206 std::vector<std::string> _collectedFunctions;
0207 };
0208
0209 template <>
0210 inline std::string CodegenContext::typeName<double>() const
0211 {
0212 return "double";
0213 }
0214 template <>
0215 inline std::string CodegenContext::typeName<int>() const
0216 {
0217 return "int";
0218 }
0219
0220 template <class T>
0221 std::string CodegenContext::buildArgSpanImpl(std::span<const T> arr)
0222 {
0223 unsigned int n = arr.size();
0224 std::string arrName = getTmpVarName();
0225 std::string arrDecl = typeName<T>() + " " + arrName + "[" + std::to_string(n) + "] = {";
0226 for (unsigned int i = 0; i < n; i++) {
0227 arrDecl += " " + std::to_string(arr[i]) + ",";
0228 }
0229 arrDecl.back() = '}';
0230 arrDecl += ";\n";
0231 addToCodeBody(arrDecl, true);
0232
0233 return arrName;
0234 }
0235
0236 void declareDispatcherCode(std::string const &funcName);
0237
0238 void codegen(RooAbsArg &arg, CodegenContext &ctx);
0239
0240 }
0241 }
0242
0243 #endif