File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012 #include <vector>
0013
0014 namespace ONNX_NAMESPACE {
0015 namespace version_conversion {
0016
0017 class Pad_10_11 final : public Adapter {
0018 public:
0019 explicit Pad_10_11() : Adapter("Pad", OpSetID(10), OpSetID(11)) {}
0020
0021 void adapt_pad_10_11(std::shared_ptr<Graph> graph, Node* node) const {
0022
0023 Tensor t_pads;
0024 t_pads.elem_type() = TensorProto_DataType_INT64;
0025 auto& data_pads = t_pads.int64s();
0026 for (int64_t shape : node->is(kpads)) {
0027 data_pads.emplace_back(shape);
0028 }
0029 t_pads.sizes() = std::vector<int64_t>{(int64_t)data_pads.size()};
0030 Value* v_pads = graph->addInitializerAndCreateValue(t_pads);
0031 node->addInput(v_pads);
0032 node->removeAttribute(kpads);
0033
0034 if (!node->hasAttribute(kmode) || node->s(kmode) == "constant") {
0035 if (!node->hasAttribute(kvalue))
0036 node->f_(kvalue, 0.);
0037 Tensor t_value;
0038 t_value.elem_type() = TensorProto_DataType_FLOAT;
0039 auto& data_value = t_value.floats();
0040 data_value.emplace_back(node->f(kvalue));
0041 Node* constant = graph->create(kConstant);
0042 constant->insertBefore(node);
0043 constant->t_(kvalue, t_value);
0044 node->addInput(constant->output());
0045 node->removeAttribute(kvalue);
0046 }
0047 }
0048
0049 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0050 adapt_pad_10_11(graph, node);
0051 return node;
0052 }
0053 };
0054
0055 }
0056 }