Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 08:53:42

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file celeritas/optical/action/TrackSlotExecutor.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include "corecel/Assert.hh"
0010 #include "corecel/Types.hh"
0011 #include "corecel/sys/ThreadId.hh"
0012 #include "celeritas/optical/CoreTrackData.hh"
0013 #include "celeritas/optical/CoreTrackView.hh"
0014 #include "celeritas/optical/SimTrackView.hh"
0015 #include "celeritas/track/TrackFunctors.hh"
0016 
0017 namespace celeritas
0018 {
0019 namespace optical
0020 {
0021 //---------------------------------------------------------------------------//
0022 /*!
0023  * Transform a thread or track slot ID into a core track view.
0024  *
0025  * This class can be used to call a functor that applies to \c
0026  * optical::CoreTrackView using a \c TrackSlotId, so that the tracks can be
0027  * easily looped over as a group on CPU or GPU.
0028  *
0029  * To facilitate kernel launches, the class can also directly map \c ThreadId
0030  * to \c TrackSlotId, which will have the same numerical value because optical
0031  * photons do not implement sorting.
0032  */
0033 template<class T>
0034 class TrackSlotExecutor
0035 {
0036   public:
0037     //!@{
0038     //! \name Type aliases
0039     using ParamsPtr = CoreParamsPtr<MemSpace::native>;
0040     using StatePtr = CoreStatePtr<MemSpace::native>;
0041     using Applier = T;
0042     //!@}
0043 
0044   public:
0045     //! Construct with core data and executor
0046     CELER_FUNCTION
0047     TrackSlotExecutor(ParamsPtr params, StatePtr state, T&& execute_track)
0048         : params_{params}
0049         , state_{state}
0050         , execute_track_{celeritas::forward<T>(execute_track)}
0051     {
0052     }
0053 
0054     //! Call the underlying function based on the index in the state array
0055     CELER_FUNCTION void operator()(TrackSlotId ts)
0056     {
0057         CELER_EXPECT(ts < state_->size());
0058         CoreTrackView track(*params_, *state_, ts);
0059         return execute_track_(track);
0060     }
0061 
0062     //! Call the underlying function using the thread index
0063     CELER_FORCEINLINE_FUNCTION void operator()(ThreadId thread)
0064     {
0065         // For optical photons, thread index maps exactly to
0066         return (*this)(TrackSlotId{thread.unchecked_get()});
0067     }
0068 
0069   private:
0070     ParamsPtr const params_;
0071     StatePtr const state_;
0072     T execute_track_;
0073 };
0074 
0075 //---------------------------------------------------------------------------//
0076 /*!
0077  * Launch the track only when a certain condition applies.
0078  *
0079  * The condition \c C must have the signature \code
0080  * (SimTrackView const&) -> bool
0081   \endcode
0082  *
0083  * see \c make_active_track_executor for an example where this is used to apply
0084  * only to active (or killed) tracks.
0085  */
0086 template<class C, class T>
0087 class ConditionalTrackSlotExecutor
0088 {
0089   public:
0090     //!@{
0091     //! \name Type aliases
0092     using ParamsPtr = CoreParamsPtr<MemSpace::native>;
0093     using StatePtr = CoreStatePtr<MemSpace::native>;
0094     using Applier = T;
0095     //!@}
0096 
0097   public:
0098     //! Construct with condition and operator
0099     CELER_FUNCTION
0100     ConditionalTrackSlotExecutor(ParamsPtr params,
0101                                  StatePtr state,
0102                                  C&& applies,
0103                                  T&& execute_track)
0104         : params_{params}
0105         , state_{state}
0106         , applies_{celeritas::forward<C>(applies)}
0107         , execute_track_{celeritas::forward<T>(execute_track)}
0108     {
0109     }
0110 
0111     //! Launch the given thread if the track meets the condition
0112     CELER_FUNCTION void operator()(TrackSlotId ts)
0113     {
0114         CELER_EXPECT(ts < state_->size());
0115         CoreTrackView track(*params_, *state_, ts);
0116         if (!applies_(track))
0117         {
0118             return;
0119         }
0120 
0121         return execute_track_(track);
0122     }
0123 
0124     //! Call the underlying function using the thread index
0125     CELER_FORCEINLINE_FUNCTION void operator()(ThreadId thread)
0126     {
0127         // For optical photons, thread index maps exactly to
0128         return (*this)(TrackSlotId{thread.unchecked_get()});
0129     }
0130 
0131   private:
0132     ParamsPtr const params_;
0133     StatePtr const state_;
0134     C applies_;
0135     T execute_track_;
0136 };
0137 
0138 //---------------------------------------------------------------------------//
0139 // DEDUCTION GUIDES
0140 //---------------------------------------------------------------------------//
0141 template<class T>
0142 CELER_FUNCTION TrackSlotExecutor(CoreParamsPtr<MemSpace::native>,
0143                                  CoreStatePtr<MemSpace::native>,
0144                                  T&&) -> TrackSlotExecutor<T>;
0145 
0146 template<class C, class T>
0147 CELER_FUNCTION
0148 ConditionalTrackSlotExecutor(CoreParamsPtr<MemSpace::native>,
0149                              CoreStatePtr<MemSpace::native>,
0150                              C&&,
0151                              T&&) -> ConditionalTrackSlotExecutor<C, T>;
0152 
0153 //---------------------------------------------------------------------------//
0154 // FREE FUNCTIONS
0155 //---------------------------------------------------------------------------//
0156 /*!
0157  * Return a track executor that only applies to active, non-errored tracks.
0158  */
0159 template<class T>
0160 inline CELER_FUNCTION decltype(auto)
0161 make_active_thread_executor(CoreParamsPtr<MemSpace::native> params,
0162                             CoreStatePtr<MemSpace::native> const& state,
0163                             T&& apply_track)
0164 {
0165     return ConditionalTrackSlotExecutor{
0166         params, state, AppliesValid{}, celeritas::forward<T>(apply_track)};
0167 }
0168 
0169 //---------------------------------------------------------------------------//
0170 /*!
0171  * Return a track executor that only applies if the action ID matches.
0172  *
0173  * \note This should generally only be used for post-step actions and other
0174  * cases where the IDs \em explicitly are set. Many explicit actions apply to
0175  * all threads, active or not.
0176  */
0177 template<class T>
0178 inline CELER_FUNCTION decltype(auto)
0179 make_action_thread_executor(CoreParamsPtr<MemSpace::native> params,
0180                             CoreStatePtr<MemSpace::native> state,
0181                             ActionId action,
0182                             T&& apply_track)
0183 {
0184     CELER_EXPECT(action);
0185     return ConditionalTrackSlotExecutor{params,
0186                                         state,
0187                                         IsStepActionEqual{action},
0188                                         celeritas::forward<T>(apply_track)};
0189 }
0190 
0191 //---------------------------------------------------------------------------//
0192 }  // namespace optical
0193 }  // namespace celeritas