File indexing completed on 2025-01-30 10:22:59
0001 #ifndef TMVA_SOFIE_ROPERATOR_GATHER
0002 #define TMVA_SOFIE_ROPERATOR_GATHER
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009 #include <stdexcept>
0010 #include <string>
0011
0012 namespace TMVA{
0013 namespace Experimental{
0014 namespace SOFIE{
0015
0016 template <typename T>
0017 class ROperator_Gather final : public ROperator
0018 {
0019 private:
0020
0021 int64_t fAttrAxis = 0;
0022
0023 std::string fNX;
0024 std::string fNIndices;
0025 std::string fNY;
0026
0027 std::vector<size_t> fShapeX;
0028 std::vector<size_t> fShapeIndices;
0029 std::vector<size_t> fShapeY;
0030
0031 std::vector<int64_t> fIndices;
0032
0033 std::string fType;
0034
0035 public:
0036 ROperator_Gather(){}
0037 ROperator_Gather(int64_t attrAxis, std::string nameX, std::string nameIndices, std::string nameY):
0038 fAttrAxis(attrAxis), fNX(UTILITY::Clean_name(nameX)), fNIndices(UTILITY::Clean_name(nameIndices)), fNY(UTILITY::Clean_name(nameY)) {
0039 }
0040
0041 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input){
0042 return input;
0043 }
0044
0045 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input){
0046 auto ret = input;
0047 return ret;
0048 }
0049
0050 void Initialize(RModel& model) {
0051 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0052 throw std::runtime_error("TMVA SOFIE Gather Op Input Tensor " + fNX + " is not found in model");
0053 }
0054 fShapeX = model.GetTensorShape(fNX);
0055 if (!model.IsInitializedTensor(fNIndices)) {
0056 throw
0057 std::runtime_error("TMVA::SOFIE - Tensor " + fNIndices + " is not initialized.");
0058 }
0059 int64_t* indicesData = static_cast<int64_t*>(model.GetInitializedTensorData(fNIndices).get());
0060 fShapeIndices = model.GetTensorShape(fNIndices);
0061 size_t q = fShapeIndices.size();
0062
0063 size_t r = fShapeX.size();
0064
0065 if (fAttrAxis < 0) {
0066 fAttrAxis = fAttrAxis + int64_t(r);
0067 }
0068
0069
0070 size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0071 fIndices = std::vector<int64_t>(indicesData, indicesData + indicesLength);
0072 for (size_t i = 0; i < indicesLength; i++) {
0073 if (fIndices[i] < 0) {
0074 fIndices[i] += fShapeX[fAttrAxis];
0075 }
0076 }
0077
0078 if (fShapeY.empty()) {
0079 fShapeY.resize(q + r - 1);
0080 if (fAttrAxis > 0) {
0081
0082 std::copy(fShapeX.begin(), fShapeX.begin() + fAttrAxis, fShapeY.begin());
0083 }
0084
0085 for (size_t i = 0; i < q; i++) {
0086 fShapeY[fAttrAxis + i] = fShapeIndices[i];
0087 }
0088
0089 std::copy(fShapeX.begin() + fAttrAxis + 1, fShapeX.end(), fShapeY.begin() + fAttrAxis + q);
0090 }
0091
0092 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0093 fType = ConvertTypeToString(model.GetTensorType(fNX));
0094 }
0095
0096 std::string Generate(std::string OpName) {
0097 OpName = "op_" + OpName;
0098 std::stringstream out;
0099
0100 size_t r = fShapeX.size();
0101
0102 size_t q = fShapeIndices.size();
0103
0104 std::vector<size_t> stridesX = UTILITY::ComputeStrideFromShape(fShapeX);
0105 std::vector<size_t> stridesY = UTILITY::ComputeStrideFromShape(fShapeY);
0106 std::vector<size_t> stridesIndices = UTILITY::ComputeStrideFromShape(fShapeIndices);
0107
0108 out << SP << "std::vector<int64_t> " << OpName << "_indices = {";
0109 size_t indicesLength = ConvertShapeToLength(fShapeIndices);
0110 for (size_t i = 0; i < indicesLength; i++) {
0111 out << fIndices[i] << (i + 1 < indicesLength? ", " : "};\n");
0112 }
0113
0114
0115
0116
0117 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0118 std::string index = "j_" + std::to_string(j);
0119 out << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[j] << "; " << index << "++) {\n";
0120 }
0121
0122 for (size_t i = 0; i < q; i++) {
0123 std::string index = "i_" + std::to_string(i);
0124 out << SP << SP << "for (size_t " << index << " = " << 0 << "; " << index << " < " << fShapeIndices[i] << "; " << index << "++) {\n";
0125 }
0126
0127 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0128 std::string index = "j_" + std::to_string(j);
0129 out << SP << SP << SP << "for (size_t " << index << " = 0; " << index << " < " << fShapeY[q + j] << "; " << index << "++) {\n";
0130 }
0131
0132 out << SP << SP << SP << "size_t y_index = 0;\n";
0133 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0134 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[j] << ";\n";
0135 }
0136 for (size_t i = 0; i < q; i++) {
0137 out << SP << SP << SP << "y_index += i_" + std::to_string(i) + " * " << stridesY[fAttrAxis + i] << ";\n";
0138 }
0139 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0140 out << SP << SP << SP << "y_index += j_" + std::to_string(j) + " * " << stridesY[q + j] << ";\n";
0141 }
0142
0143 out << SP << SP << SP << "size_t i_index = 0;\n";
0144 for (size_t i = 0; i < q; i++) {
0145 out << SP << SP << SP << "i_index += i_" + std::to_string(i) + " * " << stridesIndices[i] << ";\n";
0146 }
0147
0148 out << SP << SP << SP << "size_t k = static_cast<size_t>(" << OpName << "_indices[i_index]" << ");\n";
0149
0150 out << SP << SP << SP << "size_t x_index = k * " << stridesX[fAttrAxis] << ";\n";
0151 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0152 out << SP << SP << SP << "x_index += j_" + std::to_string(j) + " * " << stridesX[j] << ";\n";
0153 }
0154 for (size_t j = fAttrAxis + 1; j < r; j++) {
0155 out << SP << SP << SP << "x_index += j_" + std::to_string(j - 1) + " * " << stridesX[j] << ";\n";
0156 }
0157 out << SP << SP << SP << "tensor_" << fNY << "[y_index] = tensor_" << fNX << "[x_index];\n";
0158
0159
0160 for (size_t j = fAttrAxis; j + 1 < r; j++) {
0161 out << SP << SP << SP << "}\n";
0162 }
0163
0164 for (size_t i = 0; i < q; i++) {
0165 out << SP << SP << "}\n";
0166 }
0167
0168 for (size_t j = 0; j < size_t(fAttrAxis); j++) {
0169 out << SP << "}\n";
0170 }
0171
0172 return out.str();
0173 }
0174
0175 };
0176
0177 }
0178 }
0179 }
0180
0181
0182 #endif