File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
0014
0015 #include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 struct Scan_8_9 final : public Adapter {
0020 explicit Scan_8_9() : Adapter("Scan", OpSetID(8), OpSetID(9)) {}
0021
0022 void adapt_scan_8_9(std::shared_ptr<Graph>, Node* node) const {
0023 const std::vector<Value*> inputs(node->inputs().vec());
0024 const std::vector<Value*> outputs(node->outputs().vec());
0025
0026
0027
0028 Symbol dirs = Symbol("directions");
0029 if (node->hasAttribute(dirs)) {
0030 const std::vector<int64_t> directions(node->is(dirs));
0031 node->removeAttribute(dirs);
0032 node->is_(Symbol("scan_input_directions"), std::move(directions));
0033 }
0034
0035
0036
0037 node->removeAllInputs();
0038
0039 ONNX_ASSERTM(inputs[0]->uniqueName() == "", "Unsupported conversion to opset 9");
0040
0041 for (Value* input : inputs) {
0042 if (!input->sizes().empty()) {
0043 std::vector<Dimension> new_sizes(input->sizes().begin() + 1, input->sizes().end());
0044 input->setSizes(new_sizes);
0045 node->addInput(input);
0046 }
0047 }
0048
0049 for (Value* output : outputs) {
0050 if (!output->sizes().empty()) {
0051 std::vector<Dimension> new_sizes(output->sizes().begin() + 1, output->sizes().end());
0052 output->setSizes(new_sizes);
0053 }
0054 }
0055 }
0056
0057 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0058 adapt_scan_8_9(graph, node);
0059 return node;
0060 }
0061 };
0062
0063 }
0064 }