File indexing completed on 2025-09-16 09:00:13
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #pragma once
0011
0012 #include <stdint.h>
0013
0014 #include <algorithm>
0015 #include <atomic>
0016 #include <cstdint>
0017 #include <functional>
0018 #include <iostream>
0019 #include <limits>
0020 #include <memory>
0021 #include <set>
0022 #include <sstream>
0023 #include <string>
0024 #include <unordered_set>
0025 #include <utility>
0026 #include <vector>
0027
0028 #include "onnx/common/array_ref.h"
0029 #include "onnx/common/assertions.h"
0030 #include "onnx/common/common.h"
0031 #include "onnx/common/graph_node_list.h"
0032 #include "onnx/common/interned_strings.h"
0033 #include "onnx/common/tensor.h"
0034 #include "onnx/string_utils.h"
0035
// Deletes the copy constructor and copy-assignment operator of `ClassName`.
// Expand it inside the class body, followed by a semicolon. Because the copy
// operations become user-declared, the implicit move operations are
// suppressed as well; declare (or default) them explicitly where needed.
#define ONNX_DISALLOW_COPY_AND_ASSIGN(ClassName) \
  ClassName(const ClassName&) = delete;           \
  ClassName& operator=(const ClassName&) = delete
0039
0040 namespace ONNX_NAMESPACE {
0041
namespace { // internal helpers private to this header

// Returns the synthetic name "_v_<i>" used for values that have no explicit
// unique name.
//
// std::to_string is used instead of std::ostringstream. A stream picks up the
// global C++ locale when it is constructed, so a locale with digit grouping
// would otherwise yield names such as "_v_1,234". std::to_string always
// produces plain decimal digits.
std::string toVarName(size_t i) {
  return "_v_" + std::to_string(i);
}

} // namespace
0051
0052
0053
0054
0055
0056 struct Graph;
0057
0058
0059
0060 struct Node;
0061
0062
0063
0064 struct Value;
0065
// Scope guard: invokes a cleanup callback when destroyed unless release() has
// been called first.
//
// Moving a guard transfers the pending callback. The moved-from guard is
// disarmed, so the callback runs exactly once, from whichever guard ends up
// owning it.
class ResourceGuard final {
  std::function<void()> destructor_;
  bool released_;

 public:
  ResourceGuard(const ResourceGuard&) = delete;
  ResourceGuard& operator=(const ResourceGuard&) = delete;

  explicit ResourceGuard(std::function<void()> destructor) : destructor_(std::move(destructor)), released_(false) {}

  // A defaulted move would leave `other` armed with a moved-from (typically
  // empty) std::function. Its destructor would then throw
  // std::bad_function_call and terminate the program.
  ResourceGuard(ResourceGuard&& other) : destructor_(std::move(other.destructor_)), released_(other.released_) {
    other.released_ = true;
  }

  // Runs this guard's own pending cleanup before taking over `other`'s,
  // mirroring std::unique_ptr move assignment. A defaulted move assignment
  // would silently drop that cleanup.
  ResourceGuard& operator=(ResourceGuard&& other) {
    if (this != &other) {
      fire();
      destructor_ = std::move(other.destructor_);
      released_ = other.released_;
      other.released_ = true;
    }
    return *this;
  }

  ~ResourceGuard() {
    fire();
  }

  // Disarms the guard; the cleanup callback will not be invoked.
  void release() {
    released_ = true;
  }

 private:
  // Invokes the cleanup at most once; an empty callback is ignored.
  void fire() {
    if (!released_ && destructor_) {
      released_ = true;
      destructor_();
    }
  }
};
0085
// One tensor dimension. It is in one of three states: unknown, a concrete
// integer extent, or a named symbolic parameter.
struct Dimension final {
  // Unknown dimension.
  Dimension() {}
  // Symbolic dimension named `symbol`.
  Dimension(std::string symbol) : is_unknown(false), param(std::move(symbol)) {}
  // Concrete integer dimension of size `extent`.
  Dimension(int64_t extent) : is_unknown(false), is_int(true), dim(extent) {}

  bool is_unknown = true; // set only by the default constructor
  bool is_int = false; // true when `dim` holds the extent
  int64_t dim = -1; // the extent when is_int, otherwise -1
  std::string param; // symbolic name; empty unless built from a string
};
0096
// Tag identifying the concrete type held by an AttributeValue. The numeric
// values are pinned explicitly because toString() uses them to index its name
// table, so they must stay in this order.
//   f  / fs  : floating-point scalar / list
//   i  / is  : integer scalar / list
//   s  / ss  : string scalar / list
//   t  / ts  : tensor scalar / list
//   g  / gs  : graph scalar / list
//   tp / tps : TypeProto scalar / list
enum class AttributeKind : uint8_t {
  f = 0,
  fs = 1,
  i = 2,
  is = 3,
  s = 4,
  ss = 5,
  t = 6,
  ts = 7,
  g = 8,
  gs = 9,
  tp = 10,
  tps = 11
};
0113
0114 static inline const char* toString(AttributeKind kind) {
0115 static constexpr const char* names[] = {"f", "fs", "i", "is", "s", "ss", "t", "ts", "g", "gs", "tp", "tps"};
0116 ONNX_ASSERT(size_t(kind) < sizeof(names) / sizeof(const char*));
0117 return names[int(kind)];
0118 }
0119
0120 struct AttributeValue {
0121 explicit AttributeValue(Symbol name) : name(name) {}
0122 using Ptr = std::unique_ptr<AttributeValue>;
0123 Symbol name;
0124 virtual AttributeKind kind() const = 0;
0125 virtual Ptr clone() const = 0;
0126 virtual ~AttributeValue() = default;
0127 };
0128
0129 template <typename T, AttributeKind Kind>
0130 struct ScalarAttributeValue final : public AttributeValue {
0131 using ConstructorType = const T&;
0132 using ValueType = T;
0133 ScalarAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(value_) {}
0134 ValueType& value() {
0135 return value_;
0136 }
0137 virtual Ptr clone() const override {
0138 return Ptr(new ScalarAttributeValue(name, value_));
0139 }
0140 virtual AttributeKind kind() const override {
0141 return Kind;
0142 }
0143
0144 private:
0145 ValueType value_;
0146 };
0147
0148 template <typename T, AttributeKind Kind>
0149 struct VectorAttributeValue final : public AttributeValue {
0150 using ConstructorType = const std::vector<T>&&;
0151 using ValueType = std::vector<T>;
0152 VectorAttributeValue(Symbol name, ConstructorType value_) : AttributeValue(name), value_(std::move(value_)) {}
0153 ValueType& value() {
0154 return value_;
0155 }
0156 virtual AttributeKind kind() const override {
0157 return Kind;
0158 }
0159 virtual std::unique_ptr<AttributeValue> clone() const override {
0160 auto copy = value_;
0161 return Ptr(new VectorAttributeValue(name, std::move(copy)));
0162 }
0163
0164 private:
0165 ValueType value_;
0166 };
0167
0168 using FloatAttr = ScalarAttributeValue<double, AttributeKind::f>;
0169 using FloatsAttr = VectorAttributeValue<double, AttributeKind::fs>;
0170 using IntAttr = ScalarAttributeValue<int64_t, AttributeKind::i>;
0171 using IntsAttr = VectorAttributeValue<int64_t, AttributeKind::is>;
0172 using StringAttr = ScalarAttributeValue<std::string, AttributeKind::s>;
0173 using StringsAttr = VectorAttributeValue<std::string, AttributeKind::ss>;
0174 using TensorAttr = ScalarAttributeValue<Tensor, AttributeKind::t>;
0175 using TensorsAttr = VectorAttributeValue<Tensor, AttributeKind::ts>;
0176 using GraphAttr = ScalarAttributeValue<std::shared_ptr<Graph>, AttributeKind::g>;
0177 using GraphsAttr = VectorAttributeValue<std::shared_ptr<Graph>, AttributeKind::gs>;
0178 using TypeProtoAttr = ScalarAttributeValue<TypeProto, AttributeKind::tp>;
0179 using TypeProtosAttr = VectorAttributeValue<TypeProto, AttributeKind::tps>;
0180
0181
0182
0183
0184
0185 template <typename Derived>
0186 struct Attributes {
0187 Attributes() {}
0188 void copyAttributes(const Attributes& rhs) {
0189 values_.clear();
0190 values_.reserve(rhs.values_.size());
0191 for (auto& i : rhs.values_) {
0192 values_.push_back(i->clone());
0193 }
0194 }
0195 bool hasAttribute(Symbol name) const {
0196 return find(name, false) != values_.end();
0197 }
0198 AttributeKind kindOf(Symbol name) const {
0199 return (*find(name, true))->kind();
0200 }
0201 Derived* removeAttribute(Symbol name) {
0202 values_.erase(find(name, true));
0203 return This();
0204 }
0205 bool hasAttributes() const {
0206 return !values_.empty();
0207 }
0208
0209 std::vector<Symbol> attributeNames() const {
0210 std::vector<Symbol> names;
0211 names.reserve(values_.size());
0212 for (auto& a : values_)
0213 names.push_back(a->name);
0214 return names;
0215 }
0216
0217 #define CREATE_ACCESSOR(Kind, method) \
0218 Derived* method##_(Symbol name, Kind##Attr::ConstructorType v) { \
0219 return set<Kind##Attr>(name, std::forward<Kind##Attr::ConstructorType>(v)); \
0220 } \
0221 const Kind##Attr::ValueType& method(Symbol name) const { \
0222 return get<Kind##Attr>(name); \
0223 }
0224 CREATE_ACCESSOR(Float, f)
0225 CREATE_ACCESSOR(Floats, fs)
0226 CREATE_ACCESSOR(String, s)
0227 CREATE_ACCESSOR(Strings, ss)
0228 CREATE_ACCESSOR(Int, i)
0229 CREATE_ACCESSOR(Ints, is)
0230 CREATE_ACCESSOR(Tensor, t)
0231 CREATE_ACCESSOR(Tensors, ts)
0232 CREATE_ACCESSOR(Graph, g)
0233 CREATE_ACCESSOR(Graphs, gs)
0234 CREATE_ACCESSOR(TypeProto, tp)
0235 CREATE_ACCESSOR(TypeProtos, tps)
0236
0237 #undef CREATE_ACCESSOR
0238
0239 private:
0240 Derived* This() {
0241 return static_cast<Derived*>(this);
0242 }
0243 template <typename T>
0244 Derived* set(Symbol name, typename T::ConstructorType v) {
0245 auto it = find(name, false);
0246 auto nv = AVPtr(new T(name, std::forward<typename T::ConstructorType>(v)));
0247 if (it == values_.end()) {
0248 values_.push_back(std::move(nv));
0249 } else {
0250 *it = std::move(nv);
0251 }
0252 return This();
0253 }
0254 template <typename T>
0255 typename T::ValueType& get(Symbol name) const {
0256 auto it = find(name, true);
0257 T* child = static_cast<T*>(it->get());
0258 return child->value();
0259 }
0260 using AVPtr = AttributeValue::Ptr;
0261
0262
0263
0264 std::vector<AVPtr> values_;
0265 using iterator = std::vector<AVPtr>::iterator;
0266 iterator find(Symbol name, bool required) {
0267 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0268 ONNX_ASSERT(!required || it != values_.end());
0269 return it;
0270 }
0271 using const_iterator = std::vector<AVPtr>::const_iterator;
0272 const_iterator find(Symbol name, bool required) const {
0273 auto it = std::find_if(values_.begin(), values_.end(), [&](const AVPtr& v) { return v->name == name; });
0274 ONNX_ASSERTM(
0275 !required || it != values_.end(),
0276 "%s:%u: %s: required undefined attribute '%s'",
0277 __FILE__,
0278 __LINE__,
0279 __func__,
0280 name.toString());
0281 return it;
0282 }
0283 };
0284
0285
0286
0287
0288 struct Use final {
0289 Use(Node* user, size_t offset) : user(user), offset(offset) {}
0290 Node* user;
0291 size_t offset;
0292 };
0293
0294 static inline bool operator==(const Use& a, const Use& b) {
0295 return a.user == b.user && a.offset == b.offset;
0296 }
0297
0298
0299
0300 using node_list = std::vector<Node*>;
0301 using value_list = std::vector<Value*>;
0302 using use_list = std::vector<Use>;
0303 using NodeKind = Symbol;
0304
0305 struct Value final {
0306 ONNX_DISALLOW_COPY_AND_ASSIGN(Value);
0307 Value(Node* node_, size_t offset_);
0308 Value(Value&&) = default;
0309 Value& operator=(Value&&) = default;
0310 ~Value() = default;
0311
0312 private:
0313 friend struct Node;
0314 friend struct Graph;
0315 Node* node_;
0316 size_t offset_;
0317 size_t unique_ = 0;
0318 size_t stage_ = 0;
0319 use_list uses_in_current_graph_;
0320 bool has_unique_name_;
0321 std::string unique_name_;
0322 int32_t elem_type_;
0323 bool has_sizes_;
0324 std::vector<Dimension> sizes_;
0325
0326 public:
0327 Value* setElemType(int32_t elem_type) {
0328 elem_type_ = elem_type;
0329 return this;
0330 }
0331 int32_t elemType() const {
0332 return elem_type_;
0333 }
0334 bool has_sizes() const {
0335 return has_sizes_;
0336 }
0337 Value* setSizes(std::vector<Dimension> sizes) {
0338 has_sizes_ = true;
0339 sizes_ = std::move(sizes);
0340 return this;
0341 }
0342 Value* wipeSizes() {
0343 has_sizes_ = false;
0344 sizes_ = std::vector<Dimension>();
0345 return this;
0346 }
0347 const std::vector<Dimension>& sizes() const {
0348 return sizes_;
0349 }
0350 size_t unique() const {
0351 return unique_;
0352 }
0353 bool has_unique_name() const {
0354 return has_unique_name_;
0355 }
0356 std::string uniqueName() const {
0357 if (has_unique_name())
0358 return unique_name_;
0359 return toVarName(unique());
0360 }
0361 Value* setUniqueName(const std::string& name, bool rename_subgraph_captured_nodes = true);
0362 Value* setStage(size_t s) {
0363 stage_ = s;
0364 return this;
0365 }
0366 size_t stage() const {
0367 return stage_;
0368 }
0369 Node* node() {
0370 return node_;
0371 }
0372 size_t offset() const {
0373 return offset_;
0374 }
0375 const Node* node() const {
0376 return node_;
0377 }
0378 Graph* owningGraph();
0379 const Graph* owningGraph() const;
0380
0381 const use_list uses() const;
0382
0383
0384
0385
0386
0387
0388
0389
0390
0391
0392 void replaceAllUsesWith(Value* newValue);
0393
0394 Value* copyMetadata(Value* from) {
0395 setElemType(from->elemType());
0396 setSizes(from->sizes());
0397 if (from->has_unique_name()) {
0398 setUniqueName(from->uniqueName());
0399 }
0400 return this;
0401 }
0402 };
0403
0404 struct Node : public Attributes<Node> {
0405 ONNX_DISALLOW_COPY_AND_ASSIGN(Node);
0406 friend struct Graph;
0407 friend struct Value;
0408 friend graph_node_list;
0409 friend const_graph_node_list;
0410 friend graph_node_list_iterator;
0411 friend const_graph_node_list_iterator;
0412
0413 private:
0414
0415
0416
0417
0418
0419
0420
0421
0422 Node* next_in_graph[2] = {nullptr, nullptr};
0423 Node*& next() {
0424 return next_in_graph[kNextDirection];
0425 }
0426 Node*& prev() {
0427 return next_in_graph[kPrevDirection];
0428 }
0429 Node* const& next() const {
0430 return next_in_graph[kNextDirection];
0431 }
0432 Node* const& prev() const {
0433 return next_in_graph[kPrevDirection];
0434 }
0435
0436 const NodeKind kind_;
0437 std::vector<Value*> inputs_;
0438 std::vector<Value*> outputs_;
0439 Graph* graph_;
0440 size_t stage_;
0441 bool has_name_;
0442 std::string name_;
0443 bool has_domain_;
0444 std::string domain_;
0445 bool has_doc_string_;
0446 std::string doc_string_;
0447 bool has_overload_;
0448 std::string overload_;
0449
0450 protected:
0451 Node(Graph* graph_, NodeKind kind_);
0452
0453 public:
0454 bool has_name() const {
0455 return has_name_;
0456 }
0457 const std::string& name() const {
0458 return name_;
0459 }
0460 void setName(std::string name) {
0461 has_name_ = true;
0462 name_ = std::move(name);
0463 }
0464 bool has_domain() const {
0465 return has_domain_;
0466 }
0467 const std::string& domain() const {
0468 return domain_;
0469 }
0470 void setDomain(std::string domain) {
0471 has_domain_ = true;
0472 domain_ = std::move(domain);
0473 }
0474 bool has_overload() const {
0475 return has_overload_;
0476 }
0477 const std::string& overload() const {
0478 return overload_;
0479 }
0480 void setOverload(std::string overload) {
0481 has_overload_ = true;
0482 overload_ = std::move(overload);
0483 }
0484 bool has_doc_string() const {
0485 return has_doc_string_;
0486 }
0487 const std::string& docString() const {
0488 return doc_string_;
0489 }
0490 void setDocString(std::string doc_string) {
0491 has_doc_string_ = true;
0492 doc_string_ = std::move(doc_string);
0493 }
0494 NodeKind kind() const {
0495 return kind_;
0496 }
0497 Graph* owningGraph() {
0498 return graph_;
0499 }
0500 const Graph* owningGraph() const {
0501 return graph_;
0502 }
0503 size_t stage() const {
0504 return stage_;
0505 }
0506 Node* setStage(size_t s) {
0507 stage_ = s;
0508 return this;
0509 }
0510
0511
0512
0513
0514
0515
0516 ArrayRef<Value*> inputs() {
0517 return inputs_;
0518 }
0519 ArrayRef<const Value*> inputs() const {
0520
0521
0522 return {inputs_.data(), inputs_.size()};
0523 }
0524
0525
0526
0527
0528
0529
0530 ArrayRef<Value*> outputs() {
0531 return outputs_;
0532 }
0533 ArrayRef<const Value*> outputs() const {
0534
0535
0536 return {outputs_.data(), outputs_.size()};
0537 }
0538 bool hasUses() const {
0539 for (auto o : outputs()) {
0540 if (!o->uses().empty())
0541 return true;
0542 }
0543 return false;
0544 }
0545 void replaceAllUsesWith(Node* n) {
0546 ONNX_ASSERT(outputs().size() == n->outputs().size());
0547 size_t nOutputs = outputs().size();
0548 for (size_t i = 0; i < nOutputs; i++) {
0549 outputs()[i]->replaceAllUsesWith(n->outputs()[i]);
0550 }
0551 }
0552
0553
0554 Value* input() {
0555 ONNX_ASSERT(inputs_.size() == 1);
0556 return inputs_.at(0);
0557 }
0558 Value* output() {
0559 ONNX_ASSERT(outputs_.size() == 1);
0560 return outputs_.at(0);
0561 }
0562 const Value* input() const {
0563 ONNX_ASSERT(inputs_.size() == 1);
0564 return inputs_.at(0);
0565 }
0566 Value* output() const {
0567 ONNX_ASSERT(outputs_.size() == 1);
0568 return outputs_.at(0);
0569 }
0570
0571 Value* input(size_t i) {
0572 return inputs_.at(i);
0573 }
0574 const Value* input(size_t i) const {
0575 return inputs_.at(i);
0576 }
0577
0578
0579
0580
0581
0582
0583
0584
0585
0586
0587
0588
0589
0590
0591
0592
0593
0594
0595
0596
0597
0598 Value* addInput(Value* node) {
0599 ONNX_ASSERT(graph_ == node->owningGraph());
0600 node->uses_in_current_graph_.emplace_back(this, inputs_.size());
0601 inputs_.push_back(node);
0602 return node;
0603 }
0604
0605
0606
0607
0608
0609
0610
0611 Value* replaceInput(size_t i, Value* newValue) {
0612 ONNX_ASSERT(newValue->owningGraph() == graph_);
0613 Value* old = dropInput(i);
0614 inputs_[i] = newValue;
0615 newValue->uses_in_current_graph_.emplace_back(this, i);
0616 return old;
0617 }
0618
0619
0620
0621
0622
0623
0624
0625 void replaceInputWith(Value* from, Value* to) {
0626 ONNX_ASSERT(from->owningGraph() == graph_);
0627 ONNX_ASSERT(to->owningGraph() == graph_);
0628 size_t i = 0;
0629 for (auto input : inputs()) {
0630 if (input == from)
0631 replaceInput(i, to);
0632 i++;
0633 }
0634 }
0635
0636 Value* addOutput() {
0637 outputs_.push_back(new Value(this, outputs_.size()));
0638 return outputs_.back();
0639 }
0640
0641 void eraseOutput(size_t i);
0642
0643
0644
0645
0646
0647
0648
0649
0650
0651
0652
0653 Node* insertBefore(Node* n) {
0654 ONNX_ASSERT(n->inGraphList());
0655 insertAfter(n->prev());
0656 return this;
0657 }
0658
0659
0660
0661
0662
0663
0664
0665
0666
0667
0668
0669 Node* insertAfter(Node* n) {
0670 ONNX_ASSERT(!inGraphList() && n->inGraphList());
0671 Node* next = n->next();
0672 n->next() = this;
0673 this->prev() = n;
0674 this->next() = next;
0675 next->prev() = this;
0676 return this;
0677 }
0678
0679
0680
0681
0682
0683
0684
0685
0686
0687 void moveAfter(Node* n) {
0688 removeFromList();
0689 insertAfter(n);
0690 }
0691
0692
0693
0694
0695
0696
0697
0698
0699 void moveBefore(Node* n) {
0700 removeFromList();
0701 insertBefore(n);
0702 }
0703
0704
0705
0706
0707
0708
0709
0710
0711
0712 void removeInput(size_t i) {
0713 dropInput(i);
0714
0715
0716 for (size_t j = i + 1; j < inputs_.size(); j++) {
0717 auto it = findUseForInput(j);
0718 it->offset--;
0719 }
0720 inputs_.erase(inputs_.begin() + i);
0721 }
0722
0723
0724
0725
0726
0727
0728 void removeAllInputs() {
0729 for (size_t i = 0; i < inputs().size(); ++i)
0730 dropInput(i);
0731 inputs_.clear();
0732 }
0733
0734
0735 bool isBefore(Node* n);
0736
0737
0738
0739 graph_node_list_iterator iterator();
0740 graph_node_list_iterator reverseIterator();
0741 const_graph_node_list_iterator iterator() const;
0742 const_graph_node_list_iterator reverseIterator() const;
0743
0744
0745
0746
0747
0748
0749
0750
0751
0752 void destroy();
0753
0754
0755
0756
0757
0758
0759
0760 template <typename T>
0761 T* cast() {
0762 if (T::Kind == kind())
0763 return static_cast<T*>(this);
0764 return nullptr;
0765 }
0766 template <typename T>
0767 T* expect() {
0768 ONNX_ASSERTM(T::Kind == kind(), "expected a %s but found a %s", T::Kind.toString(), kind().toString());
0769 return static_cast<T*>(this);
0770 }
0771
0772 virtual ~Node() = default;
0773
0774 private:
0775
0776 use_list::iterator findUseForInput(size_t i) {
0777 auto& input_uses = inputs_[i]->uses_in_current_graph_;
0778
0779
0780 auto use_it = std::find(input_uses.begin(), input_uses.end(), Use(this, i));
0781 ONNX_ASSERT(use_it != input_uses.end());
0782 return use_it;
0783 }
0784
0785
0786
0787
0788 Value* dropInput(size_t i) {
0789 ONNX_ASSERT(i < inputs_.size());
0790 auto input_node = inputs_[i];
0791 auto use_it = findUseForInput(i);
0792 input_node->uses_in_current_graph_.erase(use_it);
0793 inputs_[i] = nullptr;
0794 return input_node;
0795 }
0796
0797 bool inGraphList() const {
0798 ONNX_ASSERT(next() != nullptr || prev() == nullptr);
0799 return next() != nullptr;
0800 }
0801 void removeFromList() {
0802 ONNX_ASSERT(inGraphList());
0803 Node* next = this->next();
0804 Node* prev = this->prev();
0805 prev->next() = next;
0806 next->prev() = prev;
0807 this->next() = nullptr;
0808 this->prev() = nullptr;
0809 }
0810
0811 protected:
0812
0813
0814
0815
0816
0817 virtual Node* allocNewInstance(Graph* g) {
0818 return new Node(g, kind());
0819 }
0820
0821
0822
0823
0824
0825
0826
0827 virtual void cloneFrom(Node* s) {
0828 copyAttributes(*s);
0829 }
0830 };
0831
0832
0833
0834 class OpSetID final {
0835 private:
0836 std::string domain_;
0837 int64_t version_;
0838
0839 public:
0840 explicit OpSetID(const OperatorSetIdProto& proto) : domain_(proto.domain()), version_(proto.version()) {}
0841
0842
0843 explicit OpSetID(const int64_t version) : domain_(""), version_(version) {}
0844
0845 explicit OpSetID(const std::string& domain, int64_t version) : domain_(domain), version_(version) {}
0846
0847
0848 std::string toString() const {
0849 return domain_ + "$" + ONNX_NAMESPACE::to_string(version_);
0850 }
0851
0852
0853 static OpSetID fromString(const std::string& target) {
0854 ONNX_TRY {
0855 std::string new_domain = target.substr(0, target.find("$"));
0856 int new_version = ONNX_NAMESPACE::stoi(target.substr(target.find("$") + 1, target.length()).c_str());
0857 return OpSetID(new_domain, new_version);
0858 }
0859 ONNX_CATCH(const std::runtime_error& e) {
0860 ONNX_HANDLE_EXCEPTION([&]() { ONNX_ASSERTM(false, "Error in fromString: %s", e.what()); });
0861 }
0862
0863
0864
0865
0866
0867
0868
0869 return OpSetID("", 0);
0870 }
0871
0872 const std::string& domain() const {
0873 return domain_;
0874 }
0875
0876 int64_t version() const {
0877 return version_;
0878 }
0879
0880 void incrementVersion(int64_t step) {
0881 version_ += step;
0882 }
0883
0884 void setVersion(int64_t newVal) {
0885 version_ = newVal;
0886 }
0887 };
0888
0889 struct Graph final {
0890 ONNX_DISALLOW_COPY_AND_ASSIGN(Graph);
0891 friend struct Node;
0892 friend struct Value;
0893
0894 private:
0895
0896
0897
0898
0899 std::unordered_set<const Node*> all_nodes;
0900 std::unordered_set<const Value*> all_values;
0901 size_t next_unique_;
0902
0903 size_t new_node_stage_;
0904
0905
0906
0907
0908
0909 Node* const output_;
0910 Node* const input_;
0911
0912 Node* const initializer_node_;
0913
0914 std::vector<Tensor> initializers_;
0915 std::vector<std::string> initializer_names_;
0916
0917 bool has_name_;
0918 std::string name_;
0919 bool has_doc_string_;
0920 std::string doc_string_;
0921
0922 std::vector<OpSetID> opset_versions_;
0923
0924 bool isNameUnique(const std::string& name) const {
0925 if (std::find(initializer_names_.cbegin(), initializer_names_.cend(), name) != initializer_names_.cend()) {
0926 return false;
0927 }
0928 const auto f = [&name](const Value* v) { return v->uniqueName() == name; };
0929 for (const Node* node : all_nodes) {
0930 for (const auto& attr : node->attributeNames()) {
0931 if (node->kindOf(attr) == AttributeKind::g) {
0932 const auto& subgraph = node->g(attr);
0933 if (!subgraph->isNameUnique(name)) {
0934 return false;
0935 }
0936 } else if (node->kindOf(attr) == AttributeKind::gs) {
0937 for (const auto& subgraph : node->gs(attr)) {
0938 if (!subgraph->isNameUnique(name)) {
0939 return false;
0940 }
0941 }
0942 }
0943 }
0944 const auto found_in = std::find_if(node->inputs().begin(), node->inputs().end(), f);
0945 if (found_in != node->inputs().end()) {
0946 return false;
0947 }
0948 const auto found_out = std::find_if(node->outputs().begin(), node->outputs().end(), f);
0949 if (found_out != node->outputs().end()) {
0950 return false;
0951 }
0952 }
0953 return true;
0954 }
0955
0956 public:
0957 Graph()
0958 : next_unique_(0),
0959 new_node_stage_(0),
0960 output_(initOutput(create(kReturn, 0))),
0961 input_(create(kParam, 0)),
0962 initializer_node_(create(kParam, 0)),
0963 has_name_(false),
0964 has_doc_string_(false) {}
0965
0966 bool has_doc_string() const {
0967 return has_doc_string_;
0968 }
0969 const std::string& docString() {
0970 return doc_string_;
0971 }
0972 void setDocString(std::string doc_string) {
0973 has_doc_string_ = true;
0974 doc_string_ = std::move(doc_string);
0975 }
0976
0977 void addInitializer(Tensor& initializer) {
0978 if (initializer.name().empty()) {
0979 initializer.setName(toVarName(getNextUnique()));
0980 }
0981 initializers_.push_back(initializer);
0982 initializer_names_.push_back(initializer.name());
0983 }
0984
0985
0986
0987 Value* addInitializerAndCreateValue(Tensor& initializer) {
0988 addInitializer(initializer);
0989 auto* init_value = initializer_node_->addOutput();
0990 std::vector<Dimension> dim_sizes{initializer.sizes().cbegin(), initializer.sizes().cend()};
0991 init_value->setUniqueName(initializer.name());
0992 init_value->setSizes(dim_sizes);
0993 init_value->setElemType(initializer.elem_type());
0994 return init_value;
0995 }
0996
0997 void eraseInitializer(const std::string& name) {
0998 initializers_.erase(
0999 std::remove_if(
1000 initializers_.begin(),
1001 initializers_.end(),
1002 [&name](Tensor& initializer) { return initializer.name() == name; }),
1003 initializers_.end());
1004 initializer_names_.erase(
1005 std::remove(initializer_names_.begin(), initializer_names_.end(), name), initializer_names_.end());
1006 for (size_t i = 0; i < initializer_node_->outputs().size(); i++) {
1007 if (initializer_node_->outputs()[i]->uniqueName() == name) {
1008 initializer_node_->eraseOutput(i);
1009 break;
1010 }
1011 }
1012 }
1013 void clearInitializers() {
1014 initializers_.clear();
1015 initializer_names_.clear();
1016 }
1017 const std::vector<Tensor>& initializers() const {
1018 return initializers_;
1019 }
1020 const std::vector<std::string>& initializer_names() const {
1021 return initializer_names_;
1022 }
1023 std::vector<Tensor>::const_iterator getInitializer(const std::string& name) const {
1024 for (auto it = initializers_.cbegin(); it != initializers_.cend(); ++it) {
1025 if (name == it->name()) {
1026 return it;
1027 }
1028 }
1029 return initializers_.end();
1030 }
1031 bool is_constant_initializer(const Value* value) const {
1032 return value->node() == initializer_node_;
1033 }
1034 ArrayRef<Value*> inputs() {
1035 return input_->outputs();
1036 }
1037 ArrayRef<const Value*> inputs() const {
1038 const auto& inputs = input_->outputs();
1039 return {inputs.data(), inputs.size()};
1040 }
1041 ArrayRef<Value*> outputs() {
1042 return output_->inputs();
1043 }
1044 ArrayRef<const Value*> outputs() const {
1045 return static_cast<const Node*>(output_)->inputs();
1046 }
1047 graph_node_list nodes() {
1048 return graph_node_list(output_, kNextDirection);
1049 }
1050 const_graph_node_list nodes() const {
1051 return const_graph_node_list(output_, kNextDirection);
1052 }
1053
1054 std::vector<OpSetID>& opset_versions_mutable() {
1055 return opset_versions_;
1056 }
1057
1058 size_t getNextUnique() {
1059 std::string next_unique_name = toVarName(++next_unique_);
1060 while (!isNameUnique(next_unique_name)) {
1061 next_unique_name = toVarName(++next_unique_);
1062 }
1063 return next_unique_;
1064 }
1065
1066
1067
1068
1069 graph_node_list_iterator begin() {
1070 return nodes().begin();
1071 }
1072 const_graph_node_list_iterator begin() const {
1073 return nodes().begin();
1074 }
1075 graph_node_list_iterator end() {
1076 return nodes().end();
1077 }
1078 const_graph_node_list_iterator end() const {
1079 return nodes().end();
1080 }
1081 graph_node_list_iterator rbegin() {
1082 return nodes().rbegin();
1083 }
1084 const_graph_node_list_iterator rbegin() const {
1085 return nodes().rbegin();
1086 }
1087 graph_node_list_iterator rend() {
1088 return nodes().rend();
1089 }
1090 const_graph_node_list_iterator rend() const {
1091 return nodes().rend();
1092 }
1093 Node* return_node() {
1094 return output_;
1095 }
1096 const Node* return_node() const {
1097 return output_;
1098 }
1099
1100 Value* addInput() {
1101 return input_->addOutput();
1102 }
1103 void eraseInput(size_t i) {
1104 input_->eraseOutput(i);
1105 }
1106 void advanceStage() {
1107 new_node_stage_++;
1108 }
1109 void setStage(size_t new_stage) {
1110 new_node_stage_ = new_stage;
1111 }
1112 size_t stage() const {
1113 return new_node_stage_;
1114 }
1115 ResourceGuard setStageTemporary(size_t s) {
1116 auto prev_stage = new_node_stage_;
1117 new_node_stage_ = s;
1118 return ResourceGuard([prev_stage, this]() { this->new_node_stage_ = prev_stage; });
1119 }
1120
1121 size_t registerOutput(Value* n) {
1122 output_->addInput(n);
1123 return outputs().size() - 1;
1124 }
1125
1126 Node* create(NodeKind kind, size_t num_outputs = 1) {
1127
1128 auto n = new Node(this, kind);
1129 for (size_t i = 0; i < num_outputs; i++)
1130 n->addOutput();
1131 return n;
1132 }
1133
1134 Node* create(NodeKind kind, ArrayRef<Value*> inputs, size_t num_outputs = 1) {
1135 auto n = create(kind, num_outputs);
1136 for (auto i : inputs)
1137 n->addInput(i);
1138 return n;
1139 }
1140
1141 Node* appendNode(Node* n) {
1142 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1143 n->insertBefore(output_);
1144 return n;
1145 }
1146
1147 Node* prependNode(Node* n) {
1148 ONNX_ASSERT(n->graph_ == this && !n->inGraphList());
1149 n->insertAfter(output_);
1150 return n;
1151 }
1152
1153
1154
1155
1156 Value* addInitializerAndInput(const Tensor& initializer, const std::string& name) {
1157 Tensor initializerCopy = initializer;
1158 std::vector<Dimension> dim_sizes{initializerCopy.sizes().cbegin(), initializerCopy.sizes().cend()};
1159 Value* new_init = addInput();
1160 initializerCopy.setName(name);
1161 new_init->setUniqueName(name);
1162 new_init->setSizes(dim_sizes);
1163 new_init->setElemType(initializerCopy.elem_type());
1164 addInitializer(initializerCopy);
1165 return new_init;
1166 }
1167
1168 Value* addInitializerAndInput(const Tensor& initializer) {
1169 return addInitializerAndInput(initializer, toVarName(getNextUnique()));
1170 }
1171
1172
1173
1174 void eraseInitializerAndInput(Value* v) {
1175 eraseInitializer(v->uniqueName());
1176 if (v->node() == input_) {
1177 eraseInput(v->offset());
1178 }
1179 }
1180
1181 ~Graph() {
1182 for (const Node* n : all_nodes)
1183 delete n;
1184 for (const Value* v : all_values)
1185 delete v;
1186 }
1187
1188 std::string toString() const {
1189 std::ostringstream oss;
1190 oss << *this;
1191 return oss.str();
1192 }
1193
1194 bool has_name() const {
1195 return has_name_;
1196 }
1197
1198 const std::string& name() const {
1199 return name_;
1200 }
1201
1202 void setName(std::string name) {
1203 has_name_ = true;
1204 name_ = std::move(name);
1205 }
1206
1207 friend std::ostream& operator<<(std::ostream& out, const Graph& g);
1208
1209 void forSelfAndEachSubGraph(const std::function<void(Graph*)>& fn) {
1210 fn(this);
1211
1212 for (const Node* node : all_nodes) {
1213 for (const auto& attr : node->attributeNames()) {
1214 if (node->kindOf(attr) == AttributeKind::g) {
1215 std::shared_ptr<Graph> subgraph = node->g(attr);
1216 subgraph->forSelfAndEachSubGraph(fn);
1217 } else if (node->kindOf(attr) == AttributeKind::gs) {
1218 for (const auto& subgraph : node->gs(attr)) {
1219 subgraph->forSelfAndEachSubGraph(fn);
1220 }
1221 }
1222 }
1223 }
1224 }
1225
1226 void forSelfAndEachSubGraph(const std::function<void(const Graph*)>& fn) const {
1227 std::function<void(Graph*)> tmp_fn = [fn](Graph* graph) { fn(graph); };
1228 const_cast<Graph*>(this)->forSelfAndEachSubGraph(tmp_fn);
1229 }
1230
1231 void forEachNode(const std::function<void(Node*)>& fn) {
1232 forSelfAndEachSubGraph([fn](Graph* graph) {
1233 for (Node* node : graph->nodes()) {
1234 fn(node);
1235 }
1236 });
1237 }
1238
1239 void forEachNode(const std::function<void(const Node*)>& fn) const {
1240 std::function<void(Node*)> tmp_fn = [fn](Node* node) { fn(node); };
1241 const_cast<Graph*>(this)->forEachNode(tmp_fn);
1242 }
1243
1244 private:
1245
1246 Node* initOutput(Node* p) {
1247 p->next() = p;
1248 p->prev() = p;
1249 p->setStage(std::numeric_limits<size_t>::max());
1250 return p;
1251 }
1252
1253 void freeNode(Node* n) {
1254 auto it = all_nodes.find(n);
1255 ONNX_ASSERT(it != all_nodes.end());
1256 delete *it;
1257 all_nodes.erase(it);
1258 }
1259 void freeValue(Value* v) {
1260 auto it = all_values.find(v);
1261 ONNX_ASSERT(it != all_values.end());
1262 delete *it;
1263 all_values.erase(it);
1264 }
1265 };
1266
1267 inline Value::Value(Node* node_, size_t offset_)
1268 : node_(node_),
1269 offset_(offset_),
1270 unique_(node_->graph_->getNextUnique()),
1271 stage_(node_->graph_->new_node_stage_),
1272 has_unique_name_(false),
1273 elem_type_(ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED),
1274 has_sizes_(false) {
1275 node_->graph_->all_values.emplace(this);
1276 }
1277
1278 inline Graph* Value::owningGraph() {
1279 return node()->owningGraph();
1280 }
1281
1282 inline const Graph* Value::owningGraph() const {
1283 return node()->owningGraph();
1284 }
1285
1286
1287
1288
1289
1290
1291 inline Value* Value::setUniqueName(const std::string& name, bool update_related_names) {
1292 if (has_unique_name() && update_related_names) {
1293 auto* graph = owningGraph();
1294 auto old_name = unique_name_;
1295 for (size_t i = 0; i < owningGraph()->initializer_names_.size(); i++) {
1296 auto& initializer_name = owningGraph()->initializer_names_[i];
1297 if (initializer_name == old_name) {
1298 initializer_name = name;
1299 owningGraph()->initializers_[i].setName(name);
1300 }
1301 }
1302 graph->forEachNode([this, &name, &old_name](Node* node) {
1303 if (node->owningGraph() == this->owningGraph()) {
1304
1305 return;
1306 }
1307 if (node->kind() == kCaptured) {
1308 Value* output = node->output();
1309 if (output->uniqueName() == old_name) {
1310 output->setUniqueName(name, false);
1311 }
1312 }
1313 });
1314 }
1315 unique_name_ = name;
1316 has_unique_name_ = true;
1317 return this;
1318 }
1319
1320 inline void Value::replaceAllUsesWith(Value* newValue) {
1321 auto* graph = owningGraph();
1322 ONNX_ASSERT(graph == newValue->owningGraph());
1323
1324 if (this->has_sizes()) {
1325 newValue->setSizes(this->sizes());
1326 }
1327 if (this->elemType() != ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
1328 newValue->setElemType(this->elemType());
1329 }
1330 const auto unique_name = this->uniqueName();
1331
1332 if (std::find(graph->outputs().rbegin(), graph->outputs().rend(), this) != graph->outputs().rend()) {
1333 newValue->setUniqueName(unique_name);
1334
1335
1336 this->setUniqueName(toVarName(graph->getNextUnique()), false);
1337 }
1338 newValue->uses_in_current_graph_.reserve(this->uses_in_current_graph_.size());
1339 for (auto u : uses_in_current_graph_) {
1340 u.user->inputs_[u.offset] = newValue;
1341 newValue->uses_in_current_graph_.push_back(u);
1342 }
1343 graph->forEachNode([this, &newValue, &unique_name](Node* node) {
1344 if (node->owningGraph() == this->owningGraph()) {
1345
1346 return;
1347 }
1348 if (node->kind() == kCaptured) {
1349 Value* output = node->output();
1350 if (output->uniqueName() == unique_name) {
1351 output->setUniqueName(newValue->uniqueName());
1352 }
1353 }
1354 });
1355 uses_in_current_graph_.clear();
1356 assert(this->uses().empty());
1357 }
1358
1359 inline Node::Node(Graph* graph_, NodeKind kind_)
1360 : kind_(kind_),
1361 graph_(graph_),
1362 stage_(graph_->new_node_stage_),
1363 has_name_(false),
1364 has_domain_(false),
1365 has_doc_string_(false),
1366 has_overload_(false) {
1367 graph_->all_nodes.emplace(this);
1368 }
1369
1370 inline void Node::eraseOutput(size_t i) {
1371 ONNX_ASSERT(i < outputs_.size());
1372 ONNX_ASSERT(outputs_[i]->uses().empty());
1373 Value* n = outputs_[i];
1374 outputs_.erase(outputs_.begin() + i);
1375 owningGraph()->freeValue(n);
1376 for (size_t j = i; j < outputs_.size(); j++) {
1377 outputs_[j]->offset_--;
1378 }
1379 }
1380
1381 inline bool Node::isBefore(Node* n) {
1382 if (n == nullptr || this == n) {
1383
1384 return false;
1385 }
1386
1387 if (kind_ == kParam) {
1388 return true;
1389 }
1390
1391 if (n->kind() == kParam) {
1392 return false;
1393 }
1394 ONNX_ASSERT(n->inGraphList());
1395 for (Node* p = next(); p != *graph_->end(); p = p->next()) {
1396 if (p == n) {
1397 return true;
1398 }
1399 }
1400 return false;
1401 }
1402
1403 inline void Node::destroy() {
1404 ONNX_ASSERT(inGraphList());
1405 while (!outputs().empty())
1406 eraseOutput(outputs().size() - 1);
1407 removeAllInputs();
1408 removeFromList();
1409 graph_->freeNode(this);
1410 }
1411
1412
1413
1414 inline graph_node_list_iterator Node::iterator() {
1415 return graph_node_list_iterator(this, 0);
1416 }
1417 inline graph_node_list_iterator Node::reverseIterator() {
1418 return iterator().reverse();
1419 }
1420 inline const_graph_node_list_iterator Node::iterator() const {
1421 return const_graph_node_list_iterator(this, 0);
1422 }
1423 inline const_graph_node_list_iterator Node::reverseIterator() const {
1424 return iterator().reverse();
1425 }
1426
1427
1428
1429
1430
1431 inline const use_list Value::uses() const {
1432 use_list all_uses = uses_in_current_graph_;
1433 owningGraph()->forEachNode([this, &all_uses](const Node* node) {
1434 if (node->owningGraph() == this->owningGraph()) {
1435
1436 return;
1437 }
1438 if (node->kind() == kCaptured) {
1439 const Value* output = node->outputs()[0];
1440 if (output->uniqueName() == this->uniqueName()) {
1441 const auto output_uses = output->uses();
1442 all_uses.insert(all_uses.end(), output_uses.begin(), output_uses.end());
1443 }
1444 }
1445 });
1446 return all_uses;
1447 }
1448
1449 }