Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:42:43

0001 // Copyright (c) ONNX Project Contributors
0002 
0003 /*
0004  * SPDX-License-Identifier: Apache-2.0
0005  */
0006 
0007 // ATTENTION: The code in this file is highly EXPERIMENTAL.
0008 // Adventurous users should note that the APIs will probably change.
0009 
0010 #include "onnx/common/assertions.h"
0011 
0012 namespace ONNX_NAMESPACE {
0013 
0014 // Intrusive doubly linked lists with sane reverse iterators.
0015 // The header file is named graph_node_list.h because it is ONLY
0016 // used for Graph's Node lists, and if you want to use it for other
0017 // things, you will have to do some refactoring.
0018 //
0019 // At the moment, the templated type T must support a few operations:
0020 //
0021 //  - It must have a field: T* next_in_graph[2] = { nullptr, nullptr };
0022 //    which are used for the intrusive linked list pointers.
0023 //
0024 //  - It must have a method 'destroy()', which removes T from the
0025 //    list and frees a T.
0026 //
0027 // In practice, we are only using it with Node and const Node.  'destroy()'
0028 // needs to be renegotiated if you want to use this somewhere else.
0029 //
// Besides being intrusive, these lists differ from std::list in that they
// handle forward and backward iteration uniformly, because we require a
// "before-first-element" sentinel.  This means that reverse iterators
// physically point to the element they logically point to, avoiding the
// off-by-one behavior of standard library reverse iterators.
0035 
// Indices into T::next_in_graph: slot 0 is followed when walking forward,
// slot 1 when walking in reverse. Swapping the two reverses a traversal.
constexpr static size_t kNextDirection{0};
constexpr static size_t kPrevDirection{1};

// Forward declarations so the Node-specific aliases below can be formed
// before the templates themselves are defined.
template <typename T>
struct generic_graph_node_list_iterator;
template <typename T>
struct generic_graph_node_list;

struct Node;

// Iterators over mutable / read-only Node sequences.
using graph_node_list_iterator = generic_graph_node_list_iterator<Node>;
using const_graph_node_list_iterator = generic_graph_node_list_iterator<const Node>;
// List views over mutable / read-only Node sequences.
using graph_node_list = generic_graph_node_list<Node>;
using const_graph_node_list = generic_graph_node_list<const Node>;
0050 
0051 template <typename T>
0052 struct generic_graph_node_list_iterator final {
0053   generic_graph_node_list_iterator() : cur(nullptr), d(kNextDirection) {}
0054   generic_graph_node_list_iterator(T* cur, size_t d) : cur(cur), d(d) {}
0055   T* operator*() const {
0056     return cur;
0057   }
0058   T* operator->() const {
0059     return cur;
0060   }
0061   generic_graph_node_list_iterator& operator++() {
0062     ONNX_ASSERT(cur);
0063     cur = cur->next_in_graph[d];
0064     return *this;
0065   }
0066   generic_graph_node_list_iterator operator++(int) {
0067     generic_graph_node_list_iterator old = *this;
0068     ++(*this);
0069     return old;
0070   }
0071   generic_graph_node_list_iterator& operator--() {
0072     ONNX_ASSERT(cur);
0073     cur = cur->next_in_graph[reverseDir()];
0074     return *this;
0075   }
0076   generic_graph_node_list_iterator operator--(int) {
0077     generic_graph_node_list_iterator old = *this;
0078     --(*this);
0079     return old;
0080   }
0081 
0082   // erase cur without invalidating this iterator
0083   // named differently from destroy so that ->/. bugs do not
0084   // silently cause the wrong one to be called.
0085   // iterator will point to the previous entry after call
0086   void destroyCurrent() {
0087     T* n = cur;
0088     cur = cur->next_in_graph[reverseDir()];
0089     n->destroy();
0090   }
0091   generic_graph_node_list_iterator reverse() {
0092     return generic_graph_node_list_iterator(cur, reverseDir());
0093   }
0094 
0095  private:
0096   size_t reverseDir() {
0097     return d == kNextDirection ? kPrevDirection : kNextDirection;
0098   }
0099   T* cur;
0100   size_t d; // direction 0 is forward 1 is reverse, see next_in_graph
0101 };
0102 
0103 template <typename T>
0104 struct generic_graph_node_list final {
0105   using iterator = generic_graph_node_list_iterator<T>;
0106   using const_iterator = generic_graph_node_list_iterator<const T>;
0107   generic_graph_node_list_iterator<T> begin() {
0108     return generic_graph_node_list_iterator<T>(head->next_in_graph[d], d);
0109   }
0110   generic_graph_node_list_iterator<const T> begin() const {
0111     return generic_graph_node_list_iterator<const T>(head->next_in_graph[d], d);
0112   }
0113   generic_graph_node_list_iterator<T> end() {
0114     return generic_graph_node_list_iterator<T>(head, d);
0115   }
0116   generic_graph_node_list_iterator<const T> end() const {
0117     return generic_graph_node_list_iterator<const T>(head, d);
0118   }
0119   generic_graph_node_list_iterator<T> rbegin() {
0120     return reverse().begin();
0121   }
0122   generic_graph_node_list_iterator<const T> rbegin() const {
0123     return reverse().begin();
0124   }
0125   generic_graph_node_list_iterator<T> rend() {
0126     return reverse().end();
0127   }
0128   generic_graph_node_list_iterator<const T> rend() const {
0129     return reverse().end();
0130   }
0131   generic_graph_node_list reverse() {
0132     return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
0133   }
0134   const generic_graph_node_list reverse() const {
0135     return generic_graph_node_list(head, d == kNextDirection ? kPrevDirection : kNextDirection);
0136   }
0137   generic_graph_node_list(T* head, size_t d) : head(head), d(d) {}
0138 
0139  private:
0140   T* head;
0141   size_t d;
0142 };
0143 
0144 template <typename T>
0145 static inline bool operator==(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
0146   return *a == *b;
0147 }
0148 
0149 template <typename T>
0150 static inline bool operator!=(generic_graph_node_list_iterator<T> a, generic_graph_node_list_iterator<T> b) {
0151   return *a != *b;
0152 }
0153 
0154 } // namespace ONNX_NAMESPACE
0155 
0156 namespace std {
0157 
0158 template <typename T>
0159 struct iterator_traits<ONNX_NAMESPACE::generic_graph_node_list_iterator<T>> {
0160   using difference_type = int64_t;
0161   using value_type = T*;
0162   using pointer = T**;
0163   using reference = T*&;
0164   using iterator_category = bidirectional_iterator_tag;
0165 };
0166 
0167 } // namespace std