File indexing completed on 2025-12-15 10:29:20
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015 #ifndef RooFit_RooEvaluatorWrapper_h
0016 #define RooFit_RooEvaluatorWrapper_h
0017
#include <RooAbsData.h>
#include <RooFit/Evaluator.h>
#include <RooGlobalFunc.h>
#include <RooRealProxy.h>
#include <RooSetProxy.h>

#include <map>
#include <memory>
#include <stack>
#include <string>
#include <vector>
0025
0026 class RooAbsArg;
0027 class RooAbsCategory;
0028 class RooAbsPdf;
0029
0030 namespace RooFit::Experimental {
0031
0032 class RooFuncWrapper;
0033
0034 class RooEvaluatorWrapper final : public RooAbsReal {
0035 public:
0036 RooEvaluatorWrapper(RooAbsReal &topNode, RooAbsData *data, bool useGPU, std::string const &rangeName,
0037 RooAbsPdf const *simPdf, bool takeGlobalObservablesFromData);
0038
0039 RooEvaluatorWrapper(const RooEvaluatorWrapper &other, const char *name = nullptr);
0040
0041 ~RooEvaluatorWrapper();
0042
0043 TObject *clone(const char *newname) const override { return new RooEvaluatorWrapper(*this, newname); }
0044
0045 double defaultErrorLevel() const override { return _topNode->defaultErrorLevel(); }
0046
0047 bool getParameters(const RooArgSet *observables, RooArgSet &outputSet, bool stripDisconnected = true) const override;
0048
0049 bool setData(RooAbsData &data, bool cloneData) override;
0050
0051 double getValV(const RooArgSet *) const override { return evaluate(); }
0052
0053 void applyWeightSquared(bool flag) override { _topNode->applyWeightSquared(flag); }
0054
0055 void printMultiline(std::ostream &os, Int_t , bool = false,
0056 TString = "") const override
0057 {
0058 _evaluator->print(os);
0059 }
0060
0061
0062 void constOptimizeTestStatistic(ConstOpCode , bool ) override {}
0063
0064 bool hasGradient() const override;
0065
0066 void gradient(double *out) const override;
0067
0068 void generateGradient();
0069
0070 void setUseGeneratedFunctionCode(bool);
0071
0072 void writeDebugMacro(std::string const &) const;
0073
0074 protected:
0075 double evaluate() const override;
0076
0077 private:
0078 void createFuncWrapper();
0079
0080 std::shared_ptr<RooFit::Evaluator> _evaluator;
0081 std::shared_ptr<RooFuncWrapper> _funcWrapper;
0082 RooRealProxy _topNode;
0083 RooAbsData *_data = nullptr;
0084 RooSetProxy _paramSet;
0085 std::string _rangeName;
0086 RooAbsPdf const *_pdf = nullptr;
0087 const bool _takeGlobalObservablesFromData;
0088 bool _useGeneratedFunctionCode = false;
0089 std::stack<std::vector<double>> _vectorBuffers;
0090 std::map<RooFit::Detail::DataKey, std::span<const double>> _dataSpans;
0091 };
0092
0093 }
0094
0095 #endif
0096
0097