Warning: file /include/root/TMVA/ROperator_SubGraph.hxx was not indexed
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_SubGraph
0002 #define TMVA_SOFIE_ROPERATOR_SubGraph
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015
0016 class ROperator_If final : public ROperator
0017 {
0018
0019 private:
0020
0021 std::string fNX;
0022 ETensorType fType = ETensorType::UNDEFINED;
0023 std::vector<std::string> fNYs;
0024 std::shared_ptr<RModel> fModel_then;
0025 std::shared_ptr<RModel> fModel_else;
0026 std::string fInputSignature_modelThen;
0027 std::string fInputSignature_modelElse;
0028
0029 public:
0030 ROperator_If(){}
0031 ROperator_If(const std::string & nameX, const std::vector<std::string> & nameYs, std::unique_ptr<RModel> model_then, std::unique_ptr<RModel> model_else):
0032 fNX(UTILITY::Clean_name(nameX)), fNYs(nameYs), fModel_then(std::move(model_then)), fModel_else(std::move(model_else))
0033 {
0034 for (auto & n : fNYs)
0035 n = UTILITY::Clean_name(n);
0036
0037 fInputTensorNames = { fNX };
0038 std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
0039 [](const std::string& s) -> std::string_view { return s; });
0040 }
0041
0042 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0043 return input;
0044 }
0045
0046 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0047 auto ret = input;
0048 return ret;
0049 }
0050
0051 void Initialize(RModel& model) override {
0052
0053 if (model.CheckIfTensorAlreadyExist(fNX) == false){
0054 throw std::runtime_error("TMVA SOFIE If Op Input Tensor is not found in model");
0055 }
0056
0057 model.InitializeSubGraph(fModel_then);
0058 model.InitializeSubGraph(fModel_else);
0059
0060
0061 fInputSignature_modelThen = fModel_then->GenerateInferSignature(false);
0062 fInputSignature_modelElse = fModel_else->GenerateInferSignature(false);
0063
0064
0065 for (size_t i = 0; i < fNYs.size(); i++) {
0066
0067
0068 auto soutput_name = fModel_then->GetOutputTensorNames()[i];
0069 auto shape = fModel_then->GetTensorShape(soutput_name);
0070 auto type = fModel_then->GetTensorType(soutput_name);
0071 if (i == 0)
0072 fType = type;
0073 else {
0074 if (type != fType)
0075 throw std::runtime_error("TMVA SOFIE If Op supports only all outputs of the same type");
0076 }
0077 model.AddIntermediateTensor(fNYs[i], fType, shape );
0078 }
0079
0080 }
0081
0082
0083 std::string Generate(std::string opName) override {
0084 opName = "op_" + opName;
0085 if (fType == ETensorType::UNDEFINED) {
0086 throw std::runtime_error("TMVA If operator called to Generate without being initialized first");
0087 }
0088 std::stringstream out;
0089
0090 std::string typeName = ConvertTypeToString(fType);
0091 out << "\n//------ If operator\n";
0092 out << SP << "std::vector<std::vector<" << typeName << ">> outputs_" << opName << ";\n";
0093
0094 out << SP << "if (fTensor_" << fNX << "[0] ) { \n";
0095
0096 out << SP << SP << "outputs_" << opName << " = "
0097 << "fSession_" << fModel_then->GetName() << ".infer(" << fInputSignature_modelThen << ");\n";
0098
0099 out << SP << "} else {\n";
0100 out << SP << SP << "outputs_" << opName << " = "
0101 << "fSession_" + fModel_else->GetName() + ".infer(" << fInputSignature_modelElse << ");\n";
0102 out << SP << "}\n";
0103
0104 out << SP << "if (outputs_" << opName << ".size() != " << fNYs.size() << ")\n";
0105 out << SP << SP << "throw std::runtime_error(\" If operator: invalid output size!\");\n\n";
0106 for (size_t i = 0; i < fNYs.size(); i++) {
0107 out << SP << "std::copy(outputs_" << opName << "[" << i << "].begin(), outputs_" << opName << "[" << i << "].end(), fTensor_" << fNYs[i] << ".begin());\n";
0108 }
0109 return out.str();
0110 }
0111
0112
0113
0114 };
0115
0116 }
0117 }
0118 }
0119
0120
0121 #endif