Warning: the file /include/root/TMVA/ROperator_TopK.hxx was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_TOPK
0002 #define TMVA_SOFIE_ROPERATOR_TOPK
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_TopK final : public ROperator {
0016
0017 private:
0018 int fAttrAxis;
0019 int fAttrLargest;
0020 int fAttrSorted;
0021
0022 size_t fK;
0023 std::string fNK;
0024 std::string fNX;
0025 std::string fNVal;
0026 std::string fNInd;
0027 std::vector<size_t> fShapeX;
0028 std::vector<size_t> fShapeY;
0029 std::string fType;
0030
0031 public:
0032 ROperator_TopK() {}
0033 ROperator_TopK(int attr_axis, int attr_largest, int attr_sorted, std::string nameK, std::string nameX, std::string nameVal, std::string nameInd)
0034 : fAttrAxis(attr_axis),
0035 fAttrLargest(attr_largest),
0036 fAttrSorted(attr_sorted),
0037 fNK(UTILITY::Clean_name(nameK)),
0038 fNX(UTILITY::Clean_name(nameX)),
0039 fNVal(UTILITY::Clean_name(nameVal)),
0040 fNInd(UTILITY::Clean_name(nameInd)){
0041 fInputTensorNames = { fNX, fNK };
0042 fOutputTensorNames = { fNVal, fNInd };
0043 }
0044
0045 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0046 ETensorType ret = input[0];
0047 return {ret, ret};
0048 }
0049
0050 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0051 if (input.size() != 2) {
0052 throw std::runtime_error("TMVA SOFIE TopK Op Shape Inference needs exactly 2 input tensors");
0053 }
0054
0055 auto shape = input[0];
0056
0057
0058 shape[fAttrAxis] = fK;
0059 return {shape, shape};
0060 }
0061
0062
0063 void Initialize(RModel& model) override {
0064 if (model.CheckIfTensorAlreadyExist(fNX) == false) {
0065
0066 throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor is not found in model");
0067 }
0068 if (model.CheckIfTensorAlreadyExist(fNK) == false) {
0069
0070 throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor i.e. K is not found in model");
0071 }
0072
0073 fShapeX = model.GetTensorShape(fNX);
0074 auto fShapeK = model.GetTensorShape(fNK);
0075 auto kptr = static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
0076 fK = *kptr;
0077 model.SetNotWritableInitializedTensor(fNK);
0078 fAttrAxis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
0079 if(static_cast<size_t>(fAttrAxis) >= fShapeX.size()){
0080 throw
0081 std::runtime_error("TMVA::SOFIE ONNX TopK op axis = "+ std::to_string(fAttrAxis) +" value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
0082 }
0083
0084 fK = std::min(fK, fShapeX[fAttrAxis]);
0085
0086 fShapeY = ShapeInference({fShapeX, fShapeK})[0];
0087 model.AddIntermediateTensor(fNVal, model.GetTensorType(fNX), fShapeY);
0088
0089
0090 model.AddIntermediateTensor(fNInd, ETensorType::INT64, fShapeY);
0091 fType = ConvertTypeToString(model.GetTensorType(fNX));
0092 }
0093
0094 std::string Generate(std::string OpName) override {
0095 OpName = "op_" + OpName;
0096 if (fShapeX.empty()) {
0097 throw std::runtime_error("TMVA SOFIE Operator TopK called to Generate without being initialized first");
0098 }
0099 std::stringstream out;
0100 size_t size = fShapeX.size();
0101 size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0102 out << "\n" << SP << "//------ TopK\n";
0103
0104 size_t length=ConvertShapeToLength(fShapeX);
0105 auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
0106 auto strideY = UTILITY::ComputeStrideFromShape(fShapeY);
0107
0108 size_t n_before = (axis>0) ? length/strideX[axis-1] : 1;
0109 size_t n_after = strideX[axis];
0110 size_t n_elements = fShapeX[axis];
0111
0112
0113 out << SP << "{\n";
0114 out << SP << "std::vector<std::pair<float,int64_t>> elements(" << n_elements << ");\n";
0115
0116 if (n_before > 1) {
0117 out << SP << "for (size_t i = 0; i < " << n_before << "; i++) {\n";
0118 out << SP << SP << "size_t xoffset = i*" << strideX[axis-1] << ";\n";
0119 out << SP << SP << "size_t yoffset = i*" << strideY[axis-1] << ";\n";
0120 out << SP;
0121 } else {
0122 out << SP << "size_t xoffset = 0;\n";
0123 out << SP << "size_t yoffset = 0;\n";
0124 }
0125 if (n_after > 1)
0126 out << SP << "for (size_t j = 0; j < " << n_after << "; j++) {\n";
0127 else
0128 out << SP << "const size_t j = 0;\n";
0129
0130
0131 out << SP << SP << "for (size_t l = 0; l < " << n_elements << "; l++) {\n";
0132 out << SP << SP << SP << "elements[l] = std::make_pair(tensor_" << fNX << "[xoffset + " << strideX[axis] << "*l + j], l);\n";
0133 out << SP << SP << "}\n";
0134
0135 if (fAttrSorted) {
0136 if (fAttrLargest) {
0137 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0138 "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first>b.first) : a.second < b.second;});\n";
0139
0140 } else
0141 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0142 "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first<b.first) : a.second < b.second;});\n";
0143 } else
0144
0145 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end());\n";
0146
0147
0148 out << SP << SP << "for (size_t l = 0; l < " << fK << "; l++) {\n";
0149 out << SP << SP << SP << "tensor_" << fNVal << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].first;\n";
0150 out << SP << SP << SP << "tensor_" << fNInd << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].second;\n";
0151 out << SP << SP << "}\n";
0152 if (n_after > 1) out << SP << SP << "}\n";
0153 if (n_before> 1) out << SP << "}\n";
0154 out << SP << "}\n";
0155 return out.str();
0156 }
0157 };
0158
0159 }
0160 }
0161 }
0162
0163 #endif