File indexing completed on 2025-02-22 10:42:47
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008 #include "onnx/common/visitor.h"
0009
0010 namespace ONNX_NAMESPACE {
0011 namespace internal {
0012
0013 using AttributeMap = std::unordered_map<std::string, const AttributeProto*>;
0014
0015
0016
0017 class AttributeBinder : public MutableVisitor {
0018 public:
0019 AttributeBinder(const AttributeMap& attr_map) : attr_map_(attr_map) {}
0020
0021
0022
0023
0024
0025 void VisitNode(NodeProto* node) override {
0026 auto& attributes = *node->mutable_attribute();
0027 for (auto attr_iter = attributes.begin(); attr_iter != attributes.end();) {
0028 auto& attr = *attr_iter;
0029 if (!attr.ref_attr_name().empty()) {
0030
0031
0032 auto it = attr_map_.find(attr.ref_attr_name());
0033 if (it != attr_map_.end()) {
0034 const AttributeProto* replacement = it->second;
0035
0036 std::string name = attr.name();
0037 attr = *replacement;
0038 attr.set_name(name);
0039 ++attr_iter;
0040 } else {
0041 attr_iter = attributes.erase(attr_iter);
0042 }
0043 } else {
0044
0045 VisitAttribute(&attr);
0046 ++attr_iter;
0047 }
0048 }
0049 }
0050
0051
0052
0053 static void BindAttributes(const NodeProto& callnode, FunctionProto& callee) {
0054 AttributeMap map;
0055 for (auto& attr : callnode.attribute()) {
0056 map[attr.name()] = &attr;
0057 }
0058 AttributeBinder attr_binder(map);
0059 attr_binder.VisitFunction(&callee);
0060 }
0061
0062 private:
0063 const AttributeMap& attr_map_;
0064 };
0065
0066 }
0067 }