Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:31:26

0001 //---------------------------------*-CUDA-*----------------------------------//
0002 // Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file celeritas/optical/action/ActionLauncher.device.hh
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
0010 #include <type_traits>
0011 
0012 #include "corecel/DeviceRuntimeApi.hh"
0013 
0014 #include "corecel/Assert.hh"
0015 #include "corecel/Macros.hh"
0016 #include "corecel/Types.hh"
0017 #include "corecel/cont/Range.hh"
0018 #include "corecel/sys/KernelLauncher.device.hh"
0019 #include "corecel/sys/ThreadId.hh"
0020 #include "celeritas/optical/CoreParams.hh"
0021 #include "celeritas/optical/CoreState.hh"
0022 #include "celeritas/track/TrackInitParams.hh"
0023 
0024 #include "ActionInterface.hh"
0025 
0026 namespace celeritas
0027 {
0028 namespace optical
0029 {
0030 //---------------------------------------------------------------------------//
0031 /*!
0032  * Profile and launch optical stepping loop kernels.
0033  *
0034  * This is an extension to \c KernelLauncher which uses an action's label and
0035  * takes the optical state to determine the launch size. The "call thread"
0036  * operation (thread executor) should contain the params and state.
0037  *
0038  * Example:
0039  * \code
0040  void FooAction::step(CoreParams const& params,
0041                       CoreStateDevice& state) const
0042  {
0043    auto execute_thread = make_blah_executor(blah);
0044    static ActionLauncher<decltype(execute_thread)> const launch_kernel(*this);
0045    launch_kernel(state, execute_thread);
0046  }
0047  * \endcode
0048  */
0049 template<class F>
0050 class ActionLauncher : public KernelLauncher<F>
0051 {
0052     static_assert(
0053         (std::is_trivially_copyable_v<F> || CELERITAS_USE_HIP
0054          || CELER_COMPILER == CELER_COMPILER_CLANG)
0055             && !std::is_pointer_v<F> && !std::is_reference_v<F>,
0056         R"(Launched action must be a trivially copyable function object)");
0057     using StepActionT = OpticalStepActionInterface;
0058 
0059   public:
0060     // Create a launcher from a string
0061     using KernelLauncher<F>::KernelLauncher;
0062 
0063     // Create a launcher from an action
0064     explicit ActionLauncher(StepActionT const& action);
0065 
0066     // Create a launcher with a string extension
0067     ActionLauncher(StepActionT const& action, std::string_view ext);
0068 
0069     // Launch a kernel for a thread range or number of threads
0070     using KernelLauncher<F>::operator();
0071 
0072     // Launch a kernel for the wrapped executor
0073     void operator()(CoreState<MemSpace::device> const& state,
0074                     F const& call_thread) const;
0075 };
0076 
0077 //---------------------------------------------------------------------------//
0078 /*!
0079  * Create a launcher from an action.
0080  */
0081 template<class F>
0082 ActionLauncher<F>::ActionLauncher(StepActionT const& action)
0083     : ActionLauncher{action.label()}
0084 {
0085 }
0086 
0087 //---------------------------------------------------------------------------//
0088 /*!
0089  * Create a launcher with a string extension.
0090  */
0091 template<class F>
0092 ActionLauncher<F>::ActionLauncher(StepActionT const& action,
0093                                   std::string_view ext)
0094     : ActionLauncher{std::string(action.label()) + "-" + std::string(ext)}
0095 {
0096 }
0097 
0098 //---------------------------------------------------------------------------//
0099 /*!
0100  * Launch a kernel for the wrapped executor.
0101  */
0102 template<class F>
0103 void ActionLauncher<F>::operator()(CoreState<MemSpace::device> const& state,
0104                                    F const& call_thread) const
0105 {
0106     return (*this)(
0107         range(ThreadId{state.size()}), state.stream_id(), call_thread);
0108 }
0109 
0110 //---------------------------------------------------------------------------//
0111 }  // namespace optical
0112 }  // namespace celeritas