File indexing completed on 2025-01-18 10:11:08
0001 #ifndef TMVA_SOFIE_ROPERATOR_RESHAPE
0002 #define TMVA_SOFIE_ROPERATOR_RESHAPE
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
0010
0011 namespace TMVA{
0012 namespace Experimental{
0013 namespace SOFIE{
0014
/// Selects which ONNX shape-manipulation operator a ROperator_Reshape instance implements.
enum ReshapeOpMode {
   Reshape = 0,   ///< ONNX Reshape: target shape taken from a second (shape) tensor
   Flatten = 1,   ///< ONNX Flatten: collapse the input into a 2-D tensor
   Squeeze = 2,   ///< ONNX Squeeze: remove dimensions of extent 1
   Unsqueeze = 3  ///< ONNX Unsqueeze: insert dimensions of extent 1
};
0016
0017 template <typename T>
0018 class ROperator_Reshape final : public ROperator
0019 {
0020
0021 private:
0022
0023 ReshapeOpMode fOpMode = Reshape;
0024
0025 int fAllowZero = 0;
0026 int fAxis = 1;
0027
0028 std::string fNData;
0029 std::string fNShape;
0030 std::string fNOutput;
0031 std::vector<size_t> fShapeInput;
0032 std::vector<size_t> fShapeOutput;
0033 std::vector<int64_t> fAttrAxes;
0034
0035 public:
0036
0037 ROperator_Reshape(){}
0038 ROperator_Reshape(ReshapeOpMode opMode, int attr_value, std::string nameData, std::string nameShape, std::string nameOutput)
0039 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNShape(UTILITY::Clean_name(nameShape)),
0040 fNOutput(UTILITY::Clean_name(nameOutput))
0041 {
0042 if (opMode == Reshape) fAllowZero = attr_value;
0043 if (opMode == Flatten) fAxis = attr_value;
0044 }
0045
0046
0047
0048 ROperator_Reshape(ReshapeOpMode opMode, std::vector<int64_t> attrAxes, std::string nameData, std::string nameOutput)
0049 : fOpMode(opMode), fNData(UTILITY::Clean_name(nameData)), fNOutput(UTILITY::Clean_name(nameOutput)),
0050 fAttrAxes(attrAxes)
0051 {
0052 assert(fOpMode == Squeeze || fOpMode == Unsqueeze);
0053 }
0054
0055
0056 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0057 auto ret = std::vector<ETensorType>(1, input[0]);
0058 return ret;
0059 }
0060
0061
0062 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0063 std::vector<std::vector<size_t>> ret;
0064 auto & input_shape = input[0];
0065
0066 if (fOpMode == Reshape) {
0067 if (input.size() != 2) throw std::runtime_error("TMVA SOFIE Reshape Op needs 2 input tensors");
0068 auto output_shape = input[1];
0069 size_t input_length = ConvertShapeToLength(input_shape);
0070 size_t output_length = ConvertShapeToLength(output_shape);
0071
0072 if (input_length != output_length) {
0073 if (output_shape.size() > 1 && ((output_length == 0 && fAllowZero == 0) || output_length > INT64_MAX)) {
0074
0075 for (size_t i = 0; i < output_shape.size(); i++) {
0076 if (output_shape[i] == 0 || output_shape[i] == static_cast<size_t>(-1)) {
0077 auto tmp = output_shape;
0078 tmp.erase(tmp.begin() + i);
0079 auto tmp_length = ConvertShapeToLength(tmp);
0080 output_shape[i] = input_length / tmp_length;
0081 break;
0082 }
0083 }
0084 }
0085 if (ConvertShapeToLength(output_shape) != input_length) {
0086 throw std::runtime_error("TMVA Reshape Op : Invalid shapes : " + ConvertShapeToString(input_shape) +
0087 ConvertShapeToString(output_shape));
0088 }
0089 }
0090 ret.push_back(output_shape);
0091
0092 } else if (fOpMode == Flatten) {
0093
0094 size_t inputSize = ConvertShapeToLength(input_shape);
0095 size_t b = input[0][0];
0096 std::vector<size_t> newShape = {b, inputSize / b};
0097 ret.push_back(newShape);
0098
0099 } else if (fOpMode == Squeeze) {
0100
0101
0102 auto output_shape = input[0];
0103 if (input.size() == 1) {
0104 size_t i = 0;
0105 while (i < output_shape.size()) {
0106 if (output_shape[i] == 1 ) {
0107 output_shape.erase(output_shape.begin() + i);
0108 }
0109 else {
0110 i++;
0111 }
0112 }
0113 } else if (input.size() == 2) {
0114 auto & axes = input[1];
0115 for (size_t i = 0; i < axes.size(); i++){
0116 if (output_shape[axes[i]] != 1)
0117 throw std::runtime_error("TMVA Squeeze Op : Invalid axes : " + ConvertShapeToString(axes) +
0118 ConvertShapeToString(output_shape));
0119 output_shape.erase(output_shape.begin() + axes[i]);
0120 }
0121 }
0122 ret.push_back(output_shape);
0123 }
0124
0125 else if (fOpMode == Unsqueeze) {
0126
0127 assert(input.size() == 2);
0128 auto output_shape = input[0];
0129 auto &axes = input[1];
0130 if (axes[0] > 0) {
0131 for (auto & i : axes)
0132 output_shape.insert(output_shape.begin() + i, 1);
0133 } else {
0134
0135 for (auto &i : axes) {
0136 assert(i < 0);
0137 output_shape.insert(output_shape.begin() + (output_shape.size() + i - 1), 1);
0138 }
0139 }
0140 ret.push_back(output_shape);
0141 }
0142 return ret;
0143 }
0144
0145 void Initialize(RModel &model)
0146 {
0147
0148 if (model.CheckIfTensorAlreadyExist(fNData) == false) {
0149
0150 throw std::runtime_error("TMVA Reshape Op Input Tensor " + fNData + " is not found in model");
0151 }
0152 fShapeInput = model.GetTensorShape(fNData);
0153
0154 if (!fNShape.empty()) {
0155 if (model.CheckIfTensorAlreadyExist(fNShape)) {
0156 auto dptr = model.GetInitializedTensorData(fNShape);
0157 auto input_shape = static_cast<int64_t *>(dptr.get());
0158 auto vec = model.GetTensorShape(fNShape);
0159 assert(vec.size() == 1);
0160 size_t n = vec[0];
0161
0162 std::vector<size_t> descShape(n);
0163 std::copy(input_shape, input_shape + n, descShape.begin());
0164 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0165 } else {
0166 throw std::runtime_error("TMVA Reshape Op Shape Tensor " + fNShape + " is not found in model");
0167 }
0168 } else if (!fAttrAxes.empty()) {
0169
0170 std::vector<size_t> descShape(fAttrAxes.size());
0171 std::copy(fAttrAxes.begin(), fAttrAxes.end(), descShape.begin());
0172 fShapeOutput = ShapeInference({fShapeInput, descShape})[0];
0173 } else if (fOpMode == Flatten || fOpMode == Squeeze) {
0174 fShapeOutput = ShapeInference({fShapeInput})[0];
0175 } else {
0176 throw std::runtime_error("TMVA Reshape Op : Invalid Input/Attribute data");
0177 }
0178 model.AddIntermediateTensor(fNOutput, model.GetTensorType(fNData), fShapeOutput);
0179 }
0180
0181 std::string Generate(std::string OpName)
0182 {
0183 OpName = "op_" + OpName;
0184 if (fShapeInput.empty() || fShapeOutput.empty()) {
0185 throw std::runtime_error("TMVA SOFIE Reshape Op called to Generate without being initialized first");
0186 }
0187
0188
0189 size_t length = ConvertShapeToLength(fShapeOutput);
0190 if (length != ConvertShapeToLength(fShapeInput)) {
0191 throw std::runtime_error("TMVA SOFIE Reshape Op : wrong output shape - is " +
0192 ConvertShapeToString(fShapeOutput) + " and input is " +
0193 ConvertShapeToString(fShapeInput));
0194 }
0195 std::stringstream out;
0196 std::string opName = "Reshape";
0197 if (fOpMode == Flatten)
0198 opName = "Flatten";
0199 else if (fOpMode == Squeeze)
0200 opName = "Squeeze";
0201 else if (fOpMode == Unsqueeze)
0202 opName = "Unsquueze";
0203
0204 out << SP << "///--------" << opName << " operator\n" << std::endl;
0205 out << SP << "std::copy( tensor_" << fNData << ", tensor_" << fNData << " + " << length << ", " << "tensor_" << fNOutput
0206 << ");\n";
0207 return out.str();
0208 }
0209 };
0210
0211 }
0212 }
0213 }
0214
0215
0216 #endif