Warning, file /include/root/TMVA/ROperator_TopK.hxx was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001 #ifndef TMVA_SOFIE_ROPERATOR_TOPK
0002 #define TMVA_SOFIE_ROPERATOR_TOPK
0003
0004 #include "TMVA/SOFIE_common.hxx"
0005 #include "TMVA/ROperator.hxx"
0006 #include "TMVA/RModel.hxx"
0007
0008 #include <sstream>
0009
0010 namespace TMVA {
0011 namespace Experimental {
0012 namespace SOFIE {
0013
0014 template <typename T>
0015 class ROperator_TopK final : public ROperator {
0016
0017 private:
0018 int fAttrAxis;
0019 int fAttrLargest;
0020 int fAttrSorted;
0021
0022 size_t fK;
0023 std::string fNK;
0024 std::string fNX;
0025 std::string fNVal;
0026 std::string fNInd;
0027 std::vector<size_t> fShapeX;
0028 std::vector<size_t> fShapeY;
0029 std::string fType;
0030
0031 public:
0032 ROperator_TopK() {}
0033 ROperator_TopK(int attr_axis, int attr_largest, int attr_sorted, std::string nameK, std::string nameX, std::string nameVal, std::string nameInd)
0034 : fAttrAxis(attr_axis),
0035 fAttrLargest(attr_largest),
0036 fAttrSorted(attr_sorted),
0037 fNK(UTILITY::Clean_name(nameK)),
0038 fNX(UTILITY::Clean_name(nameX)),
0039 fNVal(UTILITY::Clean_name(nameVal)),
0040 fNInd(UTILITY::Clean_name(nameInd)){
0041 fInputTensorNames = { fNX, fNK };
0042 fOutputTensorNames = { fNVal, fNInd };
0043 }
0044
0045 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input) override {
0046 ETensorType ret = input[0];
0047 return {ret, ret};
0048 }
0049
0050 std::vector<std::vector<size_t>> ShapeInference(std::vector<std::vector<size_t>> input) override {
0051 if (input.size() != 2) {
0052 throw std::runtime_error("TMVA SOFIE TopK Op Shape Inference needs exactly 2 input tensors");
0053 }
0054
0055 auto shape = input[0];
0056
0057
0058 shape[fAttrAxis] = fK;
0059 return {shape, shape};
0060 }
0061
0062
0063 void Initialize(RModel& model) override {
0064 if (model.CheckIfTensorAlreadyExist(fNX) == false) {
0065
0066 throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor is not found in model");
0067 }
0068 if (model.CheckIfTensorAlreadyExist(fNK) == false) {
0069
0070 throw std::runtime_error("TMVA SOFIE TopK Op Input Tensor i.e. K is not found in model");
0071 }
0072
0073 fShapeX = model.GetTensorShape(fNX);
0074 auto fShapeK = model.GetTensorShape(fNK);
0075 auto kptr = static_cast<int64_t *>(model.GetInitializedTensorData(fNK).get());
0076 fK = *kptr;
0077 model.SetNotWritableInitializedTensor(fNK);
0078 fAttrAxis = fAttrAxis < 0 ? fShapeX.size() + fAttrAxis : fAttrAxis;
0079 if(static_cast<size_t>(fAttrAxis) >= fShapeX.size()){
0080 throw
0081 std::runtime_error("TMVA::SOFIE ONNX TopK op axis = "+ std::to_string(fAttrAxis) +" value exeeds size of tensor " +fNX+" of size "+fShapeX.size()+" .");
0082 }
0083
0084 fK = std::min(fK, fShapeX[fAttrAxis]);
0085
0086
0087
0088
0089
0090
0091
0092
0093
0094
0095
0096
0097
0098
0099 fShapeY=ShapeInference({fShapeX,fShapeK})[0];
0100
0101
0102
0103
0104
0105
0106 model.AddIntermediateTensor(fNVal, model.GetTensorType(fNX), fShapeY);
0107
0108 model.AddIntermediateTensor(fNInd, ETensorType::INT64, fShapeY);
0109 fType = ConvertTypeToString(model.GetTensorType(fNX));
0110 }
0111
0112 std::string Generate(std::string OpName) override {
0113 OpName = "op_" + OpName;
0114 if (fShapeX.empty()) {
0115 throw std::runtime_error("TMVA SOFIE Operator TopK called to Generate without being initialized first");
0116 }
0117 std::stringstream out;
0118 size_t size = fShapeX.size();
0119 size_t axis = fAttrAxis < 0 ? size + fAttrAxis : fAttrAxis;
0120 out << "\n" << SP << "//------ TopK\n";
0121
0122 size_t length=ConvertShapeToLength(fShapeX);
0123 auto strideX = UTILITY::ComputeStrideFromShape(fShapeX);
0124 auto strideY = UTILITY::ComputeStrideFromShape(fShapeX);
0125
0126 size_t n_before = (axis>0) ? length/strideX[axis-1] : 1;
0127 size_t n_after = strideX[axis];
0128 size_t n_elements = fShapeX[axis];
0129
0130
0131 out << SP << "{\n";
0132 out << SP << "std::vector<std::pair<float,int64_t>> elements(" << n_elements << ");\n";
0133
0134 if (n_before > 1) {
0135 out << SP << "for (size_t i = 0; i < " << n_before << "; i++) {\n";
0136 out << SP << SP << "size_t xoffset = i*" << strideX[axis-1] << ";\n";
0137 out << SP << SP << "size_t yoffset = i*" << strideY[axis-1] << ";\n";
0138 out << SP;
0139 } else {
0140 out << SP << "size_t xoffset = 0;\n";
0141 out << SP << "size_t yoffset = 0;\n";
0142 }
0143 if (n_after > 1)
0144 out << SP << "for (size_t j = 0; j < " << n_after << "; j++) {\n";
0145 else
0146 out << SP << "const size_t j = 0;\n";
0147
0148
0149 out << SP << SP << "for (size_t l = 0; l < " << n_elements << "; l++) {\n";
0150 out << SP << SP << SP << "elements[l] = std::make_pair(tensor_" << fNX << "[xoffset + " << strideX[axis] << "*l + j], l);\n";
0151 out << SP << SP << "}\n";
0152
0153 if (fAttrSorted) {
0154 if (fAttrLargest) {
0155 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0156 "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first>b.first) : a.second < b.second;});\n";
0157
0158 } else
0159 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end()," <<
0160 "[](std::pair<float,int64_t>a,std::pair<float,int64_t>b){return (a.first!=b.first) ? (a.first<b.first) : a.second < b.second;});\n";
0161 } else
0162
0163 out<<SP<<SP << "std::partial_sort(elements.begin(),elements.begin()+" << fK << ",elements.end());\n";
0164
0165
0166 out << SP << SP << "for (size_t l = 0; l < " << fK << "; l++) {\n";
0167 out << SP << SP << SP << "tensor_" << fNVal << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].first;\n";
0168 out << SP << SP << SP << "tensor_" << fNInd << "[yoffset + " << strideY[axis] << "*l + j] = elements[l].second;\n";
0169 out << SP << SP << "}\n";
0170 if (n_after > 1) out << SP << SP << "}\n";
0171 if (n_before> 1) out << SP << "}\n";
0172 out << SP << "}\n";
0173 return out.str();
0174 }
0175 };
0176
0177 }
0178 }
0179 }
0180
0181 #endif