Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-30 10:06:30

0001 //------------------------------ -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file celeritas/optical/action/ActionLauncher.device.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <type_traits>
0010 
0011 #include "corecel/DeviceRuntimeApi.hh"
0012 
0013 #include "corecel/Assert.hh"
0014 #include "corecel/Macros.hh"
0015 #include "corecel/Types.hh"
0016 #include "corecel/cont/Range.hh"
0017 #include "corecel/sys/KernelLauncher.device.hh"
0018 #include "corecel/sys/ThreadId.hh"
0019 #include "celeritas/optical/CoreParams.hh"
0020 #include "celeritas/optical/CoreState.hh"
0021 
0022 #include "ActionInterface.hh"
0023 
0024 namespace celeritas
0025 {
0026 namespace optical
0027 {
0028 //---------------------------------------------------------------------------//
0029 /*!
0030  * Profile and launch optical stepping loop kernels.
0031  *
0032  * This is an extension to \c KernelLauncher which uses an action's label and
0033  * takes the optical state to determine the launch size. The "call thread"
0034  * operation (thread executor) should contain the params and state.
0035  *
0036  * Example:
0037  * \code
0038  void FooAction::step(CoreParams const& params,
0039                       CoreStateDevice& state) const
0040  {
0041    auto execute_thread = make_blah_executor(blah);
0042    static ActionLauncher<decltype(execute_thread)> const launch_kernel(*this);
0043    launch_kernel(state, execute_thread);
0044  }
0045  * \endcode
0046  */
0047 template<class F>
0048 class ActionLauncher : public KernelLauncher<F>
0049 {
0050     static_assert(
0051         (std::is_trivially_copyable_v<F> || CELERITAS_USE_HIP
0052          || CELER_COMPILER == CELER_COMPILER_CLANG)
0053             && !std::is_pointer_v<F> && !std::is_reference_v<F>,
0054         R"(Launched action must be a trivially copyable function object)");
0055     using StepActionT = OpticalStepActionInterface;
0056 
0057   public:
0058     // Create a launcher from a string
0059     using KernelLauncher<F>::KernelLauncher;
0060 
0061     // Create a launcher from an action
0062     explicit ActionLauncher(StepActionT const& action);
0063 
0064     // Create a launcher with a string extension
0065     ActionLauncher(StepActionT const& action, std::string_view ext);
0066 
0067     // Launch a kernel for a thread range or number of threads
0068     using KernelLauncher<F>::operator();
0069 
0070     // Launch a kernel for the wrapped executor
0071     void operator()(CoreState<MemSpace::device> const& state,
0072                     F const& call_thread) const;
0073 };
0074 
0075 //---------------------------------------------------------------------------//
0076 /*!
0077  * Create a launcher from an action.
0078  */
0079 template<class F>
0080 ActionLauncher<F>::ActionLauncher(StepActionT const& action)
0081     : ActionLauncher{action.label()}
0082 {
0083 }
0084 
0085 //---------------------------------------------------------------------------//
0086 /*!
0087  * Create a launcher with a string extension.
0088  */
0089 template<class F>
0090 ActionLauncher<F>::ActionLauncher(StepActionT const& action,
0091                                   std::string_view ext)
0092     : ActionLauncher{std::string(action.label()) + "-" + std::string(ext)}
0093 {
0094 }
0095 
0096 //---------------------------------------------------------------------------//
0097 /*!
0098  * Launch a kernel for the wrapped executor.
0099  */
0100 template<class F>
0101 void ActionLauncher<F>::operator()(CoreState<MemSpace::device> const& state,
0102                                    F const& call_thread) const
0103 {
0104     return (*this)(
0105         range(ThreadId{state.size()}), state.stream_id(), call_thread);
0106 }
0107 
0108 //---------------------------------------------------------------------------//
0109 }  // namespace optical
0110 }  // namespace celeritas