File indexing completed on 2025-04-03 08:57:54
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <memory>
0012
0013 namespace ONNX_NAMESPACE {
0014 namespace version_conversion {
0015
0016 class Dropout_11_12 final : public Adapter {
0017 public:
0018 explicit Dropout_11_12() : Adapter("Dropout", OpSetID(11), OpSetID(12)) {}
0019
0020 void adapt_dropout_11_12(std::shared_ptr<Graph> graph, Node* node) const {
0021 float ratio;
0022 if (node->hasAttribute(kratio)) {
0023 ratio = node->f(kratio);
0024 node->removeAttribute(kratio);
0025 } else {
0026 ratio = 0.5;
0027 }
0028
0029 Tensor t_ratio;
0030 t_ratio.elem_type() = TensorProto_DataType_FLOAT;
0031 auto& data_ratio = t_ratio.floats();
0032 data_ratio.emplace_back(ratio);
0033 Node* constant = graph->create(kConstant);
0034 constant->insertBefore(node);
0035 constant->t_(kvalue, t_ratio);
0036 node->addInput(constant->output());
0037 }
0038
0039 Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0040 adapt_dropout_11_12(graph, node);
0041 return node;
0042 }
0043 };
0044
0045 }
0046 }