Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:55

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Pad in default domain from version 10 to 11
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 #include <vector>
0013 
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016 
0017 class Pad_10_11 final : public Adapter {
0018  public:
0019   explicit Pad_10_11() : Adapter("Pad", OpSetID(10), OpSetID(11)) {}
0020 
0021   void adapt_pad_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0022     // Turn pads attribute into input
0023     Tensor t_pads;
0024     t_pads.elem_type() = TensorProto_DataType_INT64;
0025     auto& data_pads = t_pads.int64s();
0026     for (int64_t shape : node->is(kpads)) {
0027       data_pads.emplace_back(shape);
0028     }
0029     t_pads.sizes() = std::vector<int64_t>{(int64_t)data_pads.size()};
0030     Value* v_pads = graph->addInitializerAndCreateValue(t_pads);
0031     node->addInput(v_pads);
0032     node->removeAttribute(kpads);
0033     // Turn value attribute into input
0034     if (!node->hasAttribute(kmode) || node->s(kmode) == "constant") {
0035       if (!node->hasAttribute(kvalue))
0036         node->f_(kvalue, 0.);
0037       Tensor t_value;
0038       t_value.elem_type() = TensorProto_DataType_FLOAT;
0039       auto& data_value = t_value.floats();
0040       data_value.emplace_back(node->f(kvalue));
0041       Node* constant = graph->create(kConstant);
0042       constant->insertBefore(node);
0043       constant->t_(kvalue, t_value);
0044       node->addInput(constant->output());
0045       node->removeAttribute(kvalue);
0046     }
0047   }
0048 
0049   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0050     adapt_pad_10_11(graph, node);
0051     return node;
0052   }
0053 };
0054 
0055 } // namespace version_conversion
0056 } // namespace ONNX_NAMESPACE