Back to home page

EIC code displayed by LXR

 
 

    


Warning, file /include/root/TMVA/ROperator_ScatterElements.hxx was not indexed or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).

0001 #ifndef TMVA_SOFIE_ROperator_ScatterElements
0002 #define TMVA_SOFIE_ROperator_ScatterElements
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA{
0011 namespace Experimental{
0012 namespace SOFIE{
0013 
0014 
0015 class ROperator_ScatterElements final : public ROperator{
0016 private:
0017 
0018    int64_t fAxis;
0019 
0020    std::string fNX;
0021    std::string fNI;
0022    std::string fNU;
0023    std::string fNY;
0024    std::string fReduction;
0025 
0026    std::vector<Dim> fShapeX;
0027    std::vector<Dim> fShapeI;
0028    std::vector<Dim> fShapeY;
0029 
0030    // define reduction function. Possibilities are:
0031    // none (default), add, mul, max, min
0032    std::string ReductionFunction(const std::string & t1, const std::string & t2 ) {
0033       std::string name = fReduction;
0034       if (name.empty() || name == "none")
0035          return t2;
0036       else if (name == "add")
0037          return t1 + " + " + t2;
0038       else if (name == "mul")
0039          return t1 + " * " + t2;
0040       else if (name == "max")
0041          return "std::max(" + t1 + "," + t2 + ")";
0042       else if (name == "min")
0043          return "std::min(" + t1 + "," + t2 + ")";
0044       else
0045          throw std::runtime_error("TMVA SOFIE ScatterElements : invalid reduction attribute");
0046 
0047       return std::string();
0048    }
0049 
0050 public:
0051    ROperator_ScatterElements(){}
0052    ROperator_ScatterElements(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY,
0053                            int axis, std::string reduction):
0054       fAxis(axis),
0055       fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)),
0056       fNY(UTILITY::Clean_name(nameY)),
0057       fReduction(reduction)
0058       {
0059          fInputTensorNames = { fNX, fNI, fNU };
0060          fOutputTensorNames = { fNY };
0061       }
0062 
0063 
0064    void Initialize(RModel& model) override {
0065       // input must be a graph input, or already initialized intermediate tensor
0066       if (!model.CheckIfTensorAlreadyExist(fNX)){
0067          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNX + "is not found in model");
0068       }
0069       if (!model.CheckIfTensorAlreadyExist(fNI)) {
0070          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNI + "is not found in model");
0071       }
0072       if (!model.CheckIfTensorAlreadyExist(fNU)) {
0073          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements Op Input Tensor ") + fNU + "is not found in model");
0074       }
0075       //tbd check for constant tensors
0076 
0077       fShapeX = model.GetDimTensorShape(fNX);
0078       fShapeI = model.GetDimTensorShape(fNI);
0079       auto shapeU = model.GetDimTensorShape(fNU);
0080       if (model.Verbose()) {
0081          std::cout << "ScatterElements: input: " << ConvertShapeToString(fShapeX)
0082                                                 << " indices " << ConvertShapeToString(fShapeI)
0083                                                 << " update " <<  ConvertShapeToString(shapeU) << std::endl;
0084       }
0085       if (!model.IsDynamicTensor(fNI) && !model.IsDynamicTensor(fNU)) {
0086          if (shapeU != fShapeI)
0087            throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - update tensor has invalid shape ")) ;
0088       }
0089       if (fShapeX.size() == 0)
0090          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - input tensor has zero rank  ")) ;
0091       if (fShapeX.size() != fShapeI.size())
0092          throw std::runtime_error(std::string("TMVA SOFIE ScatterElements - index tensor has invalid rank  ")) ;
0093 
0094       if (fAxis < 0) fAxis += fShapeX.size();
0095 
0096       // assume output shape is identical to input shape
0097       fShapeY = fShapeX;
0098       model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0099       if (model.Verbose())
0100          std::cout << "\t----> " << ConvertDimShapeToString(fShapeY) << std::endl;
0101    }
0102 
0103    std::string GenerateInitCode() override {
0104       std::stringstream out;
0105       return out.str();
0106    }
0107 
0108    std::string Generate(std::string opName) override {
0109 
0110       if (fIsOutputConstant) return "";
0111 
0112       if (fShapeY.empty()) {
0113          throw std::runtime_error("TMVA SOFIE ScatterElements Op called to Generate without being initialized first");
0114       }
0115       std::stringstream out;
0116       out << SP << "\n//-------- ScatterElements  --- " << opName << "\n";
0117 
0118       auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
0119       auto strideI = UTILITY::ComputeStrideFromShape(fShapeI);
0120 
0121       auto length = ConvertDimShapeToLength(fShapeY);
0122 
0123       // function to write compute expression for global index from axes indices
0124       auto tensorIndex = [](const std::vector<Dim> & stride, const std::vector<std::string> & idx) {
0125          std::stringstream strst;
0126          int dims = idx.size();
0127          assert (dims == (int) stride.size());
0128          for (int i = 0; i < dims; i++) {
0129             if (stride[i].GetVal() != "1")
0130                strst << stride[i] << "*" << idx[i];
0131             else
0132                strst << idx[i];
0133             if (i < dims-1)
0134                strst << " + ";
0135          }
0136          return strst.str();
0137       };
0138 
0139 
0140       // copy first input in output (maybe can be avoided??)
0141       out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << length << ", tensor_" << fNY << ");\n";
0142 
0143       // loop on tensor rank
0144       int dims = fShapeY.size();
0145       std::vector<std::string> idx(dims);
0146       for (int i = 0; i < dims; i++) {
0147          idx[i] = std::string("i") + std::to_string(i);
0148          for (int j = 0; j <= i; j++) out << SP;
0149          out << "for (int " << idx[i] << " = 0; " << idx[i] << " < " << fShapeI[i] << "; " << idx[i] << "++) {\n";
0150       }
0151       // correct index for specific axis
0152       for (int j = 0; j <= dims; j++) out << SP;
0153       out << "int updateIndex = " << tensorIndex(strideI,idx) << ";\n";
0154       for (int j = 0; j <= dims; j++) out << SP;
0155       out << "int iAxis = tensor_" << fNI << "[updateIndex];\n";
0156       for (int j = 0; j <= dims; j++) out << SP;
0157       out << "if (iAxis < 0) iAxis += " << fShapeY[fAxis] << ";\n";
0158       idx[fAxis] = "iAxis";
0159       for (int j = 0; j <= dims; j++) out << SP;
0160       out << "int  outIndex = " << tensorIndex(strideY, idx) << ";\n";
0161       for (int j = 0; j <= dims; j++) out << SP;
0162       out << "tensor_" << fNY << "[outIndex] = "
0163          << ReductionFunction(std::string("tensor_") + fNY + "[outIndex]", std::string("tensor_") + fNU + "[updateIndex]") << ";\n";
0164 
0165       for (int i = dims; i > 0; i--) {
0166          for (int j = 0; j < i; j++) out << SP;
0167          out << "}\n";
0168       }
0169       return out.str();
0170    }
0171 
0172 };
0173 
0174 }//SOFIE
0175 }//Experimental
0176 }//TMVA
0177 
0178 
0179 #endif //TMVA_SOFIE_ROperator_ScatterElements