File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <utility>
0013 #include <vector>
0014
0015 #include "onnx/version_converter/adapters/adapter.h"
0016
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019
0020 struct Scan_9_8 final : public Adapter {
0021 explicit Scan_9_8() : Adapter("Scan", OpSetID(9), OpSetID(8)) {}
0022
0023 void adapt_scan_9_8(std::shared_ptr<Graph>, Node* node) const {
0024 const std::vector<Value*> inputs(node->inputs().vec());
0025 const std::vector<Value*> outputs(node->outputs().vec());
0026
0027
0028
0029 Symbol input_dirs = Symbol("scan_input_directions");
0030 if (node->hasAttribute(input_dirs)) {
0031 const std::vector<int64_t> scan_input_directions(node->is(input_dirs));
0032 node->removeAttribute(input_dirs);
0033 node->is_(Symbol("directions"), std::move(scan_input_directions));
0034 }
0035
0036 Symbol output_dirs = Symbol("scan_output_directions");
0037 if (node->hasAttribute(output_dirs)) {
0038 const std::vector<int64_t> scan_output_directions(node->is(output_dirs));
0039 for (int64_t x : scan_output_directions) {
0040 ONNX_ASSERTM(x == 0, "Unsupported output direction for Version 8");
0041 }
0042 node->removeAttribute(output_dirs);
0043 }
0044
0045 Symbol input_axes = Symbol("scan_input_axes");
0046 if (node->hasAttribute(input_axes)) {
0047 const std::vector<int64_t> scan_input_axes(node->is(input_axes));
0048 for (int64_t x : scan_input_axes) {
0049 ONNX_ASSERTM(x == 0, "Unsupported input axes for Version 8");
0050 }
0051 node->removeAttribute(input_axes);
0052 }
0053
0054 Symbol output_axes = Symbol("scan_output_axes");
0055 if (node->hasAttribute(output_axes)) {
0056 const std::vector<int64_t> scan_output_axes(node->is(output_axes));
0057 for (int64_t x : scan_output_axes) {
0058 ONNX_ASSERTM(x == 0, "Unsupported output axes for Version 8");
0059 }
0060 node->removeAttribute(output_axes);
0061 }
0062
0063
0064
0065 node->removeAllInputs();
0066
0067 Value* v = new Value(node, 0);
0068 v->setUniqueName("");
0069 v->setElemType(TensorProto_DataType::TensorProto_DataType_INT32);
0070 node->addInput(v);
0071
0072 for (Value* input : inputs) {
0073 std::vector<Dimension> new_sizes{Dimension(1)};
0074 new_sizes.insert(new_sizes.end(), input->sizes().begin(), input->sizes().end());
0075 input->setSizes(new_sizes);
0076 node->addInput(input);
0077 }
0078
0079 for (Value* output : outputs) {
0080 std::vector<Dimension> new_sizes{Dimension(1)};
0081 new_sizes.insert(new_sizes.end(), output->sizes().begin(), output->sizes().end());
0082 output->setSizes(new_sizes);
0083 }
0084 }
0085
0086 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0087 adapt_scan_9_8(graph, node);
0088 return node;
0089 }
0090 };
0091
0092 }
0093 }