Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/ROperator_ScatterElements.hxx was not indexed, or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_ROperator_ScatterElements
0002 #define TMVA_SOFIE_ROperator_ScatterElements
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 
0015 class ROperator_ScatterElements final : public ROperator{
0016 private:
0017 
0018    int64_t fAxis;
0019 
0020    std::string fNX;
0021    std::string fNI;
0022    std::string fNU;
0023    std::string fNY;
0024    std::string fReduction;
0025 
0026    std::vector<size_t> fShapeX;
0027    std::vector<size_t> fShapeI;
0028    std::vector<size_t> fShapeY;
0029 
0030    // define reduction function. Possibilities are:
0031    // none (default), add, mul, max, min
0032    std::string ReductionFunction(const std::string & t1, const std::string & t2 ) {
0033       std::string name = fReduction;
0034       if (name.empty() || name == "none")
0035          return t2;
0036       else if (name == "add")
0037          return t1 + " + " + t2;
0038       else if (name == "mul")
0039          return t1 + " * " + t2;
0040       else if (name == "max")
0041          return "std::max(" + t1 + "," + t2 + ")";
0042       else if (name == "min")
0043          return "std::min(" + t1 + "," + t2 + ")";
0044       else
0045          throw std::runtime_error("TMVA SOFIE ScatterElements : invalid reduction attribute");
0046 
0047       return std::string();
0048    }
0049 
0050 public:
0051    ROperator_ScatterElements(){}
0052    ROperator_ScatterElements(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY,
0053                            int axis, std::string reduction):
0054       fAxis(axis),
0055       fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)),
0056       fNY(UTILITY::Clean_name(nameY)),
0057       fReduction(reduction)
0058       {
0059          fInputTensorNames = { fNX, fNI, fNU };
0060          fOutputTensorNames = { fNY };
0061       }
0062 
0063    // type of output given input
0064    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0065       return input;
0066    }
0067 
0068    // shape of output tensors given input tensors
0069    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0070       auto ret = std::vector<std::vector<size_t>>(1, input[0]); // return vector size 1 with first input
0071       return ret;
0072    }
0073 
0074    void Initialize(RModel& model) override {
0075       // input must be a graph input, or already initialized intermediate tensor
0076       if (!model.CheckIfTensorAlreadyExist(fNX)){
0077          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNX + "is not found in model");
0078       }
0079       if (!model.CheckIfTensorAlreadyExist(fNI)) {
0080          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNI + "is not found in model");
0081       }
0082       if (!model.CheckIfTensorAlreadyExist(fNU)) {
0083          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNU + "is not found in model");
0084       }
0085       //tbd check for constant tensors
0086 
0087       fShapeX = model.GetTensorShape(fNX);
0088       fShapeI = model.GetTensorShape(fNI);
0089       if (model.GetTensorShape(fNU) != fShapeI)
0090          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - update tensor has invalid shape ")) ;
0091       if (fShapeX.size() == 0)
0092          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - input tensor has zero rank  ")) ;
0093       if (fShapeX.size() != fShapeI.size())
0094          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - index tensor has invalid rank  ")) ;
0095 
0096       if (fAxis < 0) fAxis += fShapeX.size();
0097 
0098       // assume output shape is identical to input shape
0099       fShapeY = fShapeX;
0100       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0101    }
0102 
0103    std::string GenerateInitCode() override {
0104       std::stringstream out;
0105       return out.str();
0106    }
0107 
0108    std::string Generate(std::string opName) override {
0109 
0110       if (fIsOutputConstant) return "";
0111 
0112       if (fShapeY.empty()) {
0113          throw std::runtime_error("TMVA SOFIE ScatterElements Op called to Generate without being initialized first");
0114       }
0115       std::stringstream out;
0116       out << SP << "\n//-------- ScatterElements  --- " << opName << "\n";
0117 
0118       auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
0119       auto strideI = UTILITY::ComputeStrideFromShape(fShapeI);
0120 
0121       size_t length = ConvertShapeToLength(fShapeY);
0122 
0123       // function to write compute expression for global index from axes indices
0124       auto tensorIndex = [](const std::vector<size_t> & stride, const std::vector<std::string> & idx) {
0125          std::stringstream strst;
0126          int dims = idx.size();
0127          assert (dims == (int) stride.size());
0128          for (int i = 0; i < dims; i++) {
0129             if (stride[i] != 1)
0130                strst << stride[i] << "*" << idx[i];
0131             else
0132                strst << idx[i];
0133             if (i < dims-1)
0134                strst << " + ";
0135          }
0136          return strst.str();
0137       };
0138 
0139 
0140       // copy first input in output (maybe can be avoided??)
0141       out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
0142 
0143       // loop on tensor rank
0144       int dims = fShapeY.size();
0145       std::vector<std::string> idx(dims);
0146       for (int i = 0; i < dims; i++) {
0147          idx[i] = std::string("i") + std::to_string(i);
0148          for (int j = 0; j <= i; j++) out << SP;
0149          out << "for (int " << idx[i] << " = 0; " << idx[i] << " < " << fShapeI[i] << "; " << idx[i] << "++) {\n";
0150       }
0151       // correct index for specific axis
0152       for (int j = 0; j <= dims; j++) out << SP;
0153       out << "int updateIndex = " << tensorIndex(strideI,idx) << ";\n";
0154       for (int j = 0; j <= dims; j++) out << SP;
0155       out << "int iAxis = tensor_" << fNI << "[updateIndex];\n";
0156       for (int j = 0; j <= dims; j++) out << SP;
0157       out << "if (iAxis < 0) iAxis += " << fShapeY[fAxis] << ";\n";
0158       idx[fAxis] = "iAxis";
0159       for (int j = 0; j <= dims; j++) out << SP;
0160       out << "int  outIndex = " << tensorIndex(strideY, idx) << ";\n";
0161       for (int j = 0; j <= dims; j++) out << SP;
0162       out << "tensor_" << fNY << "[outIndex] = "
0163          << ReductionFunction(std::string("tensor_") + fNY + "[outIndex]", std::string("tensor_") + fNU + "[updateIndex]") << ";\n";
0164 
0165       for (int i = dims; i > 0; i--) {
0166          for (int j = 0; j < i; j++) out << SP;
0167          out << "}\n";
0168       }
0169       return out.str();
0170    }
0171 
0172 };
0173 
0174 }//SOFIE
0175 }//Experimental
0176 }//TMVA
0177 
0178 
0179 #endif //TMVA_SOFIE_ROperator_ScatterElements