File indexing completed on 2025-01-18 09:30:49
0001
0002
0003
0004
0005
0006
0007 #ifndef BOOST_FIBERS_CUDA_WAITFOR_H
0008 #define BOOST_FIBERS_CUDA_WAITFOR_H
0009
0010 #include <initializer_list>
0011 #include <mutex>
0012 #include <iostream>
0013 #include <set>
0014 #include <tuple>
0015 #include <vector>
0016
0017 #include <boost/assert.hpp>
0018 #include <boost/config.hpp>
0019
0020 #include <hip/hip_runtime.h>
0021
0022 #include <boost/fiber/detail/config.hpp>
0023 #include <boost/fiber/detail/is_all_same.hpp>
0024 #include <boost/fiber/condition_variable.hpp>
0025 #include <boost/fiber/mutex.hpp>
0026
0027 #ifdef BOOST_HAS_ABI_HEADERS
0028 # include BOOST_ABI_PREFIX
0029 #endif
0030
0031 namespace boost {
0032 namespace fibers {
0033 namespace cuda {
0034 namespace detail {
0035
0036 template< typename Rendezvous >
0037 static void trampoline( hipStream_t st, hipError_t status, void * vp) {
0038 Rendezvous * data = static_cast< Rendezvous * >( vp);
0039 data->notify( st, status);
0040 }
0041
0042 class single_stream_rendezvous {
0043 public:
0044 single_stream_rendezvous( hipStream_t st) {
0045 unsigned int flags = 0;
0046 hipError_t status = ::hipStreamAddCallback( st, trampoline< single_stream_rendezvous >, this, flags);
0047 if ( hipSuccess != status) {
0048 st_ = st;
0049 status_ = status;
0050 done_ = true;
0051 }
0052 }
0053
0054 void notify( hipStream_t st, hipError_t status) noexcept {
0055 std::unique_lock< mutex > lk{ mtx_ };
0056 st_ = st;
0057 status_ = status;
0058 done_ = true;
0059 lk.unlock();
0060 cv_.notify_one();
0061 }
0062
0063 std::tuple< hipStream_t, hipError_t > wait() {
0064 std::unique_lock< mutex > lk{ mtx_ };
0065 cv_.wait( lk, [this]{ return done_; });
0066 return std::make_tuple( st_, status_);
0067 }
0068
0069 private:
0070 mutex mtx_{};
0071 condition_variable cv_{};
0072 hipStream_t st_{};
0073 hipError_t status_{ hipErrorUnknown };
0074 bool done_{ false };
0075 };
0076
0077 class many_streams_rendezvous {
0078 public:
0079 many_streams_rendezvous( std::initializer_list< hipStream_t > l) :
0080 stx_{ l } {
0081 results_.reserve( stx_.size() );
0082 for ( hipStream_t st : stx_) {
0083 unsigned int flags = 0;
0084 hipError_t status = ::hipStreamAddCallback( st, trampoline< many_streams_rendezvous >, this, flags);
0085 if ( hipSuccess != status) {
0086 std::unique_lock< mutex > lk{ mtx_ };
0087 stx_.erase( st);
0088 results_.push_back( std::make_tuple( st, status) );
0089 }
0090 }
0091 }
0092
0093 void notify( hipStream_t st, hipError_t status) noexcept {
0094 std::unique_lock< mutex > lk{ mtx_ };
0095 stx_.erase( st);
0096 results_.push_back( std::make_tuple( st, status) );
0097 if ( stx_.empty() ) {
0098 lk.unlock();
0099 cv_.notify_one();
0100 }
0101 }
0102
0103 std::vector< std::tuple< hipStream_t, hipError_t > > wait() {
0104 std::unique_lock< mutex > lk{ mtx_ };
0105 cv_.wait( lk, [this]{ return stx_.empty(); });
0106 return results_;
0107 }
0108
0109 private:
0110 mutex mtx_{};
0111 condition_variable cv_{};
0112 std::set< hipStream_t > stx_;
0113 std::vector< std::tuple< hipStream_t, hipError_t > > results_;
0114 };
0115
0116 }
0117
0118 void waitfor_all();
0119
0120 inline
0121 std::tuple< hipStream_t, hipError_t > waitfor_all( hipStream_t st) {
0122 detail::single_stream_rendezvous rendezvous( st);
0123 return rendezvous.wait();
0124 }
0125
0126 template< typename ... STP >
0127 std::vector< std::tuple< hipStream_t, hipError_t > > waitfor_all( hipStream_t st0, STP ... stx) {
0128 static_assert( boost::fibers::detail::is_all_same< hipStream_t, STP ...>::value, "all arguments must be of type `CUstream*`.");
0129 detail::many_streams_rendezvous rendezvous{ st0, stx ... };
0130 return rendezvous.wait();
0131 }
0132
0133 }}}
0134
0135 #ifdef BOOST_HAS_ABI_HEADERS
0136 # include BOOST_ABI_SUFFIX
0137 #endif
0138
0139 #endif