File indexing completed on 2025-09-12 09:09:38
0001 #ifndef TMVA_SOFIE_ROPERATOR_GATHER
0002 #define TMVA_SOFIE_ROPERATOR_GATHER
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009 #include <stdexcept>
0010 #include <string>
0011
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015
0016 class ROperator_Gather final : public ROperator
0017 {
0018 private:
0019
0020 int64_t fAttrAxis = 0;
0021
0022 std::string fNX;
0023 std::string fNIndices;
0024 std::string fNY;
0025
0026 std::vector<size_t> fShapeX;
0027 std::vector<size_t> fShapeIndices;
0028 std::vector<size_t> fShapeY;
0029
0030 std::vector<int64_t> fIndices;
0031
0032 std::string fType;
0033
0034 public:
0035 ROperator_Gather(){}
0036 ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
0037 fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
0038 fInputTensorNames = { fNX, fNIndices };
0039 fOutputTensorNames = { fNY };
0040 }
0041
0042 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0043 return input;
0044 }
0045
0046 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0047 auto ret = input;
0048 return ret;
0049 }
0050
0051 void Initialize(RModel& model) override {
0052 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0053 throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
0054 }
0055 fShapeX = model.GetTensorShape(fNX);
0056 fShapeIndices = model.GetTensorShape(fNIndices);
0057 size_t q = fShapeIndices.size();
0058
0059 size_t r = fShapeX.size();
0060
0061 if (fAttrAxis < 0) {
0062 fAttrAxis = fAttrAxis + int64_t(r);
0063 }
0064
0065 size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0066
0067
0068 if (model.IsInitializedTensor(fNIndices)) {
0069 int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
0070
0071 model.SetNotWritableInitializedTensor(fNIndices);
0072
0073 for (size_t i = 0; i < indicesLength; i++) {
0074 if (indicesData[i] < 0) {
0075 indicesData[i] += fShapeX[fAttrAxis];
0076 }
0077 }
0078
0079 fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
0080 }
0081
0082 if (model.Verbose())
0083 std::cout << "Gather: q and r " << q << " " << r << " shape indices " << ConvertShapeToString(fShapeIndices) << std::endl;
0084
0085 if (fShapeY.empty()) {
0086 fShapeY.resize(q + r - 1);
0087 if (fAttrAxis > 0) {
0088
0089 std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
0090 }
0091
0092 for (size_t i = 0; i < q; i++) {
0093 fShapeY[fAttrAxis + i] = fShapeIndices[i];
0094 }
0095
0096 std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
0097 }
0098
0099 if (model.IsInitializedTensor(fNX) && q <= 1 && r == 1 && fIndices.size() > 0) {
0100 if (model.GetTensorType(fNX) == ETensorType::INT64) {
0101 auto inputData = static_cast<int64_t*>(model.GetInitializedTensorData(fNX).get());
0102
0103 std::vector<int64_t> outputData(ConvertShapeToLength(fShapeY));
0104 outputData[0] = inputData[fIndices[0]];
0105 model.AddConstantTensor(fNY, fShapeY, outputData.data());
0106 if (model.Verbose())
0107 std::cout << "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
0108 << " and values " << ConvertValuesToString(outputData) << " (constant) " << std::endl;
0109 fIsOutputConstant = true;
0110 }
0111 }
0112 if (!fIsOutputConstant) {
0113
0114 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0115 fType = ConvertTypeToString(model.GetTensorType(fNX));
0116 if (model.Verbose())
0117 std::cout << "Gather: " << fNX << " " << ConvertShapeToString(fShapeX) << " -> " << fNY << " with shape " << ConvertShapeToString(fShapeY)
0118 << std::endl;
0119 }
0120 }
0121
0122 std::string Generate(std::string OpName) override {
0123 if (fIsOutputConstant) {
0124
0125 return "//---------------------------------------\n";
0126 }
0127 OpName = "op_" + OpName;
0128 std::stringstream out;
0129 out << "//--------- Gather operator \n";
0130
0131 size_t r = fShapeX.size();
0132
0133 size_t q = fShapeIndices.size();
0134
0135 std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
0136 std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
0137 std::vector<size_t> stridesIndices = UTILITY::ComputeStrideFromShape(fShapeIndices);
0138
0139
0140 if (fIndices.empty()) {
0141 size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0142 out << SP << "// correct in case of negative gather indices\n";
0143 out << SP << "for (size_t i = 0; i < " << indicesLength << "; i++){\n";
0144 out << SP << SP << "if (tensor_" << fNIndices << "[i] < 0)\n";
0145 out << SP << SP << SP << "tensor_" << fNIndices << "[i] += " << fShapeX[fAttrAxis] << ";\n";
0146 out << SP << "}\n";
0147 }
0148
0149
0150
0151
0152
0153
0154 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0155 std::string index = "j_" + std::to_string(j);
0156 out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
0157 }
0158
0159 if (q == 0)
0160 out << SP << SP << "{\n";
0161 for (size_t i = 0; i < q; i++) {
0162 std::string index = "i_" + std::to_string(i);
0163 out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
0164 }
0165
0166 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0167 std::string index = "j_" + std::to_string(j);
0168 out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
0169 }
0170
0171 out << SP << SP << SP << "size_t y_index = 0;\n";
0172 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0173 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
0174 }
0175 for (size_t i = 0; i < q; i++) {
0176 out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
0177 }
0178 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0179 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
0180 }
0181
0182 out << SP << SP << SP << "size_t i_index = 0;\n";
0183 for (size_t i = 0; i < q; i++) {
0184 out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
0185 }
0186
0187 out << SP << SP << SP << "size_t k = static_cast<size_t>(" << "tensor_" << fNIndices << "[i_index]" << ");\n";
0188
0189 out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
0190 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0191 out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
0192 }
0193 for (size_t j = fAttrAxis + 1; j < r; j++) {
0194 out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
0195 }
0196 out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
0197
0198
0199 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0200 out << SP << SP << SP << "}\n";
0201 }
0202
0203 if (q == 0)
0204 out << SP << SP << "}\n";
0205 for (size_t i = 0; i < q; i++) {
0206 out << SP << SP << "}\n";
0207 }
0208
0209 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0210 out << SP << "}\n";
0211 }
0212
0213 return out.str();
0214 }
0215
0216 };
0217
0218 }
0219 }
0220 }
0221
0222
0223 #endif