Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-22 10:28:05

0001 #ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
0002 #define TMVA_SOFIE_ROPERATOR_RESHAPE
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <cassert>
0009 #include <cctype>
0010 #include <sstream>
0011 #include <algorithm>
0012 
0013 namespace TMVA{
0014 namespace Experimental{
0015 namespace SOFIE{
0016 
/// Selects which shape-manipulation operator an ROperator_Reshape instance implements.
enum ReshapeOpMode {
   Reshape = 0,   ///< output shape taken from a shape tensor (0 / -1 entries resolved against the input)
   Flatten = 1,   ///< input collapsed into two dimensions split at an axis
   Squeeze = 2,   ///< dimensions of size 1 removed (all of them, or the listed axes)
   Unsqueeze = 3  ///< dimensions of size 1 inserted at the listed axes
};
0018 
0019 
0020 class ROperator_Reshape final : public ROperator
0021 {
0022 
0023 private:
0024 
0025    bool fVerbose = false;
0026    bool fDimInput = false;
0027    bool fDynamicShape = false;
0028    ReshapeOpMode fOpMode = Reshape;   // type of Reshape operator
0029 
0030    int fAllowZero = 0; // (for Reshape) zero in tensor shape makes output shape equal to input tensor shape
0031    int fAxis = 1;      // (for Flatten)
0032 
0033    std::string fNData;        // input data tensor name
0034    std::string fNInput2;       // reshape or axes tensor name depending on operator
0035    std::string fNOutput;               // output tensor name
0036    std::vector<Dim> fShapeInput;     // input shape data
0037    std::vector<Dim> fShapeOutput;   // output shape data
0038    std::vector<int64_t> fAttrAxes;         // axes attributes (provided for all version of Squeeze/Unsqueeze)
0039    std::vector<int64_t> fShape;     // shape tensor values provided for Reshape
0040 
0041 public:
0042 
0043    std::string Name() const {
0044       if (fOpMode == Reshape) return "Reshape";
0045       if (fOpMode == Flatten) return "Flatten";
0046       if (fOpMode == Squeeze) return "Squeeze";
0047       if (fOpMode == Unsqueeze) return "Unsqueeze";
0048       return "";
0049    }
0050 
0051    ROperator_Reshape(){}
0052    ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameInput2, std::string nameOutput)
0053       : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNInput2(UTILITY::Clean_name(nameInput2)),
0054          fNOutput(UTILITY::Clean_name(nameOutput))
0055    {
0056       if (opMode == Reshape) fAllowZero = attr_value;
0057       if (opMode == Flatten) fAxis = attr_value;
0058 
0059       fInputTensorNames = { fNData };
0060       if(!fNInput2.empty()){
0061          fInputTensorNames.emplace_back(fNInput2);
0062       }
0063       fOutputTensorNames = { fNOutput };
0064    }
0065 
0066    // for squeeze/unsqueezed operators following old ONNX version (< 10)
0067    // In this cases axes are passed as attribute values
0068    ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
0069       : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
0070         fAttrAxes(attrAxes)
0071    {
0072       assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
0073    }
0074 
0075    // output type is same as input
0076    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0077       auto ret = std::vector<ETensorType>(1, input[0]);
0078       return ret;
0079    }
0080    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0081       return input;
0082    }
0083 
0084    // output shape
0085    std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & input)  {
0086       std::vector<std::vector<Dim>> ret;
0087       auto & input_shape = input[0];
0088       if (fOpMode == Reshape) {
0089          // correct the provided shape (here we have the value) for 0 or -1
0090          std::vector<Dim> output_shape(fShape.size());
0091          assert(!fShape.empty() && !fDynamicShape);
0092          for (size_t i = 0; i < output_shape.size(); i++) {
0093             if (fShape[i] > 0 || (fAllowZero && fShape[i] >= 0))
0094                output_shape[i] = Dim{ static_cast<size_t>(fShape[i]) };
0095             else if (!fAllowZero && fShape[i] == 0)
0096                output_shape[i] = input_shape[i];
0097          }
0098          // now case of -1 in shape
0099          for (size_t i = 0; i < output_shape.size(); i++) {
0100             if (fShape[i] == -1) {
0101                auto tmp = output_shape;
0102                tmp.erase(tmp.begin() + i);
0103                auto tmp_length = ConvertDimShapeToLength(tmp);
0104                auto input_length = ConvertDimShapeToLength(input_shape);
0105                if (fVerbose)
0106                   std::cout << "reshape- try simplifying " << ConvertDimShapeToString(input_shape) << " with length "
0107                             << input_length << " to " << tmp_length << std::endl;
0108 
0109                if (IsInteger(tmp_length) && IsInteger(input_length))
0110                   output_shape[i] = Dim{static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
0111                else {
0112                   //we can try simplifying expression if tmp_length is integer and part of input_length
0113                   // contains tmp_length
0114                   bool canSimplify = false;
0115                   std::vector <Dim> reduced_input;
0116                   if (IsInteger(tmp_length)) {
0117 
0118                      // try to tokenize with * the input length
0119 
0120                      std::stringstream ss(input_length);
0121 
0122                      std::string token;
0123 
0124                      // Tokenizing w.r.t. space '*'
0125                      while(getline(ss, token, '*'))
0126                      {
0127                         // remove any whitespace
0128                         token.erase(std::remove_if(token.begin(), token.end(),
0129                                                    [](unsigned char x) { return std::isspace(x); }), token.end());
0130                         if (token != tmp_length) {
0131                            if (IsInteger(token)) {
0132                                  size_t il = static_cast<size_t>(std::stoi(input_length));
0133                                  size_t tl = static_cast<size_t>(std::stoi(tmp_length));
0134                                  if ((il % tl) == 0) {
0135                                  canSimplify = true;
0136                                  reduced_input.push_back(Dim{il / tl});
0137                                  }
0138                            } else {
0139                               reduced_input.push_back(Dim{token});
0140                            }
0141                         } else {
0142                            // token is equal to tmp_length, can be not considered and is simplified
0143                            canSimplify = true;
0144                         }
0145                      }
0146                   }
0147                   if (canSimplify) {
0148                      // if length contains * we need to add some brackets
0149                      std::string res_shape = ConvertDimShapeToLength(reduced_input);
0150                      if (res_shape.find('*') != std::string::npos)
0151                         output_shape[i] = Dim{std::string("(") + res_shape + ")", static_cast<size_t>(-1)};
0152                      else
0153                         output_shape[i] = Dim{res_shape};
0154                   }
0155                   if (!canSimplify)
0156                      output_shape[i] = Dim{std::string("(") + input_length + " / (" + tmp_length + "))", static_cast<size_t>(-1)};
0157                }
0158 
0159                break; // cannot have more than -1
0160             }
0161             //  throw std::runtime_error(
0162             //                   "TMVA Reshape Op : output shape has multiple negative or zero values");
0163          }
0164 
0165          if (fVerbose)
0166             std::cout << "Reshape: correct output shape  to " << ConvertShapeToString(output_shape) << std::endl;
0167 
0168          if (!fDimInput && ConvertDimShapeToLength(output_shape) != ConvertDimShapeToLength(input_shape)) {
0169             throw std::runtime_error("TMVA Reshape Op : Invalid  shapes : " + ConvertShapeToString(input_shape) +
0170                                      ConvertShapeToString(output_shape));
0171          }
0172          ret.push_back(output_shape);
0173 
0174       } else if (fOpMode == Flatten) {
0175          // flatten case
0176          if (fAxis < 0)
0177             fAxis += input_shape.size();
0178          auto s1 = std::vector<Dim>(input_shape.begin(), input_shape.begin() + fAxis);
0179          auto s2 = std::vector<Dim>(input_shape.begin() + fAxis, input_shape.end());
0180          auto l1 = ConvertDimShapeToLength(s1);
0181          auto l2 = ConvertDimShapeToLength(s2);
0182          std::vector<Dim> newShape = {Dim{l1}, Dim{l2}};
0183          ret.push_back(newShape);
0184       } else if (fOpMode == Squeeze) {
0185          // squeeze
0186          // assume no axis is provided - remove all axes with value equal to 1
0187          auto output_shape = input_shape;
0188          if (fAttrAxes.empty()) {
0189             size_t i = 0;
0190             while (i < output_shape.size()) {
0191                if (output_shape[i] == Dim{1}) {
0192                   output_shape.erase(output_shape.begin() + i);
0193                } else {
0194                   i++;
0195                }
0196             }
0197          } else {
0198             auto &axes = fAttrAxes;
0199             for (size_t i = 0; i < axes.size(); i++) {
0200                if (axes[i] < 0)
0201                   axes[i] += input_shape.size();
0202                if (!(output_shape[axes[i]] == Dim{1}))
0203                   throw std::runtime_error("TMVA Squeeze Op : Invalid  axis value " + std::to_string(axes[i]) +
0204                                            " for " + ConvertShapeToString(output_shape));
0205                output_shape.erase(output_shape.begin() + axes[i]);
0206             }
0207          }
0208          ret.push_back(output_shape);
0209       }
0210       else if (fOpMode == Unsqueeze) {
0211          // unsqueeze
0212          std::cout << "doing unsqueeze....\n";
0213          assert(!fAttrAxes.empty());
0214          auto output_shape = input_shape;
0215          auto &axes = fAttrAxes;
0216          // output rank
0217          int64_t r = input[0].size() + axes.size();
0218          for (auto &a : axes) {
0219             int64_t i = static_cast<int64_t>(a);
0220             if (i < -r || i > r - 1)
0221                throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
0222             if (i >= 0)
0223                output_shape.insert(output_shape.begin() + i, Dim{1});
0224             else
0225                // negative axes
0226                output_shape.insert(output_shape.end() + i + 1, Dim{1});
0227          }
0228          ret.push_back(output_shape);
0229       }
0230       return ret;
0231    }
0232 
0233    void Initialize(RModel& model) override {
0234 
0235       std::cout << "initialize reshape op type " << fOpMode << " - " << fNInput2 << " " << fNData << std::endl;
0236       fVerbose = model.Verbose();
0237       if (model.CheckIfTensorAlreadyExist(fNData) == false) {
0238           // input must be a graph input, or already initialized intermediate tensor
0239          throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + "  is not found in model");
0240       }
0241       fShapeInput = model.GetDimTensorShape(fNData);
0242       fDimInput = model.IsDynamicTensor(fNData);
0243       // check if optional tensor exists defining shape or axes
0244       if (!fNInput2.empty()) {
0245          if (model.CheckIfTensorAlreadyExist(fNInput2)) {
0246             if (model.IsConstantTensor(fNInput2) || model.IsInitializedTensor(fNInput2)) {
0247                // assume input shape is an initialized tensor
0248                auto dptr = model.GetInitializedTensorData(fNInput2);
0249                auto values = static_cast<int64_t *>(dptr.get());
0250                auto vec = model.GetTensorShape(fNInput2);
0251                size_t n = 1;
0252                if (vec.size() > 0)
0253                   n = vec[0]; // size of shape input tensor
0254                // copy values in fShape vector or fAttrAxes
0255                if (fOpMode == Reshape)
0256                   fShape = std::vector<int64_t>(values, values + n);
0257                else
0258                   fAttrAxes = std::vector<int64_t>(values, values + n);
0259 
0260                fShapeOutput = ShapeInference({fShapeInput})[0];
0261                // set flag to not write tensor in weight file. Its data will be hard-coded in way model is constructed
0262                model.SetNotWritableInitializedTensor(fNInput2);
0263             } else {
0264                // we cannot get shape at initialization time but at run-time
0265                fDynamicShape = true;
0266                // size of shape output us given by size of shape input tensor
0267                auto shapeInput2 = model.GetTensorShape(fNInput2);
0268                fShapeOutput.resize(shapeInput2[0]);
0269                for (size_t i = 0; i < fShapeOutput.size(); i++) {
0270                   fShapeOutput[i] = Dim{ std::string("s_") + fNOutput + "_" + std::to_string(i)};
0271                }
0272             }
0273          } else {
0274             throw std::runtime_error("TMVA Reshape Op 2nd input Tensor " + fNInput2 + " is not found in model");
0275          }
0276       } else if (!fAttrAxes.empty()) {
0277          // case fNShape is empty and axes are provided as attributes (e.g. for Unsqueeze)
0278          std::cout << "attribute axes exists\n";
0279          fShapeOutput = ShapeInference({fShapeInput})[0];
0280       } else if (fOpMode == Flatten || fOpMode == Squeeze) {
0281          fShapeOutput = ShapeInference({fShapeInput})[0];
0282       } else {
0283          throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
0284       }
0285       // check if output is constant or not
0286       if (model.IsInitializedTensor(fNData) && model.GetTensorType(fNData) == ETensorType::INT64) {
0287          fIsOutputConstant = true;
0288          auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
0289          auto o_shape = ConvertShapeToInt(fShapeOutput);
0290          if (ConvertShapeToLength(ConvertShapeToInt(fShapeInput)) != ConvertShapeToLength(o_shape) )
0291             throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
0292          model.AddConstantTensor<int64_t>(fNOutput, o_shape, inputData);
0293          if (model.Verbose()) {
0294             std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " -->  " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput)  << " : " <<
0295             ConvertValuesToString(ConvertShapeToLength(o_shape), inputData) << std::endl;
0296          }
0297       }
0298       // for shape tensors we can have it if output shape is size==1 or a scalar
0299       else if (model.IsShapeTensor(fNData) && fShapeOutput.size() <=1) {
0300          fIsOutputConstant = true;
0301          auto inputData = model.GetShapeTensorValues(fNData);
0302          model.AddShapeTensor(fNOutput, inputData);
0303          if (model.Verbose()) {
0304             std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " -->  " << fNOutput << " (shape) " << ConvertShapeToString(fShapeOutput)  << " : " <<
0305             ConvertShapeToString(inputData) << std::endl;
0306          }
0307       }
0308       else {
0309          // non-constant case
0310          model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0311          if (model.Verbose())
0312             std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " -->  "<< fNOutput << "  " << ConvertShapeToString(fShapeOutput)  << std::endl;
0313       }
0314    }
0315 
0316    std::string Generate(std::string opName) override {
0317       if (fIsOutputConstant) return "";  //no op for constant tensors
0318 
0319       std::stringstream out;
0320       std::string opType = "Reshape";
0321       if (fOpMode == Flatten)
0322          opType = "Flatten";
0323       else if (fOpMode == Squeeze)
0324          opType = "Squeeze";
0325       else if (fOpMode == Unsqueeze)
0326          opType = "Unsquueze";
0327 
0328       out << SP << "///--------" << opType << " operator " << opName << " --> " << ConvertShapeToString(fShapeOutput) << "\n";
0329 
0330       // in case of dynamic output shape we need to set the shape value from input shape tensor
0331       // and take case of the zero values
0332       if (fDynamicShape) {
0333          for (size_t i = 0; i < fShapeOutput.size(); i++) {
0334             // since fNInput2 values are int64_t, should we check if they are negative?
0335             out << SP << "size_t " << fShapeOutput[i].param << " = " << "tensor_" << fNInput2 << "[" << i << "];\n";
0336             if (!fAllowZero)
0337                out << SP << "if (tensor_" << fNInput2 << "[" << i << "] <= 0 ) "
0338                          <<  fShapeOutput[i].param << " = " <<  fShapeInput[i] << ";\n";
0339          }
0340       }
0341 
0342       // output of reshape is same as input
0343       auto lengthOut = ConvertDimShapeToLength(fShapeOutput);
0344       auto lengthIn = ConvertDimShapeToLength(fShapeInput);
0345       if (lengthOut != lengthIn) {
0346          // check needs to be done at run-time
0347          out << SP << "if (" << lengthOut << "!=" << lengthIn << ")\n";
0348          out << "throw std::runtime_error(\"TMVA SOFIE Reshape Op : output lengths is different than input one\");\n";
0349       }
0350 
0351 
0352       out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << lengthIn << ", " << "tensor_" << fNOutput
0353           << ");\n";
0354       return out.str();
0355    }
0356 };
0357 
0358 }//SOFIE
0359 }//Experimental
0360 }//TMVA
0361 
0362 
0363 #endif //TMVA_SOFIE_ROPERATOR_RESHAPE