Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:54:52

0001 ///////////////////////////////////////////////////////////////////////////////
0002 // Copyright (c) Lewis Baker
0003 // Licenced under MIT license. See LICENSE.txt for details.
0004 ///////////////////////////////////////////////////////////////////////////////
0005 #ifndef CPPCORO_RECURSIVE_GENERATOR_HPP_INCLUDED
0006 #define CPPCORO_RECURSIVE_GENERATOR_HPP_INCLUDED
0007 
0008 #include <cppcoro/generator.hpp>
0009 
0010 #include <cppcoro/coroutine.hpp>
0011 #include <type_traits>
0012 #include <utility>
0013 #include <cassert>
0014 #include <functional>
0015 
0016 namespace cppcoro
0017 {
0018     template<typename T>
0019     class [[nodiscard]] recursive_generator
0020     {
0021     public:
0022 
0023         class promise_type final
0024         {
0025         public:
0026 
0027             promise_type() noexcept
0028                 : m_value(nullptr)
0029                 , m_exception(nullptr)
0030                 , m_root(this)
0031                 , m_parentOrLeaf(this)
0032             {}
0033 
0034             promise_type(const promise_type&) = delete;
0035             promise_type(promise_type&&) = delete;
0036 
0037             auto get_return_object() noexcept
0038             {
0039                 return recursive_generator<T>{ *this };
0040             }
0041 
0042             cppcoro::suspend_always initial_suspend() noexcept
0043             {
0044                 return {};
0045             }
0046 
0047             cppcoro::suspend_always final_suspend() noexcept
0048             {
0049                 return {};
0050             }
0051 
0052             void unhandled_exception() noexcept
0053             {
0054                 m_exception = std::current_exception();
0055             }
0056 
0057             void return_void() noexcept {}
0058 
0059             cppcoro::suspend_always yield_value(T& value) noexcept
0060             {
0061                 m_value = std::addressof(value);
0062                 return {};
0063             }
0064 
0065             cppcoro::suspend_always yield_value(T&& value) noexcept
0066             {
0067                 m_value = std::addressof(value);
0068                 return {};
0069             }
0070 
0071             auto yield_value(recursive_generator&& generator) noexcept
0072             {
0073                 return yield_value(generator);
0074             }
0075 
0076             auto yield_value(recursive_generator& generator) noexcept
0077             {
0078                 struct awaitable
0079                 {
0080 
0081                     awaitable(promise_type* childPromise)
0082                         : m_childPromise(childPromise)
0083                     {}
0084 
0085                     bool await_ready() noexcept
0086                     {
0087                         return this->m_childPromise == nullptr;
0088                     }
0089 
0090                     void await_suspend(cppcoro::coroutine_handle<promise_type>) noexcept
0091                     {}
0092 
0093                     void await_resume()
0094                     {
0095                         if (this->m_childPromise != nullptr)
0096                         {
0097                             this->m_childPromise->throw_if_exception();
0098                         }
0099                     }
0100 
0101                 private:
0102                     promise_type* m_childPromise;
0103                 };
0104 
0105                 if (generator.m_promise != nullptr)
0106                 {
0107                     m_root->m_parentOrLeaf = generator.m_promise;
0108                     generator.m_promise->m_root = m_root;
0109                     generator.m_promise->m_parentOrLeaf = this;
0110                     generator.m_promise->resume();
0111 
0112                     if (!generator.m_promise->is_complete())
0113                     {
0114                         return awaitable{ generator.m_promise };
0115                     }
0116 
0117                     m_root->m_parentOrLeaf = this;
0118                 }
0119 
0120                 return awaitable{ nullptr };
0121             }
0122 
0123             // Don't allow any use of 'co_await' inside the recursive_generator coroutine.
0124             template<typename U>
0125             cppcoro::suspend_never await_transform(U&& value) = delete;
0126 
0127             void destroy() noexcept
0128             {
0129                 cppcoro::coroutine_handle<promise_type>::from_promise(*this).destroy();
0130             }
0131 
0132             void throw_if_exception()
0133             {
0134                 if (m_exception != nullptr)
0135                 {
0136                     std::rethrow_exception(std::move(m_exception));
0137                 }
0138             }
0139 
0140             bool is_complete() noexcept
0141             {
0142                 return cppcoro::coroutine_handle<promise_type>::from_promise(*this).done();
0143             }
0144 
0145             T& value() noexcept
0146             {
0147                 assert(this == m_root);
0148                 assert(!is_complete());
0149                 return *(m_parentOrLeaf->m_value);
0150             }
0151 
0152             void pull() noexcept
0153             {
0154                 assert(this == m_root);
0155                 assert(!m_parentOrLeaf->is_complete());
0156 
0157                 m_parentOrLeaf->resume();
0158 
0159                 while (m_parentOrLeaf != this && m_parentOrLeaf->is_complete())
0160                 {
0161                     m_parentOrLeaf = m_parentOrLeaf->m_parentOrLeaf;
0162                     m_parentOrLeaf->resume();
0163                 }
0164             }
0165 
0166         private:
0167 
0168             void resume() noexcept
0169             {
0170                 cppcoro::coroutine_handle<promise_type>::from_promise(*this).resume();
0171             }
0172 
0173             std::add_pointer_t<T> m_value;
0174             std::exception_ptr m_exception;
0175 
0176             promise_type* m_root;
0177 
0178             // If this is the promise of the root generator then this field
0179             // is a pointer to the leaf promise.
0180             // For non-root generators this is a pointer to the parent promise.
0181             promise_type* m_parentOrLeaf;
0182 
0183         };
0184 
0185         recursive_generator() noexcept
0186             : m_promise(nullptr)
0187         {}
0188 
0189         recursive_generator(promise_type& promise) noexcept
0190             : m_promise(&promise)
0191         {}
0192 
0193         recursive_generator(recursive_generator&& other) noexcept
0194             : m_promise(other.m_promise)
0195         {
0196             other.m_promise = nullptr;
0197         }
0198 
0199         recursive_generator(const recursive_generator& other) = delete;
0200         recursive_generator& operator=(const recursive_generator& other) = delete;
0201 
0202         ~recursive_generator()
0203         {
0204             if (m_promise != nullptr)
0205             {
0206                 m_promise->destroy();
0207             }
0208         }
0209 
0210         recursive_generator& operator=(recursive_generator&& other) noexcept
0211         {
0212             if (this != &other)
0213             {
0214                 if (m_promise != nullptr)
0215                 {
0216                     m_promise->destroy();
0217                 }
0218 
0219                 m_promise = other.m_promise;
0220                 other.m_promise = nullptr;
0221             }
0222 
0223             return *this;
0224         }
0225 
0226         class iterator
0227         {
0228         public:
0229 
0230             using iterator_category = std::input_iterator_tag;
0231             // What type should we use for counting elements of a potentially infinite sequence?
0232             using difference_type = std::ptrdiff_t;
0233             using value_type = std::remove_reference_t<T>;
0234             using reference = std::conditional_t<std::is_reference_v<T>, T, T&>;
0235             using pointer = std::add_pointer_t<T>;
0236 
0237             iterator() noexcept
0238                 : m_promise(nullptr)
0239             {}
0240 
0241             explicit iterator(promise_type* promise) noexcept
0242                 : m_promise(promise)
0243             {}
0244 
0245             bool operator==(const iterator& other) const noexcept
0246             {
0247                 return m_promise == other.m_promise;
0248             }
0249 
0250             bool operator!=(const iterator& other) const noexcept
0251             {
0252                 return m_promise != other.m_promise;
0253             }
0254 
0255             iterator& operator++()
0256             {
0257                 assert(m_promise != nullptr);
0258                 assert(!m_promise->is_complete());
0259 
0260                 m_promise->pull();
0261                 if (m_promise->is_complete())
0262                 {
0263                     auto* temp = m_promise;
0264                     m_promise = nullptr;
0265                     temp->throw_if_exception();
0266                 }
0267 
0268                 return *this;
0269             }
0270 
0271             void operator++(int)
0272             {
0273                 (void)operator++();
0274             }
0275 
0276             reference operator*() const noexcept
0277             {
0278                 assert(m_promise != nullptr);
0279                 return static_cast<reference>(m_promise->value());
0280             }
0281 
0282             pointer operator->() const noexcept
0283             {
0284                 return std::addressof(operator*());
0285             }
0286 
0287         private:
0288 
0289             promise_type* m_promise;
0290 
0291         };
0292 
0293         iterator begin()
0294         {
0295             if (m_promise != nullptr)
0296             {
0297                 m_promise->pull();
0298                 if (!m_promise->is_complete())
0299                 {
0300                     return iterator(m_promise);
0301                 }
0302 
0303                 m_promise->throw_if_exception();
0304             }
0305 
0306             return iterator(nullptr);
0307         }
0308 
0309         iterator end() noexcept
0310         {
0311             return iterator(nullptr);
0312         }
0313 
0314         void swap(recursive_generator& other) noexcept
0315         {
0316             std::swap(m_promise, other.m_promise);
0317         }
0318 
0319     private:
0320 
0321         friend class promise_type;
0322 
0323         promise_type* m_promise;
0324 
0325     };
0326 
0327     template<typename T>
0328     void swap(recursive_generator<T>& a, recursive_generator<T>& b) noexcept
0329     {
0330         a.swap(b);
0331     }
0332 
0333     // Note: When applying fmap operator to a recursive_generator we just yield a non-recursive
0334     // generator since we generally won't be using the result in a recursive context.
0335     template<typename FUNC, typename T>
0336     generator<std::invoke_result_t<FUNC&, typename recursive_generator<T>::iterator::reference>> fmap(FUNC func, recursive_generator<T> source)
0337     {
0338         for (auto&& value : source)
0339         {
0340             co_yield std::invoke(func, static_cast<decltype(value)>(value));
0341         }
0342     }
0343 }
0344 
0345 #endif