Warning, file /include/root/TMVA/ROperator_ScatterElements.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROperator_ScatterElements
0002 #define TMVA_SOFIE_ROperator_ScatterElements
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013
0014
0015 class ROperator_ScatterElements final : public ROperator{
0016 private:
0017
0018 int64_t fAxis;
0019
0020 std::string fNX;
0021 std::string fNI;
0022 std::string fNU;
0023 std::string fNY;
0024 std::string fReduction;
0025
0026 std::vector<size_t> fShapeX;
0027 std::vector<size_t> fShapeI;
0028 std::vector<size_t> fShapeY;
0029
0030
0031
0032 std::string ReductionFunction(const std::string & t1, const std::string & t2 ) {
0033 std::string name = fReduction;
0034 if (name.empty() || name == "none")
0035 return t2;
0036 else if (name == "add")
0037 return t1 + " + " + t2;
0038 else if (name == "mul")
0039 return t1 + " * " + t2;
0040 else if (name == "max")
0041 return "std::max(" + t1 + "," + t2 + ")";
0042 else if (name == "min")
0043 return "std::min(" + t1 + "," + t2 + ")";
0044 else
0045 throw std::runtime_error("TMVA SOFIE ScatterElements : invalid reduction attribute");
0046
0047 return std::string();
0048 }
0049
0050 public:
0051 ROperator_ScatterElements(){}
0052 ROperator_ScatterElements(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY,
0053 int axis, std::string reduction):
0054 fAxis(axis),
0055 fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)),
0056 fNY(UTILITY::Clean_name(nameY)),
0057 fReduction(reduction)
0058 {
0059 fInputTensorNames = { fNX, fNI, fNU };
0060 fOutputTensorNames = { fNY };
0061 }
0062
0063
0064 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0065 return input;
0066 }
0067
0068
0069 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0070 auto ret = std::vector<std::vector<size_t>>(1, input[0]);
0071 return ret;
0072 }
0073
0074 void Initialize(RModel& model) override {
0075
0076 if (!model.CheckIfTensorAlreadyExist(fNX)){
0077 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNX + "is not found in model");
0078 }
0079 if (!model.CheckIfTensorAlreadyExist(fNI)) {
0080 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNI + "is not found in model");
0081 }
0082 if (!model.CheckIfTensorAlreadyExist(fNU)) {
0083 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNU + "is not found in model");
0084 }
0085
0086
0087 fShapeX = model.GetTensorShape(fNX);
0088 fShapeI = model.GetTensorShape(fNI);
0089 if (model.GetTensorShape(fNU) != fShapeI)
0090 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - update tensor has invalid shape ")) ;
0091 if (fShapeX.size() == 0)
0092 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - input tensor has zero rank ")) ;
0093 if (fShapeX.size() != fShapeI.size())
0094 throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - index tensor has invalid rank ")) ;
0095
0096 if (fAxis < 0) fAxis += fShapeX.size();
0097
0098
0099 fShapeY = fShapeX;
0100 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0101 }
0102
0103 std::string GenerateInitCode() override {
0104 std::stringstream out;
0105 return out.str();
0106 }
0107
0108 std::string Generate(std::string opName) override {
0109
0110 if (fIsOutputConstant) return "";
0111
0112 if (fShapeY.empty()) {
0113 throw std::runtime_error("TMVA SOFIE ScatterElements Op called to Generate without being initialized first");
0114 }
0115 std::stringstream out;
0116 out << SP << "\n//-------- ScatterElements --- " << opName << "\n";
0117
0118 auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
0119 auto strideI = UTILITY::ComputeStrideFromShape(fShapeI);
0120
0121 size_t length = ConvertShapeToLength(fShapeY);
0122
0123
0124 auto tensorIndex = [](const std::vector<size_t> & stride, const std::vector<std::string> & idx) {
0125 std::stringstream strst;
0126 int dims = idx.size();
0127 assert (dims == (int) stride.size());
0128 for (int i = 0; i < dims; i++) {
0129 if (stride[i] != 1)
0130 strst << stride[i] << "*" << idx[i];
0131 else
0132 strst << idx[i];
0133 if (i < dims-1)
0134 strst << " + ";
0135 }
0136 return strst.str();
0137 };
0138
0139
0140
0141 out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
0142
0143
0144 int dims = fShapeY.size();
0145 std::vector<std::string> idx(dims);
0146 for (int i = 0; i < dims; i++) {
0147 idx[i] = std::string("i") + std::to_string(i);
0148 for (int j = 0; j <= i; j++) out << SP;
0149 out << "for (int " << idx[i] << " = 0; " << idx[i] << " < " << fShapeI[i] << "; " << idx[i] << "++) {\n";
0150 }
0151
0152 for (int j = 0; j <= dims; j++) out << SP;
0153 out << "int updateIndex = " << tensorIndex(strideI,idx) << ";\n";
0154 for (int j = 0; j <= dims; j++) out << SP;
0155 out << "int iAxis = tensor_" << fNI << "[updateIndex];\n";
0156 for (int j = 0; j <= dims; j++) out << SP;
0157 out << "if (iAxis < 0) iAxis += " << fShapeY[fAxis] << ";\n";
0158 idx[fAxis] = "iAxis";
0159 for (int j = 0; j <= dims; j++) out << SP;
0160 out << "int outIndex = " << tensorIndex(strideY, idx) << ";\n";
0161 for (int j = 0; j <= dims; j++) out << SP;
0162 out << "tensor_" << fNY << "[outIndex] = "
0163 << ReductionFunction(std::string("tensor_") + fNY + "[outIndex]", std::string("tensor_") + fNU + "[updateIndex]") << ";\n";
0164
0165 for (int i = dims; i > 0; i--) {
0166 for (int j = 0; j < i; j++) out << SP;
0167 out << "}\n";
0168 }
0169 return out.str();
0170 }
0171
0172 };
0173
0174 }
0175 }
0176 }
0177
0178
0179 #endif