File indexing completed on 2025-04-03 08:57:55
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
0009 #include <cinttypes>
0010 #include <string>
0011 #include <utility>
0012 #include <vector>
0013
0014
0015
0016
0017
// Opens a lambda that captures everything by copy ([=]) and takes
// (std::shared_ptr<Graph>, Node*). The macro argument becomes the name of the
// Node* parameter; the Graph pointer is deliberately left unnamed, so a
// transformer body cannot touch the graph. The factories in this header return
// these lambdas as NodeTransformerFunction (declared elsewhere).
// Usage:  return NODE_TRANSFORMER(node) { /* edit *node */ return node; };
// NOTE(review): the macro is not #undef'd in the visible header, so it leaks
// into every includer; std::shared_ptr is used without a visible <memory>
// include — presumably it arrives transitively, confirm.
#define NODE_TRANSFORMER(node_param) [=](std::shared_ptr<Graph>, Node* node_param)
0019
0020 namespace ONNX_NAMESPACE {
0021 namespace version_conversion {
0022
0023 inline NodeTransformerFunction RemoveAttribute(Symbol attr) {
0024 return NODE_TRANSFORMER(node) {
0025 if (node->hasAttribute(attr)) {
0026 node->removeAttribute(attr);
0027 }
0028 return node;
0029 };
0030 }
0031
0032 inline NodeTransformerFunction RemoveAttribute(Symbol attr, int64_t value) {
0033 return NODE_TRANSFORMER(node) {
0034 if (node->hasAttribute(attr)) {
0035 ONNX_ASSERTM(node->i(attr) == value, "Attribute %s must have value %" PRId64, attr.toString(), value);
0036 node->removeAttribute(attr);
0037 }
0038 return node;
0039 };
0040 }
0041
0042 inline NodeTransformerFunction RemoveAttributeNotEq(Symbol attr, int64_t value) {
0043 return NODE_TRANSFORMER(node) {
0044 if (node->hasAttribute(attr)) {
0045 ONNX_ASSERTM(node->i(attr) != value, "Attribute %s must not have value %" PRId64, attr.toString(), value);
0046 node->removeAttribute(attr);
0047 }
0048 return node;
0049 };
0050 }
0051
0052 inline NodeTransformerFunction SetAttribute(Symbol attr, int64_t value) {
0053 return NODE_TRANSFORMER(node) {
0054 node->i_(attr, value);
0055 return node;
0056 };
0057 }
0058
0059 inline NodeTransformerFunction SetAttribute(Symbol attr, const std::string& value) {
0060 return NODE_TRANSFORMER(node) {
0061 node->s_(attr, value);
0062 return node;
0063 };
0064 }
0065
0066 inline NodeTransformerFunction SetAttribute(Symbol attr, std::vector<int64_t> value) {
0067 return NODE_TRANSFORMER(node) {
0068 std::vector<int64_t> local(value);
0069 node->is_(attr, std::move(local));
0070 return node;
0071 };
0072 }
0073
0074 inline NodeTransformerFunction SetAttributeIfAbsent(Symbol attr, int64_t value) {
0075 return NODE_TRANSFORMER(node) {
0076 if (!node->hasAttribute(attr)) {
0077 node->i_(attr, value);
0078 }
0079 return node;
0080 };
0081 }
0082
0083 }
0084 }