File indexing completed on 2025-12-30 10:06:30
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
0009 #include <type_traits>
0010
0011 #include "corecel/DeviceRuntimeApi.hh"
0012
0013 #include "corecel/Assert.hh"
0014 #include "corecel/Macros.hh"
0015 #include "corecel/Types.hh"
0016 #include "corecel/cont/Range.hh"
0017 #include "corecel/sys/KernelLauncher.device.hh"
0018 #include "corecel/sys/ThreadId.hh"
0019 #include "celeritas/optical/CoreParams.hh"
0020 #include "celeritas/optical/CoreState.hh"
0021
0022 #include "ActionInterface.hh"
0023
0024 namespace celeritas
0025 {
0026 namespace optical
0027 {
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047 template<class F>
0048 class ActionLauncher : public KernelLauncher<F>
0049 {
0050 static_assert(
0051 (std::is_trivially_copyable_v<F> || CELERITAS_USE_HIP
0052 || CELER_COMPILER == CELER_COMPILER_CLANG)
0053 && !std::is_pointer_v<F> && !std::is_reference_v<F>,
0054 R"(Launched action must be a trivially copyable function object)");
0055 using StepActionT = OpticalStepActionInterface;
0056
0057 public:
0058
0059 using KernelLauncher<F>::KernelLauncher;
0060
0061
0062 explicit ActionLauncher(StepActionT const& action);
0063
0064
0065 ActionLauncher(StepActionT const& action, std::string_view ext);
0066
0067
0068 using KernelLauncher<F>::operator();
0069
0070
0071 void operator()(CoreState<MemSpace::device> const& state,
0072 F const& call_thread) const;
0073 };
0074
0075
0076
0077
0078
0079 template<class F>
0080 ActionLauncher<F>::ActionLauncher(StepActionT const& action)
0081 : ActionLauncher{action.label()}
0082 {
0083 }
0084
0085
0086
0087
0088
0089 template<class F>
0090 ActionLauncher<F>::ActionLauncher(StepActionT const& action,
0091 std::string_view ext)
0092 : ActionLauncher{std::string(action.label()) + "-" + std::string(ext)}
0093 {
0094 }
0095
0096
0097
0098
0099
0100 template<class F>
0101 void ActionLauncher<F>::operator()(CoreState<MemSpace::device> const& state,
0102 F const& call_thread) const
0103 {
0104 return (*this)(
0105 range(ThreadId{state.size()}), state.stream_id(), call_thread);
0106 }
0107
0108
0109 }
0110 }