File indexing completed on 2025-01-18 10:11:04
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020 #ifndef TMVA_RBDT
0021 #define TMVA_RBDT
0022
0023 #include <Rtypes.h>
0024 #include <ROOT/RSpan.hxx>
0025 #include <TMVA/RTensor.hxx>
0026
0027 #include <array>
0028 #include <istream>
0029 #include <string>
0030 #include <unordered_map>
0031 #include <vector>
0032
0033 namespace TMVA {
0034
0035 namespace Experimental {
0036
0037 class RBDT final {
0038 public:
0039 typedef float Value_t;
0040
0041
0042 RBDT() = default;
0043
0044
0045 RBDT(const std::string &key, const std::string &filename);
0046
0047
0048
0049
0050
0051 template <typename Vector>
0052 Vector Compute(const Vector &x) const
0053 {
0054 std::size_t nOut = fBaseResponses.size() > 2 ? fBaseResponses.size() : 1;
0055 Vector y(nOut);
0056 ComputeImpl(x.data(), y.data());
0057 return y;
0058 }
0059
0060
0061 inline std::vector<Value_t> Compute(std::vector<Value_t> const &x) const { return Compute<std::vector<Value_t>>(x); }
0062
0063 RTensor<Value_t> Compute(RTensor<Value_t> const &x) const;
0064
0065 static RBDT LoadText(std::string const &txtpath, std::vector<std::string> &features, int nClasses, bool logistic,
0066 Value_t baseScore);
0067
0068 private:
0069
0070 using IndexMap = std::unordered_map<int, int>;
0071
0072 void Softmax(const Value_t *array, Value_t *out) const;
0073 void ComputeImpl(const Value_t *array, Value_t *out) const;
0074 Value_t EvaluateBinary(const Value_t *array) const;
0075 static void correctIndices(std::span<int> indices, IndexMap const &nodeIndices, IndexMap const &leafIndices);
0076 static void terminateTree(TMVA::Experimental::RBDT &ff, int &nPreviousNodes, int &nPreviousLeaves,
0077 IndexMap &nodeIndices, IndexMap &leafIndices, int &treesSkipped);
0078 static RBDT
0079 LoadText(std::istream &is, std::vector<std::string> &features, int nClasses, bool logistic, Value_t baseScore);
0080
0081 std::vector<int> fRootIndices;
0082 std::vector<unsigned int> fCutIndices;
0083 std::vector<Value_t> fCutValues;
0084 std::vector<int> fLeftIndices;
0085 std::vector<int> fRightIndices;
0086 std::vector<Value_t> fResponses;
0087 std::vector<int> fTreeNumbers;
0088 std::vector<Value_t> fBaseResponses;
0089 Value_t fBaseScore = 0.0;
0090 bool fLogistic = false;
0091
0092 ClassDefNV(RBDT, 1);
0093 };
0094
0095 }
0096
0097 }
0098
0099 #endif