Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-12 09:09:38

0001 #ifndef TMVA_SOFIE_ROPERATOR_GATHER
0002 #define TMVA_SOFIE_ROPERATOR_GATHER
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 #include <stdexcept>
0010 #include <string>
0011 
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015 
0016 class ROperator_Gather final : public ROperator
0017 {
0018 private:
0019 
0020    int64_t fAttrAxis = 0;
0021 
0022    std::string fNX;
0023    std::string fNIndices;
0024    std::string fNY;
0025 
0026    std::vector<size_t> fShapeX;
0027    std::vector<size_t> fShapeIndices;
0028    std::vector<size_t> fShapeY;
0029 
0030    std::vector<int64_t> fIndices;  // indices vector in case they are known at initialization
0031 
0032    std::string fType;
0033 
0034 public:
0035    ROperator_Gather(){}
0036    ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
0037       fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
0038          fInputTensorNames = { fNX, fNIndices };
0039          fOutputTensorNames = { fNY };
0040    }
0041 
0042    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0043       return input;
0044    }
0045 
0046    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0047       auto ret = input;
0048       return ret;
0049    }
0050 
0051    void Initialize(RModel& model) override {
0052       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0053          throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
0054       }
0055       fShapeX = model.GetTensorShape(fNX);
0056       fShapeIndices = model.GetTensorShape(fNIndices);
0057       size_t q = fShapeIndices.size();
0058       // Axis in range [0, r) where r=rank(X)
0059       size_t r = fShapeX.size();
0060        // Set the axis
0061       if (fAttrAxis < 0) {
0062          fAttrAxis = fAttrAxis + int64_t(r);
0063       }
0064       // empty fShapeIndices is a scalar value for the indices
0065       size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0066 
0067       // case indices tensor is initialized
0068       if (model.IsInitializedTensor(fNIndices)) {
0069          int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
0070          //flag index tensor as not writable (not sure this is needed since index tensor might be used in generated code)
0071          model.SetNotWritableInitializedTensor(fNIndices);
0072          // update indices data in case of negative dim values
0073          for (size_t i = 0; i < indicesLength; i++) {
0074             if (indicesData[i] < 0) {
0075                indicesData[i] += fShapeX[fAttrAxis];
0076             }
0077          }
0078          // Save in a vector gather Indices of size q
0079          fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
0080       }
0081       // Output shape
0082       if (model.Verbose())
0083          std::cout << "Gather: q and r " << q << " " << r << " shape indices " << ConvertShapeToString(fShapeIndices) << std::endl;
0084 
0085       if (fShapeY.empty()) {
0086          fShapeY.resize(q + r - 1);
0087          if (fAttrAxis > 0) {
0088             // Copy shape of X[0, ..., axis) to Shape of Y[0, ..., axis)
0089             std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
0090          }
0091          // Set shape of Y[axis, ..., axis + q)
0092          for (size_t i = 0; i < q; i++) {
0093             fShapeY[fAttrAxis + i] = fShapeIndices[i];
0094          }
0095          // Copy shape of X[axis + 1, ..., axis + r) to shape of Y[axis + q, ... q + r - 1)
0096          std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
0097       }
0098       // case input is known (type is an integer) and input indices is a scalar (or vector of size 1)
0099       if (model.IsInitializedTensor(fNX) && q <= 1 && r == 1 && fIndices.size() > 0) {
0100          if (model.GetTensorType(fNX) == ETensorType::INT64) {
0101             auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNX).get());
0102             // if q <=1 and r = 1 output length = 1 (it is a scalar)
0103             std::vector<int64_t> outputData(ConvertShapeToLength(fShapeY));
0104             outputData[0] = inputData[fIndices[0]];
0105             model.AddConstantTensor(fNY, fShapeY, outputData.data());
0106             if (model.Verbose())
0107                std::cout << "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
0108                    << " and values " << ConvertValuesToString(outputData) << " (constant) " << std::endl;
0109             fIsOutputConstant = true;
0110          }
0111       }
0112       if (!fIsOutputConstant) {
0113          // Add output tensor
0114          model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0115          fType = ConvertTypeToString(model.GetTensorType(fNX));
0116          if (model.Verbose())
0117                std::cout <<  "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
0118                   << std::endl;
0119       }
0120    }
0121 
0122    std::string Generate(std::string OpName) override {
0123       if (fIsOutputConstant) {
0124          // no code to generate here for constant output. Tensor output is defined in Session constructor
0125          return "//---------------------------------------\n";
0126       }
0127       OpName = "op_" + OpName;
0128       std::stringstream out;
0129       out << "//--------- Gather operator \n";
0130       // The shape of the output is q + r - 1
0131       size_t r = fShapeX.size();
0132       // Indices of shape q
0133       size_t q = fShapeIndices.size();
0134       // Strides
0135       std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
0136       std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
0137       std::vector<size_t> stridesIndices = UTILITY::ComputeStrideFromShape(fShapeIndices);
0138 
0139       // case fIndices is not known we need to correct for negative axis indices at run-time
0140       if (fIndices.empty()) {
0141          size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0142          out << SP << "// correct in case of negative gather indices\n";
0143          out << SP << "for (size_t i = 0; i < " << indicesLength << "; i++){\n";
0144          out << SP << SP << "if (tensor_" << fNIndices << "[i] < 0)\n";
0145          out << SP << SP << SP <<  "tensor_" << fNIndices << "[i] += " << fShapeX[fAttrAxis] << ";\n";
0146          out << SP << "}\n";
0147       }
0148 
0149 
0150       // Fill the output Y[j_0, j_1, ..., j_{axis - 1}, i_0, i_1, ..., i_{q - 1}, j_{axis + 1}, ..., j_{r - 1}]
0151       // [0 ... axis) [axis ... axis + q) [axis + q ... q + r - 1)
0152       // iterate in [0 ... axis) [0 ... q) [axis ... r - 1)
0153       // for j_0, j_1, ..., j_{axis-1}
0154       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0155          std::string index = "j_" + std::to_string(j);
0156          out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
0157       }
0158       // for i_0, i_1, ..., i_{q - 1}
0159       if (q == 0)
0160          out << SP << SP << "{\n";  // add a scope for local variables
0161       for (size_t i = 0; i < q; i++) {
0162          std::string index = "i_" + std::to_string(i);
0163          out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
0164       }
0165       // for j_axis, j_{axis + 1}, ..., j_{r - 1}
0166       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0167          std::string index = "j_" + std::to_string(j);
0168          out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
0169       }
0170 
0171       out << SP << SP << SP << "size_t y_index = 0;\n";
0172       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0173          out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
0174       }
0175       for (size_t i = 0; i < q; i++) {
0176          out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
0177       }
0178       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0179          out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
0180       }
0181       // Indices
0182       out << SP << SP << SP << "size_t i_index = 0;\n";
0183       for (size_t i = 0; i < q; i++) {
0184          out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
0185       }
0186       // K
0187       out << SP << SP << SP << "size_t k = static_cast<size_t>(" << "tensor_" << fNIndices << "[i_index]" << ");\n";
0188       // Input
0189       out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
0190       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0191          out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
0192       }
0193       for (size_t j = fAttrAxis + 1; j < r; j++) {
0194          out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
0195       }
0196       out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
0197 
0198       // end loops j_k, j_{k + 1}, ..., j_{r - 2}
0199       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0200          out << SP << SP << SP << "}\n";
0201       }
0202       // end loops i_0, i_1, ..., i_{q - 1}
0203       if (q == 0)
0204          out << SP << SP << "}\n";  // end of scope for q = 0
0205       for (size_t i = 0; i < q; i++) {
0206          out << SP << SP << "}\n";
0207       }
0208       // end loops j_0, j_1, ..., j_{axis - 1}
0209       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0210          out << SP << "}\n";
0211       }
0212 
0213       return out.str();
0214    }
0215 
0216 };
0217 
0218 }//SOFIE
0219 }//Experimental
0220 }//TMVA
0221 
0222 
0223 #endif //TMVA_SOFIE_ROPERATOR_GATHER