Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:59

0001 #ifndef TMVA_SOFIE_ROPERATOR_GATHER
0002 #define TMVA_SOFIE_ROPERATOR_GATHER
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 #include <stdexcept>
0010 #include <string>
0011 
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015 
0016 template <typename T>
0017 class ROperator_Gather final : public ROperator
0018 {
0019 private:
0020 
0021    int64_t fAttrAxis = 0;
0022 
0023    std::string fNX;
0024    std::string fNIndices;
0025    std::string fNY;
0026 
0027    std::vector<size_t> fShapeX;
0028    std::vector<size_t> fShapeIndices;
0029    std::vector<size_t> fShapeY;
0030 
0031    std::vector<int64_t> fIndices;
0032 
0033    std::string fType;
0034 
0035 public:
0036    ROperator_Gather(){}
0037    ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
0038       fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
0039    }
0040 
0041    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0042       return input;
0043    }
0044 
0045    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0046       auto ret = input;
0047       return ret;
0048    }
0049 
0050    void Initialize(RModel& model) {
0051       if (!model.CheckIfTensorAlreadyExist(fNX)) {
0052          throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
0053       }
0054       fShapeX = model.GetTensorShape(fNX);
0055       if (!model.IsInitializedTensor(fNIndices)) {
0056          throw
0057             std::runtime_error("TMVA::SOFIE - Tensor " + fNIndices + " is not initialized.");
0058       }
0059       int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
0060       fShapeIndices = model.GetTensorShape(fNIndices);
0061       size_t q = fShapeIndices.size();
0062       // Axis in range [0, r) where r=rank(X)
0063       size_t r = fShapeX.size();
0064       // Set the axis
0065       if (fAttrAxis < 0) {
0066          fAttrAxis = fAttrAxis + int64_t(r);
0067       }
0068       // Indices of size q
0069       // empty fShapeIndices is a scalar value for the indices
0070       size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0071       fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
0072       for (size_t i = 0; i < indicesLength; i++) {
0073          if (fIndices[i] < 0) {
0074             fIndices[i] += fShapeX[fAttrAxis];
0075          }
0076       }
0077       // Output shape
0078       if (fShapeY.empty()) {
0079          fShapeY.resize(q + r - 1);
0080          if (fAttrAxis > 0) {
0081             // Copy shape of X[0, ..., axis) to Shape of Y[0, ..., axis)
0082             std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
0083          }
0084          // Set shape of Y[axis, ..., axis + q)
0085          for (size_t i = 0; i < q; i++) {
0086             fShapeY[fAttrAxis + i] = fShapeIndices[i];
0087          }
0088          // Copy shape of X[axis + 1, ..., axis + r) to shape of Y[axis + q, ... q + r - 1)
0089          std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
0090       }
0091       // Add output tensor
0092       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0093       fType = ConvertTypeToString(model.GetTensorType(fNX));
0094    }
0095 
0096    std::string Generate(std::string OpName) {
0097       OpName = "op_" + OpName;
0098       std::stringstream out;
0099       // The shape of the output is q + r - 1
0100       size_t r = fShapeX.size();
0101       // Indices of shape q
0102       size_t q = fShapeIndices.size();
0103       // Strides
0104       std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
0105       std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
0106       std::vector<size_t> stridesIndices = UTILITY::ComputeStrideFromShape(fShapeIndices);
0107       // Indices vector
0108       out << SP << "std::vector<int64_t> " << OpName << "_indices = {";
0109       size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0110       for (size_t i = 0; i < indicesLength; i++) {
0111          out << fIndices[i] << (i + 1 < indicesLength? ", " : "};\n");
0112       }
0113       // Fill the output Y[j_0, j_1, ..., j_{axis - 1}, i_0, i_1, ..., i_{q - 1}, j_{axis + 1}, ..., j_{r - 1}]
0114       // [0 ... axis) [axis ... axis + q) [axis + q ... q + r - 1)
0115       // iterate in [0 ... axis) [0 ... q) [axis ... r - 1)
0116       // for j_0, j_1, ..., j_{axis-1}
0117       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0118          std::string index = "j_" + std::to_string(j);
0119          out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
0120       }
0121       // for i_0, i_1, ..., i_{q - 1}
0122       for (size_t i = 0; i < q; i++) {
0123          std::string index = "i_" + std::to_string(i);
0124          out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
0125       }
0126       // for j_axis, j_{axis + 1}, ..., j_{r - 1}
0127       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0128          std::string index = "j_" + std::to_string(j);
0129          out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
0130       }
0131 
0132       out << SP << SP << SP << "size_t y_index = 0;\n";
0133       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0134          out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
0135       }
0136       for (size_t i = 0; i < q; i++) {
0137          out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
0138       }
0139       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0140          out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
0141       }
0142       // Indices
0143       out << SP << SP << SP << "size_t i_index = 0;\n";
0144       for (size_t i = 0; i < q; i++) {
0145          out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
0146       }
0147       // K
0148       out << SP << SP << SP << "size_t k = static_cast<size_t>(" << OpName << "_indices[i_index]" << ");\n";
0149       // Input
0150       out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
0151       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0152          out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
0153       }
0154       for (size_t j = fAttrAxis + 1; j < r; j++) {
0155          out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
0156       }
0157       out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
0158 
0159       // end loops j_k, j_{k + 1}, ..., j_{r - 2}
0160       for (size_t j = fAttrAxis; j + 1 < r; j++) {
0161          out << SP << SP << SP << "}\n";
0162       }
0163       // end loops i_0, i_1, ..., i_{q - 1}
0164       for (size_t i = 0; i < q; i++) {
0165          out << SP << SP << "}\n";
0166       }
0167       // end loops j_0, j_1, ..., j_{axis - 1}
0168       for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0169          out << SP << "}\n";
0170       }
0171 
0172       return out.str();
0173    }
0174 
0175 };
0176 
0177 }//SOFIE
0178 }//Experimental
0179 }//TMVA
0180 
0181 
0182 #endif //TMVA_SOFIE_ROPERATOR_GATHER