File indexing completed on 2025-09-18 09:32:38
0001 #ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
0002 #define TMVA_SOFIE_ROPERATOR_RESHAPE
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <cassert>
0009 #include <sstream>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
/// Selects which ONNX shape-manipulation operator a ROperator_Reshape instance implements.
/// Kept as an unscoped enum because the enumerators are referred to unqualified
/// (e.g. `fOpMode == Reshape`).
enum ReshapeOpMode {
   Reshape = 0,   ///< ONNX Reshape
   Flatten = 1,   ///< ONNX Flatten
   Squeeze = 2,   ///< ONNX Squeeze
   Unsqueeze = 3  ///< ONNX Unsqueeze
};
0016
0017
0018 class ROperator_Reshape final : public ROperator
0019 {
0020
0021 private:
0022
0023 bool fVerbose = false;
0024 ReshapeOpMode fOpMode = Reshape;
0025
0026 int fAllowZero = 0;
0027 int fAxis = 1;
0028
0029 std::string fNData;
0030 std::string fNShape;
0031 std::string fNOutput;
0032 std::vector<size_t> fShapeInput;
0033 std::vector<size_t> fShapeOutput;
0034 std::vector<int64_t> fAttrAxes;
0035
0036 public:
0037
0038 std::string Name() const {
0039 if (fOpMode == Reshape) return "Reshape";
0040 if (fOpMode == Flatten) return "Flatten";
0041 if (fOpMode == Squeeze) return "Squeeze";
0042 if (fOpMode == Unsqueeze) return "Unsqueeze";
0043 return "";
0044 }
0045
0046 ROperator_Reshape(){}
0047 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
0048 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
0049 fNOutput(UTILITY::Clean_name(nameOutput))
0050 {
0051 if (opMode == Reshape) fAllowZero = attr_value;
0052 if (opMode == Flatten) fAxis = attr_value;
0053
0054 fInputTensorNames = { fNData };
0055 if(!fNShape.empty()){
0056 fInputTensorNames.emplace_back(fNShape);
0057 }
0058 fOutputTensorNames = { fNOutput };
0059 }
0060
0061
0062
0063 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
0064 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
0065 fAttrAxes(attrAxes)
0066 {
0067 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
0068 }
0069
0070
0071 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0072 auto ret = std::vector<ETensorType>(1, input[0]);
0073 return ret;
0074 }
0075
0076
0077 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0078 std::vector<std::vector<size_t>> ret;
0079 auto & input_shape = input[0];
0080
0081 if (fOpMode == Reshape) {
0082 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
0083 auto output_shape = input[1];
0084 size_t input_length = ConvertShapeToLength(input_shape);
0085 size_t output_length = ConvertShapeToLength(output_shape);
0086
0087 if (input_length != output_length) {
0088 if ((output_length == 0 && fAllowZero == 0) || static_cast<long>(output_length) < 0) {
0089
0090 bool replacementDone = false;
0091 for (size_t i = 0; i < output_shape.size(); i++) {
0092 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
0093 if (replacementDone) {
0094 throw std::runtime_error("TMVA Reshape Op : output shape has multiple negative or zero values");
0095 }
0096 auto tmp = output_shape;
0097 tmp.erase(tmp.begin() + i);
0098 auto tmp_length = ConvertShapeToLength(tmp);
0099 output_shape[i] = input_length / tmp_length;
0100 replacementDone = true;
0101 }
0102 }
0103 if (fVerbose)
0104 std::cout << "Reshape: correct output shape from " << ConvertShapeToString(input[1])
0105 << " to " << ConvertShapeToString(output_shape) << std::endl;
0106 }
0107 if (ConvertShapeToLength(output_shape) != input_length) {
0108 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
0109 ConvertShapeToString(output_shape));
0110 }
0111 }
0112 ret.push_back(output_shape);
0113
0114 } else if (fOpMode == Flatten) {
0115
0116 size_t inputSize = ConvertShapeToLength(input_shape);
0117 size_t b = input[0][0];
0118 std::vector<size_t> newShape = {b, inputSize / b};
0119 ret.push_back(newShape);
0120
0121 } else if (fOpMode == Squeeze) {
0122
0123
0124 auto output_shape = input[0];
0125 if (input.size() == 1) {
0126 size_t i = 0;
0127 while (i < output_shape.size()) {
0128 if (output_shape[i] == 1 ) {
0129 output_shape.erase(output_shape.begin() + i);
0130 }
0131 else {
0132 i++;
0133 }
0134 }
0135 } else if (input.size() == 2) {
0136 auto & axes = input[1];
0137 for (size_t i = 0; i < axes.size(); i++){
0138 if (output_shape[axes[i]] != 1)
0139 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
0140 ConvertShapeToString(output_shape));
0141 output_shape.erase(output_shape.begin() + axes[i]);
0142 }
0143 }
0144 ret.push_back(output_shape);
0145 }
0146
0147 else if (fOpMode == Unsqueeze) {
0148
0149 assert(input.size() == 2);
0150 auto output_shape = input[0];
0151 auto &axes = input[1];
0152
0153 int64_t r = input[0].size() + axes.size();
0154 for (auto & a : axes) {
0155 int64_t i = static_cast<int64_t>(a);
0156 if ( i < -r || i > r - 1 )
0157 throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
0158 if (i >= 0)
0159 output_shape.insert(output_shape.begin() + i, 1);
0160 else
0161
0162 output_shape.insert(output_shape.end() + i + 1, 1);
0163 }
0164 ret.push_back(output_shape);
0165 }
0166 return ret;
0167 }
0168
0169 void Initialize(RModel& model) override {
0170
0171 fVerbose = model.Verbose();
0172 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
0173
0174 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
0175 }
0176 fShapeInput = model.GetTensorShape(fNData);
0177
0178 if (!fNShape.empty()) {
0179 if (model.CheckIfTensorAlreadyExist(fNShape)) {
0180 auto dptr = model.GetInitializedTensorData(fNShape);
0181 auto input_shape = static_cast<int64_t *>(dptr.get());
0182 auto vec = model.GetTensorShape(fNShape);
0183 assert(vec.size() == 1);
0184 size_t n = vec[0];
0185
0186 std::vector<size_t> descShape(n);
0187 std::copy(input_shape, input_shape + n, descShape.begin());
0188 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0189
0190 model.SetNotWritableInitializedTensor(fNShape);
0191 } else {
0192 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
0193 }
0194 } else if (!fAttrAxes.empty()) {
0195
0196 std::vector<size_t> descShape(fAttrAxes.size());
0197 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
0198 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0199 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
0200 fShapeOutput = ShapeInference({fShapeInput})[0];
0201 } else {
0202 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
0203 }
0204
0205 if (model.IsInitializedTensor(fNData) && model.GetTensorType(fNData) == ETensorType::INT64) {
0206 fIsOutputConstant = true;
0207 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
0208 if (ConvertShapeToLength(fShapeInput) != ConvertShapeToLength(fShapeOutput))
0209 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
0210 model.AddConstantTensor<int64_t>(fNOutput, fShapeOutput, inputData);
0211 if (model.Verbose()) {
0212 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput) << " : " <<
0213 ConvertValuesToString(ConvertShapeToLength(fShapeOutput), inputData) << std::endl;
0214 }
0215 } else {
0216
0217 model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0218 if (model.Verbose())
0219 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> "<< fNOutput << " " << ConvertShapeToString(fShapeOutput) << std::endl;
0220 }
0221 }
0222
0223 std::string Generate(std::string OpName) override {
0224 if (fIsOutputConstant) return "";
0225
0226 OpName = "op_" + OpName;
0227
0228
0229 size_t length = ConvertShapeToLength(fShapeOutput);
0230 if (length != ConvertShapeToLength(fShapeInput)) {
0231 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
0232 ConvertShapeToString(fShapeOutput) + " and input is " +
0233 ConvertShapeToString(fShapeInput));
0234 }
0235 std::stringstream out;
0236 std::string opName = "Reshape";
0237 if (fOpMode == Flatten)
0238 opName = "Flatten";
0239 else if (fOpMode == Squeeze)
0240 opName = "Squeeze";
0241 else if (fOpMode == Unsqueeze)
0242 opName = "Unsquueze";
0243
0244 out << SP << "///--------" << opName << " operator\n" << std::endl;
0245 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
0246 << ");\n";
0247 return out.str();
0248 }
0249 };
0250
0251 }
0252 }
0253 }
0254
0255
0256 #endif