Back to home page

EIC code displayed by LXR

 
 

    


Warning: file /include/root/TMVA/ROperator_TopK.hxx was not indexed or was modified since the last indexation (in which case cross-reference links may be missing, inaccurate, or erroneous).

0001 #ifndef TMVA_SOFIE_ROPERATOR_TOPK
0002 #define TMVA_SOFIE_ROPERATOR_TOPK
0003 
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007 
0008 #include <sstream>
0009 
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013 
0014 template <typename T>
0015 class ROperator_TopK final : public ROperator {
0016 
0017 private:
0018    int fAttrAxis;
0019    int fAttrLargest;
0020    int fAttrSorted;
0021 
0022    size_t fK;
0023    std::string fNK;
0024    std::string fNX;
0025    std::string fNVal;
0026    std::string fNInd;
0027    std::vector<size_t> fShapeX;
0028    std::vector<size_t> fShapeY;
0029    std::string fType;
0030 
0031 public:
0032    ROperator_TopK() {}
0033    ROperator_TopK(int attr_axis, int attr_largest, int attr_sorted, std::string nameK, std::string nameX, std::string nameVal, std::string nameInd)
0034       : fAttrAxis(attr_axis),
0035         fAttrLargest(attr_largest),
0036         fAttrSorted(attr_sorted),
0037         fNK(UTILITY::Clean_name(nameK)),
0038         fNX(UTILITY::Clean_name(nameX)),
0039         fNVal(UTILITY::Clean_name(nameVal)),
0040         fNInd(UTILITY::Clean_name(nameInd)){
0041             fInputTensorNames = { fNX, fNK };
0042             fOutputTensorNames = { fNVal, fNInd };
0043         }
0044 
0045    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0046          ETensorType ret = input[0];
0047          return {ret, ret};
0048       }
0049 
0050    std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0051       if (input.size() != 2) {
0052          throw std::runtime_error("TMVA SOFIE TopK Op Shape Inference needs exactly 2 input tensors");
0053       }
0054 
0055       auto shape = input[0]; // Shape format: [ m x n x o x p ... ]
0056 
0057       // set the dimension at the specified axis to k  (fAttrAxis is checked before that is in the correct range
0058       shape[fAttrAxis] = fK; // Modified shape: [ m x n x k x p ... ]
0059       return {shape, shape};
0060    }
0061 
0062 
0063    void Initialize(RModel& model) override {
0064       if (model.CheckIfTensorAlreadyExist(fNX) == false) {
0065          // input must be a graph input, or already initialized intermediate tensor
0066          throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor is not found in model");
0067       }
0068       if (model.CheckIfTensorAlreadyExist(fNK) == false) {
0069          // input must be a graph input, or already initialized intermediate tensor
0070          throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor i.e. K is not found in model");
0071       }
0072 
0073       fShapeX = model.GetTensorShape(fNX);
0074       auto fShapeK = model.GetTensorShape(fNK);
0075       auto kptr = static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
0076       fK = *kptr;
0077       model.SetNotWritableInitializedTensor(fNK);
0078       fAttrAxis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
0079       if(static_cast<size_t>(fAttrAxis) >=  fShapeX.size()){
0080          throw
0081             std::runtime_error("TMVA::SOFIE ONNX TopK op axis = "+ std::to_string(fAttrAxis) +" value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
0082       }
0083       // fK cannot be larger that axis dimension
0084       fK = std::min(fK, fShapeX[fAttrAxis]);
0085 
0086       fShapeY = ShapeInference({fShapeX, fShapeK})[0];
0087       model.AddIntermediateTensor(fNVal, model.GetTensorType(fNX), fShapeY);
0088 
0089       // output indices should be an int64 tensor
0090       model.AddIntermediateTensor(fNInd, ETensorType::INT64, fShapeY);
0091       fType = ConvertTypeToString(model.GetTensorType(fNX));
0092    }
0093 
0094    std::string Generate(std::string OpName) override {
0095       OpName = "op_" + OpName;
0096       if (fShapeX.empty()) {
0097          throw std::runtime_error("TMVA SOFIE Operator TopK called to Generate without being initialized first");
0098       }
0099       std::stringstream out;
0100       size_t size = fShapeX.size();
0101       size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0102       out << "\n" << SP << "//------ TopK\n";
0103 
0104       size_t length=ConvertShapeToLength(fShapeX);
0105       auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
0106       auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
0107       // we perform loop on dimension before sorted axis and after sorted axis
0108       size_t n_before = (axis>0) ? length/strideX[axis-1] : 1;
0109       size_t n_after = strideX[axis];
0110       size_t n_elements = fShapeX[axis]; // number of elements to be sorted
0111 
0112       // }
0113       out << SP << "{\n"; // to define a separate scope for the operator code
0114       out << SP << "std::vector<std::pair<float,int64_t>> elements(" << n_elements << ");\n";
0115       // loop on elements before
0116       if (n_before > 1) {
0117          out << SP << "for (size_t i = 0; i < " << n_before << "; i++) {\n";
0118          out << SP << SP << "size_t xoffset = i*" << strideX[axis-1] << ";\n";
0119          out << SP << SP << "size_t yoffset = i*" << strideY[axis-1] << ";\n";
0120          out << SP;
0121       } else {
0122          out << SP << "size_t xoffset = 0;\n";
0123          out << SP << "size_t yoffset = 0;\n";
0124       }
0125       if (n_after > 1)
0126          out << SP << "for (size_t j = 0; j < " << n_after << "; j++) {\n";
0127       else
0128          out << SP << "const size_t j = 0;\n";
0129 
0130       // copy elements to be sorted in vector of pair
0131       out << SP << SP << "for (size_t l = 0; l < " << n_elements << "; l++) {\n";
0132       out << SP << SP << SP << "elements[l] = std::make_pair(tensor_" << fNX << "[xoffset + " << strideX[axis] << "*l + j], l);\n";
0133       out << SP << SP << "}\n";
0134 
0135       if (fAttrSorted) {
0136          if (fAttrLargest) {
0137             out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0138                "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first>b.first) : a.second < b.second;});\n";
0139 
0140          } else
0141             out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0142             "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first<b.first) : a.second < b.second;});\n";
0143       } else
0144          // in this case we don;t need to return sorted elements, so we keep same order as before
0145          out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end());\n";
0146 
0147       // copy the selected elements in the output
0148       out << SP << SP << "for (size_t l = 0; l < " << fK << "; l++) {\n";
0149       out << SP << SP << SP << "tensor_" << fNVal   << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].first;\n";
0150       out << SP << SP << SP << "tensor_" << fNInd << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].second;\n";
0151       out << SP << SP << "}\n";
0152       if (n_after > 1) out << SP << SP << "}\n";
0153       if (n_before> 1) out << SP << "}\n";
0154       out << SP << "}\n"; // end operator scope
0155       return out.str();
0156    }
0157 };
0158 
0159 } // namespace SOFIE
0160 } // namespace Experimental
0161 } // namespace TMVA
0162 
0163 #endif // TMVA_SOFIE_ROPERATOR_TOPK