File indexing completed on 2025-01-30 10:03:45
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <thrust/execution_policy.h>
0012 #include <thrust/mr/allocator.h>
0013 #include <thrust/version.h>
0014
0015 #include "corecel/Config.hh"
0016
0017 #include "corecel/Assert.hh"
0018
0019 #include "Device.hh"
0020 #include "Stream.hh"
0021 #include "ThreadId.hh"
0022
0023 namespace celeritas
0024 {
// Alias for the Thrust device-system namespace of the enabled GPU backend
// (CUDA or HIP), so the helpers in this header can name policies uniformly.
#if CELERITAS_USE_CUDA
namespace thrust_native = ::thrust::cuda;
#elif CELERITAS_USE_HIP
namespace thrust_native = ::thrust::hip;
#endif
0030
0031
0032
0033
0034
//! Select which Thrust execution policy thrust_execution_policy and
//! thrust_execute_on hand out.
enum class ThrustExecMode
{
    Sync = 0,  //!< Always use thrust_native::par
    Async = 1,  //!< Use par_nosync when the Thrust version provides it
};
0040
0041
0042
0043
0044
0045
0046
0047 template<ThrustExecMode T = ThrustExecMode::Async>
0048 inline auto& thrust_execution_policy()
0049 {
0050 if constexpr (T == ThrustExecMode::Async)
0051 {
0052 #if THRUST_MAJOR_VERSION == 1 && THRUST_MINOR_VERSION < 16
0053 return thrust_native::par;
0054 #else
0055 return thrust_native::par_nosync;
0056 #endif
0057 }
0058 else
0059 {
0060 return thrust_native::par;
0061 }
0062 #if (__CUDACC_VER_MAJOR__ < 11) \
0063 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ < 5)
0064 CELER_ASSERT_UNREACHABLE();
0065 #endif
0066 }
0067
0068
0069
0070
0071
0072
0073
0074
0075 template<ThrustExecMode T = ThrustExecMode::Async>
0076 inline auto thrust_execute_on(StreamId stream_id)
0077 {
0078 if constexpr (T == ThrustExecMode::Sync)
0079 {
0080 return thrust_execution_policy<T>().on(
0081 celeritas::device().stream(stream_id).get());
0082 }
0083 else if constexpr (T == ThrustExecMode::Async)
0084 {
0085 using Alloc = thrust::mr::allocator<char, Stream::ResourceT>;
0086 Stream& stream = celeritas::device().stream(stream_id);
0087 return thrust_execution_policy<T>()(Alloc(&stream.memory_resource()))
0088 .on(stream.get());
0089 }
0090 #if (__CUDACC_VER_MAJOR__ < 11) \
0091 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ < 5)
0092 CELER_ASSERT_UNREACHABLE();
0093 #endif
0094 }
0095
0096
0097 }