Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Scan in default domain from version 9 to 8
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <utility>
0013 #include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 
0020 struct Scan_9_8 final : public Adapter {
0021   explicit Scan_9_8() : Adapter("Scan", OpSetID(9), OpSetID(8)) {}
0022 
0023   void adapt_scan_9_8(std::shared_ptr<Graph>, Node* node) const {
0024     const std::vector<Value*> inputs(node->inputs().vec());
0025     const std::vector<Value*> outputs(node->outputs().vec());
0026 
0027     // Handling Attribute Changes
0028 
0029     Symbol input_dirs = Symbol("scan_input_directions");
0030     if (node->hasAttribute(input_dirs)) {
0031       const std::vector<int64_t> scan_input_directions(node->is(input_dirs));
0032       node->removeAttribute(input_dirs);
0033       node->is_(Symbol("directions"), std::move(scan_input_directions));
0034     }
0035 
0036     Symbol output_dirs = Symbol("scan_output_directions");
0037     if (node->hasAttribute(output_dirs)) {
0038       const std::vector<int64_t> scan_output_directions(node->is(output_dirs));
0039       for (int64_t x : scan_output_directions) {
0040         ONNX_ASSERTM(x == 0, "Unsupported output direction for Version 8");
0041       }
0042       node->removeAttribute(output_dirs);
0043     }
0044 
0045     Symbol input_axes = Symbol("scan_input_axes");
0046     if (node->hasAttribute(input_axes)) {
0047       const std::vector<int64_t> scan_input_axes(node->is(input_axes));
0048       for (int64_t x : scan_input_axes) {
0049         ONNX_ASSERTM(x == 0, "Unsupported input axes for Version 8");
0050       }
0051       node->removeAttribute(input_axes);
0052     }
0053 
0054     Symbol output_axes = Symbol("scan_output_axes");
0055     if (node->hasAttribute(output_axes)) {
0056       const std::vector<int64_t> scan_output_axes(node->is(output_axes));
0057       for (int64_t x : scan_output_axes) {
0058         ONNX_ASSERTM(x == 0, "Unsupported output axes for Version 8");
0059       }
0060       node->removeAttribute(output_axes);
0061     }
0062 
0063     // Handling Input and Output Changes
0064 
0065     node->removeAllInputs();
0066 
0067     Value* v = new Value(node, 0);
0068     v->setUniqueName("");
0069     v->setElemType(TensorProto_DataType::TensorProto_DataType_INT32);
0070     node->addInput(v);
0071 
0072     for (Value* input : inputs) {
0073       std::vector<Dimension> new_sizes{Dimension(1)};
0074       new_sizes.insert(new_sizes.end(), input->sizes().begin(), input->sizes().end());
0075       input->setSizes(new_sizes);
0076       node->addInput(input);
0077     }
0078 
0079     for (Value* output : outputs) {
0080       std::vector<Dimension> new_sizes{Dimension(1)};
0081       new_sizes.insert(new_sizes.end(), output->sizes().begin(), output->sizes().end());
0082       output->setSizes(new_sizes);
0083     }
0084   }
0085 
0086   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0087     adapt_scan_9_8(graph, node);
0088     return node;
0089   }
0090 };
0091 
0092 } // namespace version_conversion
0093 } // namespace ONNX_NAMESPACE