Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Scan in default domain from version 8 to 9
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <utility>
0013 #include <vector>
0014 
0015 #include "onnx/version_converter/adapters/adapter.h"
0016 
0017 namespace ONNX_NAMESPACE {
0018 namespace version_conversion {
0019 struct Scan_8_9 final : public Adapter {
0020   explicit Scan_8_9() : Adapter("Scan", OpSetID(8), OpSetID(9)) {}
0021 
0022   void adapt_scan_8_9(std::shared_ptr<Graph>, Node* node) const {
0023     const std::vector<Value*> inputs(node->inputs().vec());
0024     const std::vector<Value*> outputs(node->outputs().vec());
0025 
0026     // Handling Attribute Changes
0027 
0028     Symbol dirs = Symbol("directions");
0029     if (node->hasAttribute(dirs)) {
0030       const std::vector<int64_t> directions(node->is(dirs));
0031       node->removeAttribute(dirs);
0032       node->is_(Symbol("scan_input_directions"), std::move(directions));
0033     }
0034 
0035     // Handling Input and Output Changes
0036 
0037     node->removeAllInputs();
0038 
0039     ONNX_ASSERTM(inputs[0]->uniqueName() == "", "Unsupported conversion to opset 9");
0040 
0041     for (Value* input : inputs) {
0042       if (!input->sizes().empty()) {
0043         std::vector<Dimension> new_sizes(input->sizes().begin() + 1, input->sizes().end());
0044         input->setSizes(new_sizes);
0045         node->addInput(input);
0046       }
0047     }
0048 
0049     for (Value* output : outputs) {
0050       if (!output->sizes().empty()) {
0051         std::vector<Dimension> new_sizes(output->sizes().begin() + 1, output->sizes().end());
0052         output->setSizes(new_sizes);
0053       }
0054     }
0055   }
0056 
0057   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0058     adapt_scan_8_9(graph, node);
0059     return node;
0060   }
0061 };
0062 
0063 } // namespace version_conversion
0064 } // namespace ONNX_NAMESPACE