Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-16 08:52:44

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file corecel/sys/Thrust.device.hh
0006 //! \brief Platform and version-specific thrust setup
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
0010 #include <thrust/execution_policy.h>
0011 #include <thrust/mr/allocator.h>
0012 #include <thrust/version.h>
0013 
0014 #include "corecel/DeviceRuntimeApi.hh"  // IWYU pragma: keep
0015 
0016 #include "Device.hh"
0017 #include "Stream.hh"
0018 #include "ThreadId.hh"
0019 
0020 #include "detail/AsyncMemoryResource.device.hh"  // IWYU pragma: keep
0021 
0022 namespace celeritas
0023 {
0024 //---------------------------------------------------------------------------//
0025 // FREE FUNCTIONS
0026 //---------------------------------------------------------------------------//
0027 /*!
0028  * Get the Thrust synchronous parallel policy.
0029  */
0030 inline auto& thrust_execute()
0031 {
0032     return thrust::CELER_DEVICE_PLATFORM::par;
0033 }
0034 
0035 //---------------------------------------------------------------------------//
0036 /*!
0037  * Get a Thrust asynchronous parallel policy for the given stream.
0038  *
0039  * For older versions of thrust, this executes synchronously on the stream.
0040  */
0041 inline auto thrust_execute_on(StreamId stream_id)
0042 {
0043     using Alloc = thrust::mr::allocator<char, Stream::ResourceT>;
0044     Stream& stream = celeritas::device().stream(stream_id);
0045 #if THRUST_VERSION >= 101600
0046     // Newer thrust supports asynchronous par
0047     auto& par_nosync = thrust::CELER_DEVICE_PLATFORM::par_nosync;
0048 #else
0049     // Fall back to synchronous execution
0050     auto& par_nosync = thrust::CELER_DEVICE_PLATFORM::par;
0051 #endif
0052     return par_nosync(Alloc(&stream.memory_resource())).on(stream.get());
0053 }
0054 
0055 //---------------------------------------------------------------------------//
0056 }  // namespace celeritas