#ifndef TMVA_RSOFIEREADER
#define TMVA_RSOFIEREADER


#include <string>
#include <vector>
#include <memory> // std::unique_ptr
#include <sstream> // std::stringstream
#include <iostream>
#include <stdexcept> // std::runtime_error
#include <algorithm> // std::remove_if
#include <cctype> // std::isalnum
#include "TROOT.h"
#include "TSystem.h"
#include "TError.h"
#include "TInterpreter.h"
#include "TUUID.h"
#include "TMVA/RTensor.hxx"
#include "Math/Util.h"

namespace TMVA {
namespace Experimental {

/// RSofieReader: evaluate a machine-learning model with SOFIE from a model file.
///
/// The model can be provided as an ONNX file, a Keras .h5 file, a PyTorch .pt file
/// or a ROOT file containing a saved SOFIE RModel. The class parses the model,
/// generates the SOFIE inference code, compiles it with the ROOT interpreter (Cling)
/// and exposes the inference through the Compute() functions.
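///
/// A minimal usage sketch (the file name "model.onnx" and the input size of 10
/// are illustrative only, not part of this header):
/// \code
/// TMVA::Experimental::RSofieReader reader("model.onnx");
/// std::vector<float> input(10);   // size must match the model input tensor
/// std::vector<float> output = reader.Compute(input);
/// \endcode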
class RSofieReader {


public:

   /// Default constructor: creates an empty reader that must be initialized with Load().
   RSofieReader() {}

   /// Construct the reader and immediately load the model (see Load() for the arguments).
   RSofieReader(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
   {
      Load(path, inputShapes, verbose);
   }

   /// Load a model from a file, generate the SOFIE inference code and compile it.
   /// \param path  model file (.onnx, .h5, .pt, or .root containing a SOFIE RModel)
   /// \param inputShapes  shapes of the input tensors (required for PyTorch models;
   ///                     the first dimension, if given, is used as the batch size)
   /// \param verbose  0 = silent, 1 = print progress, 2 = also print the generated code
   void Load(const std::string &path, std::vector<std::vector<size_t>> inputShapes = {}, int verbose = 0)
   {
      // determine the model type from the file extension
      enum EModelType {kONNX, kKeras, kPt, kROOT, kNotDef};
      EModelType type = kNotDef;

      auto pos1 = path.rfind("/");
      auto pos2 = path.find(".onnx");
      if (pos2 != std::string::npos) {
         type = kONNX;
      } else {
         pos2 = path.find(".h5");
         if (pos2 != std::string::npos) {
            type = kKeras;
         } else {
            pos2 = path.find(".pt");
            if (pos2 != std::string::npos) {
               type = kPt;
            }
            else {
               pos2 = path.find(".root");
               if (pos2 != std::string::npos) {
                  type = kROOT;
               }
            }
         }
      }
      if (type == kNotDef) {
         throw std::runtime_error("RSofieReader: input file is not an ONNX, Keras, PyTorch or ROOT file");
      }
      if (pos1 == std::string::npos)
         pos1 = 0;
      else
         pos1 += 1;
      std::string modelName = path.substr(pos1, pos2 - pos1);
      std::string fileType = path.substr(pos2 + 1, path.length() - pos2 - 1);
      if (verbose) std::cout << "Parsing SOFIE model " << modelName << " of type " << fileType << std::endl;

      // build the code that parses the model file into a SOFIE RModel
      std::string parserCode;
      if (type == kONNX) {
         // ONNX models: parse with RModelParser_ONNX from libROOTTMVASofieParser
         if (gSystem->Load("libROOTTMVASofieParser") < 0) {
            throw std::runtime_error("RSofieReader: cannot use SOFIE with ONNX since libROOTTMVASofieParser is missing");
         }
         gInterpreter->Declare("#include \"TMVA/RModelParser_ONNX.hxx\"");
         parserCode += "{\nTMVA::Experimental::SOFIE::RModelParser_ONNX parser ; \n";
         if (verbose == 2)
            parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\",true); \n";
         else
            parserCode += "TMVA::Experimental::SOFIE::RModel model = parser.Parse(\"" + path + "\"); \n";
      }
      else if (type == kKeras) {
         // Keras .h5 models: parse with the Python-based parser from libPyMVA
         if (gSystem->Load("libPyMVA") < 0) {
            throw std::runtime_error("RSofieReader: cannot use SOFIE with Keras since libPyMVA is missing");
         }
         // pass the batch size (first dimension of the first input shape) if provided
         std::string batch_size = "-1";
         if (!inputShapes.empty() && !inputShapes[0].empty())
            batch_size = std::to_string(inputShapes[0][0]);
         parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyKeras::Parse(\"" + path +
                       "\"," + batch_size + "); \n";
      }
      else if (type == kPt) {
         // PyTorch .pt models: parse with the Python-based parser from libPyMVA;
         // the input tensor shapes are required by the PyTorch parser
         if (gSystem->Load("libPyMVA") < 0) {
            throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since libPyMVA is missing");
         }
         if (inputShapes.size() == 0) {
            throw std::runtime_error("RSofieReader: cannot use SOFIE with PyTorch since the input tensor shape is missing and is needed by the PyTorch parser");
         }
         // convert the input shapes into a C++ initializer-list string, e.g. {{ 1, 10 }}
         std::string inputShapesStr = "{";
         for (unsigned int i = 0; i < inputShapes.size(); i++) {
            inputShapesStr += "{ ";
            for (unsigned int j = 0; j < inputShapes[i].size(); j++) {
               inputShapesStr += ROOT::Math::Util::ToString(inputShapes[i][j]);
               if (j < inputShapes[i].size() - 1) inputShapesStr += ", ";
            }
            inputShapesStr += "}";
            if (i < inputShapes.size() - 1) inputShapesStr += ", ";
         }
         inputShapesStr += "}";
         parserCode += "{\nTMVA::Experimental::SOFIE::RModel model = TMVA::Experimental::SOFIE::PyTorch::Parse(\"" + path + "\", "
                       + inputShapesStr + "); \n";
      }
      else if (type == kROOT) {
         // ROOT files: read back a SOFIE RModel object previously stored in the file
         parserCode += "{\nauto fileRead = TFile::Open(\"" + path + "\",\"READ\");\n";
         parserCode += "TMVA::Experimental::SOFIE::RModel * modelPtr;\n";
         parserCode += "auto keyList = fileRead->GetListOfKeys(); TString name;\n";
         parserCode += "for (const auto&& k : *keyList) { \n";
         parserCode += "  TString cname = ((TKey*)k)->GetClassName(); if (cname==\"TMVA::Experimental::SOFIE::RModel\") name = k->GetName(); }\n";
         parserCode += "fileRead->GetObject(name,modelPtr); fileRead->Close(); delete fileRead;\n";
         parserCode += "TMVA::Experimental::SOFIE::RModel & model = *modelPtr;\n";
      }

      // determine the batch size used for the code generation (default is 1)
      int batchSize = 1;
      if (inputShapes.size() > 0 && inputShapes[0].size() > 0) {
         batchSize = inputShapes[0][0];
         if (batchSize < 1) batchSize = 1;
      }
      if (verbose) std::cout << "generating the code with batch size = " << batchSize << " ...\n";

      parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
                    + ROOT::Math::Util::ToString(batchSize) + "); \n";

      // add the registered custom operators (if any) and regenerate the code
      if (fCustomOperators.size() > 0) {
         if (verbose) {
            parserCode += "model.PrintRequiredInputTensors();\n";
            parserCode += "model.PrintIntermediateTensors();\n";
            parserCode += "model.PrintOutputTensors();\n";
         }
         for (auto & op : fCustomOperators) {
            parserCode += "{ auto p = new TMVA::Experimental::SOFIE::ROperator_Custom<float>(\""
                          + op.fOpName + "\"," + op.fInputNames + "," + op.fOutputNames + "," + op.fOutputShapes + ",\"" + op.fFileName + "\");\n";
            parserCode += "std::unique_ptr<TMVA::Experimental::SOFIE::ROperator> op(p);\n";
            parserCode += "model.AddOperator(std::move(op));\n}\n";
         }
         parserCode += "model.Generate(TMVA::Experimental::SOFIE::Options::kDefault,"
                       + ROOT::Math::Util::ToString(batchSize) + "); \n";
      }
      if (verbose > 1)
         parserCode += "model.PrintGenerated(); \n";
      parserCode += "model.OutputGenerated();\n";

      // the parser code returns the number of model inputs
      parserCode += "int nInputs = model.GetInputTensorNames().size();\n";

      parserCode += "return nInputs;\n}\n";

      if (verbose) std::cout << "//ParserCode being executed:\n" << parserCode << std::endl;

      // run the parser code with the ROOT interpreter; it writes the generated
      // inference header (modelName.hxx) in the current directory
      auto iret = gROOT->ProcessLine(parserCode.c_str());
      if (iret <= 0) {
         std::string msg = "RSofieReader: error processing the parser code: \n" + parserCode;
         throw std::runtime_error(msg);
      }
      fNInputs = iret;
      if (fNInputs > 3) {
         throw std::runtime_error("RSofieReader does not yet support models with more than 3 inputs");
      }

      // compile the generated header with the interpreter and create a global session object
      std::string modelHeader = modelName + ".hxx";
      if (verbose) std::cout << "compile generated code from file " << modelHeader << std::endl;
      if (gSystem->AccessPathName(modelHeader.c_str())) {
         std::string msg = "RSofieReader: input header file " + modelHeader + " does not exist";
         throw std::runtime_error(msg);
      }
      if (verbose) std::cout << "Creating Inference function for model " << modelName << std::endl;
      std::string declCode;
      declCode += "#pragma cling optimize(2)\n";
      declCode += "#include \"" + modelHeader + "\"\n";
      // declare the Session instance with a unique name built from a UUID
      std::string sessionClassName = "TMVA_SOFIE_" + modelName + "::Session";
      TUUID uuid;
      std::string uidName = uuid.AsString();
      uidName.erase(std::remove_if(uidName.begin(), uidName.end(),
                    []( char const& c ) -> bool { return !std::isalnum(c); } ), uidName.end());

      std::string sessionName = "session_" + uidName;
      declCode += sessionClassName + " " + sessionName + ";";

      if (verbose) std::cout << "//global session declaration\n" << declCode << std::endl;

      bool ret = gInterpreter->Declare(declCode.c_str());
      if (!ret) {
         std::string msg = "RSofieReader: error compiling inference code and creating session class\n" + declCode;
         throw std::runtime_error(msg);
      }
      // keep a pointer to the interpreted session object
      fSessionPtr = (void *) gInterpreter->Calc(sessionName.c_str());

      // declare a free function that forwards to Session::infer and keep its address
      std::stringstream ifuncCode;
      std::string funcName = "SofieInference_" + uidName;
      ifuncCode << "std::vector<float> " + funcName + "( void * ptr";
      for (int i = 0; i < fNInputs; i++)
         ifuncCode << ", float * data" << i;
      ifuncCode << ") {\n";
      ifuncCode << " " << sessionClassName << " * s = " << "(" << sessionClassName << "*) (ptr);\n";
      ifuncCode << " return s->infer(";
      for (int i = 0; i < fNInputs; i++) {
         if (i > 0) ifuncCode << ",";
         ifuncCode << "data" << i;
      }
      ifuncCode << ");\n";
      ifuncCode << "}\n";

      if (verbose) std::cout << "//Inference function code using global session instance\n"
                             << ifuncCode.str() << std::endl;

      ret = gInterpreter->Declare(ifuncCode.str().c_str());
      if (!ret) {
         std::string msg = "RSofieReader: error compiling inference function\n" + ifuncCode.str();
         throw std::runtime_error(msg);
      }
      fFuncPtr = (void *) gInterpreter->Calc(funcName.c_str());

      fInitialized = true;
   }
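
   // Example sketch of loading a PyTorch model, where the input shapes must be
   // given explicitly (the file name and shapes below are illustrative only):
   //
   //    RSofieReader reader;
   //    reader.Load("model.pt", {{1, 10}});   // one input tensor of shape {1, 10}
   //    auto out = reader.Compute(std::vector<float>(10));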

   /// Register a user-defined (custom) operator to be added to the model.
   /// The input names, output names and output shapes are passed as strings containing
   /// C++ initializer lists (they are spliced verbatim into the generated parser code),
   /// and fileName is the header file implementing the operator.
   /// Custom operators must be added before calling Load().
   void AddCustomOperator(const std::string &opName, const std::string &inputNames, const std::string & outputNames,
                          const std::string & outputShapes, const std::string & fileName) {
      if (fInitialized) std::cout << "WARNING: the model is already loaded and initialized. Custom operators must be added before loading the model" << std::endl;
      fCustomOperators.push_back( {fileName, opName, inputNames, outputNames, outputShapes});
   }
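
   // Example sketch of registering a custom operator before loading; the operator
   // name, tensor names, output shape and header file below are hypothetical:
   //
   //    RSofieReader reader;
   //    reader.AddCustomOperator("MyOp", "{\"input\"}", "{\"output\"}",
   //                             "{{1, 10}}", "MyOpImplementation.hxx");
   //    reader.Load("model.onnx");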

   /// Run the inference for a model with one input tensor.
   std::vector<float> DoCompute(const std::vector<float> & x1) {
      if (fNInputs != 1) {
         std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
         throw std::runtime_error(msg);
      }
      auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
      return fptr(fSessionPtr, x1.data());
   }
   /// Run the inference for a model with two input tensors.
   std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2) {
      if (fNInputs != 2) {
         std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
         throw std::runtime_error(msg);
      }
      auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *)>(fFuncPtr);
      return fptr(fSessionPtr, x1.data(), x2.data());
   }
   /// Run the inference for a model with three input tensors.
   std::vector<float> DoCompute(const std::vector<float> & x1, const std::vector<float> & x2, const std::vector<float> & x3) {
      if (fNInputs != 3) {
         std::string msg = "Wrong number of inputs - model requires " + std::to_string(fNInputs);
         throw std::runtime_error(msg);
      }
      auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *, const float *, const float *)>(fFuncPtr);
      return fptr(fSessionPtr, x1.data(), x2.data(), x3.data());
   }

   /// Compute the model output for the given input tensors (one std::vector<float>
   /// per model input). Returns an empty vector if the model is not initialized.
   template<typename... T>
   std::vector<float> Compute(T... x)
   {
      if(!fInitialized) {
         return std::vector<float>();
      }

      // take the ROOT write lock to protect the model evaluation
      R__WRITE_LOCKGUARD(ROOT::gCoreMutex);

      return DoCompute(x...);

   }
   /// Compute the model output for a single input tensor.
   std::vector<float> Compute(const std::vector<float> &x) {
      if(!fInitialized) {
         return std::vector<float>();
      }

      // take the ROOT write lock to protect the model evaluation
      R__WRITE_LOCKGUARD(ROOT::gCoreMutex);

      return DoCompute(x);
   }


   /// Compute the model predictions for a 2D input RTensor, where the first dimension
   /// is the number of events and the second the input size per event.
   /// This overload assumes a model with a single input tensor; the result is an
   /// RTensor of shape { number of events, output size }.
   RTensor<float> Compute(RTensor<float> &x)
   {
      if(!fInitialized) {
         return RTensor<float>({0});
      }
      const auto nrows = x.GetShape()[0];
      const auto rowsize = x.GetStrides()[0];
      auto fptr = reinterpret_cast<std::vector<float> (*)(void *, const float *)>(fFuncPtr);
      // evaluate the first row to determine the output size
      auto result = fptr(fSessionPtr, x.GetData());

      RTensor<float> y({nrows, result.size()}, MemoryLayout::ColumnMajor);
      std::copy(result.begin(), result.end(), y.GetData());

      // evaluate the remaining rows
      for (size_t i = 1; i < nrows; i++) {
         result = fptr(fSessionPtr, x.GetData() + i*rowsize);
         std::copy(result.begin(), result.end(), y.GetData() + i*result.size());
      }
      return y;
   }
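
   // Example sketch of batched inference with an RTensor (the event count and
   // input size below are illustrative only):
   //
   //    TMVA::Experimental::RTensor<float> events({100, 10});   // 100 events, 10 inputs each
   //    TMVA::Experimental::RTensor<float> predictions = reader.Compute(events);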

private:

   bool fInitialized = false;      ///< true after a model has been successfully loaded
   int fNInputs = 0;               ///< number of input tensors of the loaded model
   void * fSessionPtr = nullptr;   ///< pointer to the interpreted SOFIE Session object
   void * fFuncPtr = nullptr;      ///< pointer to the interpreted inference function

   /// data describing a user-registered custom operator
   struct CustomOperatorData {
      std::string fFileName;       ///< header file implementing the operator
      std::string fOpName;         ///< operator name
      std::string fInputNames;     ///< input tensor names (C++ initializer-list string)
      std::string fOutputNames;    ///< output tensor names (C++ initializer-list string)
      std::string fOutputShapes;   ///< output shapes (C++ initializer-list string)
   };
   std::vector<CustomOperatorData> fCustomOperators;

};

} // namespace Experimental
} // namespace TMVA

#endif // TMVA_RSOFIEREADER