Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/ROperator_TopK.hxx was not indexed or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_ROPERATOR_TOPK
0002 #define TMVA_SOFIE_ROPERATOR_TOPK
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 template <typename T>
0015 class ROperator_TopK final : public ROperator {
0016 
0017 private:
0018    int fAttrAxis;
0019    int fAttrLargest;
0020    int fAttrSorted;
0021 
0022    size_t fK;
0023    std::string fNK;
0024    std::string fNX;
0025    std::string fNVal;
0026    std::string fNInd;
0027    std::vector<size_t> fShapeX;
0028    std::vector<size_t> fShapeY;
0029    std::string fType;
0030 
0031 public:
0032    ROperator_TopK() {}
0033    ROperator_TopK(int attr_axis, int attr_largest, int attr_sorted, std::string nameK, std::string nameX, std::string nameVal, std::string nameInd)
0034       : fAttrAxis(attr_axis),
0035         fAttrLargest(attr_largest),
0036         fAttrSorted(attr_sorted),
0037         fNK(UTILITY::Clean_name(nameK)),
0038         fNX(UTILITY::Clean_name(nameX)),
0039         fNVal(UTILITY::Clean_name(nameVal)),
0040         fNInd(UTILITY::Clean_name(nameInd)){
0041             fInputTensorNames = { fNX, fNK };
0042             fOutputTensorNames = { fNVal, fNInd };
0043         }
0044 
0045    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0046          ETensorType ret = input[0];
0047          return {ret, ret};
0048       }
0049 
0050    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0051       if (input.size() != 2) {
0052          throw std::runtime_error("TMVA SOFIE TopK Op Shape Inference needs exactly 2 input tensors");
0053       }
0054 
0055       auto shape = input[0]; // Shape format: [ m x n x o x p ... ]
0056 
0057       // set the dimension at the specified axis to k  (fAttrAxis is checked before that is in the correct range
0058       shape[fAttrAxis] = fK; // Modified shape: [ m x n x k x p ... ]
0059       return {shape, shape};
0060    }
0061 
0062 
0063    void Initialize(RModel& model) override {
0064       if (model.CheckIfTensorAlreadyExist(fNX) == false) {
0065          // input must be a graph input, or already initialized intermediate tensor
0066          throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor is not found in model");
0067       }
0068       if (model.CheckIfTensorAlreadyExist(fNK) == false) {
0069          // input must be a graph input, or already initialized intermediate tensor
0070          throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor i.e. K is not found in model");
0071       }
0072 
0073       fShapeX = model.GetTensorShape(fNX);
0074       auto fShapeK = model.GetTensorShape(fNK);
0075       auto kptr = static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
0076       fK = *kptr;
0077       model.SetNotWritableInitializedTensor(fNK);
0078       fAttrAxis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
0079       if(static_cast<size_t>(fAttrAxis) >=  fShapeX.size()){
0080          throw
0081             std::runtime_error("TMVA::SOFIE ONNX TopK op axis = "+ std::to_string(fAttrAxis) +" value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
0082       }
0083       // fK cannot be larger that axis dimension
0084       fK = std::min(fK, fShapeX[fAttrAxis]);
0085       // if(fK>fShapeX[fAttrAxis]){
0086       //    throw
0087       //       std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(fK) +" value exeeds value of tensor " +fNX+" of size "+fShapeX.size()+" at axis= "+std::to_string(fAttrAxis)+".");
0088       // }
0089       // fShapeX = model.GetTensorShape(fNX); //  [ m x n x o x p ... ]
0090       // if(k[0]>=fShapeX.size()){
0091       //    throw
0092       //       std::runtime_error("TMVA::SOFIE ONNX TopK op k = "+ std::to_string(k[0]) +"value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
0093       // }
0094       // fShapeY.push_back(2);
0095       // for (auto i : fShapeX)
0096       //    fShapeY.push_back(i); //  [ 2 x m x n x o x p ... ]
0097       // size_t axis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
0098       // fShapeY[axis] = k[0]; //  [ 2 x m x n x K x p ... ]
0099       fShapeY=ShapeInference({fShapeX,fShapeK})[0];
0100 
0101       // for(int i=0;i<fShapeX.size();i++)
0102       // std::cout<<fShapeX[i]<<" ";
0103       // std::cout<<"\ny size -> "<<fShapeY.size()<<std::endl;
0104 
0105 
0106       model.AddIntermediateTensor(fNVal, model.GetTensorType(fNX), fShapeY);
0107       // output indices should be an int64 tensor
0108       model.AddIntermediateTensor(fNInd, ETensorType::INT64, fShapeY);
0109       fType = ConvertTypeToString(model.GetTensorType(fNX));
0110    }
0111 
0112    std::string Generate(std::string OpName) override {
0113       OpName = "op_" + OpName;
0114       if (fShapeX.empty()) {
0115          throw std::runtime_error("TMVA SOFIE Operator TopK called to Generate without being initialized first");
0116       }
0117       std::stringstream out;
0118       size_t size = fShapeX.size();
0119       size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0120       out << "\n" << SP << "//------ TopK\n";
0121 
0122       size_t length=ConvertShapeToLength(fShapeX);
0123       auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
0124       auto strideY = UTILITY::ComputeStrideFromShape(fShapeX);
0125       // we perform loop on dimension before sorted axis and after sorted axis
0126       size_t n_before = (axis>0) ? length/strideX[axis-1] : 1;
0127       size_t n_after = strideX[axis];
0128       size_t n_elements = fShapeX[axis]; // number of elements to be sorted
0129 
0130       // }
0131       out << SP << "{\n"; // to define a separate scope for the operator code
0132       out << SP << "std::vector<std::pair<float,int64_t>> elements(" << n_elements << ");\n";
0133       // loop on elements before
0134       if (n_before > 1) {
0135          out << SP << "for (size_t i = 0; i < " << n_before << "; i++) {\n";
0136          out << SP << SP << "size_t xoffset = i*" << strideX[axis-1] << ";\n";
0137          out << SP << SP << "size_t yoffset = i*" << strideY[axis-1] << ";\n";
0138          out << SP;
0139       } else {
0140          out << SP << "size_t xoffset = 0;\n";
0141          out << SP << "size_t yoffset = 0;\n";
0142       }
0143       if (n_after > 1)
0144          out << SP << "for (size_t j = 0; j < " << n_after << "; j++) {\n";
0145       else
0146          out << SP << "const size_t j = 0;\n";
0147 
0148       // copy elements to be sorted in vector of pair
0149       out << SP << SP << "for (size_t l = 0; l < " << n_elements << "; l++) {\n";
0150       out << SP << SP << SP << "elements[l] = std::make_pair(tensor_" << fNX << "[xoffset + " << strideX[axis] << "*l + j], l);\n";
0151       out << SP << SP << "}\n";
0152 
0153       if (fAttrSorted) {
0154          if (fAttrLargest) {
0155             out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0156                "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first>b.first) : a.second < b.second;});\n";
0157 
0158          } else
0159             out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0160             "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first<b.first) : a.second < b.second;});\n";
0161       } else
0162          // in this case we don;t need to return sorted elements, so we keep same order as before
0163          out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end());\n";
0164 
0165       // copy the selected elements in the output
0166       out << SP << SP << "for (size_t l = 0; l < " << fK << "; l++) {\n";
0167       out << SP << SP << SP << "tensor_" << fNVal   << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].first;\n";
0168       out << SP << SP << SP << "tensor_" << fNInd << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].second;\n";
0169       out << SP << SP << "}\n";
0170       if (n_after > 1) out << SP << SP << "}\n";
0171       if (n_before> 1) out << SP << "}\n";
0172       out << SP << "}\n"; // end operator scope
0173       return out.str();
0174    }
0175 };
0176 
0177 } // namespace SOFIE
0178 } // namespace Experimental
0179 } // namespace TMVA
0180 
0181 #endif // TMVA_SOFIE_ROPERATOR_TOPK