Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-05-11 08:38:33

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file corecel/data/StreamStore.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
#include <type_traits>
#include <utility>
#include <vector>
0010 
0011 #include "corecel/Assert.hh"
0012 #include "corecel/OpaqueId.hh"
0013 #include "corecel/Types.hh"
0014 #include "corecel/cont/Range.hh"
0015 #include "corecel/sys/Device.hh"
0016 #include "corecel/sys/ThreadId.hh"
0017 
0018 #include "Collection.hh"
0019 #include "CollectionMirror.hh"
0020 #include "CollectionStateStore.hh"
0021 
0022 namespace celeritas
0023 {
0024 //---------------------------------------------------------------------------//
0025 /*!
0026  * Helper class for storing parameters and multiple stream-dependent states.
0027  *
0028  * This requires a templated ParamsData and StateData. Hopefully this
0029  * frankenstein of a class will be replaced by a std::any-like data container
0030  * owned by each (possibly thread-local) State.
0031  *
0032  * Usage:
0033  * \code
0034    StreamStore<FooParams, FooState> store{host_val, num_streams};
0035    assert(store);
0036 
0037    execute_kernel(store.params(), store.state<Memspace::host>(StreamId{0},
0038  state_size))
0039 
0040    if (auto* state = store.state<Memspace::device>(StreamId{1}))
0041    {
0042        cout << "Have device data for stream 1" << endl;
0043    }
0044    \endcode
0045  *
0046  * There is some additional complexity in the "state" accessors to allow for
0047  * const correctness.
0048  */
0049 template<template<Ownership, MemSpace> class P,
0050          template<Ownership, MemSpace> class S>
0051 class StreamStore
0052 {
0053   public:
0054     //!@{
0055     //! \name Type aliases
0056     using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0057     //!@}
0058 
0059   public:
0060     // Default for unassigned/lazy construction
0061     StreamStore() = default;
0062 
0063     // Construct with number of streams and host data
0064     inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0065 
0066     //// ACCESSORS ////
0067 
0068     //! Whether the instance is ready for storing data
0069     explicit operator bool() const { return num_streams_ > 0; }
0070 
0071     //! Number of streams being stored
0072     StreamId::size_type num_streams() const { return num_streams_; }
0073 
0074     // Get references to the params data
0075     template<MemSpace M>
0076     inline P<Ownership::const_reference, M> const& params() const;
0077 
0078     // Get references to the state data for a given stream, allocating if
0079     // necessary.
0080     template<MemSpace M>
0081     inline S<Ownership::reference, M>&
0082     state(StreamId stream_id, size_type size);
0083 
0084     //! Get a pointer to the state data, null if not allocated
0085     template<MemSpace M>
0086     S<Ownership::reference, M> const* state(StreamId stream_id) const
0087     {
0088         return StreamStore::stateptr_impl<M>(*this, stream_id);
0089     }
0090 
0091     //! Get a mutable pointer to the state data, null if not allocated
0092     template<MemSpace M>
0093     S<Ownership::reference, M>* state(StreamId stream_id)
0094     {
0095         return StreamStore::stateptr_impl<M>(*this, stream_id);
0096     }
0097 
0098   private:
0099     //// TYPES ////
0100     using ParamMirror = CollectionMirror<P>;
0101     template<MemSpace M>
0102     using StateStore = CollectionStateStore<S, M>;
0103     template<MemSpace M>
0104     using VecSS = std::vector<StateStore<M>>;
0105 
0106     //// DATA ////
0107 
0108     CollectionMirror<P> params_;
0109     StreamId::size_type num_streams_{0};
0110     VecSS<MemSpace::host> host_states_;
0111     VecSS<MemSpace::device> device_states_;
0112 
0113     //// FUNCTIONS ////
0114 
0115     template<MemSpace M, class Self>
0116     static constexpr decltype(auto) states_impl(Self&& self)
0117     {
0118         if constexpr (M == MemSpace::host)
0119         {
0120             // Extra parens needed to return reference instead of copy
0121             return (self.host_states_);
0122         }
0123         else
0124         {
0125             return (self.device_states_);
0126         }
0127 #if CELER_CUDACC_BUGGY_IF_CONSTEXPR
0128         CELER_ASSERT_UNREACHABLE();
0129 #endif
0130     }
0131 
0132     template<MemSpace M, class Self>
0133     static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0134     {
0135         CELER_EXPECT(stream_id < self.num_streams_ || !self);
0136         using result_type = std::add_pointer_t<
0137             decltype(StreamStore::states_impl<M>(self).front().ref())>;
0138         if (!self)
0139         {
0140             return result_type{nullptr};
0141         }
0142 
0143         auto& state_vec = StreamStore::states_impl<M>(self);
0144         CELER_ASSERT(state_vec.size() == self.num_streams_);
0145         auto& state_store = state_vec[stream_id.unchecked_get()];
0146         if (!state_store)
0147         {
0148             return result_type{nullptr};
0149         }
0150 
0151         return &state_store.ref();
0152     }
0153 };
0154 
0155 //---------------------------------------------------------------------------//
0156 // INLINE DEFINITIONS
0157 //---------------------------------------------------------------------------//
0158 /*!
0159  * Construct with parameters and the number of streams.
0160  *
0161  * The constructor is *not* thread safe and should be called during params
0162  * setup, not at run time.
0163  */
0164 template<template<Ownership, MemSpace> class P,
0165          template<Ownership, MemSpace> class S>
0166 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0167                                StreamId::size_type num_streams)
0168     : params_(std::move(host)), num_streams_(num_streams)
0169 {
0170     CELER_EXPECT(params_);
0171     CELER_EXPECT(num_streams_ > 0);
0172 
0173     // Resize stores in advance, but don't allocate memory.
0174     host_states_.resize(num_streams_);
0175     device_states_.resize(num_streams_);
0176 }
0177 
0178 //---------------------------------------------------------------------------//
0179 /*!
0180  * Get a reference to the params data.
0181  */
0182 template<template<Ownership, MemSpace> class P,
0183          template<Ownership, MemSpace> class S>
0184 template<MemSpace M>
0185 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0186 {
0187     CELER_EXPECT(*this);
0188     return params_.template ref<M>();
0189 }
0190 
0191 //---------------------------------------------------------------------------//
0192 /*!
0193  * Get a reference to the state data, allocating if necessary.
0194  */
0195 template<template<Ownership, MemSpace> class P,
0196          template<Ownership, MemSpace> class S>
0197 template<MemSpace M>
0198 S<Ownership::reference, M>&
0199 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0200 {
0201     CELER_EXPECT(*this);
0202     CELER_EXPECT(stream_id < num_streams_);
0203 
0204     auto& state_vec = StreamStore::states_impl<M>(*this);
0205     CELER_ASSERT(state_vec.size() == num_streams_);
0206     auto& state_store = state_vec[stream_id.unchecked_get()];
0207     if (CELER_UNLIKELY(!state_store))
0208     {
0209         state_store = {this->params<MemSpace::host>(), stream_id, size};
0210     }
0211 
0212     CELER_ENSURE(state_store.size() == size);
0213     return state_store.ref();
0214 }
0215 
0216 //---------------------------------------------------------------------------//
0217 // HELPER FUNCTIONS
0218 //---------------------------------------------------------------------------//
0219 /*!
0220  * Apply a function to all streams.
0221  */
0222 template<class S, class F>
0223 void apply_to_all_streams(S&& store, F&& func)
0224 {
0225     // Apply on host
0226     for (StreamId s : range(StreamId{store.num_streams()}))
0227     {
0228         if (auto* state = store.template state<MemSpace::host>(s))
0229         {
0230             func(*state);
0231         }
0232     }
0233 
0234     // Apply on device
0235     for (StreamId s : range(StreamId{store.num_streams()}))
0236     {
0237         if (auto* state = store.template state<MemSpace::device>(s))
0238         {
0239             func(*state);
0240         }
0241     }
0242 }
0243 
0244 //---------------------------------------------------------------------------//
0245 /*!
0246  * Accumulate data over all streams.
0247  */
0248 template<class S, class F, class T>
0249 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0250 {
0251     std::vector<T> temp_host;
0252 
0253     // Accumulate on host
0254     for (StreamId s : range(StreamId{store.num_streams()}))
0255     {
0256         if (auto* state = store.template state<MemSpace::host>(s))
0257         {
0258             auto data = func(*state)[AllItems<T>{}];
0259             CELER_EXPECT(data.size() == result->size());
0260             for (auto i : range(data.size()))
0261             {
0262                 (*result)[i] += data[i];
0263             }
0264         }
0265     }
0266 
0267     // Accumulate on device
0268     for (StreamId s : range(StreamId{store.num_streams()}))
0269     {
0270         if (auto* state = store.template state<MemSpace::device>(s))
0271         {
0272             auto data = func(*state);
0273             CELER_EXPECT(data.size() == result->size());
0274 
0275             if (temp_host.empty())
0276             {
0277                 temp_host.resize(result->size());
0278             }
0279             copy_to_host(data, make_span(temp_host));
0280 
0281             for (auto i : range(data.size()))
0282             {
0283                 (*result)[i] += temp_host[i];
0284             }
0285         }
0286     }
0287 }
0288 
0289 //---------------------------------------------------------------------------//
0290 }  // namespace celeritas