Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 09:08:54

0001 #ifndef TMVA_SOFIE_ROPERATOR_Shape
0002 #define TMVA_SOFIE_ROPERATOR_Shape
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 #include<sstream>
0010 #include<vector>
0011 #include <iterator>
0012 #include<string>
0013 namespace TMVA{
0014 namespace Experimental{
0015 namespace SOFIE{
0016 
0017 class ROperator_Shape final : public ROperator
0018 {
0019 
0020 private:
0021 
0022    /* Attributes*/
0023    int fStart = 0;  // default is beginning
0024    int fEnd = 0; // default is input length (all input tensor shape included)
0025    std::string fNX;
0026    std::string fNY;
0027    std::vector<size_t> fShape;
0028    std::vector<size_t> fOutput_shape;
0029 
0030 public:
0031    ROperator_Shape(){}
0032    ROperator_Shape(int start, int end, std::string nameX, std::string nameY):
0033    fStart(start) ,fEnd(end), fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)){
0034          fInputTensorNames = { fNX };
0035          fOutputTensorNames = { fNY };
0036    }
0037 
0038    void Initialize(RModel& model) override {
0039       if (model.CheckIfTensorAlreadyExist(fNX) == false){   //input must be a graph input, or already initialized intermediate tensor
0040          throw std::runtime_error("TMVA SOFIE Shape Op Input Tensor " + fNX + " is not found in model");
0041       }
0042       fShape = model.GetTensorShape(fNX);
0043       size_t length = fShape.size();  // this the size of shape not length of tensor
0044       fStart = std::max(fStart,(int) -length);
0045       fStart = std::min(fStart,(int) length);
0046       if (fStart < 0) fStart += length;
0047       fEnd = std::max(fEnd,(int) -length);
0048       fEnd = std::min(fEnd, (int) length);
0049       if (fEnd < 0) fEnd += length;
0050       if (fEnd > fStart)
0051          fOutput_shape = { size_t(fEnd - fStart) };
0052       // in case the input tensor is not a dynamic tensor we should register the output as a Constant tensor since we know
0053       // its content
0054       if (!model.IsDynamicTensor(fNX) && !fOutput_shape.empty()) {
0055          std::shared_ptr<void> data(malloc(length * sizeof(int64_t)), free);
0056          auto shape_values = std::vector<int64_t>(fShape.begin()+fStart, fShape.begin() + fEnd );
0057          std::memcpy(data.get(), (void*) shape_values.data(), length * sizeof(int64_t));
0058          model.AddConstantTensor(fNY, ETensorType::INT64, fOutput_shape, data);
0059          fOutputTensorNames.pop_back();
0060          if (model.Verbose()) {
0061             std::cout << "Output of Shape is constant tensor with shape " << ConvertShapeToString(fOutput_shape) << " and values ";
0062             for (size_t i = 0; i < shape_values.size(); i++)
0063                std::cout << shape_values[i] << "  ";
0064             std::cout << std::endl;
0065          }
0066          fIsOutputConstant = true;
0067       }
0068       else
0069          model.AddIntermediateTensor(fNY, ETensorType::INT64, fOutput_shape);
0070 
0071 
0072    }
0073 
0074    std::string Generate(std::string OpName) override {
0075       // no need to generate code if the output is constant
0076       if (fIsOutputConstant) return "";
0077 
0078       OpName = "op_" + OpName;
0079       if (fShape.empty()) {
0080          throw std::runtime_error("TMVA SOFIE Shape op called to Generate without being initialized first");
0081       }
0082       std::stringstream out;
0083 
0084       out << "\n//------ Shape\n";
0085       // add a dummy statement to avoid warning for unused input
0086       out << SP << "(void) tensor_" << fNX << ";\n";
0087       size_t length = ConvertShapeToLength(fOutput_shape);
0088       for (size_t id = 0; id < length; id++) {
0089          out << SP << "tensor_" << fNY << "["<< id << "] = " << fShape[fStart+id] << ";\n";
0090       }
0091       return out.str();
0092    }
0093 
0094 };
0095 
0096 }//SOFIE
0097 }//Experimental
0098 }//TMVA
0099 
0100 
0101 #endif //TMVA_SOFIE_ROPERATOR_Shape