Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:45

0001 /*
0002  * SPDX-License-Identifier: Apache-2.0
0003  */
0004 
0005 #pragma once
0006 
#include <cstdint>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "attr_proto_util.h"
#include "onnx/common/constants.h"
#include "onnx/common/status.h"
#include "onnx/defs/parser.h"
#include "onnx/defs/schema.h"
#include "tensor_proto_util.h"
0019 
0020 namespace ONNX_NAMESPACE {
0021 // Helper function to expand a function node given the function proto
0022 void FunctionExpandHelper(
0023     const NodeProto& node,
0024     const FunctionProto& func,
0025     GraphProto& g,
0026     const std::string& node_prefix = "");
0027 
0028 class FunctionBodyHelper {
0029  public:
0030   struct AttributeProtoWrapper {
0031     AttributeProto proto;
0032 
0033     AttributeProtoWrapper() {}
0034 
0035     AttributeProtoWrapper(const AttributeProto& attr_prot) {
0036       proto = attr_prot;
0037     }
0038 
0039     template <typename T>
0040     AttributeProtoWrapper(const std::string& attr_name, const T& value) {
0041       proto = MakeAttribute(attr_name, value);
0042     }
0043   };
0044 
0045   struct NodeDef {
0046     NodeDef(
0047         std::vector<std::string> outputs,
0048         std::string op_type,
0049         std::vector<std::string> inputs,
0050         std::vector<AttributeProtoWrapper> attributes = {},
0051         std::string domain = "")
0052         : outputs(std::move(outputs)),
0053           op_type(std::move(op_type)),
0054           inputs(std::move(inputs)),
0055           attributes(std::move(attributes)),
0056           domain(std::move(domain)) {}
0057 
0058     std::vector<std::string> outputs;
0059     std::string op_type;
0060     std::vector<std::string> inputs;
0061     std::vector<AttributeProtoWrapper> attributes;
0062     std::string domain;
0063   };
0064 
0065   /*
0066   BuildNodes() is an utility function for easily define a Function Body.
0067 
0068   To build a simple node:
0069     {{"Z"}, "Add", {"X", "Y"}} represents Z = Add(X,Y)
0070 
0071   To build a node with attribute:
0072     {{"Y"}, "Concat", {"X1", "X2", "X3"}, {{"axis", 1}}}
0073       represents Y = Concat(X1,X2,X3) with axis = 1
0074     The attribute type are infered from the attribute value's c++ type
0075     Supported value types are
0076       int64_t -> int, vector<int64_t> -> ints
0077       float -> float, vector<float> -> floats
0078       string -> string, vector<string> ->strings
0079     For refering an attribute from parent, use:
0080       {MakeRefAttribute("axes", AttributeProto::INTS)}}
0081 
0082   To build a node which belongs to a domain other than onnx standard domain:
0083     {{"Z"}, "Foo", {"X", "Y"}, "customdomain"} represents Z = customdomain.Foo(X,Y)
0084     or
0085     {{"Y"}, "Bar", {"X1", "X2", "X3"}, {{"axis", 1}}, "customdomain"}
0086       represents Y = customdomain.Bar(X1,X2,X3) with axis = 1
0087 
0088   For more examples, please find the references of this function
0089   */
0090   static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs);
0091 
0092   static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs);
0093 
0094   static bool BuildFunctionProto(
0095       FunctionProto& functionProto,
0096       const OpSchema& schema,
0097       const std::vector<NodeDef>& node_defs,
0098       const std::vector<OperatorSetIdProto>& relied_opsets);
0099 
0100   template <typename T>
0101   static NodeDef Const(const std::string& name, const T& value) {
0102     return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(value)}}};
0103   }
0104 
0105   template <typename T>
0106   static NodeDef Const(const std::string& name, const std::vector<T>& values) {
0107     return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(values)}}};
0108   }
0109 };
0110 
0111 class FunctionBuilder {
0112  public:
0113   FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
0114 
0115   FunctionBuilder& Add(const char* nodes_txt) {
0116     OnnxParser parser(nodes_txt);
0117     auto& nodes = *funProto.mutable_node();
0118 
0119     while (!parser.EndOfInput()) {
0120       auto status = parser.Parse(*nodes.Add());
0121       if (!status.IsOK())
0122         ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
0123     }
0124 
0125     return *this;
0126   }
0127 
0128   FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
0129     OnnxParser parser(node_txt);
0130     auto& node = *funProto.add_node();
0131     auto status = parser.Parse(node);
0132     if (!status.IsOK()) {
0133       ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
0134     }
0135 
0136     if (!parser.EndOfInput()) {
0137       ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
0138     }
0139 
0140     *node.add_attribute() = attr;
0141 
0142     return *this;
0143   }
0144 
0145   template <typename T>
0146   FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) {
0147     return Add(node_txt, MakeAttribute(attr_name, attr_value));
0148   }
0149 
0150   FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) {
0151     std::string constant_op(name);
0152     constant_op += " = Constant()";
0153     return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0154   }
0155 
0156   // Creates a scalar constant (a tensor of rank zero).
0157   template <typename T>
0158   FunctionBuilder& Const(const std::string& name, T const_value) {
0159     std::string constant_op(name);
0160     constant_op += " = Constant()";
0161     return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
0162   }
0163 
0164   // Creates a 1D tensor constant consisting of a single value.
0165   template <typename T>
0166   FunctionBuilder& Const1D(const std::string& name, T const_value) {
0167     std::string constant_op(name);
0168     constant_op += " = Constant()";
0169     auto tensor = ToTensor(const_value);
0170     tensor.add_dims(1);
0171     return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0172   }
0173 
0174   // Creates a 1D tensor constant consisting of zero or more values.
0175   template <typename T>
0176   FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
0177     std::string constant_op(name);
0178     constant_op += " = Constant()";
0179     auto tensor = ToTensor(values);
0180     tensor.add_dims(values.size()); // Treat as 1D tensor.
0181 
0182     return Add(constant_op.c_str(), MakeAttribute("value", tensor));
0183   }
0184 
0185   FunctionBuilder& AddOpset(const char* domain, int version) {
0186     auto* opset = funProto.add_opset_import();
0187     opset->set_domain(domain);
0188     opset->set_version(version);
0189     return *this;
0190   }
0191 
0192  private:
0193   FunctionProto& funProto;
0194 };
0195 
0196 } // namespace ONNX_NAMESPACE