Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:09

0001 #ifndef TMVA_RREADER
0002 #define TMVA_RREADER
0003 
0004 #include "TString.h"
0005 #include "TXMLEngine.h"
0006 
0007 #include "TMVA/RTensor.hxx"
0008 #include "TMVA/Reader.h"
0009 
0010 #include <memory> // std::unique_ptr
0011 #include <sstream> // std::stringstream
0012 
0013 namespace TMVA {
0014 namespace Experimental {
0015 
0016 namespace Internal {
0017 
/// Internal definition of analysis types.
/// ParseXMLConfig() maps the `AnalysisType` value of the `GeneralInfo` section onto
/// these enumerators; `Undefined` means no recognised value was found.
enum AnalysisType : unsigned int {
   Undefined = 0,
   Classification = 1,
   Regression = 2,
   Multiclass = 3
};

/// Container for information extracted from a TMVA XML config by ParseXMLConfig().
/// A default-constructed object has all counts set to zero, all name lists empty
/// and an `Undefined` analysis type.
struct XMLConfig {
   unsigned int numVariables = 0;                       ///< `NVar` attribute of the `Variables` node
   std::vector<std::string> variables{};                ///< Variable `Title`s, indexed by `VarIndex`
   std::vector<std::string> variable_expressions{};     ///< Variable `Expression`s, indexed by `VarIndex`
   unsigned int numSpectators = 0;                      ///< `NSpec` attribute of the `Spectators` node
   std::vector<std::string> spectators{};               ///< Spectator `Title`s, indexed by `SpecIndex`
   std::vector<std::string> spectator_expressions{};    ///< Spectator `Expression`s, indexed by `SpecIndex`
   unsigned int numClasses = 0;                         ///< `NClass` attribute of the `Classes` node
   std::vector<std::string> classes{};                  ///< Class `Name`s in document order
   AnalysisType analysisType = AnalysisType::Undefined; ///< Parsed from `GeneralInfo`
};
0040 
0041 /// Parse TMVA XML config
0042 inline XMLConfig ParseXMLConfig(const std::string &filename)
0043 {
0044    XMLConfig c;
0045 
0046    // Parse XML file and find root node
0047    TXMLEngine xml;
0048    auto xmldoc = xml.ParseFile(filename.c_str());
0049    if (!xmldoc) {
0050       std::stringstream ss;
0051       ss << "Failed to open TMVA XML file "
0052          << filename << ".";
0053       throw std::runtime_error(ss.str());
0054    }
0055    auto mainNode = xml.DocGetRootElement(xmldoc);
0056    for (auto node = xml.GetChild(mainNode); node; node = xml.GetNext(node)) {
0057       const auto nodeName = std::string(xml.GetNodeName(node));
0058       // Read out input variables
0059       if (nodeName.compare("Variables") == 0) {
0060          c.numVariables = std::atoi(xml.GetAttr(node, "NVar"));
0061          c.variables = std::vector<std::string>(c.numVariables);
0062          c.variable_expressions = std::vector<std::string>(c.numVariables);
0063          for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0064             const auto iVariable = std::atoi(xml.GetAttr(thisNode, "VarIndex"));
0065             c.variables[iVariable] = xml.GetAttr(thisNode, "Title");
0066             c.variable_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
0067          }
0068       }
0069       // Read out input spectators
0070       else if (nodeName.compare("Spectators") == 0) {
0071          c.numSpectators = std::atoi(xml.GetAttr(node, "NSpec"));
0072          c.spectators = std::vector<std::string>(c.numSpectators);
0073          c.spectator_expressions = std::vector<std::string>(c.numSpectators);
0074          for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0075             const auto iVariable = std::atoi(xml.GetAttr(thisNode, "SpecIndex"));
0076             c.spectators[iVariable] = xml.GetAttr(thisNode, "Title");
0077             c.spectator_expressions[iVariable] = xml.GetAttr(thisNode, "Expression");
0078          }
0079       }
0080       // Read out output classes
0081       else if (nodeName.compare("Classes") == 0) {
0082          c.numClasses = std::atoi(xml.GetAttr(node, "NClass"));
0083          for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0084             c.classes.push_back(xml.GetAttr(thisNode, "Name"));
0085          }
0086       }
0087       // Read out analysis type
0088       else if (nodeName.compare("GeneralInfo") == 0) {
0089          std::string analysisType = "";
0090          for (auto thisNode = xml.GetChild(node); thisNode; thisNode = xml.GetNext(thisNode)) {
0091             if (std::string("AnalysisType").compare(xml.GetAttr(thisNode, "name")) == 0) {
0092                analysisType = xml.GetAttr(thisNode, "value");
0093             }
0094          }
0095          if (analysisType.compare("Classification") == 0) {
0096             c.analysisType = Internal::AnalysisType::Classification;
0097          } else if (analysisType.compare("Regression") == 0) {
0098             c.analysisType = Internal::AnalysisType::Regression;
0099          } else if (analysisType.compare("Multiclass") == 0) {
0100             c.analysisType = Internal::AnalysisType::Multiclass;
0101          }
0102       }
0103    }
0104    xml.FreeDoc(xmldoc);
0105 
0106    // Error-handling
0107    if (c.numVariables != c.variables.size() || c.numVariables == 0) {
0108       std::stringstream ss;
0109       ss << "Failed to parse input variables from TMVA config " << filename << ".";
0110       throw std::runtime_error(ss.str());
0111    }
0112    if (c.numSpectators != c.spectators.size()) {
0113       std::stringstream ss;
0114       ss << "Failed to parse input spectators from TMVA config " << filename << ".";
0115       throw std::runtime_error(ss.str());
0116    }
0117    if (c.numClasses != c.classes.size() || c.numClasses == 0) {
0118       std::stringstream ss;
0119       ss << "Failed to parse output classes from TMVA config " << filename << ".";
0120       throw std::runtime_error(ss.str());
0121    }
0122    if (c.analysisType == Internal::AnalysisType::Undefined) {
0123       std::stringstream ss;
0124       ss << "Failed to parse analysis type from TMVA config " << filename << ".";
0125       throw std::runtime_error(ss.str());
0126    }
0127 
0128    return c;
0129 }
0130 
0131 } // namespace Internal
0132 
0133 /// A replacement for the TMVA::Reader legacy interface.
0134 /// Performs inference for TMVA models stored as XML files.
0135 /// For neural network inference consider using [SOFIE](https://github.com/root-project/root/blob/master/tmva/sofie/README.md) instead.
0136 class RReader {
0137 private:
0138    std::unique_ptr<Reader> fReader;
0139    std::vector<float> fVariableValues;
0140    std::vector<std::string> fVariables;
0141    std::vector<std::string> fVariableExpressions;
0142    std::vector<float> fSpectatorValues;
0143    std::vector<std::string> fSpectators;
0144    std::vector<std::string> fSpectatorExpressions;
0145    unsigned int fNumClasses;
0146    const char *name = "RReader";
0147    Internal::AnalysisType fAnalysisType;
0148 
0149 public:
0150    /// Create TMVA model from XML file
0151    RReader(const std::string &path)
0152    {
0153       // Load config
0154       auto c = Internal::ParseXMLConfig(path);
0155       fVariables = c.variables;
0156       fVariableExpressions = c.variable_expressions;
0157       fSpectators = c.spectators;
0158       fSpectatorExpressions = c.spectator_expressions;
0159       fAnalysisType = c.analysisType;
0160       fNumClasses = c.numClasses;
0161 
0162       // Setup reader
0163       fReader = std::make_unique<Reader>("Silent");
0164       const auto numVars = fVariables.size();
0165       fVariableValues = std::vector<float>(numVars);
0166       for (std::size_t i = 0; i < numVars; i++) {
0167          fReader->AddVariable(TString(fVariableExpressions[i]), &fVariableValues[i]);
0168       }
0169       const auto numSpecs = fSpectators.size();
0170       fSpectatorValues = std::vector<float>(numSpecs);
0171       for (std::size_t i = 0; i < numSpecs; i++) {
0172          fReader->AddSpectator(TString(fSpectatorExpressions[i]), &fSpectatorValues[i]);
0173       }
0174       fReader->BookMVA(name, path.c_str());
0175    }
0176 
0177    /// Compute model prediction on vector
0178    std::vector<float> Compute(const std::vector<float> &x)
0179    {
0180       if (x.size() != (fVariables.size()+fSpectators.size()))
0181          throw std::runtime_error("Size of input vector is not equal to number of variables.");
0182 
0183       // Copy over inputs to memory used by TMVA reader
0184       const auto nVars = fVariables.size();
0185       for (std::size_t i = 0; i != nVars ; ++i) {
0186          fVariableValues[i] = x[i];
0187       }
0188       for (std::size_t i = 0; i != fSpectators.size(); ++i) {
0189          fSpectatorValues[i] = x[nVars+i];
0190       }
0191 
0192       // Take lock to protect model evaluation
0193       R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0194 
0195       // Evaluate TMVA model
0196       // Classification
0197       if (fAnalysisType == Internal::AnalysisType::Classification) {
0198          return std::vector<float>({static_cast<float>(fReader->EvaluateMVA(name))});
0199       }
0200       // Regression
0201       else if (fAnalysisType == Internal::AnalysisType::Regression) {
0202          return fReader->EvaluateRegression(name);
0203       }
0204       // Multiclass
0205       else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
0206          return fReader->EvaluateMulticlass(name);
0207       }
0208       // Throw error
0209       else {
0210          throw std::runtime_error("RReader has undefined analysis type.");
0211          return std::vector<float>();
0212       }
0213    }
0214 
0215    /// Compute model prediction on input RTensor
0216    RTensor<float> Compute(RTensor<float> &x)
0217    {
0218       // Error-handling for input tensor
0219       const auto shape = x.GetShape();
0220       if (shape.size() != 2)
0221          throw std::runtime_error("Can only compute model outputs for input tensor of rank 2.");
0222 
0223       const auto numEntries = shape[0];
0224       const auto numVars = shape[1];
0225       if (numVars != (fVariables.size()+fSpectators.size()))
0226          throw std::runtime_error("Second dimension of input tensor is not equal to number of variables.");
0227 
0228       // Define shape of output tensor based on analysis type
0229       unsigned int numClasses = 1;
0230       if (fAnalysisType == Internal::AnalysisType::Multiclass)
0231          numClasses = fNumClasses;
0232       RTensor<float> y({numEntries * numClasses});
0233       if (fAnalysisType == Internal::AnalysisType::Multiclass)
0234          y = y.Reshape({numEntries, numClasses});
0235 
0236       // Fill output tensor
0237       const auto nVars = fVariables.size(); // number of non-spectator variables
0238       for (std::size_t i = 0; i < numEntries; i++) {
0239          for (std::size_t j = 0; j < nVars; j++) {
0240             fVariableValues[j] = x(i, j);
0241          }
0242          for (std::size_t j = 0; j < fSpectators.size(); ++j) {
0243             fSpectatorValues[j] = x(i, nVars+j);
0244          }
0245          R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0246          // Classification
0247          if (fAnalysisType == Internal::AnalysisType::Classification) {
0248             y(i) = fReader->EvaluateMVA(name);
0249          }
0250          // Regression
0251          else if (fAnalysisType == Internal::AnalysisType::Regression) {
0252             y(i) = fReader->EvaluateRegression(name)[0];
0253          }
0254          // Multiclass
0255          else if (fAnalysisType == Internal::AnalysisType::Multiclass) {
0256             const auto p = fReader->EvaluateMulticlass(name);
0257             for (std::size_t k = 0; k < numClasses; k++)
0258                y(i, k) = p[k];
0259          }
0260       }
0261 
0262       return y;
0263    }
0264 
0265    std::vector<std::string> GetVariableNames() { return fVariables; }
0266    std::vector<std::string> GetSpectatorNames() { return fSpectators; }
0267 };
0268 
0269 } // namespace Experimental
0270 } // namespace TMVA
0271 
0272 #endif // TMVA_RREADER