Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 09:09:05

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file celeritas/global/Stepper.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <memory>
0010 #include <vector>
0011 
0012 #include "corecel/Types.hh"
0013 #include "corecel/cont/Span.hh"
0014 #include "corecel/data/CollectionStateStore.hh"
0015 #include "corecel/random/params/RngParamsFwd.hh"
0016 #include "celeritas/Types.hh"
0017 #include "celeritas/geo/GeoFwd.hh"
0018 #include "celeritas/phys/Primary.hh"
0019 #include "celeritas/track/TrackInitData.hh"
0020 
0021 #include "CoreState.hh"
0022 #include "CoreTrackData.hh"
0023 
0024 namespace celeritas
0025 {
0026 //---------------------------------------------------------------------------//
0027 class CoreParams;
0028 struct Primary;
0029 class ExtendFromPrimariesAction;
0030 
0031 class ActionSequence;
0032 
0033 //---------------------------------------------------------------------------//
0034 /*!
0035  * State-specific options for the stepper.
0036  *
0037  * - \c params : Problem definition
0038  * - \c num_track_slots : Maximum number of threads to run in parallel on GPU
0039  *   (optional, could be set by params)
0040  *   \c stream_id : Unique (thread/task) ID for this process
0041  * - \c action_times : Whether to synchronize device between actions for timing
0042  */
0043 struct StepperInput
0044 {
0045     std::shared_ptr<CoreParams const> params;
0046     StreamId stream_id{};
0047     size_type num_track_slots{};
0048     bool action_times{false};
0049 
0050     //! True if defined
0051     explicit operator bool() const { return params && stream_id; }
0052 };
0053 
0054 //---------------------------------------------------------------------------//
0055 /*!
0056  * Track counters for a step.
0057  */
0058 struct StepperResult
0059 {
0060     size_type generated{};  //!< New primaries added
0061     size_type queued{};  //!< Pending track initializers at end of step
0062     size_type active{};  //!< Active tracks at start of step
0063     size_type alive{};  //!< Active and alive at end of step
0064 
0065     //! True if more steps need to be run
0066     explicit operator bool() const { return queued > 0 || alive > 0; }
0067 };
0068 
0069 //---------------------------------------------------------------------------//
0070 /*!
0071  * Interface class for stepper classes.
0072  *
0073  * This allows higher-level classes not to care whether the stepper operates on
0074  * host or device.
0075  *
0076  * \note This class and its daughter may be removed soon to facilitate step
0077  * gathering.
0078  */
0079 class StepperInterface
0080 {
0081   public:
0082     //!@{
0083     //! \name Type aliases
0084     using Input = StepperInput;
0085     using ActionSequenceT = ActionSequence;
0086     using SpanConstPrimary = Span<Primary const>;
0087     using result_type = StepperResult;
0088     using SPState = std::shared_ptr<CoreStateInterface>;
0089     //!@}
0090 
0091   public:
0092     // Default virtual destructor
0093     virtual ~StepperInterface();
0094 
0095     // Warm up before stepping
0096     virtual void warm_up() = 0;
0097 
0098     // Transport existing states
0099     virtual StepperResult operator()() = 0;
0100 
0101     // Transport existing states and these new primaries
0102     virtual StepperResult operator()(SpanConstPrimary primaries) = 0;
0103 
0104     // Kill all tracks in flight to debug "stuck" tracks
0105     virtual void kill_active() = 0;
0106 
0107     // Reseed the RNGs at the start of an event for reproducibility
0108     virtual void reseed(UniqueEventId event_id) = 0;
0109 
0110     //! Get action sequence for timing diagnostics
0111     virtual ActionSequenceT const& actions() const = 0;
0112 
0113     //! Get the core state interface
0114     virtual CoreStateInterface const& state() const = 0;
0115 
0116     //! Get a shared pointer to the state (TEMPORARY)
0117     virtual SPState sp_state() = 0;
0118 
0119   protected:
0120     StepperInterface() = default;
0121     CELER_DEFAULT_COPY_MOVE(StepperInterface);
0122 };
0123 
0124 //---------------------------------------------------------------------------//
0125 /*!
0126  * Manage a state vector and execute a single step on all of them.
0127  *
0128  * \note This is likely to be removed and refactored since we're changing how
0129  * primaries are created and how multithread state ownership is managed.
0130  *
0131  * \code
0132    Stepper<MemSpace::host> step(input);
0133 
0134    // Transport primaries for the initial step
0135    StepperResult alive_tracks = step(my_primaries);
0136    while (alive_tracks)
0137    {
0138        // Transport secondaries
0139        alive_tracks = step();
0140    }
0141    \endcode
0142  */
0143 template<MemSpace M>
0144 class Stepper final : public StepperInterface
0145 {
0146   public:
0147     //!@{
0148     //! \name Type aliases
0149     using StateRef = CoreStateData<Ownership::reference, M>;
0150     //!@}
0151 
0152   public:
0153     // Construct with problem parameters and setup options
0154     explicit Stepper(Input input);
0155 
0156     // Default destructor
0157     ~Stepper() final;
0158 
0159     // Warm up before stepping
0160     void warm_up() final;
0161 
0162     // Transport existing states
0163     StepperResult operator()() final;
0164 
0165     // Transport existing states and these new primaries
0166     StepperResult operator()(SpanConstPrimary primaries) final;
0167 
0168     // Kill all tracks in flight to debug "stuck" tracks
0169     void kill_active() final;
0170 
0171     // Reseed the RNGs at the start of an event for reproducibility
0172     void reseed(UniqueEventId event_id) final;
0173 
0174     //! Get action sequence for timing diagnostics
0175     ActionSequenceT const& actions() const final { return *actions_; }
0176 
0177     //! Access core data, primarily for debugging
0178     StateRef const& state_ref() const { return state_->ref(); }
0179 
0180     //! Get the core state interface for diagnostic output
0181     CoreStateInterface const& state() const final { return *state_; }
0182 
0183     //! Reset the core state counters and data so it can be reused
0184     void reset_state() { state_->reset(); }
0185 
0186     //! Get a shared pointer to the state (TEMPORARY, DO NOT USE)
0187     SPState sp_state() final { return state_; }
0188 
0189   private:
0190     // Params data
0191     std::shared_ptr<CoreParams const> params_;
0192     // Primary initialization
0193     std::shared_ptr<ExtendFromPrimariesAction const> primaries_action_;
0194     // State data
0195     std::shared_ptr<CoreState<M>> state_;
0196     // Call sequence
0197     std::shared_ptr<ActionSequenceT> actions_;
0198 };
0199 
0200 //---------------------------------------------------------------------------//
0201 // EXPLICIT INSTANTIATION
0202 //---------------------------------------------------------------------------//
0203 
0204 extern template class Stepper<MemSpace::host>;
0205 extern template class Stepper<MemSpace::device>;
0206 
0207 //---------------------------------------------------------------------------//
0208 }  // namespace celeritas