File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016
0017 class TopK_9_10 final : public Adapter {
0018 public:
0019 explicit TopK_9_10() : Adapter("TopK", OpSetID(9), OpSetID(10)) {}
0020
0021 void adapt_topk_9_10(std::shared_ptr<Graph> graph, Node* node) const {
0022 Tensor t;
0023 t.elem_type() = TensorProto_DataType_INT64;
0024 t.sizes() = std::vector<int64_t>{1};
0025 auto& data = t.int64s();
0026 data.emplace_back(node->i(kk));
0027
0028 Node* constant = graph->create(kConstant);
0029 constant->insertBefore(node);
0030 constant->t_(kvalue, t);
0031 node->addInput(constant->output());
0032
0033 node->removeAttribute(kk);
0034 }
0035
0036 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0037 adapt_topk_9_10(graph, node);
0038 return node;
0039 }
0040 };
0041
0042 }
0043 }