File indexing completed on 2025-01-18 10:11:09
0001 #ifndef TMVA_RREADER
0002 #define TMVA_RREADER
0003
0004 #include "TString.h"
0005 #include "TXMLEngine.h"
0006
0007 #include "TMVA/RTensor.hxx"
0008 #include "TMVA/Reader.h"
0009
0010 #include <memory> // std::unique_ptr
0011 #include <sstream> // std::stringstream
0012
0013 namespace TMVA {
0014 namespace Experimental {
0015
0016 namespace Internal {
0017
0018
/// Kind of analysis performed by the booked TMVA method, as decoded from the
/// "AnalysisType" entry of the weight file's GeneralInfo section.
/// Undefined marks a configuration whose analysis type could not be decoded.
enum AnalysisType : unsigned int {
   Undefined = 0,
   Classification = 1,
   Regression = 2,
   Multiclass = 3
};
0020
0021
0022 struct XMLConfig {
0023 unsigned int numVariables;
0024 std::vector<std::string> variables;
0025 std::vector<std::string> variable_expressions;
0026 unsigned int numSpectators;
0027 std::vector<std::string> spectators;
0028 std::vector<std::string> spectator_expressions;
0029 unsigned int numClasses;
0030 std::vector<std::string> classes;
0031 AnalysisType analysisType;
0032 XMLConfig()
0033 : numVariables(0), variables(std::vector<std::string>(0)),
0034 numSpectators(0), spectators(std::vector<std::string>(0)),
0035 numClasses(0), classes(std::vector<std::string>(0)),
0036 analysisType(Internal::AnalysisType::Undefined)
0037 {
0038 }
0039 };
0040
0041
0042 inline XMLConfig ParseXMLConfig(const std::string &filename)
0043 {
0044 XMLConfig c;
0045
0046
0047 TXMLEngine xml;
0048 auto xmldoc = xml.ParseFile(filename.c_str());
0049 if (!xmldoc) {
0050 std::stringstream ss;
0051 ss << "Failed to open TMVA XML file "
0052 << filename << ".";
0053 throw std::runtime_error(ss.str());
0054 }
0055 auto mainNode = xml.DocGetRootElement(xmldoc);
0056 for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
0057 const auto nodeName = std::string(xml.GetNodeName(node));
0058
0059 if (nodeName.compare("Variables") == 0) {
0060 c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
0061 c.variables = std::vector<std::string>(c.numVariables);
0062 c.variable_expressions = std::vector<std::string>(c.numVariables);
0063 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0064 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
0065 c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
0066 c.variable_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
0067 }
0068 }
0069
0070 else if (nodeName.compare("Spectators") == 0) {
0071 c.numSpectators = std::atoi(xml.GetAttr(node, "NSpec"));
0072 c.spectators = std::vector<std::string>(c.numSpectators);
0073 c.spectator_expressions = std::vector<std::string>(c.numSpectators);
0074 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0075 const auto iVariable = std::atoi(xml.GetAttr(thisNode, "SpecIndex"));
0076 c.spectators[iVariable] = xml.GetAttr(thisNode, "Title");
0077 c.spectator_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
0078 }
0079 }
0080
0081 else if (nodeName.compare("Classes") == 0) {
0082 c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
0083 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0084 c.classes.push_back(xml.GetAttr(thisNode, "Name"));
0085 }
0086 }
0087
0088 else if (nodeName.compare("GeneralInfo") == 0) {
0089 std::string analysisType = "";
0090 for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0091 if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
0092 analysisType = xml.GetAttr(thisNode, "value");
0093 }
0094 }
0095 if (analysisType.compare("Classification") == 0) {
0096 c.analysisType = Internal::AnalysisType::Classification;
0097 } else if (analysisType.compare("Regression") == 0) {
0098 c.analysisType = Internal::AnalysisType::Regression;
0099 } else if (analysisType.compare("Multiclass") == 0) {
0100 c.analysisType = Internal::AnalysisType::Multiclass;
0101 }
0102 }
0103 }
0104 xml.FreeDoc(xmldoc);
0105
0106
0107 if (c.numVariables != c.variables.size() || c.numVariables == 0) {
0108 std::stringstream ss;
0109 ss << "Failed to parse input variables from TMVA config " << filename << ".";
0110 throw std::runtime_error(ss.str());
0111 }
0112 if (c.numSpectators != c.spectators.size()) {
0113 std::stringstream ss;
0114 ss << "Failed to parse input spectators from TMVA config " << filename << ".";
0115 throw std::runtime_error(ss.str());
0116 }
0117 if (c.numClasses != c.classes.size() || c.numClasses == 0) {
0118 std::stringstream ss;
0119 ss << "Failed to parse output classes from TMVA config " << filename << ".";
0120 throw std::runtime_error(ss.str());
0121 }
0122 if (c.analysisType == Internal::AnalysisType::Undefined) {
0123 std::stringstream ss;
0124 ss << "Failed to parse analysis type from TMVA config " << filename << ".";
0125 throw std::runtime_error(ss.str());
0126 }
0127
0128 return c;
0129 }
0130
0131 }
0132
0133
0134
0135
0136 class RReader {
0137 private:
0138 std::unique_ptr<Reader> fReader;
0139 std::vector<float> fVariableValues;
0140 std::vector<std::string> fVariables;
0141 std::vector<std::string> fVariableExpressions;
0142 std::vector<float> fSpectatorValues;
0143 std::vector<std::string> fSpectators;
0144 std::vector<std::string> fSpectatorExpressions;
0145 unsigned int fNumClasses;
0146 const char *name = "RReader";
0147 Internal::AnalysisType fAnalysisType;
0148
0149 public:
0150
0151 RReader(const std::string &path)
0152 {
0153
0154 auto c = Internal::ParseXMLConfig(path);
0155 fVariables = c.variables;
0156 fVariableExpressions = c.variable_expressions;
0157 fSpectators = c.spectators;
0158 fSpectatorExpressions = c.spectator_expressions;
0159 fAnalysisType = c.analysisType;
0160 fNumClasses = c.numClasses;
0161
0162
0163 fReader = std::make_unique<Reader>("Silent");
0164 const auto numVars = fVariables.size();
0165 fVariableValues = std::vector<float>(numVars);
0166 for (std::size_t i = 0; i < numVars; i++) {
0167 fReader->AddVariable(TString(fVariableExpressions[i]), &fVariableValues[i]);
0168 }
0169 const auto numSpecs = fSpectators.size();
0170 fSpectatorValues = std::vector<float>(numSpecs);
0171 for (std::size_t i = 0; i < numSpecs; i++) {
0172 fReader->AddSpectator(TString(fSpectatorExpressions[i]), &fSpectatorValues[i]);
0173 }
0174 fReader->BookMVA(name, path.c_str());
0175 }
0176
0177
0178 std::vector<float> Compute(const std::vector<float> &x)
0179 {
0180 if (x.size() != (fVariables.size()+fSpectators.size()))
0181 throw std::runtime_error("Size of input vector is not equal to number of variables.");
0182
0183
0184 const auto nVars = fVariables.size();
0185 for (std::size_t i = 0; i != nVars ; ++i) {
0186 fVariableValues[i] = x[i];
0187 }
0188 for (std::size_t i = 0; i != fSpectators.size(); ++i) {
0189 fSpectatorValues[i] = x[nVars+i];
0190 }
0191
0192
0193 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0194
0195
0196
0197 if (fAnalysisType == Internal::AnalysisType::Classification) {
0198 return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
0199 }
0200
0201 else if (fAnalysisType == Internal::AnalysisType::Regression) {
0202 return fReader->EvaluateRegression(name);
0203 }
0204
0205 else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
0206 return fReader->EvaluateMulticlass(name);
0207 }
0208
0209 else {
0210 throw std::runtime_error("RReader has undefined analysis type.");
0211 return std::vector<float>();
0212 }
0213 }
0214
0215
0216 RTensor<float> Compute(RTensor<float> &x)
0217 {
0218
0219 const auto shape = x.GetShape();
0220 if (shape.size() != 2)
0221 throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
0222
0223 const auto numEntries = shape[0];
0224 const auto numVars = shape[1];
0225 if (numVars != (fVariables.size()+fSpectators.size()))
0226 throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
0227
0228
0229 unsigned int numClasses = 1;
0230 if (fAnalysisType == Internal::AnalysisType::Multiclass)
0231 numClasses = fNumClasses;
0232 RTensor<float> y({numEntries * numClasses});
0233 if (fAnalysisType == Internal::AnalysisType::Multiclass)
0234 y = y.Reshape({numEntries, numClasses});
0235
0236
0237 const auto nVars = fVariables.size();
0238 for (std::size_t i = 0; i < numEntries; i++) {
0239 for (std::size_t j = 0; j < nVars; j++) {
0240 fVariableValues[j] = x(i, j);
0241 }
0242 for (std::size_t j = 0; j < fSpectators.size(); ++j) {
0243 fSpectatorValues[j] = x(i, nVars+j);
0244 }
0245 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0246
0247 if (fAnalysisType == Internal::AnalysisType::Classification) {
0248 y(i) = fReader->EvaluateMVA(name);
0249 }
0250
0251 else if (fAnalysisType == Internal::AnalysisType::Regression) {
0252 y(i) = fReader->EvaluateRegression(name)[0];
0253 }
0254
0255 else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
0256 const auto p = fReader->EvaluateMulticlass(name);
0257 for (std::size_t k = 0; k < numClasses; k++)
0258 y(i, k) = p[k];
0259 }
0260 }
0261
0262 return y;
0263 }
0264
0265 std::vector<std::string> GetVariableNames() { return fVariables; }
0266 std::vector<std::string> GetSpectatorNames() { return fSpectators; }
0267 };
0268
0269 }
0270 }
0271
0272 #endif