// File indexing completed on 2025-12-22 10:28:05
0001 #ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
0002 #define TMVA_SOFIE_ROPERATOR_RESHAPE
0003
#include "TMVA/SOFIE_common.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/RModel.hxx"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstdint>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0012
0013 namespace TMVA{
0014 namespace Experimental{
0015 namespace SOFIE{
0016
// The four reshape-like ONNX operators implemented by ROperator_Reshape.
// Kept as a plain (unscoped) enum since the values are used unqualified.
enum ReshapeOpMode { Reshape, Flatten, Squeeze, Unsqueeze };
0018
0019
0020 class ROperator_Reshape final : public ROperator
0021 {
0022
0023 private:
0024
0025 bool fVerbose = false;
0026 bool fDimInput = false;
0027 bool fDynamicShape = false;
0028 ReshapeOpMode fOpMode = Reshape;
0029
0030 int fAllowZero = 0;
0031 int fAxis = 1;
0032
0033 std::string fNData;
0034 std::string fNInput2;
0035 std::string fNOutput;
0036 std::vector<Dim> fShapeInput;
0037 std::vector<Dim> fShapeOutput;
0038 std::vector<int64_t> fAttrAxes;
0039 std::vector<int64_t> fShape;
0040
0041 public:
0042
0043 std::string Name() const {
0044 if (fOpMode == Reshape) return "Reshape";
0045 if (fOpMode == Flatten) return "Flatten";
0046 if (fOpMode == Squeeze) return "Squeeze";
0047 if (fOpMode == Unsqueeze) return "Unsqueeze";
0048 return "";
0049 }
0050
0051 ROperator_Reshape(){}
0052 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameInput2, std::string nameOutput)
0053 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNInput2(UTILITY::Clean_name(nameInput2)),
0054 fNOutput(UTILITY::Clean_name(nameOutput))
0055 {
0056 if (opMode == Reshape) fAllowZero = attr_value;
0057 if (opMode == Flatten) fAxis = attr_value;
0058
0059 fInputTensorNames = { fNData };
0060 if(!fNInput2.empty()){
0061 fInputTensorNames.emplace_back(fNInput2);
0062 }
0063 fOutputTensorNames = { fNOutput };
0064 }
0065
0066
0067
0068 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
0069 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
0070 fAttrAxes(attrAxes)
0071 {
0072 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
0073 }
0074
0075
0076 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0077 auto ret = std::vector<ETensorType>(1, input[0]);
0078 return ret;
0079 }
0080 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0081 return input;
0082 }
0083
0084
0085 std::vector<std::vector<Dim>> ShapeInference(const std::vector<std::vector<Dim>> & input) {
0086 std::vector<std::vector<Dim>> ret;
0087 auto & input_shape = input[0];
0088 if (fOpMode == Reshape) {
0089
0090 std::vector<Dim> output_shape(fShape.size());
0091 assert(!fShape.empty() && !fDynamicShape);
0092 for (size_t i = 0; i < output_shape.size(); i++) {
0093 if (fShape[i] > 0 || (fAllowZero && fShape[i] >= 0))
0094 output_shape[i] = Dim{ static_cast<size_t>(fShape[i]) };
0095 else if (!fAllowZero && fShape[i] == 0)
0096 output_shape[i] = input_shape[i];
0097 }
0098
0099 for (size_t i = 0; i < output_shape.size(); i++) {
0100 if (fShape[i] == -1) {
0101 auto tmp = output_shape;
0102 tmp.erase(tmp.begin() + i);
0103 auto tmp_length = ConvertDimShapeToLength(tmp);
0104 auto input_length = ConvertDimShapeToLength(input_shape);
0105 if (fVerbose)
0106 std::cout << "reshape- try simplifying " << ConvertDimShapeToString(input_shape) << " with length "
0107 << input_length << " to " << tmp_length << std::endl;
0108
0109 if (IsInteger(tmp_length) && IsInteger(input_length))
0110 output_shape[i] = Dim{static_cast<size_t>(std::stoi(input_length) / std::stoi(tmp_length))};
0111 else {
0112
0113
0114 bool canSimplify = false;
0115 std::vector <Dim> reduced_input;
0116 if (IsInteger(tmp_length)) {
0117
0118
0119
0120 std::stringstream ss(input_length);
0121
0122 std::string token;
0123
0124
0125 while(getline(ss, token, '*'))
0126 {
0127
0128 token.erase(std::remove_if(token.begin(), token.end(),
0129 [](unsigned char x) { return std::isspace(x); }), token.end());
0130 if (token != tmp_length) {
0131 if (IsInteger(token)) {
0132 size_t il = static_cast<size_t>(std::stoi(input_length));
0133 size_t tl = static_cast<size_t>(std::stoi(tmp_length));
0134 if ((il % tl) == 0) {
0135 canSimplify = true;
0136 reduced_input.push_back(Dim{il / tl});
0137 }
0138 } else {
0139 reduced_input.push_back(Dim{token});
0140 }
0141 } else {
0142
0143 canSimplify = true;
0144 }
0145 }
0146 }
0147 if (canSimplify) {
0148
0149 std::string res_shape = ConvertDimShapeToLength(reduced_input);
0150 if (res_shape.find('*') != std::string::npos)
0151 output_shape[i] = Dim{std::string("(") + res_shape + ")", static_cast<size_t>(-1)};
0152 else
0153 output_shape[i] = Dim{res_shape};
0154 }
0155 if (!canSimplify)
0156 output_shape[i] = Dim{std::string("(") + input_length + " / (" + tmp_length + "))", static_cast<size_t>(-1)};
0157 }
0158
0159 break;
0160 }
0161
0162
0163 }
0164
0165 if (fVerbose)
0166 std::cout << "Reshape: correct output shape to " << ConvertShapeToString(output_shape) << std::endl;
0167
0168 if (!fDimInput && ConvertDimShapeToLength(output_shape) != ConvertDimShapeToLength(input_shape)) {
0169 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
0170 ConvertShapeToString(output_shape));
0171 }
0172 ret.push_back(output_shape);
0173
0174 } else if (fOpMode == Flatten) {
0175
0176 if (fAxis < 0)
0177 fAxis += input_shape.size();
0178 auto s1 = std::vector<Dim>(input_shape.begin(), input_shape.begin() + fAxis);
0179 auto s2 = std::vector<Dim>(input_shape.begin() + fAxis, input_shape.end());
0180 auto l1 = ConvertDimShapeToLength(s1);
0181 auto l2 = ConvertDimShapeToLength(s2);
0182 std::vector<Dim> newShape = {Dim{l1}, Dim{l2}};
0183 ret.push_back(newShape);
0184 } else if (fOpMode == Squeeze) {
0185
0186
0187 auto output_shape = input_shape;
0188 if (fAttrAxes.empty()) {
0189 size_t i = 0;
0190 while (i < output_shape.size()) {
0191 if (output_shape[i] == Dim{1}) {
0192 output_shape.erase(output_shape.begin() + i);
0193 } else {
0194 i++;
0195 }
0196 }
0197 } else {
0198 auto &axes = fAttrAxes;
0199 for (size_t i = 0; i < axes.size(); i++) {
0200 if (axes[i] < 0)
0201 axes[i] += input_shape.size();
0202 if (!(output_shape[axes[i]] == Dim{1}))
0203 throw std::runtime_error("TMVA Squeeze Op : Invalid axis value " + std::to_string(axes[i]) +
0204 " for " + ConvertShapeToString(output_shape));
0205 output_shape.erase(output_shape.begin() + axes[i]);
0206 }
0207 }
0208 ret.push_back(output_shape);
0209 }
0210 else if (fOpMode == Unsqueeze) {
0211
0212 std::cout << "doing unsqueeze....\n";
0213 assert(!fAttrAxes.empty());
0214 auto output_shape = input_shape;
0215 auto &axes = fAttrAxes;
0216
0217 int64_t r = input[0].size() + axes.size();
0218 for (auto &a : axes) {
0219 int64_t i = static_cast<int64_t>(a);
0220 if (i < -r || i > r - 1)
0221 throw std::runtime_error("TMVA Unsqueeze Op - axes input is not in correct range");
0222 if (i >= 0)
0223 output_shape.insert(output_shape.begin() + i, Dim{1});
0224 else
0225
0226 output_shape.insert(output_shape.end() + i + 1, Dim{1});
0227 }
0228 ret.push_back(output_shape);
0229 }
0230 return ret;
0231 }
0232
0233 void Initialize(RModel& model) override {
0234
0235 std::cout << "initialize reshape op type " << fOpMode << " - " << fNInput2 << " " << fNData << std::endl;
0236 fVerbose = model.Verbose();
0237 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
0238
0239 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
0240 }
0241 fShapeInput = model.GetDimTensorShape(fNData);
0242 fDimInput = model.IsDynamicTensor(fNData);
0243
0244 if (!fNInput2.empty()) {
0245 if (model.CheckIfTensorAlreadyExist(fNInput2)) {
0246 if (model.IsConstantTensor(fNInput2) || model.IsInitializedTensor(fNInput2)) {
0247
0248 auto dptr = model.GetInitializedTensorData(fNInput2);
0249 auto values = static_cast<int64_t *>(dptr.get());
0250 auto vec = model.GetTensorShape(fNInput2);
0251 size_t n = 1;
0252 if (vec.size() > 0)
0253 n = vec[0];
0254
0255 if (fOpMode == Reshape)
0256 fShape = std::vector<int64_t>(values, values + n);
0257 else
0258 fAttrAxes = std::vector<int64_t>(values, values + n);
0259
0260 fShapeOutput = ShapeInference({fShapeInput})[0];
0261
0262 model.SetNotWritableInitializedTensor(fNInput2);
0263 } else {
0264
0265 fDynamicShape = true;
0266
0267 auto shapeInput2 = model.GetTensorShape(fNInput2);
0268 fShapeOutput.resize(shapeInput2[0]);
0269 for (size_t i = 0; i < fShapeOutput.size(); i++) {
0270 fShapeOutput[i] = Dim{ std::string("s_") + fNOutput + "_" + std::to_string(i)};
0271 }
0272 }
0273 } else {
0274 throw std::runtime_error("TMVA Reshape Op 2nd input Tensor " + fNInput2 + " is not found in model");
0275 }
0276 } else if (!fAttrAxes.empty()) {
0277
0278 std::cout << "attribute axes exists\n";
0279 fShapeOutput = ShapeInference({fShapeInput})[0];
0280 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
0281 fShapeOutput = ShapeInference({fShapeInput})[0];
0282 } else {
0283 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
0284 }
0285
0286 if (model.IsInitializedTensor(fNData) && model.GetTensorType(fNData) == ETensorType::INT64) {
0287 fIsOutputConstant = true;
0288 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNData).get());
0289 auto o_shape = ConvertShapeToInt(fShapeOutput);
0290 if (ConvertShapeToLength(ConvertShapeToInt(fShapeInput)) != ConvertShapeToLength(o_shape) )
0291 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Output lengths");
0292 model.AddConstantTensor<int64_t>(fNOutput, o_shape, inputData);
0293 if (model.Verbose()) {
0294 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> " << fNOutput << " (constant) " << ConvertShapeToString(fShapeOutput) << " : " <<
0295 ConvertValuesToString(ConvertShapeToLength(o_shape), inputData) << std::endl;
0296 }
0297 }
0298
0299 else if (model.IsShapeTensor(fNData) && fShapeOutput.size() <=1) {
0300 fIsOutputConstant = true;
0301 auto inputData = model.GetShapeTensorValues(fNData);
0302 model.AddShapeTensor(fNOutput, inputData);
0303 if (model.Verbose()) {
0304 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> " << fNOutput << " (shape) " << ConvertShapeToString(fShapeOutput) << " : " <<
0305 ConvertShapeToString(inputData) << std::endl;
0306 }
0307 }
0308 else {
0309
0310 model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0311 if (model.Verbose())
0312 std::cout << Name() << " : " << fNData << " " << ConvertShapeToString(fShapeInput) << " --> "<< fNOutput << " " << ConvertShapeToString(fShapeOutput) << std::endl;
0313 }
0314 }
0315
0316 std::string Generate(std::string opName) override {
0317 if (fIsOutputConstant) return "";
0318
0319 std::stringstream out;
0320 std::string opType = "Reshape";
0321 if (fOpMode == Flatten)
0322 opType = "Flatten";
0323 else if (fOpMode == Squeeze)
0324 opType = "Squeeze";
0325 else if (fOpMode == Unsqueeze)
0326 opType = "Unsquueze";
0327
0328 out << SP << "///--------" << opType << " operator " << opName << " --> " << ConvertShapeToString(fShapeOutput) << "\n";
0329
0330
0331
0332 if (fDynamicShape) {
0333 for (size_t i = 0; i < fShapeOutput.size(); i++) {
0334
0335 out << SP << "size_t " << fShapeOutput[i].param << " = " << "tensor_" << fNInput2 << "[" << i << "];\n";
0336 if (!fAllowZero)
0337 out << SP << "if (tensor_" << fNInput2 << "[" << i << "] <= 0 ) "
0338 << fShapeOutput[i].param << " = " << fShapeInput[i] << ";\n";
0339 }
0340 }
0341
0342
0343 auto lengthOut = ConvertDimShapeToLength(fShapeOutput);
0344 auto lengthIn = ConvertDimShapeToLength(fShapeInput);
0345 if (lengthOut != lengthIn) {
0346
0347 out << SP << "if (" << lengthOut << "!=" << lengthIn << ")\n";
0348 out << "throw std::runtime_error(\"TMVA SOFIE Reshape Op : output lengths is different than input one\");\n";
0349 }
0350
0351
0352 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << lengthIn << ", " << "tensor_" << fNOutput
0353 << ");\n";
0354 return out.str();
0355 }
0356 };
0357
0358 }
0359 }
0360 }
0361
0362
0363 #endif