Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:04

0001 /**********************************************************************************
0002  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis       *
0003  * Package: TMVA                                                                  *
0004  *                                                                                *
0005  *                                                                                *
0006  * Description:                                                                   *
0007  *      Declaration of RBDT, an evaluator for decision-tree ensembles             *
0008  * Authors:                                                                       *
0009  *      Stefan Wunsch (stefan.wunsch@cern.ch)                                     *
0010  *      Jonas Rembser (jonas.rembser@cern.ch)                                     *
0011  *                                                                                *
0012  * Copyright (c) 2024:                                                            *
0013  *      CERN, Switzerland                                                         *
0014  *                                                                                *
0015  * Redistribution and use in source and binary forms, with or without             *
0016  * modification, are permitted according to the terms listed in LICENSE           *
0017  * (see tmva/doc/LICENSE)                                                         *
0018  **********************************************************************************/
0019 
0020 #ifndef TMVA_RBDT
0021 #define TMVA_RBDT
0022 
0023 #include <Rtypes.h>
0024 #include <ROOT/RSpan.hxx>
0025 #include <TMVA/RTensor.hxx>
0026 
0027 #include <array>
0028 #include <istream>
0029 #include <string>
0030 #include <unordered_map>
0031 #include <vector>
0032 
0033 namespace TMVA {
0034 
0035 namespace Experimental {
0036 
0037 class RBDT final {
0038 public:
0039    typedef float Value_t;
0040 
0041    /// IO constructor (both for ROOT IO and LoadText()).
0042    RBDT() = default;
0043 
0044    /// Construct backends from model in ROOT file.
0045    RBDT(const std::string &key, const std::string &filename);
0046 
0047    /// Compute model prediction on a single event.
0048    ///
0049    /// The method is intended to be used with std::vectors-like containers,
0050    /// for example RVecs.
0051    template <typename Vector>
0052    Vector Compute(const Vector &x) const
0053    {
0054       std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
0055       Vector y(nOut);
0056       ComputeImpl(x.data(), y.data());
0057       return y;
0058    }
0059 
0060    /// Compute model prediction on a single event.
0061    inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }
0062 
0063    RTensor<Value_t> Compute(RTensor<Value_t> const &x) const;
0064 
0065    static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses, bool logistic,
0066                         Value_t baseScore);
0067 
0068 private:
0069    /// Map from XGBoost to RBDT indices.
0070    using IndexMap = std::unordered_map<int, int>;
0071 
0072    void Softmax(const Value_t *array, Value_t *out) const;
0073    void ComputeImpl(const Value_t *array, Value_t *out) const;
0074    Value_t EvaluateBinary(const Value_t *array) const;
0075    static void correctIndices(std::span<int> indices, IndexMap const &nodeIndices, IndexMap const &leafIndices);
0076    static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves,
0077                              IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped);
0078    static RBDT
0079    LoadText(std::istream &is, std::vector<std::string> &features, int nClasses, bool logistic, Value_t baseScore);
0080 
0081    std::vector<int> fRootIndices;
0082    std::vector<unsigned int> fCutIndices;
0083    std::vector<Value_t> fCutValues;
0084    std::vector<int> fLeftIndices;
0085    std::vector<int> fRightIndices;
0086    std::vector<Value_t> fResponses;
0087    std::vector<int> fTreeNumbers;
0088    std::vector<Value_t> fBaseResponses;
0089    Value_t fBaseScore = 0.0;
0090    bool fLogistic = false;
0091 
0092    ClassDefNV(RBDT, 1);
0093 };
0094 
0095 } // namespace Experimental
0096 
0097 } // namespace TMVA
0098 
0099 #endif // TMVA_RBDT