Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-03 08:57:54

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // Adapter for Dropout in default domain from version 11 to 12
0008 
0009 #pragma once
0010 
0011 #include <memory>
0012 
0013 namespace ONNX_NAMESPACE {
0014 namespace version_conversion {
0015 
0016 class Dropout_11_12 final : public Adapter {
0017  public:
0018   explicit Dropout_11_12() : Adapter("Dropout", OpSetID(11), OpSetID(12)) {}
0019 
0020   void adapt_dropout_11_12(std::shared_ptr<Graph> graph, Node* node) const {
0021     float ratio;
0022     if (node->hasAttribute(kratio)) {
0023       ratio = node->f(kratio);
0024       node->removeAttribute(kratio);
0025     } else {
0026       ratio = 0.5;
0027     }
0028 
0029     Tensor t_ratio;
0030     t_ratio.elem_type() = TensorProto_DataType_FLOAT;
0031     auto& data_ratio = t_ratio.floats();
0032     data_ratio.emplace_back(ratio);
0033     Node* constant = graph->create(kConstant);
0034     constant->insertBefore(node);
0035     constant->t_(kvalue, t_ratio);
0036     node->addInput(constant->output());
0037   }
0038 
0039   Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
0040     adapt_dropout_11_12(graph, node);
0041     return node;
0042   }
0043 };
0044 
0045 } // namespace version_conversion
0046 } // namespace ONNX_NAMESPACE