Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 09:32:40

0001 /**********************************************************************************
0002  * Project: ROOT - a Root-integrated toolkit for multivariate data analysis       *
0003  * Package: TMVA                                                                  *
0004  *                                                                                *
0005  * Description:                                                                   *
0006  *                                                                                *
0007  * Authors:                                                                       *
0008  *      Lorenzo Moneta                                  *
0009  *                                                                                *
0010  * Copyright (c) 2022:                                                            *
0011  *      CERN, Switzerland                                                         *
0012  *                                                                                *
0013  **********************************************************************************/
0014 
0015 
0016 #ifndef TMVA_RSOFIEREADER
0017 #define TMVA_RSOFIEREADER
0018 
0019 
#include <algorithm> // std::remove_if, std::copy
#include <cctype>    // std::isalnum
#include <iostream>
#include <memory>    // std::unique_ptr
#include <sstream>   // std::stringstream
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "TROOT.h"
#include "TSystem.h"
#include "TError.h"
#include "TInterpreter.h"
#include "TUUID.h"
#include "TMVA/RTensor.hxx"
#include "Math/Util.h"
0032 
0033 namespace TMVA {
0034 namespace Experimental {
0035 
0036 
0037 
0038 
0039 /// TMVA::RSofieReader class for reading external Machine Learning models
0040 /// in ONNX files, Keras .h5 files or PyTorch .pt files
0041 /// and performing the inference using SOFIE
0042  * It is recommended to use ONNX if possible since there is a larger support for
0043 /// model operators.
0044 
0045 class RSofieReader  {
0046 
0047 
0048 public:
0049    /// Dummy constructor which needs model loading  afterwards
0050    RSofieReader() {}
0051    /// Create TMVA model from ONNX file
0052    /// print level can be 0 (minimal) 1 with info , 2 with all ONNX parsing info
0053    RSofieReader(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
0054    {
0055       Load(path, inputShapes, verbose);
0056    }
0057 
0058    void Load(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
0059    {
0060 
0061       enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef}; // type of model
0062       EModelType type = kNotDef;
0063 
0064       auto pos1 = path.rfind("/");
0065       auto pos2 = path.find(".onnx");
0066       if (pos2 != std::string::npos) {
0067          type = kONNX;
0068       } else {
0069          pos2 = path.find(".h5");
0070          if (pos2 != std::string::npos) {
0071              type = kKeras;
0072          } else {
0073             pos2 = path.find(".pt");
0074             if (pos2 != std::string::npos) {
0075                type = kPt;
0076             }
0077             else {
0078                pos2 = path.find(".root");
0079                if (pos2 != std::string::npos) {
0080                   type = kROOT;
0081                }
0082             }
0083          }
0084       }
0085       if (type == kNotDef) {
0086          throw std::runtime_error("Input file is not an ONNX or Keras or PyTorch file");
0087       }
0088       if (pos1 == std::string::npos)
0089          pos1 = 0;
0090       else
0091          pos1 += 1;
0092       std::string modelName = path.substr(pos1,pos2-pos1);
0093       std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
0094       if (verbose) std::cout << "Parsing SOFIE model " << modelName << " of type " << fileType << std::endl;
0095 
0096       // create code for parsing model and generate C++ code for inference
0097       // make it in a separate scope to avoid polluting global interpreter space
0098       std::string parserCode;
0099       if (type == kONNX) {
0100          // check first if we can load the SOFIE parser library
0101          if (gSystem->Load("libROOTTMVASofieParser") < 0) {
0102             throw std::runtime_error("RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
0103          }
0104          gInterpreter->Declare("#include \"TMVA/RModelParser_ONNX.hxx\"");
0105          parserCode += "{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
0106          if (verbose == 2)
0107             parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\",true); \n";
0108          else
0109             parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\"); \n";
0110       }
0111       else if (type == kKeras) {
0112          // use Keras direct parser
0113          if (gSystem->Load("libPyMVA") < 0) {
0114             throw std::runtime_error("RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
0115          }
0116          // assume batch size is first entry in first input !
0117          std::string batch_size = "-1";
0118          if (!inputShapes.empty() && ! inputShapes[0].empty())
0119             batch_size = std::to_string(inputShapes[0][0]);
0120          parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
0121                        "\"," + batch_size + "); \n";
0122       }
0123       else if (type == kPt) {
0124          // use PyTorch direct parser
0125          if (gSystem->Load("libPyMVA") < 0) {
0126             throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
0127          }
0128          if (inputShapes.size() == 0) {
0129             throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
0130          }
0131          std::string inputShapesStr = "{";
0132          for (unsigned int i = 0; i < inputShapes.size(); i++) {
0133             inputShapesStr += "{ ";
0134             for (unsigned int j = 0; j < inputShapes[i].size(); j++) {
0135                inputShapesStr += ROOT::Math::Util::ToString(inputShapes[i][j]);
0136                if (j < inputShapes[i].size()-1) inputShapesStr += ", ";
0137             }
0138             inputShapesStr += "}";
0139             if (i < inputShapes.size()-1) inputShapesStr += ", ";
0140          }
0141          inputShapesStr += "}";
0142          parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path + "\", "
0143                     + inputShapesStr + "); \n";
0144       }
0145       else if (type == kROOT) {
0146          // use  parser from ROOT
0147          parserCode += "{\nauto fileRead = TFile::Open(\"" + path + "\",\"READ\");\n";
0148          parserCode += "TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
0149          parserCode += "auto keyList = fileRead->GetListOfKeys(); TString name;\n";
0150          parserCode += "for (const auto&& k : *keyList)  { \n";
0151          parserCode += "   TString cname =  ((TKey*)k)->GetClassName();  if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
0152          parserCode += "fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
0153          parserCode += "TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
0154       }
0155 
0156        // add custom operators if needed
0157       if (fCustomOperators.size() > 0) {
0158 
0159          for (auto & op : fCustomOperators) {
0160             parserCode += "{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
0161                       + op.fOpName + "\"," + op.fInputNames + "," + op.fOutputNames + "," + op.fOutputShapes + ",\"" + op.fFileName + "\");\n";
0162             parserCode += "std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
0163             parserCode += "model.AddOperator(std::move(op));\n}\n";
0164          }
0165       }
0166 
0167       int batchSize = 1;
0168       if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
0169          batchSize = inputShapes[0][0];
0170          if (batchSize < 1) batchSize = 1;
0171       }
0172       if (verbose) std::cout << "generating the code with batch size = " << batchSize << " ...\n";
0173 
0174       parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
0175                    + ROOT::Math::Util::ToString(batchSize) + ", 0, " + std::to_string(verbose) + "); \n";
0176 
0177       if (verbose) {
0178          parserCode += "model.PrintRequiredInputTensors();\n";
0179          parserCode += "model.PrintIntermediateTensors();\n";
0180          parserCode += "model.PrintOutputTensors();\n";
0181       }
0182 
0183       // add custom operators if needed
0184 #if 0
0185       if (fCustomOperators.size() > 0) {
0186          if (verbose) {
0187             parserCode += "model.PrintRequiredInputTensors();\n";
0188             parserCode += "model.PrintIntermediateTensors();\n";
0189             parserCode += "model.PrintOutputTensors();\n";
0190          }
0191          for (auto & op : fCustomOperators) {
0192             parserCode += "{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
0193                       + op.fOpName + "\"," + op.fInputNames + "," + op.fOutputNames + "," + op.fOutputShapes + ",\"" + op.fFileName + "\");\n";
0194             parserCode += "std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
0195             parserCode += "model.AddOperator(std::move(op));\n}\n";
0196          }
0197          parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
0198                    + ROOT::Math::Util::ToString(batchSize) + "); \n";
0199       }
0200 #endif
0201       if (verbose > 1)
0202          parserCode += "model.PrintGenerated(); \n";
0203       parserCode += "model.OutputGenerated();\n";
0204 
0205       parserCode += "int nInputs = model.GetInputTensorNames().size();\n";
0206 
0207       // need information on number of inputs (assume output is 1)
0208 
0209       //end of parsing code, close the scope and return 1 to indicate a success
0210       parserCode += "return nInputs;\n}\n";
0211 
0212       if (verbose) std::cout << "//ParserCode being executed:\n" << parserCode << std::endl;
0213 
0214       auto iret = gROOT->ProcessLine(parserCode.c_str());
0215       if (iret <= 0) {
0216          std::string msg = "RSofieReader: error processing the parser code: \n" + parserCode;
0217          throw std::runtime_error(msg);
0218       }
0219       fNInputs = iret;
0220       if (fNInputs > 3) {
0221          throw std::runtime_error("RSofieReader does not yet support model with > 3 inputs");
0222       }
0223 
0224       // compile now the generated code and create Session class
0225       std::string modelHeader = modelName + ".hxx";
0226       if (verbose) std::cout << "compile generated code from file " <<modelHeader << std::endl;
0227       if (gSystem->AccessPathName(modelHeader.c_str())) {
0228          std::string msg = "RSofieReader: input header file " + modelHeader + " is not existing";
0229          throw std::runtime_error(msg);
0230       }
0231       if (verbose) std::cout << "Creating Inference function for model " << modelName << std::endl;
0232       std::string declCode;
0233       declCode += "#pragma cling optimize(2)\n";
0234       declCode += "#include \"" + modelHeader + "\"\n";
0235       // create global session instance: use UUID to have an unique name
0236       std::string sessionClassName = "TMVA_SOFIE_" + modelName + "::Session";
0237       TUUID uuid;
0238       std::string uidName = uuid.AsString();
0239       uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
0240          []( char const& c ) -> bool { return !std::isalnum(c); } ), uidName.end());
0241 
0242       std::string sessionName = "session_" + uidName;
0243       declCode += sessionClassName + " " + sessionName + ";";
0244 
0245       if (verbose) std::cout << "//global session declaration\n" << declCode << std::endl;
0246 
0247       bool ret = gInterpreter->Declare(declCode.c_str());
0248       if (!ret) {
0249          std::string msg = "RSofieReader: error compiling inference code and creating session class\n" + declCode;
0250          throw std::runtime_error(msg);
0251       }
0252 
0253       fSessionPtr = (void *) gInterpreter->Calc(sessionName.c_str());
0254 
0255       // define a function to be called for inference
0256       std::stringstream ifuncCode;
0257       std::string funcName = "SofieInference_" + uidName;
0258       ifuncCode << "std::vector<float> " + funcName + "( void * ptr";
0259       for (int i = 0; i < fNInputs; i++)
0260          ifuncCode << ", float * data" << i;
0261       ifuncCode << ") {\n";
0262       ifuncCode << "   " << sessionClassName << " * s = " << "(" << sessionClassName << "*) (ptr);\n";
0263       ifuncCode << "   return s->infer(";
0264       for (int i = 0; i < fNInputs; i++) {
0265          if (i>0) ifuncCode << ",";
0266          ifuncCode << "data" << i;
0267       }
0268       ifuncCode << ");\n";
0269       ifuncCode << "}\n";
0270 
0271       if (verbose) std::cout << "//Inference function code using global session instance\n"
0272                               << ifuncCode.str() << std::endl;
0273 
0274       ret = gInterpreter->Declare(ifuncCode.str().c_str());
0275       if (!ret) {
0276          std::string msg = "RSofieReader: error compiling inference function\n" + ifuncCode.str();
0277          throw std::runtime_error(msg);
0278       }
0279       fFuncPtr = (void *) gInterpreter->Calc(funcName.c_str());
0280       //fFuncPtr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fptr);
0281       fInitialized = true;
0282    }
0283 
0284    // Add custom operator
0285     void AddCustomOperator(const std::string &opName, const std::string &inputNames, const std::string & outputNames,
0286       const std::string & outputShapes, const std::string & fileName) {
0287          if (fInitialized)  std::cout << "WARNING: Model is already loaded and initialised. It must be done after adding the custom operators" << std::endl;
0288          fCustomOperators.push_back( {fileName, opName,inputNames, outputNames,outputShapes});
0289       }
0290 
0291    // implementations for different outputs
0292    std::vector<float> DoCompute(const std::vector<float> & x1) {
0293       if (fNInputs != 1) {
0294          std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0295          throw std::runtime_error(msg);
0296       }
0297       auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
0298       return fptr(fSessionPtr, x1.data());
0299    }
0300    std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2) {
0301       if (fNInputs != 2) {
0302          std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0303          throw std::runtime_error(msg);
0304       }
0305       auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *)>(fFuncPtr);
0306       return fptr(fSessionPtr, x1.data(),x2.data());
0307    }
0308    std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2, const std::vector<float> & x3) {
0309       if (fNInputs != 3) {
0310          std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0311          throw std::runtime_error(msg);
0312       }
0313       auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *, const float *)>(fFuncPtr);
0314       return fptr(fSessionPtr, x1.data(),x2.data(),x3.data());
0315    }
0316 
0317    /// Compute model prediction on vector
0318    template<typename... T>
0319    std::vector<float> Compute(T... x)
0320    {
0321       if(!fInitialized) {
0322          return std::vector<float>();
0323       }
0324 
0325       // Take lock to protect model evaluation
0326       R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0327 
0328       // Evaluate TMVA model (need to add support for multiple outputs)
0329       return DoCompute(x...);
0330 
0331    }
0332    std::vector<float> Compute(const std::vector<float> &x) {
0333       if(!fInitialized) {
0334          return std::vector<float>();
0335       }
0336 
0337       // Take lock to protect model evaluation
0338       R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0339 
0340       // Evaluate TMVA model (need to add support for multiple outputs)
0341       return DoCompute(x);
0342    }
0343    /// Compute model prediction on input RTensor
0344    /// The shape of the input tensor should be {nevents, nfeatures}
0345    /// and the return shape will be {nevents, noutputs}
0346    /// support for now only a single input
0347    RTensor<float> Compute(RTensor<float> &x)
0348    {
0349       if(!fInitialized) {
0350          return RTensor<float>({0});
0351       }
0352       const auto nrows = x.GetShape()[0];
0353       const auto rowsize = x.GetStrides()[0];
0354       auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
0355       auto result = fptr(fSessionPtr, x.GetData());
0356 
0357       RTensor<float> y({nrows, result.size()}, MemoryLayout::ColumnMajor);
0358       std::copy(result.begin(),result.end(), y.GetData());
0359       //const bool layout = x.GetMemoryLayout() == MemoryLayout::ColumnMajor ? false : true;
0360       // assume column major layout
0361       for (size_t i = 1; i < nrows; i++) {
0362          result = fptr(fSessionPtr, x.GetData() + i*rowsize);
0363          std::copy(result.begin(),result.end(), y.GetData() + i*result.size());
0364       }
0365       return y;
0366    }
0367 
0368 private:
0369 
0370    bool fInitialized = false;
0371    int fNInputs = 0;
0372    void * fSessionPtr = nullptr;
0373    void * fFuncPtr = nullptr;
0374 
0375    // data to insert custom operators
0376    struct CustomOperatorData {
0377       std::string fFileName; // code implementing the custom operator
0378       std::string fOpName; // operator name
0379       std::string fInputNames;  // input tensor names (convert as string as {"n1", "n2"})
0380       std::string fOutputNames;  // output tensor names converted as trind
0381       std::string fOutputShapes; // output shapes
0382    };
0383    std::vector<CustomOperatorData> fCustomOperators;
0384 
0385 };
0386 
0387 } // namespace Experimental
0388 } // namespace TMVA
0389 
0390 #endif // TMVA_RSOFIEREADER