Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-02-22 10:31:26

0001 //----------------------------------*-C++-*----------------------------------//
0002 // Copyright 2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file celeritas/optical/action/TrackSlotExecutor.hh
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
0010 #include "corecel/Assert.hh"
0011 #include "corecel/Types.hh"
0012 #include "corecel/sys/ThreadId.hh"
0013 #include "celeritas/optical/CoreTrackData.hh"
0014 #include "celeritas/optical/CoreTrackView.hh"
0015 #include "celeritas/optical/SimTrackView.hh"
0016 #include "celeritas/track/SimFunctors.hh"
0017 
0018 namespace celeritas
0019 {
0020 namespace optical
0021 {
0022 //---------------------------------------------------------------------------//
0023 /*!
0024  * Transform a thread or track slot ID into a core track view.
0025  *
0026  * This class can be used to call a functor that applies to \c
0027  * optical::CoreTrackView using a \c TrackSlotId, so that the tracks can be
0028  * easily looped over as a group on CPU or GPU.
0029  *
0030  * To facilitate kernel launches, the class can also directly map \c ThreadId
0031  * to \c TrackSlotId, which will have the same numerical value because optical
0032  * photons do not implement sorting.
0033  */
0034 template<class T>
0035 class TrackSlotExecutor
0036 {
0037   public:
0038     //!@{
0039     //! \name Type aliases
0040     using ParamsPtr = CoreParamsPtr<MemSpace::native>;
0041     using StatePtr = CoreStatePtr<MemSpace::native>;
0042     using Applier = T;
0043     //!@}
0044 
0045   public:
0046     //! Construct with core data and executor
0047     CELER_FUNCTION
0048     TrackSlotExecutor(ParamsPtr params, StatePtr state, T&& execute_track)
0049         : params_{params}
0050         , state_{state}
0051         , execute_track_{celeritas::forward<T>(execute_track)}
0052     {
0053     }
0054 
0055     //! Call the underlying function based on the index in the state array
0056     CELER_FUNCTION void operator()(TrackSlotId ts)
0057     {
0058         CELER_EXPECT(ts < state_->size());
0059         CoreTrackView track(*params_, *state_, ts);
0060         return execute_track_(track);
0061     }
0062 
0063     //! Call the underlying function using the thread index
0064     CELER_FORCEINLINE_FUNCTION void operator()(ThreadId thread)
0065     {
0066         // For optical photons, thread index maps exactly to
0067         return (*this)(TrackSlotId{thread.unchecked_get()});
0068     }
0069 
0070   private:
0071     ParamsPtr const params_;
0072     StatePtr const state_;
0073     T execute_track_;
0074 };
0075 
0076 //---------------------------------------------------------------------------//
0077 /*!
0078  * Launch the track only when a certain condition applies to the sim state.
0079  *
0080  * The condition \c C must have the signature \code
0081  * (SimTrackView const&) -> bool
0082   \endcode
0083  *
0084  * see \c make_active_track_executor for an example where this is used to apply
0085  * only to active (or killed) tracks.
0086  */
0087 template<class C, class T>
0088 class ConditionalTrackSlotExecutor
0089 {
0090   public:
0091     //!@{
0092     //! \name Type aliases
0093     using ParamsPtr = CoreParamsPtr<MemSpace::native>;
0094     using StatePtr = CoreStatePtr<MemSpace::native>;
0095     using Applier = T;
0096     //!@}
0097 
0098   public:
0099     //! Construct with condition and operator
0100     CELER_FUNCTION
0101     ConditionalTrackSlotExecutor(ParamsPtr params,
0102                                  StatePtr state,
0103                                  C&& applies,
0104                                  T&& execute_track)
0105         : params_{params}
0106         , state_{state}
0107         , applies_{celeritas::forward<C>(applies)}
0108         , execute_track_{celeritas::forward<T>(execute_track)}
0109     {
0110     }
0111 
0112     //! Launch the given thread if the track meets the condition
0113     CELER_FUNCTION void operator()(TrackSlotId ts)
0114     {
0115         CELER_EXPECT(ts < state_->size());
0116         CoreTrackView track(*params_, *state_, ts);
0117         if (!applies_(track.sim()))
0118         {
0119             return;
0120         }
0121 
0122         return execute_track_(track);
0123     }
0124 
0125     //! Call the underlying function using the thread index
0126     CELER_FORCEINLINE_FUNCTION void operator()(ThreadId thread)
0127     {
0128         // For optical photons, thread index maps exactly to
0129         return (*this)(TrackSlotId{thread.unchecked_get()});
0130     }
0131 
0132   private:
0133     ParamsPtr const params_;
0134     StatePtr const state_;
0135     C applies_;
0136     T execute_track_;
0137 };
0138 
0139 //---------------------------------------------------------------------------//
0140 // DEDUCTION GUIDES
0141 //---------------------------------------------------------------------------//
0142 template<class T>
0143 CELER_FUNCTION TrackSlotExecutor(CoreParamsPtr<MemSpace::native>,
0144                                  CoreStatePtr<MemSpace::native>,
0145                                  T&&) -> TrackSlotExecutor<T>;
0146 
0147 template<class C, class T>
0148 CELER_FUNCTION
0149 ConditionalTrackSlotExecutor(CoreParamsPtr<MemSpace::native>,
0150                              CoreStatePtr<MemSpace::native>,
0151                              C&&,
0152                              T&&) -> ConditionalTrackSlotExecutor<C, T>;
0153 
0154 //---------------------------------------------------------------------------//
0155 // FREE FUNCTIONS
0156 //---------------------------------------------------------------------------//
0157 /*!
0158  * Return a track executor that only applies to active, non-errored tracks.
0159  */
0160 template<class T>
0161 inline CELER_FUNCTION decltype(auto)
0162 make_active_thread_executor(CoreParamsPtr<MemSpace::native> params,
0163                             CoreStatePtr<MemSpace::native> const& state,
0164                             T&& apply_track)
0165 {
0166     return ConditionalTrackSlotExecutor{
0167         params, state, AppliesValid{}, celeritas::forward<T>(apply_track)};
0168 }
0169 
0170 //---------------------------------------------------------------------------//
0171 /*!
0172  * Return a track executor that only applies if the action ID matches.
0173  *
0174  * \note This should generally only be used for post-step actions and other
0175  * cases where the IDs \em explicitly are set. Many explicit actions apply to
0176  * all threads, active or not.
0177  */
0178 template<class T>
0179 inline CELER_FUNCTION decltype(auto)
0180 make_action_thread_executor(CoreParamsPtr<MemSpace::native> params,
0181                             CoreStatePtr<MemSpace::native> state,
0182                             ActionId action,
0183                             T&& apply_track)
0184 {
0185     CELER_EXPECT(action);
0186     return ConditionalTrackSlotExecutor{params,
0187                                         state,
0188                                         IsStepActionEqual{action},
0189                                         celeritas::forward<T>(apply_track)};
0190 }
0191 
0192 //---------------------------------------------------------------------------//
0193 }  // namespace optical
0194 }  // namespace celeritas