Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:08

0001 #ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
0002 #define TMVA_SOFIE_ROPERATOR_RESHAPE
0003 
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <functional>
#include <limits>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0010 
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014 
// Selects which shape-changing operator a ROperator_Reshape instance implements.
// The enumerators keep their implicit values (0..3) and their declaration order.
enum ReshapeOpMode {
   Reshape,   // generic reshape driven by a shape tensor
   Flatten,   // flatten using an axis attribute
   Squeeze,   // squeeze, axes optional
   Unsqueeze  // unsqueeze, axes required
};
0016 
0017 template <typename T>
0018 class ROperator_Reshape final : public ROperator
0019 {
0020 
0021 private:
0022 
0023    ReshapeOpMode fOpMode = Reshape;   // type of Reshape operator
0024 
0025    int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
0026    int fAxis = 1;      // (for Flatten)
0027 
0028    std::string fNData;        // input data tensor name
0029    std::string fNShape;       // reshape tensor name
0030    std::string fNOutput;               // output tensor name
0031    std::vector<size_t> fShapeInput;     // input shape data
0032    std::vector<size_t> fShapeOutput;   // output shape data
0033    std::vector<int64_t> fAttrAxes;         // axes attributes (provided for all version of Squeeze/Unsqueeze)
0034 
0035 public:
0036 
0037    ROperator_Reshape(){}
0038    ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
0039       : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
0040       fNOutput(UTILITY::Clean_name(nameOutput))
0041    {
0042       if (opMode == Reshape) fAllowZero = attr_value;
0043       if (opMode == Flatten) fAxis = attr_value;
0044    }
0045 
0046    // for squeeze/unsqueezed operators following old ONNX version (< 10)
0047    // In this cases axes are passed as attribute values
0048    ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
0049       : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
0050         fAttrAxes(attrAxes)
0051    {
0052       assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
0053    }
0054 
0055    // output type is same as input
0056    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0057       auto ret = std::vector<ETensorType>(1, input[0]);
0058       return ret;
0059    }
0060 
0061    // output shape
0062    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0063       std::vector<std::vector<size_t>> ret;
0064       auto & input_shape = input[0];
0065 
0066       if (fOpMode == Reshape) {
0067          if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
0068          auto output_shape = input[1]; // the provided shape
0069          size_t input_length = ConvertShapeToLength(input_shape);
0070          size_t output_length = ConvertShapeToLength(output_shape);
0071          // (input_length == output_length) is the easy case : (2,3,4) -> (2,12)
0072          if (input_length != output_length) {
0073             if (output_shape.size() > 1 && ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX)) {
0074                // in this case value 0 in shape are automatically corrected
0075                for (size_t i = 0; i < output_shape.size(); i++) {
0076                   if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
0077                      auto tmp = output_shape;
0078                      tmp.erase(tmp.begin() + i);
0079                      auto tmp_length = ConvertShapeToLength(tmp);
0080                      output_shape[i] = input_length / tmp_length;
0081                      break;
0082                   }
0083                }
0084             }
0085             if (ConvertShapeToLength(output_shape) != input_length) {
0086                throw std::runtime_error("TMVA Reshape Op : Invalid  shapes : " + ConvertShapeToString(input_shape) +
0087                                         ConvertShapeToString(output_shape));
0088             }
0089          }
0090          ret.push_back(output_shape);
0091 
0092       } else if (fOpMode == Flatten) {
0093          // flattenig case
0094          size_t inputSize = ConvertShapeToLength(input_shape);
0095          size_t b = input[0][0];
0096          std::vector<size_t> newShape = {b, inputSize / b};
0097          ret.push_back(newShape);
0098 
0099       } else if (fOpMode == Squeeze) {
0100          // squeeze
0101          // assume no axis is provided - remove all axes with value equal to 1
0102          auto output_shape = input[0];
0103          if (input.size() == 1) {
0104             size_t i = 0;
0105             while (i < output_shape.size()) {
0106                if (output_shape[i] == 1 ) {
0107                   output_shape.erase(output_shape.begin() + i);
0108                }
0109                else {
0110                   i++;
0111                }
0112             }
0113          } else if (input.size() == 2) {
0114             auto & axes = input[1];
0115             for (size_t i = 0; i < axes.size(); i++){
0116                if (output_shape[axes[i]] != 1)
0117                   throw std::runtime_error("TMVA Squeeze Op : Invalid  axes : " + ConvertShapeToString(axes) +
0118                                            ConvertShapeToString(output_shape));
0119                output_shape.erase(output_shape.begin() + axes[i]);
0120             }
0121          }
0122          ret.push_back(output_shape);
0123       }
0124 
0125       else if (fOpMode == Unsqueeze) {
0126          // unsqueeze
0127          assert(input.size() == 2);
0128          auto output_shape = input[0];
0129          auto &axes = input[1];
0130          if (axes[0] > 0) { // positive axis start from beginning
0131             for (auto & i : axes)
0132                output_shape.insert(output_shape.begin() + i, 1);
0133          } else {
0134             //negative axes
0135             for (auto &i : axes) {
0136                assert(i < 0);
0137                output_shape.insert(output_shape.begin() + (output_shape.size() + i - 1), 1);
0138             }
0139          }
0140          ret.push_back(output_shape);
0141       }
0142       return ret;
0143    }
0144 
0145    void Initialize(RModel &model)
0146    {
0147 
0148       if (model.CheckIfTensorAlreadyExist(fNData) == false) {
0149           // input must be a graph input, or already initialized intermediate tensor
0150          throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + "  is not found in model");
0151       }
0152       fShapeInput = model.GetTensorShape(fNData);
0153       // check if optional shape tensor exist
0154       if (!fNShape.empty()) {
0155          if (model.CheckIfTensorAlreadyExist(fNShape)) {
0156             auto dptr = model.GetInitializedTensorData(fNShape);
0157             auto input_shape = static_cast<int64_t *>(dptr.get());
0158             auto vec = model.GetTensorShape(fNShape);
0159             assert(vec.size() == 1);
0160             size_t n = vec[0]; // size of shape input tensor
0161 
0162             std::vector<size_t> descShape(n);
0163             std::copy(input_shape, input_shape + n, descShape.begin());
0164             fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0165          } else {
0166             throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
0167          }
0168       } else if (!fAttrAxes.empty()) {
0169          // case fNShape is empty and axes are provided as attributes
0170          std::vector<size_t> descShape(fAttrAxes.size());
0171          std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
0172          fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0173       } else if (fOpMode == Flatten || fOpMode == Squeeze) {
0174          fShapeOutput = ShapeInference({fShapeInput})[0];
0175       } else {
0176          throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
0177       }
0178       model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0179    }
0180 
0181    std::string Generate(std::string OpName)
0182    {
0183       OpName = "op_" + OpName;
0184       if (fShapeInput.empty() || fShapeOutput.empty()) {
0185          throw std::runtime_error("TMVA SOFIE Reshape Op called to Generate without being initialized first");
0186       }
0187 
0188       // output of reshape is same as input
0189       size_t length = ConvertShapeToLength(fShapeOutput);
0190       if (length != ConvertShapeToLength(fShapeInput)) {
0191          throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
0192                                   ConvertShapeToString(fShapeOutput) + " and input is " +
0193                                   ConvertShapeToString(fShapeInput));
0194       }
0195       std::stringstream out;
0196       std::string opName = "Reshape";
0197       if (fOpMode == Flatten)
0198          opName = "Flatten";
0199       else if (fOpMode == Squeeze)
0200          opName = "Squeeze";
0201       else if (fOpMode == Unsqueeze)
0202          opName = "Unsquueze";
0203 
0204       out << SP << "///--------" << opName << " operator\n" << std::endl;
0205       out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
0206           << ");\n";
0207       return out.str();
0208    }
0209 };
0210 
0211 }//SOFIE
0212 }//Experimental
0213 }//TMVA
0214 
0215 
0216 #endif //TMVA_SOFIE_ROPERATOR_RESHAPE