Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 08:54:05

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file corecel/data/StreamStore.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <type_traits>
0010 
0011 #include "corecel/Assert.hh"
0012 #include "corecel/OpaqueId.hh"
0013 #include "corecel/Types.hh"
0014 #include "corecel/cont/Range.hh"
0015 #include "corecel/sys/Device.hh"
0016 #include "corecel/sys/ThreadId.hh"
0017 
0018 #include "Collection.hh"
0019 #include "CollectionMirror.hh"
0020 #include "CollectionStateStore.hh"
0021 
0022 namespace celeritas
0023 {
0024 //---------------------------------------------------------------------------//
0025 /*!
0026  * Helper class for storing parameters and multiple stream-dependent states.
0027  *
 * This requires templated ParamsData and StateData classes. Hopefully this
 * Frankenstein of a class will be replaced by a std::any-like data container
 * owned by each (possibly thread-local) State.
0031  *
 * Usage:
 * \code
   StreamStore<FooParams, FooState> store{std::move(host_val), num_streams};
   assert(store);

   execute_kernel(store.params<MemSpace::host>(),
                  store.state<MemSpace::host>(StreamId{0}, state_size));

   if (auto* state = store.state<MemSpace::device>(StreamId{1}))
   {
       cout << "Have device data for stream 1" << endl;
   }
   \endcode
0045  *
0046  * There is some additional complexity in the "state" accessors to allow for
0047  * const correctness.
0048  */
0049 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0050 class StreamStore
0051 {
0052   public:
0053     //!@{
0054     //! \name Type aliases
0055     using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0056     //!@}
0057 
0058   public:
0059     // Default for unassigned/lazy construction
0060     StreamStore() = default;
0061 
0062     // Construct with number of streams and host data
0063     inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0064 
0065     //// ACCESSORS ////
0066 
0067     //! Whether the instance is ready for storing data
0068     explicit operator bool() const { return num_streams_ > 0; }
0069 
0070     //! Number of streams being stored
0071     StreamId::size_type num_streams() const { return num_streams_; }
0072 
0073     // Get references to the params data
0074     template<MemSpace M>
0075     inline P<Ownership::const_reference, M> const& params() const;
0076 
0077     // Get references to the state data for a given stream, allocating if
0078     // necessary.
0079     template<MemSpace M>
0080     inline S<Ownership::reference, M>&
0081     state(StreamId stream_id, size_type size);
0082 
0083     //! Get a pointer to the state data, null if not allocated
0084     template<MemSpace M>
0085     S<Ownership::reference, M> const* state(StreamId stream_id) const
0086     {
0087         return StreamStore::stateptr_impl<M>(*this, stream_id);
0088     }
0089 
0090     //! Get a mutable pointer to the state data, null if not allocated
0091     template<MemSpace M>
0092     S<Ownership::reference, M>* state(StreamId stream_id)
0093     {
0094         return StreamStore::stateptr_impl<M>(*this, stream_id);
0095     }
0096 
0097   private:
0098     //// TYPES ////
0099     using ParamMirror = CollectionMirror<P>;
0100     template<MemSpace M>
0101     using StateStore = CollectionStateStore<S, M>;
0102     template<MemSpace M>
0103     using VecSS = std::vector<StateStore<M>>;
0104 
0105     //// DATA ////
0106 
0107     CollectionMirror<P> params_;
0108     StreamId::size_type num_streams_{0};
0109     VecSS<MemSpace::host> host_states_;
0110     VecSS<MemSpace::device> device_states_;
0111 
0112     //// FUNCTIONS ////
0113 
0114     template<MemSpace M, class Self>
0115     static constexpr decltype(auto) states_impl(Self&& self)
0116     {
0117         if constexpr (M == MemSpace::host)
0118         {
0119             // Extra parens needed to return reference instead of copy
0120             return (self.host_states_);
0121         }
0122         else
0123         {
0124             return (self.device_states_);
0125         }
0126 #if CELER_CUDACC_BUGGY_IF_CONSTEXPR
0127         CELER_ASSERT_UNREACHABLE();
0128 #endif
0129     }
0130 
0131     template<MemSpace M, class Self>
0132     static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0133     {
0134         CELER_EXPECT(stream_id < self.num_streams_ || !self);
0135         using result_type = std::add_pointer_t<
0136             decltype(StreamStore::states_impl<M>(self).front().ref())>;
0137         if (!self)
0138         {
0139             return result_type{nullptr};
0140         }
0141 
0142         auto& state_vec = StreamStore::states_impl<M>(self);
0143         CELER_ASSERT(state_vec.size() == self.num_streams_);
0144         auto& state_store = state_vec[stream_id.unchecked_get()];
0145         if (!state_store)
0146         {
0147             return result_type{nullptr};
0148         }
0149 
0150         return &state_store.ref();
0151     }
0152 };
0153 
0154 //---------------------------------------------------------------------------//
0155 // INLINE DEFINITIONS
0156 //---------------------------------------------------------------------------//
0157 /*!
0158  * Construct with parameters and the number of streams.
0159  *
0160  * The constructor is *not* thread safe and should be called during params
0161  * setup, not at run time.
0162  */
0163 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0164 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0165                                StreamId::size_type num_streams)
0166     : params_(std::move(host)), num_streams_(num_streams)
0167 {
0168     CELER_EXPECT(params_);
0169     CELER_EXPECT(num_streams_ > 0);
0170 
0171     // Resize stores in advance, but don't allocate memory.
0172     host_states_.resize(num_streams_);
0173     device_states_.resize(num_streams_);
0174 }
0175 
0176 //---------------------------------------------------------------------------//
0177 /*!
0178  * Get a reference to the params data.
0179  */
0180 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0181 template<MemSpace M>
0182 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0183 {
0184     CELER_EXPECT(*this);
0185     return params_.template ref<M>();
0186 }
0187 
0188 //---------------------------------------------------------------------------//
0189 /*!
0190  * Get a reference to the state data, allocating if necessary.
0191  */
0192 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0193 template<MemSpace M>
0194 S<Ownership::reference, M>&
0195 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0196 {
0197     CELER_EXPECT(*this);
0198     CELER_EXPECT(stream_id < num_streams_);
0199 
0200     auto& state_vec = StreamStore::states_impl<M>(*this);
0201     CELER_ASSERT(state_vec.size() == num_streams_);
0202     auto& state_store = state_vec[stream_id.unchecked_get()];
0203     if (CELER_UNLIKELY(!state_store))
0204     {
0205         state_store = {this->params<MemSpace::host>(), stream_id, size};
0206     }
0207 
0208     CELER_ENSURE(state_store.size() == size);
0209     return state_store.ref();
0210 }
0211 
0212 //---------------------------------------------------------------------------//
0213 // HELPER FUNCTIONS
0214 //---------------------------------------------------------------------------//
0215 /*!
0216  * Apply a function to all streams.
0217  */
0218 template<class S, class F>
0219 void apply_to_all_streams(S&& store, F&& func)
0220 {
0221     // Apply on host
0222     for (StreamId s : range(StreamId{store.num_streams()}))
0223     {
0224         if (auto* state = store.template state<MemSpace::host>(s))
0225         {
0226             func(*state);
0227         }
0228     }
0229 
0230     // Apply on device
0231     for (StreamId s : range(StreamId{store.num_streams()}))
0232     {
0233         if (auto* state = store.template state<MemSpace::device>(s))
0234         {
0235             func(*state);
0236         }
0237     }
0238 }
0239 
0240 //---------------------------------------------------------------------------//
0241 /*!
0242  * Accumulate data over all streams.
0243  */
0244 template<class S, class F, class T>
0245 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0246 {
0247     std::vector<T> temp_host;
0248 
0249     // Accumulate on host
0250     for (StreamId s : range(StreamId{store.num_streams()}))
0251     {
0252         if (auto* state = store.template state<MemSpace::host>(s))
0253         {
0254             auto data = func(*state)[AllItems<T>{}];
0255             CELER_EXPECT(data.size() == result->size());
0256             for (auto i : range(data.size()))
0257             {
0258                 (*result)[i] += data[i];
0259             }
0260         }
0261     }
0262 
0263     // Accumulate on device
0264     for (StreamId s : range(StreamId{store.num_streams()}))
0265     {
0266         if (auto* state = store.template state<MemSpace::device>(s))
0267         {
0268             auto data = func(*state);
0269             CELER_EXPECT(data.size() == result->size());
0270 
0271             if (temp_host.empty())
0272             {
0273                 temp_host.resize(result->size());
0274             }
0275             copy_to_host(data, make_span(temp_host));
0276 
0277             for (auto i : range(data.size()))
0278             {
0279                 (*result)[i] += temp_host[i];
0280             }
0281         }
0282     }
0283 }
0284 
0285 //---------------------------------------------------------------------------//
0286 }  // namespace celeritas