#ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
#define TMVA_SOFIE_ROPERATOR_RESHAPE

#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <cassert>
#include <sstream>

namespace TMVA{
namespace Experimental{
namespace SOFIE{

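// Mode selecting which of the reshape-like ONNX operators this class implements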
enum ReshapeOpMode { Reshape, Flatten, Squeeze, Unsqueeze };

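// Implementation of the ONNX Reshape, Flatten, Squeeze and Unsqueeze operators.
// All of them leave the tensor data unchanged and only change its shape, so the
// generated inference code is a plain copy of the input tensor into the output tensor.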
class ROperator_Reshape final : public ROperator
{

private:

   bool fVerbose = false;
   ReshapeOpMode fOpMode = Reshape;   // type of Reshape operator

   int fAllowZero = 0; // (for Reshape) if 0, a zero in the shape tensor is replaced automatically to match the input length
   int fAxis = 1;      // (for Flatten)

   std::string fNData;               // input data tensor name
   std::string fNShape;              // reshape tensor name
   std::string fNOutput;             // output tensor name
   std::vector<size_t> fShapeInput;  // input shape data
   std::vector<size_t> fShapeOutput; // output shape data
   std::vector<int64_t> fAttrAxes;   // axes attribute (provided in old versions of Squeeze/Unsqueeze)

public:

   std::string Name() const {
      if (fOpMode == Reshape) return "Reshape";
      if (fOpMode == Flatten) return "Flatten";
      if (fOpMode == Squeeze) return "Squeeze";
      if (fOpMode == Unsqueeze) return "Unsqueeze";
      return "";
   }

   ROperator_Reshape(){}
   ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
      : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
      fNOutput(UTILITY::Clean_name(nameOutput))
   {
      if (opMode == Reshape) fAllowZero = attr_value;
      if (opMode == Flatten) fAxis = attr_value;

      fInputTensorNames = { fNData };
      if (!fNShape.empty()) {
         fInputTensorNames.emplace_back(fNShape);
      }
      fOutputTensorNames = { fNOutput };
   }

   // for Squeeze/Unsqueeze operators following old ONNX versions (< 10),
   // where the axes are passed as attribute values
   ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
      : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
        fAttrAxes(attrAxes)
   {
      assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
   }

   // output type is same as input
   std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
      auto ret = std::vector<ETensorType>(1, input[0]);
      return ret;
   }

   // compute the output shape from the input shape(s)
   std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
      std::vector<std::vector<size_t>> ret;
      auto & input_shape = input[0];

      if (fOpMode == Reshape) {
         if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
         auto output_shape = input[1]; // the provided shape
         size_t input_length = ConvertShapeToLength(input_shape);
         size_t output_length = ConvertShapeToLength(output_shape);
         // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
         if (input_length != output_length) {
            if ((output_length == 0 && fAllowZero == 0) || static_cast<long>(output_length) < 0) {
               // in this case a value of 0 or -1 in the shape is corrected automatically
               bool replacementDone = false;
               for (size_t i = 0; i < output_shape.size(); i++) {
                  if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
                     if (replacementDone) {
                        throw std::runtime_error("TMVA Reshape Op : output shape has multiple negative or zero values");
                     }
                     auto tmp = output_shape;
                     tmp.erase(tmp.begin() + i);
                     auto tmp_length = ConvertShapeToLength(tmp);
                     output_shape[i] = input_length / tmp_length;
                     replacementDone = true;
                  }
               }
               if (fVerbose)
                  std::cout << "Reshape: correct output shape from " << ConvertShapeToString(input[1])
                        << " to " << ConvertShapeToString(output_shape) << std::endl;
            }
            if (ConvertShapeToLength(output_shape) != input_length) {
               throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
                                        " and " + ConvertShapeToString(output_shape));
            }
         }
         ret.push_back(output_shape);

      } else if (fOpMode == Flatten) {
         // flattening case: collapse everything after the first (batch) dimension
         size_t inputSize = ConvertShapeToLength(input_shape);
         size_t b = input[0][0];
         std::vector<size_t> newShape = {b, inputSize / b};
         ret.push_back(newShape);

      } else if (fOpMode == Squeeze) {
         // squeeze case
         // if no axes input is provided, remove all axes with a value equal to 1
         auto output_shape = input[0];
         if (input.size() == 1) {
            size_t i = 0;
            while (i < output_shape.size()) {
               if (output_shape[i] == 1) {
                  output_shape.erase(output_shape.begin() + i);
               }
               else {
                  i++;
               }
            }
         } else if (input.size() == 2) {
            auto & axes = input[1];
            for (size_t i = 0; i < axes.size(); i++){
               if (output_shape[axes[i]] != 1)
                  throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
                                           " and shape " + ConvertShapeToString(output_shape));
               output_shape.erase(output_shape.begin() + axes[i]);
            }
         }
         ret.push_back(output_shape);
      }

      else if (fOpMode == Unsqueeze) {
         // unsqueeze case
         assert(input.size() == 2);
         auto output_shape = input[0];
         auto & axes = input[1];
         // output rank
         int64_t r = input[0].size() + axes.size();
         for (auto & a : axes) {
            int64_t i = static_cast<int64_t>(a);
            if (i < -r || i > r - 1)
               throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
            if (i >= 0)
               output_shape.insert(output_shape.begin() + i, 1);
            else
               // negative axes
               output_shape.insert(output_shape.end() + i + 1, 1);
         }
         ret.push_back(output_shape);
      }
      return ret;
   }

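   // Initialize the operator from the model: resolve the input shape, read the optional
   // shape (or axes) tensor, compute the output shape and register the output tensor.
   // If the input data tensor is itself an initialized INT64 tensor, the output is added
   // directly as a constant tensor and no inference code is generated for this operator.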
   void Initialize(RModel& model) override {

      fVerbose = model.Verbose();
      if (model.CheckIfTensorAlreadyExist(fNData) == false) {
         // input must be a graph input or an already initialized intermediate tensor
         throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
      }
      fShapeInput = model.GetTensorShape(fNData);
      // check if the optional shape tensor exists
      if (!fNShape.empty()) {
         if (model.CheckIfTensorAlreadyExist(fNShape)) {
            auto dptr = model.GetInitializedTensorData(fNShape);
            auto input_shape = static_cast<int64_t *>(dptr.get());
            auto vec = model.GetTensorShape(fNShape);
            assert(vec.size() == 1);
            size_t n = vec[0]; // size of the shape input tensor

            std::vector<size_t> descShape(n);
            std::copy(input_shape, input_shape + n, descShape.begin());
            fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
            // set flag to not write the tensor in the weight file: its data will be hard-coded in the way the model is constructed
            model.SetNotWritableInitializedTensor(fNShape);
         } else {
            throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
         }
      } else if (!fAttrAxes.empty()) {
         // case fNShape is empty and the axes are provided as attributes
         std::vector<size_t> descShape(fAttrAxes.size());
         std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
         fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
      } else if (fOpMode == Flatten || fOpMode == Squeeze) {
         fShapeOutput = ShapeInference({fShapeInput})[0];
      } else {
         throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
      }
      // check if the output is constant or not
      if (model.IsInitializedTensor(fNData) && model.GetTensorType(fNData) == ETensorType::INT64) {
         fIsOutputConstant = true;
         auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
         if (ConvertShapeToLength(fShapeInput) != ConvertShapeToLength(fShapeOutput))
            throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
         model.AddConstantTensor<int64_t>(fNOutput, fShapeOutput, inputData);
         if (model.Verbose()) {
            std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " -->  " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput) << " : " <<
            ConvertValuesToString(ConvertShapeToLength(fShapeOutput), inputData) << std::endl;
         }
      } else {
         // non-constant case
         model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
         if (model.Verbose())
            std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " -->  " << fNOutput << "  " << ConvertShapeToString(fShapeOutput) << std::endl;
      }
   }

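   // Generate the inference code. Since reshape-type operators do not modify the data,
   // the emitted code is a single std::copy of the input tensor into the output tensor.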
   std::string Generate(std::string OpName) override {
      if (fIsOutputConstant) return "";  // no-op for constant tensors

      OpName = "op_" + OpName;

      // the output of a reshape has the same length as the input
      size_t length = ConvertShapeToLength(fShapeOutput);
      if (length != ConvertShapeToLength(fShapeInput)) {
         throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
                                  ConvertShapeToString(fShapeOutput) + " and input is " +
                                  ConvertShapeToString(fShapeInput));
      }
      std::stringstream out;
      std::string opName = "Reshape";
      if (fOpMode == Flatten)
         opName = "Flatten";
      else if (fOpMode == Squeeze)
         opName = "Squeeze";
      else if (fOpMode == Unsqueeze)
         opName = "Unsqueeze";

      out << SP << "///--------" << opName << " operator\n" << std::endl;
      out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
          << ");\n";
      return out.str();
   }
};

}//SOFIE
}//Experimental
}//TMVA


#endif //TMVA_SOFIE_ROPERATOR_RESHAPE