Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:47

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 #pragma once
0008 #include "onnx/common/visitor.h"
0009 
0010 namespace ONNX_NAMESPACE {
0011 namespace internal { // internal/private API
0012 
0013 using AttributeMap = std::unordered_map<std::string, const AttributeProto*>;
0014 
0015 // Class for binding formal attribute-parameters (in a node or graph) to their values.
0016 
0017 class AttributeBinder : public MutableVisitor {
0018  public:
0019   AttributeBinder(const AttributeMap& attr_map) : attr_map_(attr_map) {}
0020 
0021   // Binding a formal attribute-parameter to a value may, as a special case, also
0022   // remove the attribute from the list of attributes of a node (when the attribute
0023   // has no specified value). Hence, we need to do the processing at a Node level
0024   // rather than an attribute level.
0025   void VisitNode(NodeProto* node) override {
0026     auto& attributes = *node->mutable_attribute();
0027     for (auto attr_iter = attributes.begin(); attr_iter != attributes.end();) {
0028       auto& attr = *attr_iter;
0029       if (!attr.ref_attr_name().empty()) {
0030         // Attribute-references must be replaced by the corresponding attribute-value in the call-node
0031         // if the call-node contains the attribute. Otherwise, this attribute must be removed.
0032         auto it = attr_map_.find(attr.ref_attr_name());
0033         if (it != attr_map_.end()) {
0034           const AttributeProto* replacement = it->second;
0035           // Copy value of attribute, but retain original name:
0036           std::string name = attr.name();
0037           attr = *replacement;
0038           attr.set_name(name);
0039           ++attr_iter;
0040         } else {
0041           attr_iter = attributes.erase(attr_iter);
0042         }
0043       } else {
0044         // For regular attributes, we process subgraphs, if present, recursively.
0045         VisitAttribute(&attr);
0046         ++attr_iter;
0047       }
0048     }
0049   }
0050 
0051   // Updates a FunctionProto by replacing all attribute-references with the corresponding
0052   // attribute-values in the call-node, if present. Otherwise, the attribute is removed.
0053   static void BindAttributes(const NodeProto& callnode, FunctionProto& callee) {
0054     AttributeMap map;
0055     for (auto& attr : callnode.attribute()) {
0056       map[attr.name()] = &attr;
0057     }
0058     AttributeBinder attr_binder(map);
0059     attr_binder.VisitFunction(&callee);
0060   }
0061 
0062  private:
0063   const AttributeMap& attr_map_;
0064 };
0065 
0066 } // namespace internal
0067 } // namespace ONNX_NAMESPACE