File indexing completed on 2026-05-11 08:38:33
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
#include <type_traits>
#include <utility>
#include <vector>

#include "corecel/Assert.hh"
#include "corecel/OpaqueId.hh"
#include "corecel/Types.hh"
#include "corecel/cont/Range.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/ThreadId.hh"

#include "Collection.hh"
#include "CollectionMirror.hh"
#include "CollectionStateStore.hh"
0021
0022 namespace celeritas
0023 {
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049 template<template<Ownership, MemSpace> class P,
0050 template<Ownership, MemSpace> class S>
0051 class StreamStore
0052 {
0053 public:
0054
0055
0056 using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0057
0058
0059 public:
0060
0061 StreamStore() = default;
0062
0063
0064 inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0065
0066
0067
0068
0069 explicit operator bool() const { return num_streams_ > 0; }
0070
0071
0072 StreamId::size_type num_streams() const { return num_streams_; }
0073
0074
0075 template<MemSpace M>
0076 inline P<Ownership::const_reference, M> const& params() const;
0077
0078
0079
0080 template<MemSpace M>
0081 inline S<Ownership::reference, M>&
0082 state(StreamId stream_id, size_type size);
0083
0084
0085 template<MemSpace M>
0086 S<Ownership::reference, M> const* state(StreamId stream_id) const
0087 {
0088 return StreamStore::stateptr_impl<M>(*this, stream_id);
0089 }
0090
0091
0092 template<MemSpace M>
0093 S<Ownership::reference, M>* state(StreamId stream_id)
0094 {
0095 return StreamStore::stateptr_impl<M>(*this, stream_id);
0096 }
0097
0098 private:
0099
0100 using ParamMirror = CollectionMirror<P>;
0101 template<MemSpace M>
0102 using StateStore = CollectionStateStore<S, M>;
0103 template<MemSpace M>
0104 using VecSS = std::vector<StateStore<M>>;
0105
0106
0107
0108 CollectionMirror<P> params_;
0109 StreamId::size_type num_streams_{0};
0110 VecSS<MemSpace::host> host_states_;
0111 VecSS<MemSpace::device> device_states_;
0112
0113
0114
0115 template<MemSpace M, class Self>
0116 static constexpr decltype(auto) states_impl(Self&& self)
0117 {
0118 if constexpr (M == MemSpace::host)
0119 {
0120
0121 return (self.host_states_);
0122 }
0123 else
0124 {
0125 return (self.device_states_);
0126 }
0127 #if CELER_CUDACC_BUGGY_IF_CONSTEXPR
0128 CELER_ASSERT_UNREACHABLE();
0129 #endif
0130 }
0131
0132 template<MemSpace M, class Self>
0133 static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0134 {
0135 CELER_EXPECT(stream_id < self.num_streams_ || !self);
0136 using result_type = std::add_pointer_t<
0137 decltype(StreamStore::states_impl<M>(self).front().ref())>;
0138 if (!self)
0139 {
0140 return result_type{nullptr};
0141 }
0142
0143 auto& state_vec = StreamStore::states_impl<M>(self);
0144 CELER_ASSERT(state_vec.size() == self.num_streams_);
0145 auto& state_store = state_vec[stream_id.unchecked_get()];
0146 if (!state_store)
0147 {
0148 return result_type{nullptr};
0149 }
0150
0151 return &state_store.ref();
0152 }
0153 };
0154
0155
0156
0157
0158
0159
0160
0161
0162
0163
0164 template<template<Ownership, MemSpace> class P,
0165 template<Ownership, MemSpace> class S>
0166 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0167 StreamId::size_type num_streams)
0168 : params_(std::move(host)), num_streams_(num_streams)
0169 {
0170 CELER_EXPECT(params_);
0171 CELER_EXPECT(num_streams_ > 0);
0172
0173
0174 host_states_.resize(num_streams_);
0175 device_states_.resize(num_streams_);
0176 }
0177
0178
0179
0180
0181
0182 template<template<Ownership, MemSpace> class P,
0183 template<Ownership, MemSpace> class S>
0184 template<MemSpace M>
0185 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0186 {
0187 CELER_EXPECT(*this);
0188 return params_.template ref<M>();
0189 }
0190
0191
0192
0193
0194
0195 template<template<Ownership, MemSpace> class P,
0196 template<Ownership, MemSpace> class S>
0197 template<MemSpace M>
0198 S<Ownership::reference, M>&
0199 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0200 {
0201 CELER_EXPECT(*this);
0202 CELER_EXPECT(stream_id < num_streams_);
0203
0204 auto& state_vec = StreamStore::states_impl<M>(*this);
0205 CELER_ASSERT(state_vec.size() == num_streams_);
0206 auto& state_store = state_vec[stream_id.unchecked_get()];
0207 if (CELER_UNLIKELY(!state_store))
0208 {
0209 state_store = {this->params<MemSpace::host>(), stream_id, size};
0210 }
0211
0212 CELER_ENSURE(state_store.size() == size);
0213 return state_store.ref();
0214 }
0215
0216
0217
0218
0219
0220
0221
0222 template<class S, class F>
0223 void apply_to_all_streams(S&& store, F&& func)
0224 {
0225
0226 for (StreamId s : range(StreamId{store.num_streams()}))
0227 {
0228 if (auto* state = store.template state<MemSpace::host>(s))
0229 {
0230 func(*state);
0231 }
0232 }
0233
0234
0235 for (StreamId s : range(StreamId{store.num_streams()}))
0236 {
0237 if (auto* state = store.template state<MemSpace::device>(s))
0238 {
0239 func(*state);
0240 }
0241 }
0242 }
0243
0244
0245
0246
0247
0248 template<class S, class F, class T>
0249 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0250 {
0251 std::vector<T> temp_host;
0252
0253
0254 for (StreamId s : range(StreamId{store.num_streams()}))
0255 {
0256 if (auto* state = store.template state<MemSpace::host>(s))
0257 {
0258 auto data = func(*state)[AllItems<T>{}];
0259 CELER_EXPECT(data.size() == result->size());
0260 for (auto i : range(data.size()))
0261 {
0262 (*result)[i] += data[i];
0263 }
0264 }
0265 }
0266
0267
0268 for (StreamId s : range(StreamId{store.num_streams()}))
0269 {
0270 if (auto* state = store.template state<MemSpace::device>(s))
0271 {
0272 auto data = func(*state);
0273 CELER_EXPECT(data.size() == result->size());
0274
0275 if (temp_host.empty())
0276 {
0277 temp_host.resize(result->size());
0278 }
0279 copy_to_host(data, make_span(temp_host));
0280
0281 for (auto i : range(data.size()))
0282 {
0283 (*result)[i] += temp_host[i];
0284 }
0285 }
0286 }
0287 }
0288
0289
0290 }