File indexing completed on 2025-02-22 10:42:45
0001
0002
0003
0004
0005 #pragma once
0006
0007 #include <mutex>
0008 #include <string>
0009 #include <unordered_map>
0010 #include <utility>
0011 #include <vector>
0012
0013 #include "attr_proto_util.h"
0014 #include "onnx/common/constants.h"
0015 #include "onnx/common/status.h"
0016 #include "onnx/defs/parser.h"
0017 #include "onnx/defs/schema.h"
0018 #include "tensor_proto_util.h"
0019
0020 namespace ONNX_NAMESPACE {
0021
0022 void FunctionExpandHelper(
0023 const NodeProto& node,
0024 const FunctionProto& func,
0025 GraphProto& g,
0026 const std::string& node_prefix = "");
0027
0028 class FunctionBodyHelper {
0029 public:
0030 struct AttributeProtoWrapper {
0031 AttributeProto proto;
0032
0033 AttributeProtoWrapper() {}
0034
0035 AttributeProtoWrapper(const AttributeProto& attr_prot) {
0036 proto = attr_prot;
0037 }
0038
0039 template <typename T>
0040 AttributeProtoWrapper(const std::string& attr_name, const T& value) {
0041 proto = MakeAttribute(attr_name, value);
0042 }
0043 };
0044
0045 struct NodeDef {
0046 NodeDef(
0047 std::vector<std::string> outputs,
0048 std::string op_type,
0049 std::vector<std::string> inputs,
0050 std::vector<AttributeProtoWrapper> attributes = {},
0051 std::string domain = "")
0052 : outputs(std::move(outputs)),
0053 op_type(std::move(op_type)),
0054 inputs(std::move(inputs)),
0055 attributes(std::move(attributes)),
0056 domain(std::move(domain)) {}
0057
0058 std::vector<std::string> outputs;
0059 std::string op_type;
0060 std::vector<std::string> inputs;
0061 std::vector<AttributeProtoWrapper> attributes;
0062 std::string domain;
0063 };
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086
0087
0088
0089
0090 static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs);
0091
0092 static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs);
0093
0094 static bool BuildFunctionProto(
0095 FunctionProto& functionProto,
0096 const OpSchema& schema,
0097 const std::vector<NodeDef>& node_defs,
0098 const std::vector<OperatorSetIdProto>& relied_opsets);
0099
0100 template <typename T>
0101 static NodeDef Const(const std::string& name, const T& value) {
0102 return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(value)}}};
0103 }
0104
0105 template <typename T>
0106 static NodeDef Const(const std::string& name, const std::vector<T>& values) {
0107 return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(values)}}};
0108 }
0109 };
0110
0111 class FunctionBuilder {
0112 public:
0113 FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
0114
0115 FunctionBuilder& Add(const char* nodes_txt) {
0116 OnnxParser parser(nodes_txt);
0117 auto& nodes = *funProto.mutable_node();
0118
0119 while (!parser.EndOfInput()) {
0120 auto status = parser.Parse(*nodes.Add());
0121 if (!status.IsOK())
0122 ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
0123 }
0124
0125 return *this;
0126 }
0127
0128 FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
0129 OnnxParser parser(node_txt);
0130 auto& node = *funProto.add_node();
0131 auto status = parser.Parse(node);
0132 if (!status.IsOK()) {
0133 ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
0134 }
0135
0136 if (!parser.EndOfInput()) {
0137 ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
0138 }
0139
0140 *node.add_attribute() = attr;
0141
0142 return *this;
0143 }
0144
0145 template <typename T>
0146 FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) {
0147 return Add(node_txt, MakeAttribute(attr_name, attr_value));
0148 }
0149
0150 FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) {
0151 std::string constant_op(name);
0152 constant_op += " = Constant()";
0153 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0154 }
0155
0156
0157 template <typename T>
0158 FunctionBuilder& Const(const std::string& name, T const_value) {
0159 std::string constant_op(name);
0160 constant_op += " = Constant()";
0161 return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
0162 }
0163
0164
0165 template <typename T>
0166 FunctionBuilder& Const1D(const std::string& name, T const_value) {
0167 std::string constant_op(name);
0168 constant_op += " = Constant()";
0169 auto tensor = ToTensor(const_value);
0170 tensor.add_dims(1);
0171 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0172 }
0173
0174
0175 template <typename T>
0176 FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
0177 std::string constant_op(name);
0178 constant_op += " = Constant()";
0179 auto tensor = ToTensor(values);
0180 tensor.add_dims(values.size());
0181
0182 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0183 }
0184
0185 FunctionBuilder& AddOpset(const char* domain, int version) {
0186 auto* opset = funProto.add_opset_import();
0187 opset->set_domain(domain);
0188 opset->set_version(version);
0189 return *this;
0190 }
0191
0192 private:
0193 FunctionProto& funProto;
0194 };
0195
0196 }