Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:03:42

0001 //----------------------------------*-C++-*----------------------------------//
0002 // Copyright 2023-2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file corecel/data/StreamStore.hh
0007 //---------------------------------------------------------------------------//
0008 #pragma once
0009 
#include <type_traits>
#include <utility>
#include <vector>

#include "corecel/Assert.hh"
#include "corecel/OpaqueId.hh"
#include "corecel/Types.hh"
#include "corecel/cont/Range.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/ThreadId.hh"

#include "Collection.hh"
#include "CollectionMirror.hh"
#include "CollectionStateStore.hh"
0022 
0023 namespace celeritas
0024 {
0025 //---------------------------------------------------------------------------//
0026 /*!
0027  * Helper class for storing parameters and multiple stream-dependent states.
0028  *
0029  * This requires a templated ParamsData and StateData. Hopefully this
0030  * frankenstein of a class will be replaced by a std::any-like data container
0031  * owned by each (possibly thread-local) State.
0032  *
 * Usage:
 * \code
   StreamStore<FooParams, FooState> store{std::move(host_val), num_streams};
   assert(store);

   execute_kernel(store.params<MemSpace::host>(),
                  store.state<MemSpace::host>(StreamId{0}, state_size));

   if (auto* state = store.state<MemSpace::device>(StreamId{1}))
   {
       cout << "Have device data for stream 1" << endl;
   }
   \endcode
0046  *
0047  * There is some additional complexity in the "state" accessors to allow for
0048  * const correctness.
0049  */
0050 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0051 class StreamStore
0052 {
0053   public:
0054     //!@{
0055     //! \name Type aliases
0056     using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0057     //!@}
0058 
0059   public:
0060     // Default for unassigned/lazy construction
0061     StreamStore() = default;
0062 
0063     // Construct with number of streams and host data
0064     inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0065 
0066     //// ACCESSORS ////
0067 
0068     //! Whether the instance is ready for storing data
0069     explicit operator bool() const { return num_streams_ > 0; }
0070 
0071     //! Number of streams being stored
0072     StreamId::size_type num_streams() const { return num_streams_; }
0073 
0074     // Get references to the params data
0075     template<MemSpace M>
0076     inline P<Ownership::const_reference, M> const& params() const;
0077 
0078     // Get references to the state data for a given stream, allocating if
0079     // necessary.
0080     template<MemSpace M>
0081     inline S<Ownership::reference, M>&
0082     state(StreamId stream_id, size_type size);
0083 
0084     //! Get a pointer to the state data, null if not allocated
0085     template<MemSpace M>
0086     S<Ownership::reference, M> const* state(StreamId stream_id) const
0087     {
0088         return StreamStore::stateptr_impl<M>(*this, stream_id);
0089     }
0090 
0091     //! Get a mutable pointer to the state data, null if not allocated
0092     template<MemSpace M>
0093     S<Ownership::reference, M>* state(StreamId stream_id)
0094     {
0095         return StreamStore::stateptr_impl<M>(*this, stream_id);
0096     }
0097 
0098   private:
0099     //// TYPES ////
0100     using ParamMirror = CollectionMirror<P>;
0101     template<MemSpace M>
0102     using StateStore = CollectionStateStore<S, M>;
0103     template<MemSpace M>
0104     using VecSS = std::vector<StateStore<M>>;
0105 
0106     //// DATA ////
0107 
0108     CollectionMirror<P> params_;
0109     StreamId::size_type num_streams_{0};
0110     VecSS<MemSpace::host> host_states_;
0111     VecSS<MemSpace::device> device_states_;
0112 
0113     //// FUNCTIONS ////
0114 
0115     template<MemSpace M, class Self>
0116     static constexpr decltype(auto) states_impl(Self&& self)
0117     {
0118         if constexpr (M == MemSpace::host)
0119         {
0120             // Extra parens needed to return reference instead of copy
0121             return (self.host_states_);
0122         }
0123 #ifndef __NVCC__
0124         // CUDA 11.4 complains about 'else if constexpr' ("missing return
0125         // statement") and GCC 11.2 complains about leaving off the 'else'
0126         // ("inconsistent deduction for auto return type")
0127         else
0128 #endif
0129             return (self.device_states_);
0130     }
0131 
0132     template<MemSpace M, class Self>
0133     static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0134     {
0135         CELER_EXPECT(stream_id < self.num_streams_ || !self);
0136         using result_type = std::add_pointer_t<
0137             decltype(StreamStore::states_impl<M>(self).front().ref())>;
0138         if (!self)
0139         {
0140             return result_type{nullptr};
0141         }
0142 
0143         auto& state_vec = StreamStore::states_impl<M>(self);
0144         CELER_ASSERT(state_vec.size() == self.num_streams_);
0145         auto& state_store = state_vec[stream_id.unchecked_get()];
0146         if (!state_store)
0147         {
0148             return result_type{nullptr};
0149         }
0150 
0151         return &state_store.ref();
0152     }
0153 };
0154 
0155 //---------------------------------------------------------------------------//
0156 // INLINE DEFINITIONS
0157 //---------------------------------------------------------------------------//
0158 /*!
0159  * Construct with parameters and the number of streams.
0160  *
0161  * The constructor is *not* thread safe and should be called during params
0162  * setup, not at run time.
0163  */
0164 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0165 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0166                                StreamId::size_type num_streams)
0167     : params_(std::move(host)), num_streams_(num_streams)
0168 {
0169     CELER_EXPECT(params_);
0170     CELER_EXPECT(num_streams_ > 0);
0171 
0172     // Resize stores in advance, but don't allocate memory.
0173     host_states_.resize(num_streams_);
0174     device_states_.resize(num_streams_);
0175 }
0176 
0177 //---------------------------------------------------------------------------//
0178 /*!
0179  * Get a reference to the params data.
0180  */
0181 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0182 template<MemSpace M>
0183 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0184 {
0185     CELER_EXPECT(*this);
0186     return params_.template ref<M>();
0187 }
0188 
0189 //---------------------------------------------------------------------------//
0190 /*!
0191  * Get a reference to the state data, allocating if necessary.
0192  */
0193 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0194 template<MemSpace M>
0195 S<Ownership::reference, M>&
0196 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0197 {
0198     CELER_EXPECT(*this);
0199     CELER_EXPECT(stream_id < num_streams_);
0200 
0201     auto& state_vec = StreamStore::states_impl<M>(*this);
0202     CELER_ASSERT(state_vec.size() == num_streams_);
0203     auto& state_store = state_vec[stream_id.unchecked_get()];
0204     if (CELER_UNLIKELY(!state_store))
0205     {
0206         state_store = {this->params<MemSpace::host>(), stream_id, size};
0207     }
0208 
0209     CELER_ENSURE(state_store.size() == size);
0210     return state_store.ref();
0211 }
0212 
0213 //---------------------------------------------------------------------------//
0214 // HELPER FUNCTIONS
0215 //---------------------------------------------------------------------------//
0216 /*!
0217  * Apply a function to all streams.
0218  */
0219 template<class S, class F>
0220 void apply_to_all_streams(S&& store, F&& func)
0221 {
0222     // Apply on host
0223     for (StreamId s : range(StreamId{store.num_streams()}))
0224     {
0225         if (auto* state = store.template state<MemSpace::host>(s))
0226         {
0227             func(*state);
0228         }
0229     }
0230 
0231     // Apply on device
0232     for (StreamId s : range(StreamId{store.num_streams()}))
0233     {
0234         if (auto* state = store.template state<MemSpace::device>(s))
0235         {
0236             func(*state);
0237         }
0238     }
0239 }
0240 
0241 //---------------------------------------------------------------------------//
0242 /*!
0243  * Accumulate data over all streams.
0244  */
0245 template<class S, class F, class T>
0246 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0247 {
0248     std::vector<T> temp_host;
0249 
0250     // Accumulate on host
0251     for (StreamId s : range(StreamId{store.num_streams()}))
0252     {
0253         if (auto* state = store.template state<MemSpace::host>(s))
0254         {
0255             auto data = func(*state)[AllItems<T>{}];
0256             CELER_EXPECT(data.size() == result->size());
0257             for (auto i : range(data.size()))
0258             {
0259                 (*result)[i] += data[i];
0260             }
0261         }
0262     }
0263 
0264     // Accumulate on device
0265     for (StreamId s : range(StreamId{store.num_streams()}))
0266     {
0267         if (auto* state = store.template state<MemSpace::device>(s))
0268         {
0269             auto data = func(*state);
0270             CELER_EXPECT(data.size() == result->size());
0271 
0272             if (temp_host.empty())
0273             {
0274                 temp_host.resize(result->size());
0275             }
0276             copy_to_host(data, make_span(temp_host));
0277 
0278             for (auto i : range(data.size()))
0279             {
0280                 (*result)[i] += temp_host[i];
0281             }
0282         }
0283     }
0284 }
0285 
0286 //---------------------------------------------------------------------------//
0287 }  // namespace celeritas