// RSofieReader.hxx
// TMVA/SOFIE reader: on-the-fly parsing, code generation and JIT
// compilation of machine-learning models for inference within ROOT.
0016 #ifndef TMVA_RSOFIEREADER
0017 #define TMVA_RSOFIEREADER
0018
0019
#include <algorithm>  // std::remove_if, std::copy
#include <cctype>     // std::isalnum
#include <iostream>
#include <memory>     // std::unique_ptr
#include <sstream>    // std::stringstream
#include <stdexcept>  // std::runtime_error
#include <string>
#include <vector>

#include "TROOT.h"
#include "TSystem.h"
#include "TError.h"
#include "TInterpreter.h"
#include "TUUID.h"
#include "TMVA/RTensor.hxx"
#include "Math/Util.h"
0032
0033 namespace TMVA {
0034 namespace Experimental {
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045 class RSofieReader {
0046
0047
0048 public:
0049
0050 RSofieReader() {}
0051
0052
0053 RSofieReader(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
0054 {
0055 Load(path, inputShapes, verbose);
0056 }
0057
0058 void Load(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
0059 {
0060
0061 enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef};
0062 EModelType type = kNotDef;
0063
0064 auto pos1 = path.rfind("/");
0065 auto pos2 = path.find(".onnx");
0066 if (pos2 != std::string::npos) {
0067 type = kONNX;
0068 } else {
0069 pos2 = path.find(".h5");
0070 if (pos2 != std::string::npos) {
0071 type = kKeras;
0072 } else {
0073 pos2 = path.find(".pt");
0074 if (pos2 != std::string::npos) {
0075 type = kPt;
0076 }
0077 else {
0078 pos2 = path.find(".root");
0079 if (pos2 != std::string::npos) {
0080 type = kROOT;
0081 }
0082 }
0083 }
0084 }
0085 if (type == kNotDef) {
0086 throw std::runtime_error("Input file is not an ONNX or Keras or PyTorch file");
0087 }
0088 if (pos1 == std::string::npos)
0089 pos1 = 0;
0090 else
0091 pos1 += 1;
0092 std::string modelName = path.substr(pos1,pos2-pos1);
0093 std::string fileType = path.substr(pos2+1, path.length()-pos2-1);
0094 if (verbose) std::cout << "Parsing SOFIE model " << modelName << " of type " << fileType << std::endl;
0095
0096
0097
0098 std::string parserCode;
0099 if (type == kONNX) {
0100
0101 if (gSystem->Load("libROOTTMVASofieParser") < 0) {
0102 throw std::runtime_error("RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
0103 }
0104 gInterpreter->Declare("#include \"TMVA/RModelParser_ONNX.hxx\"");
0105 parserCode += "{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
0106 if (verbose == 2)
0107 parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\",true); \n";
0108 else
0109 parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\"); \n";
0110 }
0111 else if (type == kKeras) {
0112
0113 if (gSystem->Load("libPyMVA") < 0) {
0114 throw std::runtime_error("RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
0115 }
0116
0117 std::string batch_size = "-1";
0118 if (!inputShapes.empty() && ! inputShapes[0].empty())
0119 batch_size = std::to_string(inputShapes[0][0]);
0120 parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
0121 "\"," + batch_size + "); \n";
0122 }
0123 else if (type == kPt) {
0124
0125 if (gSystem->Load("libPyMVA") < 0) {
0126 throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
0127 }
0128 if (inputShapes.size() == 0) {
0129 throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
0130 }
0131 std::string inputShapesStr = "{";
0132 for (unsigned int i = 0; i < inputShapes.size(); i++) {
0133 inputShapesStr += "{ ";
0134 for (unsigned int j = 0; j < inputShapes[i].size(); j++) {
0135 inputShapesStr += ROOT::Math::Util::ToString(inputShapes[i][j]);
0136 if (j < inputShapes[i].size()-1) inputShapesStr += ", ";
0137 }
0138 inputShapesStr += "}";
0139 if (i < inputShapes.size()-1) inputShapesStr += ", ";
0140 }
0141 inputShapesStr += "}";
0142 parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path + "\", "
0143 + inputShapesStr + "); \n";
0144 }
0145 else if (type == kROOT) {
0146
0147 parserCode += "{\nauto fileRead = TFile::Open(\"" + path + "\",\"READ\");\n";
0148 parserCode += "TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
0149 parserCode += "auto keyList = fileRead->GetListOfKeys(); TString name;\n";
0150 parserCode += "for (const auto&& k : *keyList) { \n";
0151 parserCode += " TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
0152 parserCode += "fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
0153 parserCode += "TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
0154 }
0155
0156
0157 if (fCustomOperators.size() > 0) {
0158
0159 for (auto & op : fCustomOperators) {
0160 parserCode += "{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
0161 + op.fOpName + "\"," + op.fInputNames + "," + op.fOutputNames + "," + op.fOutputShapes + ",\"" + op.fFileName + "\");\n";
0162 parserCode += "std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
0163 parserCode += "model.AddOperator(std::move(op));\n}\n";
0164 }
0165 }
0166
0167 int batchSize = 1;
0168 if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
0169 batchSize = inputShapes[0][0];
0170 if (batchSize < 1) batchSize = 1;
0171 }
0172 if (verbose) std::cout << "generating the code with batch size = " << batchSize << " ...\n";
0173
0174 parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
0175 + ROOT::Math::Util::ToString(batchSize) + ", 0, " + std::to_string(verbose) + "); \n";
0176
0177 if (verbose) {
0178 parserCode += "model.PrintRequiredInputTensors();\n";
0179 parserCode += "model.PrintIntermediateTensors();\n";
0180 parserCode += "model.PrintOutputTensors();\n";
0181 }
0182
0183
0184 #if 0
0185 if (fCustomOperators.size() > 0) {
0186 if (verbose) {
0187 parserCode += "model.PrintRequiredInputTensors();\n";
0188 parserCode += "model.PrintIntermediateTensors();\n";
0189 parserCode += "model.PrintOutputTensors();\n";
0190 }
0191 for (auto & op : fCustomOperators) {
0192 parserCode += "{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
0193 + op.fOpName + "\"," + op.fInputNames + "," + op.fOutputNames + "," + op.fOutputShapes + ",\"" + op.fFileName + "\");\n";
0194 parserCode += "std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
0195 parserCode += "model.AddOperator(std::move(op));\n}\n";
0196 }
0197 parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
0198 + ROOT::Math::Util::ToString(batchSize) + "); \n";
0199 }
0200 #endif
0201 if (verbose > 1)
0202 parserCode += "model.PrintGenerated(); \n";
0203 parserCode += "model.OutputGenerated();\n";
0204
0205 parserCode += "int nInputs = model.GetInputTensorNames().size();\n";
0206
0207
0208
0209
0210 parserCode += "return nInputs;\n}\n";
0211
0212 if (verbose) std::cout << "//ParserCode being executed:\n" << parserCode << std::endl;
0213
0214 auto iret = gROOT->ProcessLine(parserCode.c_str());
0215 if (iret <= 0) {
0216 std::string msg = "RSofieReader: error processing the parser code: \n" + parserCode;
0217 throw std::runtime_error(msg);
0218 }
0219 fNInputs = iret;
0220 if (fNInputs > 3) {
0221 throw std::runtime_error("RSofieReader does not yet support model with > 3 inputs");
0222 }
0223
0224
0225 std::string modelHeader = modelName + ".hxx";
0226 if (verbose) std::cout << "compile generated code from file " <<modelHeader << std::endl;
0227 if (gSystem->AccessPathName(modelHeader.c_str())) {
0228 std::string msg = "RSofieReader: input header file " + modelHeader + " is not existing";
0229 throw std::runtime_error(msg);
0230 }
0231 if (verbose) std::cout << "Creating Inference function for model " << modelName << std::endl;
0232 std::string declCode;
0233 declCode += "#pragma cling optimize(2)\n";
0234 declCode += "#include \"" + modelHeader + "\"\n";
0235
0236 std::string sessionClassName = "TMVA_SOFIE_" + modelName + "::Session";
0237 TUUID uuid;
0238 std::string uidName = uuid.AsString();
0239 uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
0240 []( char const& c ) -> bool { return !std::isalnum(c); } ), uidName.end());
0241
0242 std::string sessionName = "session_" + uidName;
0243 declCode += sessionClassName + " " + sessionName + ";";
0244
0245 if (verbose) std::cout << "//global session declaration\n" << declCode << std::endl;
0246
0247 bool ret = gInterpreter->Declare(declCode.c_str());
0248 if (!ret) {
0249 std::string msg = "RSofieReader: error compiling inference code and creating session class\n" + declCode;
0250 throw std::runtime_error(msg);
0251 }
0252
0253 fSessionPtr = (void *) gInterpreter->Calc(sessionName.c_str());
0254
0255
0256 std::stringstream ifuncCode;
0257 std::string funcName = "SofieInference_" + uidName;
0258 ifuncCode << "std::vector<float> " + funcName + "( void * ptr";
0259 for (int i = 0; i < fNInputs; i++)
0260 ifuncCode << ", float * data" << i;
0261 ifuncCode << ") {\n";
0262 ifuncCode << " " << sessionClassName << " * s = " << "(" << sessionClassName << "*) (ptr);\n";
0263 ifuncCode << " return s->infer(";
0264 for (int i = 0; i < fNInputs; i++) {
0265 if (i>0) ifuncCode << ",";
0266 ifuncCode << "data" << i;
0267 }
0268 ifuncCode << ");\n";
0269 ifuncCode << "}\n";
0270
0271 if (verbose) std::cout << "//Inference function code using global session instance\n"
0272 << ifuncCode.str() << std::endl;
0273
0274 ret = gInterpreter->Declare(ifuncCode.str().c_str());
0275 if (!ret) {
0276 std::string msg = "RSofieReader: error compiling inference function\n" + ifuncCode.str();
0277 throw std::runtime_error(msg);
0278 }
0279 fFuncPtr = (void *) gInterpreter->Calc(funcName.c_str());
0280
0281 fInitialized = true;
0282 }
0283
0284
0285 void AddCustomOperator(const std::string &opName, const std::string &inputNames, const std::string & outputNames,
0286 const std::string & outputShapes, const std::string & fileName) {
0287 if (fInitialized) std::cout << "WARNING: Model is already loaded and initialised. It must be done after adding the custom operators" << std::endl;
0288 fCustomOperators.push_back( {fileName, opName,inputNames, outputNames,outputShapes});
0289 }
0290
0291
0292 std::vector<float> DoCompute(const std::vector<float> & x1) {
0293 if (fNInputs != 1) {
0294 std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0295 throw std::runtime_error(msg);
0296 }
0297 auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
0298 return fptr(fSessionPtr, x1.data());
0299 }
0300 std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2) {
0301 if (fNInputs != 2) {
0302 std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0303 throw std::runtime_error(msg);
0304 }
0305 auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *)>(fFuncPtr);
0306 return fptr(fSessionPtr, x1.data(),x2.data());
0307 }
0308 std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2, const std::vector<float> & x3) {
0309 if (fNInputs != 3) {
0310 std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
0311 throw std::runtime_error(msg);
0312 }
0313 auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *, const float *)>(fFuncPtr);
0314 return fptr(fSessionPtr, x1.data(),x2.data(),x3.data());
0315 }
0316
0317
0318 template<typename... T>
0319 std::vector<float> Compute(T... x)
0320 {
0321 if(!fInitialized) {
0322 return std::vector<float>();
0323 }
0324
0325
0326 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0327
0328
0329 return DoCompute(x...);
0330
0331 }
0332 std::vector<float> Compute(const std::vector<float> &x) {
0333 if(!fInitialized) {
0334 return std::vector<float>();
0335 }
0336
0337
0338 R__WRITE_LOCKGUARD(ROOT::gCoreMutex);
0339
0340
0341 return DoCompute(x);
0342 }
0343
0344
0345
0346
0347 RTensor<float> Compute(RTensor<float> &x)
0348 {
0349 if(!fInitialized) {
0350 return RTensor<float>({0});
0351 }
0352 const auto nrows = x.GetShape()[0];
0353 const auto rowsize = x.GetStrides()[0];
0354 auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
0355 auto result = fptr(fSessionPtr, x.GetData());
0356
0357 RTensor<float> y({nrows, result.size()}, MemoryLayout::ColumnMajor);
0358 std::copy(result.begin(),result.end(), y.GetData());
0359
0360
0361 for (size_t i = 1; i < nrows; i++) {
0362 result = fptr(fSessionPtr, x.GetData() + i*rowsize);
0363 std::copy(result.begin(),result.end(), y.GetData() + i*result.size());
0364 }
0365 return y;
0366 }
0367
0368 private:
0369
0370 bool fInitialized = false;
0371 int fNInputs = 0;
0372 void * fSessionPtr = nullptr;
0373 void * fFuncPtr = nullptr;
0374
0375
0376 struct CustomOperatorData {
0377 std::string fFileName;
0378 std::string fOpName;
0379 std::string fInputNames;
0380 std::string fOutputNames;
0381 std::string fOutputShapes;
0382 };
0383 std::vector<CustomOperatorData> fCustomOperators;
0384
0385 };
0386
0387 }
0388 }
0389
0390 #endif