Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:30:48

0001 
0002 //          Copyright Oliver Kowalke 2017.
0003 // Distributed under the Boost Software License, Version 1.0.
0004 //    (See accompanying file LICENSE_1_0.txt or copy at
0005 //          http://www.boost.org/LICENSE_1_0.txt)
0006 
0007 #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
0008 #define BOOST_FIBERS_CUDA_WAITFOR_H
0009 
0010 #include <initializer_list>
0011 #include <mutex>
0012 #include <iostream>
0013 #include <set>
0014 #include <tuple>
0015 #include <vector>
0016 
0017 #include <boost/assert.hpp>
0018 #include <boost/config.hpp>
0019 
0020 #include <cuda.h>
0021 
0022 #include <boost/fiber/detail/config.hpp>
0023 #include <boost/fiber/detail/is_all_same.hpp>
0024 #include <boost/fiber/condition_variable.hpp>
0025 #include <boost/fiber/mutex.hpp>
0026 
0027 #ifdef BOOST_HAS_ABI_HEADERS
0028 #  include BOOST_ABI_PREFIX
0029 #endif
0030 
0031 namespace boost {
0032 namespace fibers {
0033 namespace cuda {
0034 namespace detail {
0035 
0036 template< typename Rendezvous >
0037 static void trampoline( cudaStream_t st, cudaError_t status, void * vp) {
0038     Rendezvous * data = static_cast< Rendezvous * >( vp);
0039     data->notify( st, status);
0040 }
0041 
0042 class single_stream_rendezvous {
0043 public:
0044     single_stream_rendezvous( cudaStream_t st) {
0045         unsigned int flags = 0;
0046         cudaError_t status = ::cudaStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
0047         if ( cudaSuccess != status) {
0048             st_ = st;
0049             status_ = status;
0050             done_ = true;
0051         }
0052     }
0053 
0054     void notify( cudaStream_t st, cudaError_t status) noexcept {
0055         std::unique_lock< mutex > lk{ mtx_ };
0056         st_ = st;
0057         status_ = status;
0058         done_ = true;
0059         lk.unlock();
0060         cv_.notify_one();
0061     }
0062 
0063     std::tuple< cudaStream_t, cudaError_t > wait() {
0064         std::unique_lock< mutex > lk{ mtx_ };
0065         cv_.wait( lk, [this]{ return done_; });
0066         return std::make_tuple( st_, status_);
0067     }
0068 
0069 private:
0070     mutex               mtx_{};
0071     condition_variable  cv_{};
0072     cudaStream_t        st_{};
0073     cudaError_t         status_{ cudaErrorUnknown };
0074     bool                done_{ false };
0075 };
0076 
0077 class many_streams_rendezvous {
0078 public:
0079     many_streams_rendezvous( std::initializer_list< cudaStream_t > l) :
0080             stx_{ l } {
0081         results_.reserve( stx_.size() );
0082         for ( cudaStream_t st : stx_) {
0083             unsigned int flags = 0;
0084             cudaError_t status = ::cudaStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
0085             if ( cudaSuccess != status) {
0086                 std::unique_lock< mutex > lk{ mtx_ };
0087                 stx_.erase( st);
0088                 results_.push_back( std::make_tuple( st, status) );
0089             }
0090         }
0091     }
0092 
0093     void notify( cudaStream_t st, cudaError_t status) noexcept {
0094         std::unique_lock< mutex > lk{ mtx_ };
0095         stx_.erase( st);
0096         results_.push_back( std::make_tuple( st, status) );
0097         if ( stx_.empty() ) {
0098             lk.unlock();
0099             cv_.notify_one();
0100         }
0101     }
0102 
0103     std::vector< std::tuple< cudaStream_t, cudaError_t > > wait() {
0104         std::unique_lock< mutex > lk{ mtx_ };
0105         cv_.wait( lk, [this]{ return stx_.empty(); });
0106         return results_;
0107     }
0108 
0109 private:
0110     mutex                                                   mtx_{};
0111     condition_variable                                      cv_{};
0112     std::set< cudaStream_t >                                stx_;
0113     std::vector< std::tuple< cudaStream_t, cudaError_t > >  results_;
0114 };
0115 
0116 }
0117 
0118 void waitfor_all();
0119 
0120 inline
0121 std::tuple< cudaStream_t, cudaError_t > waitfor_all( cudaStream_t st) {
0122     detail::single_stream_rendezvous rendezvous( st);
0123     return rendezvous.wait();
0124 }
0125 
0126 template< typename ... STP >
0127 std::vector< std::tuple< cudaStream_t, cudaError_t > > waitfor_all( cudaStream_t st0, STP ... stx) {
0128     static_assert( boost::fibers::detail::is_all_same< cudaStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
0129     detail::many_streams_rendezvous rendezvous{ st0, stx ... };
0130     return rendezvous.wait();
0131 }
0132 
0133 }}}
0134 
0135 #ifdef BOOST_HAS_ABI_HEADERS
0136 #  include BOOST_ABI_SUFFIX
0137 #endif
0138 
0139 #endif // BOOST_FIBERS_CUDA_WAITFOR_H