Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 09:00:13

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // ATTENTION: The code in this file is highly EXPERIMENTAL.
0008 // Adventurous users should note that the APIs will probably change.
0009 
0010 #pragma once
0011 
0012 #include <stdint.h>
0013 
0014 #include <algorithm>
0015 #include <atomic>
0016 #include <cstdint>
0017 #include <functional>
0018 #include <iostream>
0019 #include <limits>
0020 #include <memory>
0021 #include <set>
0022 #include <sstream>
0023 #include <string>
0024 #include <unordered_set>
0025 #include <utility>
0026 #include <vector>
0027 
0028 #include "onnx/common/array_ref.h"
0029 #include "onnx/common/assertions.h"
0030 #include "onnx/common/common.h"
0031 #include "onnx/common/graph_node_list.h"
0032 #include "onnx/common/interned_strings.h"
0033 #include "onnx/common/tensor.h"
0034 #include "onnx/string_utils.h"
0035 
// Declares TypeName's copy-assignment operator and copy constructor as
// deleted. The expansion intentionally has no trailing ';' so that a use
// reads like an ordinary member declaration:
//   ONNX_DISALLOW_COPY_AND_ASSIGN(Foo);
#define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName)  \
  TypeName& operator=(const TypeName&) = delete; \
  TypeName(const TypeName&) = delete
0039 
0040 namespace ONNX_NAMESPACE {
0041 
// Internal helper: the fallback display name "_v_<i>" for a Value that has no
// explicit unique name (see Value::uniqueName()).
//
// Defined `inline` rather than inside an unnamed namespace. This is a header,
// and inline member functions such as Value::uniqueName() call it. With an
// unnamed namespace, every translation unit gets its own internal-linkage
// copy, so those inline functions would refer to different entities in
// different TUs, which violates the One Definition Rule.
inline std::string toVarName(size_t i) {
  std::ostringstream oss;
  oss << "_v_" << i;
  return oss.str();
}
0051 
// Forward declarations of the three core IR types, which refer to each other.

// Graph: one "function" of computation. It uses a simple ownership model in
// which the graph owns every node inside it, and all references inside the
// graph are raw pointers. Destroying a Graph invalidates any pointers to its
// nodes.
struct Graph;

// Node: the base class of the IR graph. A Node represents a single
// computation together with its dependencies on a list of Values -- the
// "prim-ops", so to speak.
struct Node;

// Value: an input to, or output of, a node. Whether it is a Tensor or an
// opaque Handle object is determined by type().
struct Value;
0065 
// Scope guard: runs the supplied cleanup callback when the guard is destroyed,
// unless release() was called first.
//
// Move semantics transfer responsibility for the cleanup. After a move, the
// source guard is disarmed, so the callback runs exactly once, from whichever
// guard ends up owning it. The previous defaulted moves left the source armed
// with a moved-from std::function, and its destructor then invoked an empty
// function (std::bad_function_call inside a destructor -> std::terminate).
// Defaulted move assignment also discarded the target's pending cleanup
// without running it.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

  // Runs the cleanup if this guard still owns one. The emptiness check also
  // tolerates a guard constructed from an empty std::function.
  void runIfArmed() {
    if (!released_ && destructor_)
      destructor_();
  }

 public:
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {}

  // Takes over other's cleanup (or its released state) and disarms other.
  ResourceGuard(ResourceGuard&& other) : destructor_(std::move(other.destructor_)), released_(other.released_) {
    other.released_ = true;
  }

  // Runs this guard's own pending cleanup first, then takes over other's.
  ResourceGuard& operator=(ResourceGuard&& other) {
    if (this != &other) {
      runIfArmed();
      destructor_ = std::move(other.destructor_);
      released_ = other.released_;
      other.released_ = true;
    }
    return *this;
  }

  ~ResourceGuard() {
    runIfArmed();
  }

  // Disarms the guard: the cleanup callback will not be run.
  void release() {
    released_ = true;
  }
};
0085 
// One dimension of a shape. It is exactly one of:
//   - unknown (the default): is_unknown == true;
//   - a concrete integer extent: is_int == true, value stored in `dim`;
//   - a symbolic parameter: is_int == false, name stored in `param`.
// The converting constructors are intentionally implicit.
struct Dimension final {
  Dimension() = default;
  Dimension(std::string param) // NOLINT
      : is_unknown(false), param(std::move(param)) {}
  Dimension(int64_t dim) // NOLINT
      : is_unknown(false), is_int(true), dim(dim) {}

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1;
  std::string param;
};
0096 
// Tag identifying the payload type stored in an AttributeValue.
// The numeric values matter: toString(AttributeKind) indexes its name table
// with them, so the two must stay in the same order.
enum class AttributeKind : uint8_t {
  f = 0, // float (stored as double, see FloatAttr)
  fs = 1, // float list
  i = 2, // int
  is = 3, // int list
  s = 4, // string
  ss = 5, // string list
  t = 6, // tensor
  ts = 7, // tensor list
  g = 8, // subgraph
  gs = 9, // subgraph list
  tp = 10, // type proto
  tps = 11 // type proto list
};
0113 
0114 static inline const char* toString(AttributeKind kind) {
0115   static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
0116   ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
0117   return names[int(kind)];
0118 }
0119 
0120 struct AttributeValue {
0121   explicit AttributeValue(Symbol name) : name(name) {}
0122   using Ptr = std::unique_ptr<AttributeValue>;
0123   Symbol name;
0124   virtual AttributeKind kind() const = 0;
0125   virtual Ptr clone() const = 0;
0126   virtual ~AttributeValue() = default;
0127 };
0128 
0129 template <typename T, AttributeKind Kind>
0130 struct ScalarAttributeValue final : public AttributeValue {
0131   using ConstructorType = const T&;
0132   using ValueType = T;
0133   ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {}
0134   ValueType& value() {
0135     return value_;
0136   }
0137   virtual Ptr clone() const override {
0138     return Ptr(new ScalarAttributeValue(name, value_));
0139   }
0140   virtual AttributeKind kind() const override {
0141     return Kind;
0142   }
0143 
0144  private:
0145   ValueType value_;
0146 };
0147 
0148 template <typename T, AttributeKind Kind>
0149 struct VectorAttributeValue final : public AttributeValue {
0150   using ConstructorType = const std::vector<T>&&;
0151   using ValueType = std::vector<T>;
0152   VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {}
0153   ValueType& value() {
0154     return value_;
0155   }
0156   virtual AttributeKind kind() const override {
0157     return Kind;
0158   }
0159   virtual std::unique_ptr<AttributeValue> clone() const override {
0160     auto copy = value_;
0161     return Ptr(new VectorAttributeValue(name, std::move(copy)));
0162   }
0163 
0164  private:
0165   ValueType value_;
0166 };
0167 
0168 using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
0169 using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
0170 using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
0171 using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
0172 using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
0173 using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
0174 using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>;
0175 using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>;
0176 using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
0177 using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
0178 using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>;
0179 using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>;
0180 
0181 // CRTP so that Node which inherits Attributes can be return for
0182 // method chaining e.g:
0183 // Node * n = g->create(kSelect)->set_i(kOffset,3)->set_f(kValue,3.5);
0184 // we return Derived* pointers because Nodes are normally held as pointers.
0185 template <typename Derived>
0186 struct Attributes {
0187   Attributes() {}
0188   void copyAttributes(const Attributes& rhs) {
0189     values_.clear();
0190     values_.reserve(rhs.values_.size());
0191     for (auto& i : rhs.values_) {
0192       values_.push_back(i->clone());
0193     }
0194   }
0195   bool hasAttribute(Symbol name) const {
0196     return find(name, false) != values_.end();
0197   }
0198   AttributeKind kindOf(Symbol name) const {
0199     return (*find(name, true))->kind();
0200   }
0201   Derived* removeAttribute(Symbol name) {
0202     values_.erase(find(name, true));
0203     return This();
0204   }
0205   bool hasAttributes() const {
0206     return !values_.empty();
0207   }
0208   // The names are returned in order, since name actually is the index.
0209   std::vector<Symbol> attributeNames() const {
0210     std::vector<Symbol> names;
0211     names.reserve(values_.size());
0212     for (auto& a : values_)
0213       names.push_back(a->name);
0214     return names;
0215   }
0216 
0217 #define CREATE_ACCESSOR(Kind, method)                                           \
0218   Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) {              \
0219     return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \
0220   }                                                                             \
0221   const Kind##Attr::ValueType& method(Symbol name) const {                      \
0222     return get<Kind##Attr>(name);                                               \
0223   }
0224   CREATE_ACCESSOR(Float, f)
0225   CREATE_ACCESSOR(Floats, fs)
0226   CREATE_ACCESSOR(String, s)
0227   CREATE_ACCESSOR(Strings, ss)
0228   CREATE_ACCESSOR(Int, i)
0229   CREATE_ACCESSOR(Ints, is)
0230   CREATE_ACCESSOR(Tensor, t)
0231   CREATE_ACCESSOR(Tensors, ts)
0232   CREATE_ACCESSOR(Graph, g)
0233   CREATE_ACCESSOR(Graphs, gs)
0234   CREATE_ACCESSOR(TypeProto, tp)
0235   CREATE_ACCESSOR(TypeProtos, tps)
0236 
0237 #undef CREATE_ACCESSOR
0238 
0239  private:
0240   Derived* This() {
0241     return static_cast<Derived*>(this);
0242   }
0243   template <typename T>
0244   Derived* set(Symbol name, typename T::ConstructorType v) {
0245     auto it = find(name, false);
0246     auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
0247     if (it == values_.end()) {
0248       values_.push_back(std::move(nv));
0249     } else {
0250       *it = std::move(nv);
0251     }
0252     return This();
0253   }
0254   template <typename T>
0255   typename T::ValueType& get(Symbol name) const {
0256     auto it = find(name, true);
0257     T* child = static_cast<T*>(it->get());
0258     return child->value();
0259   }
0260   using AVPtr = AttributeValue::Ptr;
0261   // NB: For determinism, we use a vector rather than a hash map.  This does
0262   // mean that lookups are O(n), so you shouldn't use Attributes to store
0263   // a big pile of messages.
0264   std::vector<AVPtr> values_;
0265   using iterator = std::vector<AVPtr>::iterator;
0266   iterator find(Symbol name, bool required) {
0267     auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0268     ONNX_ASSERT(!required || it != values_.end());
0269     return it;
0270   }
0271   using const_iterator = std::vector<AVPtr>::const_iterator;
0272   const_iterator find(Symbol name, bool required) const {
0273     auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0274     ONNX_ASSERTM(
0275         !required || it != values_.end(),
0276         "%s:%u: %s: required undefined attribute '%s'",
0277         __FILE__,
0278         __LINE__,
0279         __func__,
0280         name.toString());
0281     return it;
0282   }
0283 };
0284 
0285 // Each use is represented by this type, see Node::uses()
0286 // 'user' is the consumer of the value, offset is the index into
0287 // 'user's input this where the produces will be found.
0288 struct Use final {
0289   Use(Node* user, size_t offset) : user(user), offset(offset) {}
0290   Node* user;
0291   size_t offset;
0292 };
0293 
0294 static inline bool operator==(const Use& a, const Use& b) {
0295   return a.user == b.user && a.offset == b.offset;
0296 }
0297 
0298 // the list types are intentionally simple, but we type-def
0299 // them here so if we need to change them, refactoring will be easier
0300 using node_list = std::vector<Node*>;
0301 using value_list = std::vector<Value*>;
0302 using use_list = std::vector<Use>;
0303 using NodeKind = Symbol;
0304 
0305 struct Value final {
0306   ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
0307   Value(Node* node_, size_t offset_);
0308   Value(Value&&) = default;
0309   Value& operator=(Value&&) = default;
0310   ~Value() = default;
0311 
0312  private:
0313   friend struct Node;
0314   friend struct Graph;
0315   Node* node_;
0316   size_t offset_;
0317   size_t unique_ = 0; // unique id
0318   size_t stage_ = 0; // 0-forward, 1-backward, 2-double-backward,...
0319   use_list uses_in_current_graph_;
0320   bool has_unique_name_;
0321   std::string unique_name_;
0322   int32_t elem_type_;
0323   bool has_sizes_;
0324   std::vector<Dimension> sizes_;
0325 
0326  public:
0327   Value* setElemType(int32_t elem_type) {
0328     elem_type_ = elem_type;
0329     return this;
0330   }
0331   int32_t elemType() const {
0332     return elem_type_;
0333   }
0334   bool has_sizes() const {
0335     return has_sizes_;
0336   }
0337   Value* setSizes(std::vector<Dimension> sizes) {
0338     has_sizes_ = true;
0339     sizes_ = std::move(sizes);
0340     return this;
0341   }
0342   Value* wipeSizes() {
0343     has_sizes_ = false;
0344     sizes_ = std::vector<Dimension>();
0345     return this;
0346   }
0347   const std::vector<Dimension>& sizes() const {
0348     return sizes_;
0349   }
0350   size_t unique() const {
0351     return unique_;
0352   }
0353   bool has_unique_name() const {
0354     return has_unique_name_;
0355   }
0356   std::string uniqueName() const {
0357     if (has_unique_name())
0358       return unique_name_;
0359     return toVarName(unique());
0360   }
0361   Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
0362   Value* setStage(size_t s) {
0363     stage_ = s;
0364     return this;
0365   }
0366   size_t stage() const {
0367     return stage_;
0368   }
0369   Node* node() {
0370     return node_;
0371   }
0372   size_t offset() const {
0373     return offset_;
0374   }
0375   const Node* node() const {
0376     return node_;
0377   }
0378   Graph* owningGraph();
0379   const Graph* owningGraph() const;
0380   // TODO: make this more const correct
0381   const use_list uses() const;
0382 
0383   // Replaces all uses of this node with 'newValue'.
0384   //
0385   // Given:   %3 = f(%1, %2)
0386   //          %4 = g(%3)
0387   //          %5 = h(%3, %3)
0388   // Execute: %3.replaceAllUsesWith(%6)
0389   // Result:  %3 = f(%1, %2)
0390   //          %4 = g(%6)
0391   //          %5 = h(%6, %6)
0392   void replaceAllUsesWith(Value* newValue);
0393 
0394   Value* copyMetadata(Value* from) {
0395     setElemType(from->elemType());
0396     setSizes(from->sizes());
0397     if (from->has_unique_name()) {
0398       setUniqueName(from->uniqueName());
0399     }
0400     return this;
0401   }
0402 };
0403 
0404 struct Node : public Attributes<Node> {
       // One operation in a Graph: a kind, input/output Values, attributes (inherited from Attributes<Node>),
       // optional name/domain/overload/doc-string metadata, and a position in the graph's topologically ordered node list.
0405   ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
0406   friend struct Graph;
0407   friend struct Value;
0408   friend graph_node_list;
0409   friend const_graph_node_list;
0410   friend graph_node_list_iterator;
0411   friend const_graph_node_list_iterator;
0412
0413  private:
0414   // Each node except Return/Param is associated with exactly one place in
0415   // the node list of graph_.
0416   // The list is circular and doubly linked:
0417   // the Return node is used as the sentinel for the beginning and end of the list, so the list never contains
0418   // null pointers. next_in_graph[0] is the next pointer and next_in_graph[1] is the prev pointer;
0419   // an array is used so the same iterator class can serve forward and reverse node lists.
0420   // This list represents a topological sort.
0421
0422   Node* next_in_graph[2] = {nullptr, nullptr};
0423   Node*& next() {
0424     return next_in_graph[kNextDirection];
0425   }
0426   Node*& prev() {
0427     return next_in_graph[kPrevDirection];
0428   }
0429   Node* const& next() const {
0430     return next_in_graph[kNextDirection];
0431   }
0432   Node* const& prev() const {
0433     return next_in_graph[kPrevDirection];
0434   }
0435
0436   const NodeKind kind_;
0437   std::vector<Value*> inputs_;
0438   std::vector<Value*> outputs_;
0439   Graph* graph_;
0440   size_t stage_;
0441   bool has_name_;
0442   std::string name_;
0443   bool has_domain_;
0444   std::string domain_;
0445   bool has_doc_string_;
0446   std::string doc_string_;
0447   bool has_overload_;
0448   std::string overload_;
0449
0450  protected:
0451   Node(Graph* graph_, NodeKind kind_); // defined after graph
0452
0453  public:
0454   bool has_name() const {
0455     return has_name_;
0456   }
0457   const std::string& name() const {
0458     return name_;
0459   }
0460   void setName(std::string name) {
0461     has_name_ = true;
0462     name_ = std::move(name);
0463   }
0464   bool has_domain() const {
0465     return has_domain_;
0466   }
0467   const std::string& domain() const {
0468     return domain_;
0469   }
0470   void setDomain(std::string domain) {
0471     has_domain_ = true;
0472     domain_ = std::move(domain);
0473   }
0474   bool has_overload() const {
0475     return has_overload_;
0476   }
0477   const std::string& overload() const {
0478     return overload_;
0479   }
0480   void setOverload(std::string overload) {
0481     has_overload_ = true;
0482     overload_ = std::move(overload);
0483   }
0484   bool has_doc_string() const {
0485     return has_doc_string_;
0486   }
0487   const std::string& docString() const {
0488     return doc_string_;
0489   }
0490   void setDocString(std::string doc_string) {
0491     has_doc_string_ = true;
0492     doc_string_ = std::move(doc_string);
0493   }
0494   NodeKind kind() const {
0495     return kind_;
0496   }
0497   Graph* owningGraph() {
0498     return graph_;
0499   }
0500   const Graph* owningGraph() const {
0501     return graph_;
0502   }
0503   size_t stage() const {
0504     return stage_;
0505   }
0506   Node* setStage(size_t s) {
0507     stage_ = s;
0508     return this;
0509   }
0510   // NB: This returns an ArrayRef; that means that it will
0511   // get invalidated if you resize inputs (e.g., using addInput)
0512   // We can't return a std::vector<Value*>& because there's no
0513   // way to soundly cast to std::vector<const Value*> (an insane
0514   // implementation of std::vector could make this representationally
0515   // different.)
0516   ArrayRef<Value*> inputs() {
0517     return inputs_;
0518   }
0519   ArrayRef<const Value*> inputs() const {
0520     // Vectors are not convertible in const-ness of elements, but
0521     // raw pointers are.
0522     return {inputs_.data(), inputs_.size()};
0523   }
0524   // NB: This returns an ArrayRef; that means that it will
0525   // get invalidated if you resize outputs (e.g., using addOutput)
0526   // We can't return a std::vector<Value*>& because there's no
0527   // way to soundly cast to std::vector<const Value*> (an insane
0528   // implementation of std::vector could make this representationally
0529   // different.)
0530   ArrayRef<Value*> outputs() {
0531     return outputs_;
0532   }
0533   ArrayRef<const Value*> outputs() const {
0534     // Vectors are not convertible in const-ness of elements, but
0535     // raw pointers are.
0536     return {outputs_.data(), outputs_.size()};
0537   }
       // True if any output of this node has at least one use.
0538   bool hasUses() const {
0539     for (auto o : outputs()) {
0540       if (!o->uses().empty())
0541         return true;
0542     }
0543     return false;
0544   }
       // Redirects every use of output i of 'this' to output i of 'n'; both nodes must have the same number of outputs.
0545   void replaceAllUsesWith(Node* n) {
0546     ONNX_ASSERT(outputs().size() == n->outputs().size());
0547     size_t nOutputs = outputs().size();
0548     for (size_t i = 0; i < nOutputs; i++) {
0549       outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
0550     }
0551   }
0552   // lots of things like chunk have a single input or single output, so we have a
0553   // helper to make accessing it easier
0554   Value* input() {
0555     ONNX_ASSERT(inputs_.size() == 1);
0556     return inputs_.at(0);
0557   }
0558   Value* output() {
0559     ONNX_ASSERT(outputs_.size() == 1);
0560     return outputs_.at(0);
0561   }
0562   const Value* input() const {
0563     ONNX_ASSERT(inputs_.size() == 1);
0564     return inputs_.at(0);
0565   }
       // NOTE(review): unlike input() const, this returns a mutable Value* from a const Node.
0566   Value* output() const {
0567     ONNX_ASSERT(outputs_.size() == 1);
0568     return outputs_.at(0);
0569   }
0570   // Access a particular input.  This is a checked index.
0571   Value* input(size_t i) {
0572     return inputs_.at(i);
0573   }
0574   const Value* input(size_t i) const {
0575     return inputs_.at(i);
0576   }
0577
0578   // Graphs
0579
0580   // Note [Topological invariant]
0581   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
0582   // We always maintain an up-to-date topological ordering of all nodes via
0583   // the next()/prev() links.  All transformations to graphs must preserve
0584   // this topological ordering: for example, it is only valid to 'addInput'
0585   // with an input which is topologically before the current node.
0586   //
0587   // Usually, it is obvious whether or not topological order is maintained;
0588   // for example, if you are adding nodes to the end of the topsort, it's
0589   // impossible for them to refer to inputs that are not in the topsort.
0590   // If it is not obvious, please comment accordingly.
0591
0592   // Add the value 'node' as an input to 'this' at the end of existing
0593   // arguments.  Returns the added value for ease of chaining.
0594   //
0595   // Given:   %3 = f(%1, %2)
0596   // Execute: %3.addInput(%4)
0597   // Result:  %3 = f(%1, %2, %4)
0598   Value* addInput(Value* node) {
0599     ONNX_ASSERT(graph_ == node->owningGraph());
0600     node->uses_in_current_graph_.emplace_back(this, inputs_.size());
0601     inputs_.push_back(node);
0602     return node;
0603   }
0604
0605   // Replace the input of 'this' at position 'i' with
0606   // 'newValue', returning the old value.
0607   //
0608   // Given:   %3 = f(%1, %2)
0609   // Execute: %3.replaceInput(1, %4)
0610   // Result:  %3 = f(%1, %4)
0611   Value* replaceInput(size_t i, Value* newValue) {
0612     ONNX_ASSERT(newValue->owningGraph() == graph_);
0613     Value* old = dropInput(i);
0614     inputs_[i] = newValue;
0615     newValue->uses_in_current_graph_.emplace_back(this, i);
0616     return old;
0617   }
0618
0619   // Replace all occurrences of 'from' in the inputs of this
0620   // node with 'to'. Corresponds to llvm's replaceUsesOfWith.
0621   //
0622   // Given:   %3 = f(%1, %2, %1)
0623   // Execute: %3.replaceInputWith(%1, %4)
0624   // Result:  %3 = f(%4, %2, %4)
0625   void replaceInputWith(Value* from, Value* to) {
0626     ONNX_ASSERT(from->owningGraph() == graph_);
0627     ONNX_ASSERT(to->owningGraph() == graph_);
0628     size_t i = 0;
         // replaceInput() overwrites inputs_[i] in place (no resize), so iterating the ArrayRef here stays valid.
0629     for (auto input : inputs()) {
0630       if (input == from)
0631         replaceInput(i, to);
0632       i++;
0633     }
0634   }
0635
       // Appends a new output Value (whose offset is its index in outputs_) and returns it.
0636   Value* addOutput() {
0637     outputs_.push_back(new Value(this, outputs_.size()));
0638     return outputs_.back();
0639   }
0640
0641   void eraseOutput(size_t i);
0642
0643   // Insert unattached 'this' node before 'n' in the topological order.
0644   // Returns this (for chaining).
0645   //
0646   // Given:   %3 = f(%1, %2)
0647   //          %4 = g(%3)
0648   // and unattached: %5 = h(%1)
0649   // Execute: %5.insertBefore(%4)
0650   // Result:  %3 = f(%1, %2)
0651   //          %5 = h(%1)
0652   //          %4 = g(%3)
0653   Node* insertBefore(Node* n) {
0654     ONNX_ASSERT(n->inGraphList());
0655     insertAfter(n->prev());
0656     return this;
0657   }
0658
0659   // Insert unattached 'this' node after 'n' in the topological order.
0660   // Returns this (for chaining).
0661   //
0662   // Given: %3 = f(%1, %2)
0663   //        %4 = g(%3)
0664   // and unattached: %5 = h(%1)
0665   // Execute: %5.insertAfter(%4)
0666   // Result:  %3 = f(%1, %2)
0667   //          %4 = g(%3)
0668   //          %5 = h(%1)
0669   Node* insertAfter(Node* n) {
0670     ONNX_ASSERT(!inGraphList() && n->inGraphList());
0671     Node* next = n->next();
0672     n->next() = this;
0673     this->prev() = n;
0674     this->next() = next;
0675     next->prev() = this;
0676     return this;
0677   }
0678
0679   // Move 'this' (already in the graph) after 'n' in the topological order.
0680   //
0681   // Given: %2 = f(%1)
0682   //        %3 = g(%1)
0683   // Execute: %2.moveAfter(%3)
0684   // Result: %3 = g(%1)
0685   //         %2 = f(%1)
0686   //
0687   void moveAfter(Node* n) {
0688     removeFromList();
0689     insertAfter(n);
0690   }
0691
0692   // Move 'this' (already in the graph) before 'n' in the topological order.
0693   //
0694   // Given: %2 = f(%1)
0695   //        %3 = g(%1)
0696   // Execute: %3.moveBefore(%2)
0697   // Result: %3 = g(%1)
0698   //         %2 = f(%1)
0699   void moveBefore(Node* n) {
0700     removeFromList();
0701     insertBefore(n);
0702   }
0703
0704   // Remove the input at 'i' from this node.
0705   //
0706   // WARNING: This is O(n) in the number of inputs, so avoid repeatedly calling
0707   // removeInput.
0708   //
0709   // Given: %3 = f(%1, %2)
0710   // Execute: %3.removeInput(1)
0711   // Result: %3 = f(%1)
0712   void removeInput(size_t i) {
0713     dropInput(i);
0714     // everything after this input shifts left,
0715     // so we need to update their use offsets to match
0716     for (size_t j = i + 1; j < inputs_.size(); j++) {
0717       auto it = findUseForInput(j);
0718       it->offset--;
0719     }
0720     inputs_.erase(inputs_.begin() + i);
0721   }
0722
0723   // Remove all inputs from a node.
0724   //
0725   // Given: %3 = f(%1, %2)
0726   // Execute: %3.removeAllInputs()
0727   // Result: %3 = f()
0728   void removeAllInputs() {
0729     for (size_t i = 0; i < inputs().size(); ++i)
0730       dropInput(i);
0731     inputs_.clear();
0732   }
0733
0734   // Check whether this node is before node n in the graph.
0735   bool isBefore(Node* n);
0736
0737   // iterators of the node list starting at this node
0738   // useful for resuming a search starting at this node
0739   graph_node_list_iterator iterator();
0740   graph_node_list_iterator reverseIterator();
0741   const_graph_node_list_iterator iterator() const;
0742   const_graph_node_list_iterator reverseIterator() const;
0743
0744   // Remove 'this' from the instruction list and deallocate it.
0745   //
0746   // Invariant: no outputs of 'this' may have any uses.
0747   //
0748   // Given: %2 = f(%1)
0749   //        %3 = g(%1)
0750   // Execute: %2.destroy()
0751   // Result: %3 = g(%1)
0752   void destroy();
0753
0754   // Dynamically cast this node to the subclass indicated by the
0755   // template variable, returning nullptr if the cast is invalid.
0756   //
0757   // Example usage: if(auto s = n.cast<Select>()) { ... }
0758   //
0759   // TODO: Make this const correct
0760   template <typename T>
0761   T* cast() {
0762     if (T::Kind == kind())
0763       return static_cast<T*>(this);
0764     return nullptr;
0765   }
       // Like cast<T>(), but asserts instead of returning nullptr when the kind does not match.
0766   template <typename T>
0767   T* expect() {
0768     ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
0769     return static_cast<T*>(this);
0770   }
0771
0772   virtual ~Node() = default;
0773
0774  private:
0775   // Lookup iterator in use list of _input i_ that corresponds to its use of _this_
0776   use_list::iterator findUseForInput(size_t i) {
0777     auto& input_uses = inputs_[i]->uses_in_current_graph_;
0778     // O(N) on the use list, but unless we get nodes with +100 uses
0779     // vector traversal still is probably faster than linked list
0780     auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
0781     ONNX_ASSERT(use_it != input_uses.end());
0782     return use_it;
0783   }
0784
0785   // remove the use of input i, this sets input i to nullptr, but
0786   // is only used internally to Node before setting it to a new value
0787   // or erasing the entry from the list.
       // Returns the Value that was at position i.
0788   Value* dropInput(size_t i) {
0789     ONNX_ASSERT(i < inputs_.size());
0790     auto input_node = inputs_[i];
0791     auto use_it = findUseForInput(i);
0792     input_node->uses_in_current_graph_.erase(use_it);
0793     inputs_[i] = nullptr;
0794     return input_node;
0795   }
0796
0797   bool inGraphList() const {
         // A detached node has both links null; a null next() with a non-null prev() indicates a corrupted list.
0798     ONNX_ASSERT(next() != nullptr || prev() == nullptr);
0799     return next() != nullptr;
0800   }
       // Unlinks 'this' from the node list (relinking its neighbours) and nulls both of its links.
0801   void removeFromList() {
0802     ONNX_ASSERT(inGraphList());
0803     Node* next = this->next();
0804     Node* prev = this->prev();
0805     prev->next() = next;
0806     next->prev() = prev;
0807     this->next() = nullptr;
0808     this->prev() = nullptr;
0809   }
0810
0811  protected:
0812   // subclasses must override
0813   // this function is used by createClone to initialize a new version
0814   // of a node in another graph. It should allocate a new instance of the same
0815   // concrete type as 'this', but in graph 'g' which might be different
0816   // than graph_
0817   virtual Node* allocNewInstance(Graph* g) {
0818     return new Node(g, kind());
0819   }
0820   // create a copy of all properties of Node s into this.
0821   // subclasses should extend if they have additional information to copy.
0822   // 'this' will be allocated with s->allocNewInstance(g) so it should have
0823   // the same concrete type as 's'
0824   //
0825   // NB: This does NOT clone stages.  You're expected to set the stage correctly
0826   // if you are going to preserve it.
0827   virtual void cloneFrom(Node* s) {
0828     copyAttributes(*s);
0829   }
0830 };
0831 
0832 // A class with the same properties as OperatorSetIdProto, but without protobuf
0833 // overhead, resulting in a simpler and more readable workflow.
0834 class OpSetID final {
0835  private:
0836   std::string domain_;
0837   int64_t version_;
0838 
0839  public:
0840   explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {}
0841 
0842   // Default Domain Constructor
0843   explicit OpSetID(const int64_t version) : domain_(""), version_(version) {}
0844 
0845   explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {}
0846 
0847   // target must be in the form "<domain>&<version>"
0848   std::string toString() const {
0849     return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
0850   }
0851 
0852   // target must be in the form "<domain>&<version>"
0853   static OpSetID fromString(const std::string& target) {
0854     ONNX_TRY {
0855       std::string new_domain = target.substr(0, target.find("$"));
0856       int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
0857       return OpSetID(new_domain, new_version);
0858     }
0859     ONNX_CATCH(const std::runtime_error& e) {
0860       ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); });
0861     }
0862 
0863     // The control will never reach here.
0864     // In the default build where exceptions are turned on in case of any error
0865     // the control will enter catch block where an exception will be thrown again.
0866     // In case of "no exception build" the code aborts at the site of first exception.
0867     // Adding this to appease the warning "control may reach end of non-void function"
0868     // as the mac build fails when ONNX_WERROR==ON
0869     return OpSetID("", 0);
0870   }
0871 
0872   const std::string& domain() const {
0873     return domain_;
0874   }
0875 
0876   int64_t version() const {
0877     return version_;
0878   }
0879 
0880   void incrementVersion(int64_t step) {
0881     version_ += step;
0882   }
0883 
0884   void setVersion(int64_t newVal) {
0885     version_ = newVal;
0886   }
0887 };
0888 
0889 struct Graph final {
0890   ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
0891   friend struct Node;
0892   friend struct Value;
0893 
0894  private:
0895   // only used to keep track of allocated nodes
0896   // actual representation of Graph is done with
0897   // inputs, outputs, nodes
0898 
0899   std::unordered_set<const Node*> all_nodes;
0900   std::unordered_set<const Value*> all_values;
0901   size_t next_unique_;
0902 
0903   size_t new_node_stage_;
0904 
0905   // holds outputs in a way that can be reflected
0906   // as a Use object
0907   // also used as the beginning/end of the circular node list to avoid
0908   // having corner cases where the list is empty.
0909   Node* const output_;
0910   Node* const input_;
0911   // Create an independent node list for those initializers do not exist in input
0912   Node* const initializer_node_;
0913 
0914   std::vector<Tensor> initializers_;
0915   std::vector<std::string> initializer_names_;
0916 
0917   bool has_name_;
0918   std::string name_;
0919   bool has_doc_string_;
0920   std::string doc_string_;
0921 
0922   std::vector<OpSetID> opset_versions_;
0923 
0924   bool isNameUnique(const std::string& name) const {
0925     if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) {
0926       return false;
0927     }
0928     const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
0929     for (const Node* node : all_nodes) {
0930       for (const auto& attr : node->attributeNames()) {
0931         if (node->kindOf(attr) == AttributeKind::g) {
0932           const auto& subgraph = node->g(attr);
0933           if (!subgraph->isNameUnique(name)) {
0934             return false;
0935           }
0936         } else if (node->kindOf(attr) == AttributeKind::gs) {
0937           for (const auto& subgraph : node->gs(attr)) {
0938             if (!subgraph->isNameUnique(name)) {
0939               return false;
0940             }
0941           }
0942         }
0943       }
0944       const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f);
0945       if (found_in != node->inputs().end()) {
0946         return false;
0947       }
0948       const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f);
0949       if (found_out != node->outputs().end()) {
0950         return false;
0951       }
0952     }
0953     return true;
0954   }
0955 
0956  public:
0957   Graph()
0958       : next_unique_(0),
0959         new_node_stage_(0),
0960         output_(initOutput(create(kReturn, 0))),
0961         input_(create(kParam, 0)),
0962         initializer_node_(create(kParam, 0)),
0963         has_name_(false),
0964         has_doc_string_(false) {}
0965 
0966   bool has_doc_string() const {
0967     return has_doc_string_;
0968   }
0969   const std::string& docString() {
0970     return doc_string_;
0971   }
0972   void setDocString(std::string doc_string) {
0973     has_doc_string_ = true;
0974     doc_string_ = std::move(doc_string);
0975   }
0976 
0977   void addInitializer(Tensor& initializer) {
0978     if (initializer.name().empty()) {
0979       initializer.setName(toVarName(getNextUnique()));
0980     }
0981     initializers_.push_back(initializer);
0982     initializer_names_.push_back(initializer.name());
0983   }
0984 
0985   // For IR >= 4, initializer is not required to exist in input
0986   // Add initializer into initializer node list and return its Value
0987   Value* addInitializerAndCreateValue(Tensor& initializer) {
0988     addInitializer(initializer);
0989     auto* init_value = initializer_node_->addOutput();
0990     std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()};
0991     init_value->setUniqueName(initializer.name());
0992     init_value->setSizes(dim_sizes);
0993     init_value->setElemType(initializer.elem_type());
0994     return init_value;
0995   }
0996 
0997   void eraseInitializer(const std::string& name) {
0998     initializers_.erase(
0999         std::remove_if(
1000             initializers_.begin(),
1001             initializers_.end(),
1002             [&name](Tensor& initializer) { return initializer.name() == name; }),
1003         initializers_.end());
1004     initializer_names_.erase(
1005         std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end());
1006     for (size_t i = 0; i < initializer_node_->outputs().size(); i++) {
1007       if (initializer_node_->outputs()[i]->uniqueName() == name) {
1008         initializer_node_->eraseOutput(i);
1009         break;
1010       }
1011     }
1012   }
1013   void clearInitializers() {
1014     initializers_.clear();
1015     initializer_names_.clear();
1016   }
1017   const std::vector<Tensor>& initializers() const {
1018     return initializers_;
1019   }
1020   const std::vector<std::string>& initializer_names() const {
1021     return initializer_names_;
1022   }
1023   std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
1024     for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
1025       if (name == it->name()) {
1026         return it;
1027       }
1028     }
1029     return initializers_.end();
1030   }
1031   bool is_constant_initializer(const Value* value) const {
1032     return value->node() == initializer_node_;
1033   }
1034   ArrayRef<Value*> inputs() {
1035     return input_->outputs();
1036   }
1037   ArrayRef<const Value*> inputs() const {
1038     const auto& inputs = input_->outputs();
1039     return {inputs.data(), inputs.size()};
1040   }
1041   ArrayRef<Value*> outputs() {
1042     return output_->inputs();
1043   }
1044   ArrayRef<const Value*> outputs() const {
1045     return static_cast<const Node*>(output_)->inputs();
1046   }
1047   graph_node_list nodes() {
1048     return graph_node_list(output_, kNextDirection);
1049   }
1050   const_graph_node_list nodes() const {
1051     return const_graph_node_list(output_, kNextDirection);
1052   }
1053 
1054   std::vector<OpSetID>& opset_versions_mutable() {
1055     return opset_versions_;
1056   }
1057 
1058   size_t getNextUnique() {
1059     std::string next_unique_name = toVarName(++next_unique_);
1060     while (!isNameUnique(next_unique_name)) {
1061       next_unique_name = toVarName(++next_unique_);
1062     }
1063     return next_unique_;
1064   }
1065 
1066   // These invocations of begin() on output of function are OK
1067   // because graph_node_list is non-owning, so it doesn't matter
1068   // if it immediately dies after the invocation.
1069   graph_node_list_iterator begin() {
1070     return nodes().begin();
1071   }
1072   const_graph_node_list_iterator begin() const {
1073     return nodes().begin();
1074   }
1075   graph_node_list_iterator end() {
1076     return nodes().end();
1077   }
1078   const_graph_node_list_iterator end() const {
1079     return nodes().end();
1080   }
1081   graph_node_list_iterator rbegin() {
1082     return nodes().rbegin();
1083   }
1084   const_graph_node_list_iterator rbegin() const {
1085     return nodes().rbegin();
1086   }
1087   graph_node_list_iterator rend() {
1088     return nodes().rend();
1089   }
1090   const_graph_node_list_iterator rend() const {
1091     return nodes().rend();
1092   }
1093   Node* return_node() {
1094     return output_;
1095   }
1096   const Node* return_node() const {
1097     return output_;
1098   }
1099 
1100   Value* addInput() {
1101     return input_->addOutput();
1102   }
1103   void eraseInput(size_t i) {
1104     input_->eraseOutput(i);
1105   }
1106   void advanceStage() {
1107     new_node_stage_++;
1108   }
1109   void setStage(size_t new_stage) {
1110     new_node_stage_ = new_stage;
1111   }
1112   size_t stage() const {
1113     return new_node_stage_;
1114   }
1115   ResourceGuard setStageTemporary(size_t s) {
1116     auto prev_stage = new_node_stage_;
1117     new_node_stage_ = s;
1118     return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1119   }
1120 
1121   size_t registerOutput(Value* n) {
1122     output_->addInput(n);
1123     return outputs().size() - 1;
1124   }
1125 
1126   Node* create(NodeKind kind, size_t num_outputs = 1) {
1127     // NB: Node constructor adds node to all_nodes
1128     auto n = new Node(this, kind);
1129     for (size_t i = 0; i < num_outputs; i++)
1130       n->addOutput();
1131     return n;
1132   }
1133 
1134   Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) {
1135     auto n = create(kind, num_outputs);
1136     for (auto i : inputs)
1137       n->addInput(i);
1138     return n;
1139   }
1140 
1141   Node* appendNode(Node* n) {
1142     ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1143     n->insertBefore(output_);
1144     return n;
1145   }
1146 
1147   Node* prependNode(Node* n) {
1148     ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1149     n->insertAfter(output_);
1150     return n;
1151   }
1152 
1153   // Adds to graph initializer list, initializer names list, and as a graph input
1154   // Also syncs the initializer name, tensor name, and value name
1155   // Create an initializer whose value is stored in input
1156   Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) {
1157     Tensor initializerCopy = initializer;
1158     std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(), initializerCopy.sizes().cend()};
1159     Value* new_init = addInput();
1160     initializerCopy.setName(name);
1161     new_init->setUniqueName(name);
1162     new_init->setSizes(dim_sizes);
1163     new_init->setElemType(initializerCopy.elem_type());
1164     addInitializer(initializerCopy);
1165     return new_init;
1166   }
1167 
1168   Value* addInitializerAndInput(const Tensor& initializer) {
1169     return addInitializerAndInput(initializer, toVarName(getNextUnique()));
1170   }
1171 
1172   // Erases from graph initializer list, initializer names list, and as a graph input
1173   // Must have no uses
1174   void eraseInitializerAndInput(Value* v) {
1175     eraseInitializer(v->uniqueName());
1176     if (v->node() == input_) {
1177       eraseInput(v->offset());
1178     }
1179   }
1180 
1181   ~Graph() {
1182     for (const Node* n : all_nodes)
1183       delete n;
1184     for (const Value* v : all_values)
1185       delete v;
1186   }
1187 
1188   std::string toString() const {
1189     std::ostringstream oss;
1190     oss << *this;
1191     return oss.str();
1192   }
1193 
1194   bool has_name() const {
1195     return has_name_;
1196   }
1197 
1198   const std::string& name() const {
1199     return name_;
1200   }
1201 
1202   void setName(std::string name) {
1203     has_name_ = true;
1204     name_ = std::move(name);
1205   }
1206 
1207   friend std::ostream& operator<<(std::ostream& out, const Graph& g);
1208 
1209   void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) {
1210     fn(this);
1211 
1212     for (const Node* node : all_nodes) {
1213       for (const auto& attr : node->attributeNames()) {
1214         if (node->kindOf(attr) == AttributeKind::g) {
1215           std::shared_ptr<Graph> subgraph = node->g(attr);
1216           subgraph->forSelfAndEachSubGraph(fn);
1217         } else if (node->kindOf(attr) == AttributeKind::gs) {
1218           for (const auto& subgraph : node->gs(attr)) {
1219             subgraph->forSelfAndEachSubGraph(fn);
1220           }
1221         }
1222       }
1223     }
1224   }
1225 
1226   void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const {
1227     std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); };
1228     const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1229   }
1230 
1231   void forEachNode(const std::function<void(Node*)>& fn) {
1232     forSelfAndEachSubGraph([fn](Graph* graph) {
1233       for (Node* node : graph->nodes()) {
1234         fn(node);
1235       }
1236     });
1237   }
1238 
1239   void forEachNode(const std::function<void(const Node*)>& fn) const {
1240     std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); };
1241     const_cast<Graph*>(this)->forEachNode(tmp_fn);
1242   }
1243 
1244  private:
1245   // should only be called in the constructor
1246   Node* initOutput(Node* p) {
1247     p->next() = p;
1248     p->prev() = p;
1249     p->setStage(std::numeric_limits<size_t>::max());
1250     return p;
1251   }
1252 
1253   void freeNode(Node* n) {
1254     auto it = all_nodes.find(n);
1255     ONNX_ASSERT(it != all_nodes.end());
1256     delete *it;
1257     all_nodes.erase(it);
1258   }
1259   void freeValue(Value* v) {
1260     auto it = all_values.find(v);
1261     ONNX_ASSERT(it != all_values.end());
1262     delete *it;
1263     all_values.erase(it);
1264   }
1265 };
1266 
1267 inline Value::Value(Node* node_, size_t offset_)
1268     : node_(node_),
1269       offset_(offset_),
1270       unique_(node_->graph_->getNextUnique()),
1271       stage_(node_->graph_->new_node_stage_),
1272       has_unique_name_(false),
1273       elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
1274       has_sizes_(false) {
1275   node_->graph_->all_values.emplace(this);
1276 }
1277 
1278 inline Graph* Value::owningGraph() {
1279   return node()->owningGraph();
1280 }
1281 
1282 inline const Graph* Value::owningGraph() const {
1283   return node()->owningGraph();
1284 }
1285 
1286 // `captured` nodes in subgraph determines which value it captures
1287 // by storing the value's unique name, so old unique names in `captured` nodes
1288 // should also be updated.
1289 // Initializer names are also storaged in graph.initializer_names_, it should be
1290 // updated too.
1291 inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) {
1292   if (has_unique_name() && update_related_names) {
1293     auto* graph = owningGraph();
1294     auto old_name = unique_name_;
1295     for (size_t i = 0; i < owningGraph()->initializer_names_.size(); i++) {
1296       auto& initializer_name = owningGraph()->initializer_names_[i];
1297       if (initializer_name == old_name) {
1298         initializer_name = name;
1299         owningGraph()->initializers_[i].setName(name);
1300       }
1301     }
1302     graph->forEachNode([this, &name, &old_name](Node* node) {
1303       if (node->owningGraph() == this->owningGraph()) {
1304         // skip non-subgraph
1305         return;
1306       }
1307       if (node->kind() == kCaptured) {
1308         Value* output = node->output();
1309         if (output->uniqueName() == old_name) {
1310           output->setUniqueName(name, false);
1311         }
1312       }
1313     });
1314   }
1315   unique_name_ = name;
1316   has_unique_name_ = true;
1317   return this;
1318 }
1319 
1320 inline void Value::replaceAllUsesWith(Value* newValue) {
1321   auto* graph = owningGraph();
1322   ONNX_ASSERT(graph == newValue->owningGraph());
1323   // propagate sizes and elem type
1324   if (this->has_sizes()) {
1325     newValue->setSizes(this->sizes());
1326   }
1327   if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1328     newValue->setElemType(this->elemType());
1329   }
1330   const auto unique_name = this->uniqueName();
1331   // We do not want the optimization to change the graph output name
1332   if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
1333     newValue->setUniqueName(unique_name);
1334     // The "unique" semantic of unique_name should be kept or uses()
1335     // will return an incorrect result when the value is used in subgraph
1336     this->setUniqueName(toVarName(graph->getNextUnique()), false);
1337   }
1338   newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1339   for (auto u : uses_in_current_graph_) {
1340     u.user->inputs_[u.offset] = newValue;
1341     newValue->uses_in_current_graph_.push_back(u);
1342   }
1343   graph->forEachNode([this, &newValue, &unique_name](Node* node) {
1344     if (node->owningGraph() == this->owningGraph()) {
1345       // skip non-subgraph
1346       return;
1347     }
1348     if (node->kind() == kCaptured) {
1349       Value* output = node->output();
1350       if (output->uniqueName() == unique_name) {
1351         output->setUniqueName(newValue->uniqueName());
1352       }
1353     }
1354   });
1355   uses_in_current_graph_.clear();
1356   assert(this->uses().empty());
1357 }
1358 
1359 inline Node::Node(Graph* graph_, NodeKind kind_)
1360     : kind_(kind_),
1361       graph_(graph_),
1362       stage_(graph_->new_node_stage_),
1363       has_name_(false),
1364       has_domain_(false),
1365       has_doc_string_(false),
1366       has_overload_(false) {
1367   graph_->all_nodes.emplace(this);
1368 }
1369 
1370 inline void Node::eraseOutput(size_t i) {
1371   ONNX_ASSERT(i < outputs_.size());
1372   ONNX_ASSERT(outputs_[i]->uses().empty());
1373   Value* n = outputs_[i];
1374   outputs_.erase(outputs_.begin() + i);
1375   owningGraph()->freeValue(n);
1376   for (size_t j = i; j < outputs_.size(); j++) {
1377     outputs_[j]->offset_--;
1378   }
1379 }
1380 
1381 inline bool Node::isBefore(Node* n) {
1382   if (n == nullptr || this == n) {
1383     // Bail out early.
1384     return false;
1385   }
1386   // return true if node is Param (in initializers)
1387   if (kind_ == kParam) {
1388     return true;
1389   }
1390   // return false if target node is Param (in initializers)
1391   if (n->kind() == kParam) {
1392     return false;
1393   }
1394   ONNX_ASSERT(n->inGraphList());
1395   for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1396     if (p == n) {
1397       return true;
1398     }
1399   }
1400   return false;
1401 }
1402 
1403 inline void Node::destroy() {
1404   ONNX_ASSERT(inGraphList());
1405   while (!outputs().empty())
1406     eraseOutput(outputs().size() - 1);
1407   removeAllInputs();
1408   removeFromList();
1409   graph_->freeNode(this);
1410 }
1411 
1412 /************* All nodes not required to be defined before Graph **************/
1413 
1414 inline graph_node_list_iterator Node::iterator() {
1415   return graph_node_list_iterator(this, 0);
1416 }
1417 inline graph_node_list_iterator Node::reverseIterator() {
1418   return iterator().reverse();
1419 }
1420 inline const_graph_node_list_iterator Node::iterator() const {
1421   return const_graph_node_list_iterator(this, 0);
1422 }
1423 inline const_graph_node_list_iterator Node::reverseIterator() const {
1424   return iterator().reverse();
1425 }
1426 
1427 // Returns a list about which nodes are using this value,
1428 // nodes in subgraph are also included.
1429 // This method is usually used to check whether it is
1430 // safe to delete a Value.
1431 inline const use_list Value::uses() const {
1432   use_list all_uses = uses_in_current_graph_;
1433   owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1434     if (node->owningGraph() == this->owningGraph()) {
1435       // skip non-subgraph
1436       return;
1437     }
1438     if (node->kind() == kCaptured) {
1439       const Value* output = node->outputs()[0];
1440       if (output->uniqueName() == this->uniqueName()) {
1441         const auto output_uses = output->uses();
1442         all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1443       }
1444     }
1445   });
1446   return all_uses;
1447 }
1448 
1449 } // namespace ONNX_NAMESPACE