File indexing completed on 2025-02-22 10:42:44
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #pragma once
0011
0012 #include <stdint.h>
0013
0014 #include <algorithm>
0015 #include <atomic>
0016 #include <cstdint>
0017 #include <functional>
0018 #include <iostream>
0019 #include <limits>
0020 #include <memory>
0021 #include <set>
0022 #include <sstream>
0023 #include <string>
0024 #include <unordered_set>
0025 #include <utility>
0026 #include <vector>
0027
0028 #include "onnx/common/array_ref.h"
0029 #include "onnx/common/assertions.h"
0030 #include "onnx/common/common.h"
0031 #include "onnx/common/graph_node_list.h"
0032 #include "onnx/common/interned_strings.h"
0033 #include "onnx/common/tensor.h"
0034 #include "onnx/string_utils.h"
0035
// Makes TypeName non-copyable by deleting its copy-assignment operator and
// copy constructor. Place it inside the class body followed by a semicolon:
//   ONNX_DISALLOW_COPY_AND_ASSIGN(MyType);
#define ONNX_DISALLOW_COPY_AND_ASSIGN(TypeName)   \
  TypeName& operator=(const TypeName&) = delete; \
  TypeName(const TypeName&) = delete
0039
0040 namespace ONNX_NAMESPACE {
0041
0042
0043
0044
0045
// Forward declarations. The IR types refer to each other through pointers
// (a Value points at its Node, a Node at its Graph, a Graph at its Nodes),
// so each must be declared before any of them is defined.
struct Value;
struct Node;
struct Graph;
0055
// Scope guard: runs a user-supplied cleanup action when destroyed, unless
// release() was called first.
//
// The guard is move-only. Moving transfers responsibility for the pending
// action to the destination and leaves the source inert, so the action runs
// at most once.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

 public:
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {}

  // A defaulted move left the source guard "armed" with a moved-from (empty)
  // std::function; its destructor then invoked that empty function, throwing
  // std::bad_function_call out of a destructor (=> std::terminate).
  ResourceGuard(ResourceGuard&& other) noexcept
      : destructor_(std::move(other.destructor_)), released_(std::exchange(other.released_, true)) {}

  // Runs this guard's own pending action (if any) before adopting `other`'s,
  // so the action being overwritten is not silently dropped.
  ResourceGuard& operator=(ResourceGuard&& other) {
    if (this != &other) {
      fire();
      destructor_ = std::move(other.destructor_);
      released_ = std::exchange(other.released_, true);
    }
    return *this;
  }

  ~ResourceGuard() {
    fire();
  }

  // Disarms the guard: the cleanup action will not run.
  void release() {
    released_ = true;
  }

 private:
  // Runs the pending action at most once; an empty action is a no-op.
  void fire() {
    if (!released_ && destructor_) {
      destructor_();
    }
    released_ = true;
  }
};
0075
// One entry of a Value's shape. It is one of:
//   - unknown (default-constructed): is_unknown == true,
//   - a concrete extent: is_int == true, stored in `dim`,
//   - a symbolic extent: stored by name in `param`.
// The converting constructors are deliberately not `explicit`, so an int64_t
// or std::string converts implicitly to a Dimension.
struct Dimension final {
  Dimension() = default;
  Dimension(std::string param) : is_unknown(false), param(std::move(param)) {}
  Dimension(int64_t dim) : is_unknown(false), is_int(true), dim(dim) {}

  bool is_unknown = true;
  bool is_int = false;
  int64_t dim = -1;  // -1 unless is_int
  std::string param;
};
0086
// Tag identifying which concrete attribute type an AttributeValue holds.
// Each scalar kind X is followed by its list kind Xs. The underlying values
// index the name table in toString(AttributeKind), so the order must stay in
// sync with it.
enum class AttributeKind : uint8_t {
  f = 0,     // float
  fs = 1,    // list of floats
  i = 2,     // int
  is = 3,    // list of ints
  s = 4,     // string
  ss = 5,    // list of strings
  t = 6,     // tensor
  ts = 7,    // list of tensors
  g = 8,     // graph
  gs = 9,    // list of graphs
  tp = 10,   // type proto
  tps = 11,  // list of type protos
};
0103
0104 static inline const char* toString(AttributeKind kind) {
0105 static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
0106 ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
0107 return names[int(kind)];
0108 }
0109
0110 struct AttributeValue {
0111 explicit AttributeValue(Symbol name) : name(name) {}
0112 using Ptr = std::unique_ptr<AttributeValue>;
0113 Symbol name;
0114 virtual AttributeKind kind() const = 0;
0115 virtual Ptr clone() const = 0;
0116 virtual ~AttributeValue() = default;
0117 };
0118
0119 template <typename T, AttributeKind Kind>
0120 struct ScalarAttributeValue final : public AttributeValue {
0121 using ConstructorType = const T&;
0122 using ValueType = T;
0123 ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {}
0124 ValueType& value() {
0125 return value_;
0126 }
0127 virtual Ptr clone() const override {
0128 return Ptr(new ScalarAttributeValue(name, value_));
0129 }
0130 virtual AttributeKind kind() const override {
0131 return Kind;
0132 }
0133
0134 private:
0135 ValueType value_;
0136 };
0137
0138 template <typename T, AttributeKind Kind>
0139 struct VectorAttributeValue final : public AttributeValue {
0140 using ConstructorType = const std::vector<T>&&;
0141 using ValueType = std::vector<T>;
0142 VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {}
0143 ValueType& value() {
0144 return value_;
0145 }
0146 virtual AttributeKind kind() const override {
0147 return Kind;
0148 }
0149 virtual std::unique_ptr<AttributeValue> clone() const override {
0150 auto copy = value_;
0151 return Ptr(new VectorAttributeValue(name, std::move(copy)));
0152 }
0153
0154 private:
0155 ValueType value_;
0156 };
0157
0158 using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
0159 using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
0160 using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
0161 using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
0162 using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
0163 using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
0164 using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>;
0165 using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>;
0166 using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
0167 using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
0168 using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>;
0169 using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>;
0170
0171
0172
0173
0174
0175 template <typename Derived>
0176 struct Attributes {
0177 Attributes() {}
0178 void copyAttributes(const Attributes& rhs) {
0179 values_.clear();
0180 values_.reserve(rhs.values_.size());
0181 for (auto& i : rhs.values_) {
0182 values_.push_back(i->clone());
0183 }
0184 }
0185 bool hasAttribute(Symbol name) const {
0186 return find(name, false) != values_.end();
0187 }
0188 AttributeKind kindOf(Symbol name) const {
0189 return (*find(name, true))->kind();
0190 }
0191 Derived* removeAttribute(Symbol name) {
0192 values_.erase(find(name, true));
0193 return This();
0194 }
0195 bool hasAttributes() const {
0196 return !values_.empty();
0197 }
0198
0199 std::vector<Symbol> attributeNames() const {
0200 std::vector<Symbol> names;
0201 names.reserve(values_.size());
0202 for (auto& a : values_)
0203 names.push_back(a->name);
0204 return names;
0205 }
0206
0207 #define CREATE_ACCESSOR(Kind, method) \
0208 Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
0209 return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \
0210 } \
0211 const Kind##Attr::ValueType& method(Symbol name) const { \
0212 return get<Kind##Attr>(name); \
0213 }
0214 CREATE_ACCESSOR(Float, f)
0215 CREATE_ACCESSOR(Floats, fs)
0216 CREATE_ACCESSOR(String, s)
0217 CREATE_ACCESSOR(Strings, ss)
0218 CREATE_ACCESSOR(Int, i)
0219 CREATE_ACCESSOR(Ints, is)
0220 CREATE_ACCESSOR(Tensor, t)
0221 CREATE_ACCESSOR(Tensors, ts)
0222 CREATE_ACCESSOR(Graph, g)
0223 CREATE_ACCESSOR(Graphs, gs)
0224 CREATE_ACCESSOR(TypeProto, tp)
0225 CREATE_ACCESSOR(TypeProtos, tps)
0226
0227 #undef CREATE_ACCESSOR
0228
0229 private:
0230 Derived* This() {
0231 return static_cast<Derived*>(this);
0232 }
0233 template <typename T>
0234 Derived* set(Symbol name, typename T::ConstructorType v) {
0235 auto it = find(name, false);
0236 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
0237 if (it == values_.end()) {
0238 values_.push_back(std::move(nv));
0239 } else {
0240 *it = std::move(nv);
0241 }
0242 return This();
0243 }
0244 template <typename T>
0245 typename T::ValueType& get(Symbol name) const {
0246 auto it = find(name, true);
0247 T* child = static_cast<T*>(it->get());
0248 return child->value();
0249 }
0250 using AVPtr = AttributeValue::Ptr;
0251
0252
0253
0254 std::vector<AVPtr> values_;
0255 using iterator = std::vector<AVPtr>::iterator;
0256 iterator find(Symbol name, bool required) {
0257 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0258 ONNX_ASSERT(!required || it != values_.end());
0259 return it;
0260 }
0261 using const_iterator = std::vector<AVPtr>::const_iterator;
0262 const_iterator find(Symbol name, bool required) const {
0263 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0264 ONNX_ASSERTM(
0265 !required || it != values_.end(),
0266 "%s:%u: %s: required undefined attribute '%s'",
0267 __FILE__,
0268 __LINE__,
0269 __func__,
0270 name.toString());
0271 return it;
0272 }
0273 };
0274
0275
0276
0277
0278 struct Use final {
0279 Use(Node* user, size_t offset) : user(user), offset(offset) {}
0280 Node* user;
0281 size_t offset;
0282 };
0283
0284 static inline bool operator==(const Use& a, const Use& b) {
0285 return a.user == b.user && a.offset == b.offset;
0286 }
0287
0288
0289
0290 using node_list = std::vector<Node*>;
0291 using value_list = std::vector<Value*>;
0292 using use_list = std::vector<Use>;
0293 using NodeKind = Symbol;
0294
0295 struct Value final {
0296 ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
0297 Value(Node* node_, size_t offset_);
0298 Value(Value&&) = default;
0299 Value& operator=(Value&&) = default;
0300 ~Value() = default;
0301
0302 private:
0303 friend struct Node;
0304 friend struct Graph;
0305 Node* node_;
0306 size_t offset_;
0307 size_t unique_ = 0;
0308 size_t stage_ = 0;
0309 use_list uses_in_current_graph_;
0310 bool has_unique_name_;
0311 std::string unique_name_;
0312 int32_t elem_type_;
0313 bool has_sizes_;
0314 std::vector<Dimension> sizes_;
0315
0316 public:
0317 Value* setElemType(int32_t elem_type) {
0318 elem_type_ = elem_type;
0319 return this;
0320 }
0321 int32_t elemType() const {
0322 return elem_type_;
0323 }
0324 bool has_sizes() const {
0325 return has_sizes_;
0326 }
0327 Value* setSizes(std::vector<Dimension> sizes) {
0328 has_sizes_ = true;
0329 sizes_ = std::move(sizes);
0330 return this;
0331 }
0332 Value* wipeSizes() {
0333 has_sizes_ = false;
0334 sizes_ = std::vector<Dimension>();
0335 return this;
0336 }
0337 const std::vector<Dimension>& sizes() const {
0338 return sizes_;
0339 }
0340 size_t unique() const {
0341 return unique_;
0342 }
0343 bool has_unique_name() const {
0344 return has_unique_name_;
0345 }
0346 std::string uniqueName() const {
0347 if (has_unique_name())
0348 return unique_name_;
0349 return ONNX_NAMESPACE::to_string(unique());
0350 }
0351 Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
0352 Value* setStage(size_t s) {
0353 stage_ = s;
0354 return this;
0355 }
0356 size_t stage() const {
0357 return stage_;
0358 }
0359 Node* node() {
0360 return node_;
0361 }
0362 size_t offset() const {
0363 return offset_;
0364 }
0365 const Node* node() const {
0366 return node_;
0367 }
0368 Graph* owningGraph();
0369 const Graph* owningGraph() const;
0370
0371 const use_list uses() const;
0372
0373
0374
0375
0376
0377
0378
0379
0380
0381
0382 void replaceAllUsesWith(Value* newValue);
0383
0384 Value* copyMetadata(Value* from) {
0385 setElemType(from->elemType());
0386 setSizes(from->sizes());
0387 if (from->has_unique_name()) {
0388 setUniqueName(from->uniqueName());
0389 }
0390 return this;
0391 }
0392 };
0393
0394 struct Node : public Attributes<Node> {
0395 ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
0396 friend struct Graph;
0397 friend struct Value;
0398 friend graph_node_list;
0399 friend const_graph_node_list;
0400 friend graph_node_list_iterator;
0401 friend const_graph_node_list_iterator;
0402
0403 private:
0404
0405
0406
0407
0408
0409
0410
0411
0412 Node* next_in_graph[2] = {nullptr, nullptr};
0413 Node*& next() {
0414 return next_in_graph[kNextDirection];
0415 }
0416 Node*& prev() {
0417 return next_in_graph[kPrevDirection];
0418 }
0419 Node* const& next() const {
0420 return next_in_graph[kNextDirection];
0421 }
0422 Node* const& prev() const {
0423 return next_in_graph[kPrevDirection];
0424 }
0425
0426 const NodeKind kind_;
0427 std::vector<Value*> inputs_;
0428 std::vector<Value*> outputs_;
0429 Graph* graph_;
0430 size_t stage_;
0431 bool has_name_;
0432 std::string name_;
0433 bool has_domain_;
0434 std::string domain_;
0435 bool has_doc_string_;
0436 std::string doc_string_;
0437
0438 protected:
0439 Node(Graph* graph_, NodeKind kind_);
0440
0441 public:
0442 bool has_name() const {
0443 return has_name_;
0444 }
0445 const std::string& name() const {
0446 return name_;
0447 }
0448 void setName(std::string name) {
0449 has_name_ = true;
0450 name_ = std::move(name);
0451 }
0452 bool has_domain() const {
0453 return has_domain_;
0454 }
0455 const std::string& domain() const {
0456 return domain_;
0457 }
0458 void setDomain(std::string domain) {
0459 has_domain_ = true;
0460 domain_ = std::move(domain);
0461 }
0462 bool has_doc_string() const {
0463 return has_doc_string_;
0464 }
0465 const std::string& docString() const {
0466 return doc_string_;
0467 }
0468 void setDocString(std::string doc_string) {
0469 has_doc_string_ = true;
0470 doc_string_ = std::move(doc_string);
0471 }
0472 NodeKind kind() const {
0473 return kind_;
0474 }
0475 Graph* owningGraph() {
0476 return graph_;
0477 }
0478 const Graph* owningGraph() const {
0479 return graph_;
0480 }
0481 size_t stage() const {
0482 return stage_;
0483 }
0484 Node* setStage(size_t s) {
0485 stage_ = s;
0486 return this;
0487 }
0488
0489
0490
0491
0492
0493
0494 ArrayRef<Value*> inputs() {
0495 return inputs_;
0496 }
0497 ArrayRef<const Value*> inputs() const {
0498
0499
0500 return {inputs_.data(), inputs_.size()};
0501 }
0502
0503
0504
0505
0506
0507
0508 ArrayRef<Value*> outputs() {
0509 return outputs_;
0510 }
0511 ArrayRef<const Value*> outputs() const {
0512
0513
0514 return {outputs_.data(), outputs_.size()};
0515 }
0516 bool hasUses() const {
0517 for (auto o : outputs()) {
0518 if (!o->uses().empty())
0519 return true;
0520 }
0521 return false;
0522 }
0523 void replaceAllUsesWith(Node* n) {
0524 ONNX_ASSERT(outputs().size() == n->outputs().size());
0525 size_t nOutputs = outputs().size();
0526 for (size_t i = 0; i < nOutputs; i++) {
0527 outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
0528 }
0529 }
0530
0531
0532 Value* input() {
0533 ONNX_ASSERT(inputs_.size() == 1);
0534 return inputs_.at(0);
0535 }
0536 Value* output() {
0537 ONNX_ASSERT(outputs_.size() == 1);
0538 return outputs_.at(0);
0539 }
0540 const Value* input() const {
0541 ONNX_ASSERT(inputs_.size() == 1);
0542 return inputs_.at(0);
0543 }
0544 Value* output() const {
0545 ONNX_ASSERT(outputs_.size() == 1);
0546 return outputs_.at(0);
0547 }
0548
0549 Value* input(size_t i) {
0550 return inputs_.at(i);
0551 }
0552 const Value* input(size_t i) const {
0553 return inputs_.at(i);
0554 }
0555
0556
0557
0558
0559
0560
0561
0562
0563
0564
0565
0566
0567
0568
0569
0570
0571
0572
0573
0574
0575
0576 Value* addInput(Value* node) {
0577 ONNX_ASSERT(graph_ == node->owningGraph());
0578 node->uses_in_current_graph_.emplace_back(this, inputs_.size());
0579 inputs_.push_back(node);
0580 return node;
0581 }
0582
0583
0584
0585
0586
0587
0588
0589 Value* replaceInput(size_t i, Value* newValue) {
0590 ONNX_ASSERT(newValue->owningGraph() == graph_);
0591 Value* old = dropInput(i);
0592 inputs_[i] = newValue;
0593 newValue->uses_in_current_graph_.emplace_back(this, i);
0594 return old;
0595 }
0596
0597
0598
0599
0600
0601
0602
0603 void replaceInputWith(Value* from, Value* to) {
0604 ONNX_ASSERT(from->owningGraph() == graph_);
0605 ONNX_ASSERT(to->owningGraph() == graph_);
0606 size_t i = 0;
0607 for (auto input : inputs()) {
0608 if (input == from)
0609 replaceInput(i, to);
0610 i++;
0611 }
0612 }
0613
0614 Value* addOutput() {
0615 outputs_.push_back(new Value(this, outputs_.size()));
0616 return outputs_.back();
0617 }
0618
0619 void eraseOutput(size_t i);
0620
0621
0622
0623
0624
0625
0626
0627
0628
0629
0630
0631 Node* insertBefore(Node* n) {
0632 ONNX_ASSERT(n->inGraphList());
0633 insertAfter(n->prev());
0634 return this;
0635 }
0636
0637
0638
0639
0640
0641
0642
0643
0644
0645
0646
0647 Node* insertAfter(Node* n) {
0648 ONNX_ASSERT(!inGraphList() && n->inGraphList());
0649 Node* next = n->next();
0650 n->next() = this;
0651 this->prev() = n;
0652 this->next() = next;
0653 next->prev() = this;
0654 return this;
0655 }
0656
0657
0658
0659
0660
0661
0662
0663
0664
0665 void moveAfter(Node* n) {
0666 removeFromList();
0667 insertAfter(n);
0668 }
0669
0670
0671
0672
0673
0674
0675
0676
0677 void moveBefore(Node* n) {
0678 removeFromList();
0679 insertBefore(n);
0680 }
0681
0682
0683
0684
0685
0686
0687
0688
0689
0690 void removeInput(size_t i) {
0691 dropInput(i);
0692
0693
0694 for (size_t j = i + 1; j < inputs_.size(); j++) {
0695 auto it = findUseForInput(j);
0696 it->offset--;
0697 }
0698 inputs_.erase(inputs_.begin() + i);
0699 }
0700
0701
0702
0703
0704
0705
0706 void removeAllInputs() {
0707 for (size_t i = 0; i < inputs().size(); ++i)
0708 dropInput(i);
0709 inputs_.clear();
0710 }
0711
0712
0713 bool isBefore(Node* n);
0714
0715
0716
0717 graph_node_list_iterator iterator();
0718 graph_node_list_iterator reverseIterator();
0719 const_graph_node_list_iterator iterator() const;
0720 const_graph_node_list_iterator reverseIterator() const;
0721
0722
0723
0724
0725
0726
0727
0728
0729
0730 void destroy();
0731
0732
0733
0734
0735
0736
0737
0738 template <typename T>
0739 T* cast() {
0740 if (T::Kind == kind())
0741 return static_cast<T*>(this);
0742 return nullptr;
0743 }
0744 template <typename T>
0745 T* expect() {
0746 ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
0747 return static_cast<T*>(this);
0748 }
0749
0750 virtual ~Node() = default;
0751
0752 private:
0753
0754 use_list::iterator findUseForInput(size_t i) {
0755 auto& input_uses = inputs_[i]->uses_in_current_graph_;
0756
0757
0758 auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
0759 ONNX_ASSERT(use_it != input_uses.end());
0760 return use_it;
0761 }
0762
0763
0764
0765
0766 Value* dropInput(size_t i) {
0767 ONNX_ASSERT(i < inputs_.size());
0768 auto input_node = inputs_[i];
0769 auto use_it = findUseForInput(i);
0770 input_node->uses_in_current_graph_.erase(use_it);
0771 inputs_[i] = nullptr;
0772 return input_node;
0773 }
0774
0775 bool inGraphList() const {
0776 ONNX_ASSERT(next() != nullptr || prev() == nullptr);
0777 return next() != nullptr;
0778 }
0779 void removeFromList() {
0780 ONNX_ASSERT(inGraphList());
0781 Node* next = this->next();
0782 Node* prev = this->prev();
0783 prev->next() = next;
0784 next->prev() = prev;
0785 this->next() = nullptr;
0786 this->prev() = nullptr;
0787 }
0788
0789 protected:
0790
0791
0792
0793
0794
0795 virtual Node* allocNewInstance(Graph* g) {
0796 return new Node(g, kind());
0797 }
0798
0799
0800
0801
0802
0803
0804
0805 virtual void cloneFrom(Node* s) {
0806 copyAttributes(*s);
0807 }
0808 };
0809
0810
0811
0812 class OpSetID final {
0813 private:
0814 std::string domain_;
0815 int64_t version_;
0816
0817 public:
0818 explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {}
0819
0820
0821 explicit OpSetID(const int64_t version) : domain_(""), version_(version) {}
0822
0823 explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {}
0824
0825
0826 std::string toString() const {
0827 return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
0828 }
0829
0830
0831 static OpSetID fromString(const std::string& target) {
0832 ONNX_TRY {
0833 std::string new_domain = target.substr(0, target.find("$"));
0834 int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
0835 return OpSetID(new_domain, new_version);
0836 }
0837 ONNX_CATCH(const std::runtime_error& e) {
0838 ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); });
0839 }
0840
0841
0842
0843
0844
0845
0846
0847 return OpSetID("", 0);
0848 }
0849
0850 const std::string& domain() const {
0851 return domain_;
0852 }
0853
0854 int64_t version() const {
0855 return version_;
0856 }
0857
0858 void incrementVersion(int64_t step) {
0859 version_ += step;
0860 }
0861
0862 void setVersion(int64_t newVal) {
0863 version_ = newVal;
0864 }
0865 };
0866
0867 struct Graph final {
0868 ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
0869 friend struct Node;
0870 friend struct Value;
0871
0872 private:
0873
0874
0875
0876
0877 std::unordered_set<const Node*> all_nodes;
0878 std::unordered_set<const Value*> all_values;
0879 size_t next_unique_;
0880
0881 size_t new_node_stage_;
0882
0883
0884
0885
0886
0887 Node* const output_;
0888 Node* const input_;
0889
0890 Node* const initializer_node_;
0891
0892 std::vector<Tensor> initializers_;
0893 std::vector<std::string> initializer_names_;
0894
0895 bool has_name_;
0896 std::string name_;
0897 bool has_doc_string_;
0898 std::string doc_string_;
0899
0900 std::vector<OpSetID> opset_versions_;
0901
0902 bool isNameUnique(const std::string& name) const {
0903 if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) {
0904 return false;
0905 }
0906 const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
0907 for (const Node* node : all_nodes) {
0908 for (const auto& attr : node->attributeNames()) {
0909 if (node->kindOf(attr) == AttributeKind::g) {
0910 const auto& subgraph = node->g(attr);
0911 if (!subgraph->isNameUnique(name)) {
0912 return false;
0913 }
0914 } else if (node->kindOf(attr) == AttributeKind::gs) {
0915 for (const auto& subgraph : node->gs(attr)) {
0916 if (!subgraph->isNameUnique(name)) {
0917 return false;
0918 }
0919 }
0920 }
0921 }
0922 const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f);
0923 if (found_in != node->inputs().end()) {
0924 return false;
0925 }
0926 const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f);
0927 if (found_out != node->outputs().end()) {
0928 return false;
0929 }
0930 }
0931 return true;
0932 }
0933
0934 public:
0935 Graph()
0936 : next_unique_(0),
0937 new_node_stage_(0),
0938 output_(initOutput(create(kReturn, 0))),
0939 input_(create(kParam, 0)),
0940 initializer_node_(create(kParam, 0)),
0941 has_name_(false),
0942 has_doc_string_(false) {}
0943
0944 bool has_doc_string() const {
0945 return has_doc_string_;
0946 }
0947 const std::string& docString() {
0948 return doc_string_;
0949 }
0950 void setDocString(std::string doc_string) {
0951 has_doc_string_ = true;
0952 doc_string_ = std::move(doc_string);
0953 }
0954
0955 void addInitializer(Tensor& initializer) {
0956 if (initializer.name().empty()) {
0957 initializer.setName(ONNX_NAMESPACE::to_string(getNextUnique()));
0958 }
0959 initializers_.push_back(initializer);
0960 initializer_names_.push_back(initializer.name());
0961 }
0962
0963
0964
0965 Value* addInitializerAndCreateValue(Tensor& initializer) {
0966 addInitializer(initializer);
0967 auto* init_value = initializer_node_->addOutput();
0968 std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()};
0969 init_value->setUniqueName(initializer.name());
0970 init_value->setSizes(dim_sizes);
0971 init_value->setElemType(initializer.elem_type());
0972 return init_value;
0973 }
0974
0975 void eraseInitializer(const std::string& name) {
0976 initializers_.erase(
0977 std::remove_if(
0978 initializers_.begin(),
0979 initializers_.end(),
0980 [&name](Tensor& initializer) { return initializer.name() == name; }),
0981 initializers_.end());
0982 initializer_names_.erase(
0983 std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end());
0984 for (size_t i = 0; i < initializer_node_->outputs().size(); i++) {
0985 if (initializer_node_->outputs()[i]->uniqueName() == name) {
0986 initializer_node_->eraseOutput(i);
0987 break;
0988 }
0989 }
0990 }
0991 void clearInitializers() {
0992 initializers_.clear();
0993 initializer_names_.clear();
0994 }
0995 const std::vector<Tensor>& initializers() const {
0996 return initializers_;
0997 }
0998 const std::vector<std::string>& initializer_names() const {
0999 return initializer_names_;
1000 }
1001 std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
1002 for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
1003 if (name == it->name()) {
1004 return it;
1005 }
1006 }
1007 return initializers_.end();
1008 }
1009 bool is_constant_initializer(const Value* value) const {
1010 return value->node() == initializer_node_;
1011 }
1012 ArrayRef<Value*> inputs() {
1013 return input_->outputs();
1014 }
1015 ArrayRef<const Value*> inputs() const {
1016 const auto& inputs = input_->outputs();
1017 return {inputs.data(), inputs.size()};
1018 }
1019 ArrayRef<Value*> outputs() {
1020 return output_->inputs();
1021 }
1022 ArrayRef<const Value*> outputs() const {
1023 return static_cast<const Node*>(output_)->inputs();
1024 }
1025 graph_node_list nodes() {
1026 return graph_node_list(output_, kNextDirection);
1027 }
1028 const_graph_node_list nodes() const {
1029 return const_graph_node_list(output_, kNextDirection);
1030 }
1031
1032 std::vector<OpSetID>& opset_versions_mutable() {
1033 return opset_versions_;
1034 }
1035
1036 size_t getNextUnique() {
1037 std::string next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1038 while (!isNameUnique(next_unique_name)) {
1039 next_unique_name = ONNX_NAMESPACE::to_string(++next_unique_);
1040 }
1041 return next_unique_;
1042 }
1043
1044
1045
1046
1047 graph_node_list_iterator begin() {
1048 return nodes().begin();
1049 }
1050 const_graph_node_list_iterator begin() const {
1051 return nodes().begin();
1052 }
1053 graph_node_list_iterator end() {
1054 return nodes().end();
1055 }
1056 const_graph_node_list_iterator end() const {
1057 return nodes().end();
1058 }
1059 graph_node_list_iterator rbegin() {
1060 return nodes().rbegin();
1061 }
1062 const_graph_node_list_iterator rbegin() const {
1063 return nodes().rbegin();
1064 }
1065 graph_node_list_iterator rend() {
1066 return nodes().rend();
1067 }
1068 const_graph_node_list_iterator rend() const {
1069 return nodes().rend();
1070 }
1071 Node* return_node() {
1072 return output_;
1073 }
1074 const Node* return_node() const {
1075 return output_;
1076 }
1077
1078 Value* addInput() {
1079 return input_->addOutput();
1080 }
1081 void eraseInput(size_t i) {
1082 input_->eraseOutput(i);
1083 }
1084 void advanceStage() {
1085 new_node_stage_++;
1086 }
1087 void setStage(size_t new_stage) {
1088 new_node_stage_ = new_stage;
1089 }
1090 size_t stage() const {
1091 return new_node_stage_;
1092 }
1093 ResourceGuard setStageTemporary(size_t s) {
1094 auto prev_stage = new_node_stage_;
1095 new_node_stage_ = s;
1096 return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1097 }
1098
1099 size_t registerOutput(Value* n) {
1100 output_->addInput(n);
1101 return outputs().size() - 1;
1102 }
1103
1104 Node* create(NodeKind kind, size_t num_outputs = 1) {
1105
1106 auto n = new Node(this, kind);
1107 for (size_t i = 0; i < num_outputs; i++)
1108 n->addOutput();
1109 return n;
1110 }
1111
1112 Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) {
1113 auto n = create(kind, num_outputs);
1114 for (auto i : inputs)
1115 n->addInput(i);
1116 return n;
1117 }
1118
1119 Node* appendNode(Node* n) {
1120 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1121 n->insertBefore(output_);
1122 return n;
1123 }
1124
1125 Node* prependNode(Node* n) {
1126 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1127 n->insertAfter(output_);
1128 return n;
1129 }
1130
1131
1132
1133
1134 Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) {
1135 Tensor initializerCopy = initializer;
1136 std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(), initializerCopy.sizes().cend()};
1137 Value* new_init = addInput();
1138 initializerCopy.setName(name);
1139 new_init->setUniqueName(name);
1140 new_init->setSizes(dim_sizes);
1141 new_init->setElemType(initializerCopy.elem_type());
1142 addInitializer(initializerCopy);
1143 return new_init;
1144 }
1145
1146 Value* addInitializerAndInput(const Tensor& initializer) {
1147 return addInitializerAndInput(initializer, ONNX_NAMESPACE::to_string(getNextUnique()));
1148 }
1149
1150
1151
1152 void eraseInitializerAndInput(Value* v) {
1153 eraseInitializer(v->uniqueName());
1154 if (v->node() == input_) {
1155 eraseInput(v->offset());
1156 }
1157 }
1158
1159 ~Graph() {
1160 for (const Node* n : all_nodes)
1161 delete n;
1162 for (const Value* v : all_values)
1163 delete v;
1164 }
1165
1166 std::string toString() const {
1167 std::ostringstream oss;
1168 oss << *this;
1169 return oss.str();
1170 }
1171
1172 bool has_name() const {
1173 return has_name_;
1174 }
1175
1176 const std::string& name() const {
1177 return name_;
1178 }
1179
1180 void setName(std::string name) {
1181 has_name_ = true;
1182 name_ = std::move(name);
1183 }
1184
1185 friend std::ostream& operator<<(std::ostream& out, const Graph& g);
1186
1187 void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) {
1188 fn(this);
1189
1190 for (const Node* node : all_nodes) {
1191 for (const auto& attr : node->attributeNames()) {
1192 if (node->kindOf(attr) == AttributeKind::g) {
1193 std::shared_ptr<Graph> subgraph = node->g(attr);
1194 subgraph->forSelfAndEachSubGraph(fn);
1195 } else if (node->kindOf(attr) == AttributeKind::gs) {
1196 for (const auto& subgraph : node->gs(attr)) {
1197 subgraph->forSelfAndEachSubGraph(fn);
1198 }
1199 }
1200 }
1201 }
1202 }
1203
1204 void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const {
1205 std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); };
1206 const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1207 }
1208
1209 void forEachNode(const std::function<void(Node*)>& fn) {
1210 forSelfAndEachSubGraph([fn](Graph* graph) {
1211 for (Node* node : graph->nodes()) {
1212 fn(node);
1213 }
1214 });
1215 }
1216
1217 void forEachNode(const std::function<void(const Node*)>& fn) const {
1218 std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); };
1219 const_cast<Graph*>(this)->forEachNode(tmp_fn);
1220 }
1221
1222 private:
1223
1224 Node* initOutput(Node* p) {
1225 p->next() = p;
1226 p->prev() = p;
1227 p->setStage(std::numeric_limits<size_t>::max());
1228 return p;
1229 }
1230
1231 void freeNode(Node* n) {
1232 auto it = all_nodes.find(n);
1233 ONNX_ASSERT(it != all_nodes.end());
1234 delete *it;
1235 all_nodes.erase(it);
1236 }
1237 void freeValue(Value* v) {
1238 auto it = all_values.find(v);
1239 ONNX_ASSERT(it != all_values.end());
1240 delete *it;
1241 all_values.erase(it);
1242 }
1243 };
1244
1245 inline Value::Value(Node* node_, size_t offset_)
1246 : node_(node_),
1247 offset_(offset_),
1248 unique_(node_->graph_->getNextUnique()),
1249 stage_(node_->graph_->new_node_stage_),
1250 has_unique_name_(false),
1251 elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
1252 has_sizes_(false) {
1253 node_->graph_->all_values.emplace(this);
1254 }
1255
1256 inline Graph* Value::owningGraph() {
1257 return node()->owningGraph();
1258 }
1259
1260 inline const Graph* Value::owningGraph() const {
1261 return node()->owningGraph();
1262 }
1263
1264
1265
1266
1267
1268
1269 inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) {
1270 if (has_unique_name() && update_related_names) {
1271 auto* graph = owningGraph();
1272 auto old_name = unique_name_;
1273 for (size_t i = 0; i < owningGraph()->initializer_names_.size(); i++) {
1274 auto& initializer_name = owningGraph()->initializer_names_[i];
1275 if (initializer_name == old_name) {
1276 initializer_name = name;
1277 owningGraph()->initializers_[i].setName(name);
1278 }
1279 }
1280 graph->forEachNode([this, &name, &old_name](Node* node) {
1281 if (node->owningGraph() == this->owningGraph()) {
1282
1283 return;
1284 }
1285 if (node->kind() == kCaptured) {
1286 Value* output = node->output();
1287 if (output->uniqueName() == old_name) {
1288 output->setUniqueName(name, false);
1289 }
1290 }
1291 });
1292 }
1293 unique_name_ = name;
1294 has_unique_name_ = true;
1295 return this;
1296 }
1297
1298 inline void Value::replaceAllUsesWith(Value* newValue) {
1299 auto* graph = owningGraph();
1300 ONNX_ASSERT(graph == newValue->owningGraph());
1301
1302 if (this->has_sizes()) {
1303 newValue->setSizes(this->sizes());
1304 }
1305 if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1306 newValue->setElemType(this->elemType());
1307 }
1308 const auto unique_name = this->uniqueName();
1309
1310 if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
1311 newValue->setUniqueName(unique_name);
1312
1313
1314 this->setUniqueName(ONNX_NAMESPACE::to_string(graph->getNextUnique()), false);
1315 }
1316 newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1317 for (auto u : uses_in_current_graph_) {
1318 u.user->inputs_[u.offset] = newValue;
1319 newValue->uses_in_current_graph_.push_back(u);
1320 }
1321 graph->forEachNode([this, &newValue, &unique_name](Node* node) {
1322 if (node->owningGraph() == this->owningGraph()) {
1323
1324 return;
1325 }
1326 if (node->kind() == kCaptured) {
1327 Value* output = node->output();
1328 if (output->uniqueName() == unique_name) {
1329 output->setUniqueName(newValue->uniqueName());
1330 }
1331 }
1332 });
1333 uses_in_current_graph_.clear();
1334 assert(this->uses().empty());
1335 }
1336
1337 inline Node::Node(Graph* graph_, NodeKind kind_)
1338 : kind_(kind_),
1339 graph_(graph_),
1340 stage_(graph_->new_node_stage_),
1341 has_name_(false),
1342 has_domain_(false),
1343 has_doc_string_(false) {
1344 graph_->all_nodes.emplace(this);
1345 }
1346
1347 inline void Node::eraseOutput(size_t i) {
1348 ONNX_ASSERT(i < outputs_.size());
1349 ONNX_ASSERT(outputs_[i]->uses().empty());
1350 Value* n = outputs_[i];
1351 outputs_.erase(outputs_.begin() + i);
1352 owningGraph()->freeValue(n);
1353 for (size_t j = i; j < outputs_.size(); j++) {
1354 outputs_[j]->offset_--;
1355 }
1356 }
1357
1358 inline bool Node::isBefore(Node* n) {
1359 if (n == nullptr || this == n) {
1360
1361 return false;
1362 }
1363
1364 if (kind_ == kParam) {
1365 return true;
1366 }
1367
1368 if (n->kind() == kParam) {
1369 return false;
1370 }
1371 ONNX_ASSERT(n->inGraphList());
1372 for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1373 if (p == n) {
1374 return true;
1375 }
1376 }
1377 return false;
1378 }
1379
1380 inline void Node::destroy() {
1381 ONNX_ASSERT(inGraphList());
1382 while (!outputs().empty())
1383 eraseOutput(outputs().size() - 1);
1384 removeAllInputs();
1385 removeFromList();
1386 graph_->freeNode(this);
1387 }
1388
1389
1390
1391 inline graph_node_list_iterator Node::iterator() {
1392 return graph_node_list_iterator(this, 0);
1393 }
1394 inline graph_node_list_iterator Node::reverseIterator() {
1395 return iterator().reverse();
1396 }
1397 inline const_graph_node_list_iterator Node::iterator() const {
1398 return const_graph_node_list_iterator(this, 0);
1399 }
1400 inline const_graph_node_list_iterator Node::reverseIterator() const {
1401 return iterator().reverse();
1402 }
1403
1404
1405
1406
1407
1408 inline const use_list Value::uses() const {
1409 use_list all_uses = uses_in_current_graph_;
1410 owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1411 if (node->owningGraph() == this->owningGraph()) {
1412
1413 return;
1414 }
1415 if (node->kind() == kCaptured) {
1416 const Value* output = node->outputs()[0];
1417 if (output->uniqueName() == this->uniqueName()) {
1418 const auto output_uses = output->uses();
1419 all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1420 }
1421 }
1422 });
1423 return all_uses;
1424 }
1425
1426 }