Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /include/root/TMVA/ROperator_SubGraph.hxx was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 #ifndef TMVA_SOFIE_ROPERATOR_SubGraph
0002 #define TMVA_SOFIE_ROPERATOR_SubGraph
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014    // operators dealing with subgraphs (such as If, Loop, etc.)
0015 
0016 class ROperator_If final : public ROperator
0017 {
0018 
0019 private:
0020 
0021    std::string fNX;
0022    ETensorType fType = ETensorType::UNDEFINED;  // output type (support only one common type)
0023    std::vector<std::string> fNYs;
0024    std::shared_ptr<RModel> fModel_then;
0025    std::shared_ptr<RModel> fModel_else;
0026    std::string fInputSignature_modelThen;
0027    std::string fInputSignature_modelElse;
0028 
0029 public:
0030    ROperator_If(){}
0031    ROperator_If(const std::string & nameX, const std::vector<std::string> & nameYs, std::unique_ptr<RModel> model_then, std::unique_ptr<RModel> model_else):
0032       fNX(UTILITY::Clean_name(nameX)), fNYs(nameYs), fModel_then(std::move(model_then)), fModel_else(std::move(model_else))
0033       {
0034          for (auto & n : fNYs)
0035             n = UTILITY::Clean_name(n);
0036 
0037          fInputTensorNames = { fNX };
0038          std::transform(fNYs.begin(), fNYs.end(), fOutputTensorNames.begin(),
0039                    [](const std::string& s) -> std::string_view { return s; });
0040       }
0041 
0042    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0043       return input;
0044    }
0045 
0046    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0047       auto ret = input; //suggest copy to compiler
0048       return ret;
0049    }
0050 
0051    void Initialize(RModel& model) override {
0052        //input must be a graph input, or already initialized intermediate tensor
0053       if (model.CheckIfTensorAlreadyExist(fNX) == false){
0054         throw std::runtime_error("TMVA SOFIE If Op Input Tensor is not found in model");
0055       }
0056       //add the subgraph model to parent RModel and initialize them
0057       model.InitializeSubGraph(fModel_then);
0058       model.InitializeSubGraph(fModel_else);
0059 
0060       // generate input string signature for subgraphs
0061       fInputSignature_modelThen = fModel_then->GenerateInferSignature(false);
0062       fInputSignature_modelElse = fModel_else->GenerateInferSignature(false);
0063 
0064       // add the outputs
0065       for (size_t i = 0; i < fNYs.size(); i++) {
0066          // assume shape of then tensor is same of else tensor
0067          // if not need to make a parametric tensor output (tbd)
0068          auto soutput_name = fModel_then->GetOutputTensorNames()[i];
0069          auto shape = fModel_then->GetTensorShape(soutput_name);
0070          auto type = fModel_then->GetTensorType(soutput_name);
0071          if (i == 0)
0072             fType = type;
0073          else {
0074             if (type != fType)
0075                throw std::runtime_error("TMVA SOFIE If Op supports only all outputs of the same type");
0076          }
0077          model.AddIntermediateTensor(fNYs[i], fType, shape );
0078       }
0079 
0080    }
0081 
0082 
0083    std::string Generate(std::string opName) override {
0084       opName = "op_" + opName;
0085       if (fType == ETensorType::UNDEFINED) {
0086          throw std::runtime_error("TMVA If operator called to Generate without being initialized first");
0087       }
0088       std::stringstream out;
0089       //size_t length = ConvertShapeToLength(fShape);
0090       std::string typeName = ConvertTypeToString(fType);
0091       out << "\n//------ If operator\n";
0092       out << SP << "std::vector<std::vector<" << typeName << ">> outputs_" << opName << ";\n";
0093       // use the std::vector since is a boolean
0094       out << SP << "if (fTensor_" << fNX << "[0] ) { \n";
0095       // then branch
0096       out << SP << SP << "outputs_" << opName << " = "
0097          << "fSession_" <<  fModel_then->GetName() << ".infer(" << fInputSignature_modelThen << ");\n";
0098        // else branch
0099       out << SP << "} else {\n";
0100       out << SP << SP << "outputs_" << opName << " = "
0101          << "fSession_" + fModel_else->GetName() + ".infer(" << fInputSignature_modelElse << ");\n";
0102       out << SP << "}\n";
0103       // copy the outputs
0104       out << SP << "if (outputs_" << opName << ".size() != " << fNYs.size() << ")\n";
0105       out << SP << SP << "throw std::runtime_error(\" If operator: invalid output size!\");\n\n";
0106       for (size_t i = 0; i < fNYs.size(); i++) {
0107          out << SP << "std::copy(outputs_" << opName << "[" << i << "].begin(), outputs_" << opName << "[" << i << "].end(), fTensor_" << fNYs[i] << ".begin());\n";
0108       }
0109       return out.str();
0110    }
0111 
0112 
0113 
0114 };
0115 
0116 }//SOFIE
0117 }//Experimental
0118 }//TMVA
0119 
0120 
0121 #endif //TMVA_SOFIE_ROPERATOR_Tanh