File indexing completed on 2025-09-17 08:54:05
0001
0002
0003
0004
0005
0006
0007 #pragma once
0008
0009 #include <type_traits>
0010
0011 #include "corecel/Assert.hh"
0012 #include "corecel/OpaqueId.hh"
0013 #include "corecel/Types.hh"
0014 #include "corecel/cont/Range.hh"
0015 #include "corecel/sys/Device.hh"
0016 #include "corecel/sys/ThreadId.hh"
0017
0018 #include "Collection.hh"
0019 #include "CollectionMirror.hh"
0020 #include "CollectionStateStore.hh"
0021
0022 namespace celeritas
0023 {
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0050 class StreamStore
0051 {
0052 public:
0053
0054
0055 using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0056
0057
0058 public:
0059
0060 StreamStore() = default;
0061
0062
0063 inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0064
0065
0066
0067
0068 explicit operator bool() const { return num_streams_ > 0; }
0069
0070
0071 StreamId::size_type num_streams() const { return num_streams_; }
0072
0073
0074 template<MemSpace M>
0075 inline P<Ownership::const_reference, M> const& params() const;
0076
0077
0078
0079 template<MemSpace M>
0080 inline S<Ownership::reference, M>&
0081 state(StreamId stream_id, size_type size);
0082
0083
0084 template<MemSpace M>
0085 S<Ownership::reference, M> const* state(StreamId stream_id) const
0086 {
0087 return StreamStore::stateptr_impl<M>(*this, stream_id);
0088 }
0089
0090
0091 template<MemSpace M>
0092 S<Ownership::reference, M>* state(StreamId stream_id)
0093 {
0094 return StreamStore::stateptr_impl<M>(*this, stream_id);
0095 }
0096
0097 private:
0098
0099 using ParamMirror = CollectionMirror<P>;
0100 template<MemSpace M>
0101 using StateStore = CollectionStateStore<S, M>;
0102 template<MemSpace M>
0103 using VecSS = std::vector<StateStore<M>>;
0104
0105
0106
0107 CollectionMirror<P> params_;
0108 StreamId::size_type num_streams_{0};
0109 VecSS<MemSpace::host> host_states_;
0110 VecSS<MemSpace::device> device_states_;
0111
0112
0113
0114 template<MemSpace M, class Self>
0115 static constexpr decltype(auto) states_impl(Self&& self)
0116 {
0117 if constexpr (M == MemSpace::host)
0118 {
0119
0120 return (self.host_states_);
0121 }
0122 else
0123 {
0124 return (self.device_states_);
0125 }
0126 #if CELER_CUDACC_BUGGY_IF_CONSTEXPR
0127 CELER_ASSERT_UNREACHABLE();
0128 #endif
0129 }
0130
0131 template<MemSpace M, class Self>
0132 static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0133 {
0134 CELER_EXPECT(stream_id < self.num_streams_ || !self);
0135 using result_type = std::add_pointer_t<
0136 decltype(StreamStore::states_impl<M>(self).front().ref())>;
0137 if (!self)
0138 {
0139 return result_type{nullptr};
0140 }
0141
0142 auto& state_vec = StreamStore::states_impl<M>(self);
0143 CELER_ASSERT(state_vec.size() == self.num_streams_);
0144 auto& state_store = state_vec[stream_id.unchecked_get()];
0145 if (!state_store)
0146 {
0147 return result_type{nullptr};
0148 }
0149
0150 return &state_store.ref();
0151 }
0152 };
0153
0154
0155
0156
0157
0158
0159
0160
0161
0162
0163 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0164 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0165 StreamId::size_type num_streams)
0166 : params_(std::move(host)), num_streams_(num_streams)
0167 {
0168 CELER_EXPECT(params_);
0169 CELER_EXPECT(num_streams_ > 0);
0170
0171
0172 host_states_.resize(num_streams_);
0173 device_states_.resize(num_streams_);
0174 }
0175
0176
0177
0178
0179
0180 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0181 template<MemSpace M>
0182 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0183 {
0184 CELER_EXPECT(*this);
0185 return params_.template ref<M>();
0186 }
0187
0188
0189
0190
0191
0192 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0193 template<MemSpace M>
0194 S<Ownership::reference, M>&
0195 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0196 {
0197 CELER_EXPECT(*this);
0198 CELER_EXPECT(stream_id < num_streams_);
0199
0200 auto& state_vec = StreamStore::states_impl<M>(*this);
0201 CELER_ASSERT(state_vec.size() == num_streams_);
0202 auto& state_store = state_vec[stream_id.unchecked_get()];
0203 if (CELER_UNLIKELY(!state_store))
0204 {
0205 state_store = {this->params<MemSpace::host>(), stream_id, size};
0206 }
0207
0208 CELER_ENSURE(state_store.size() == size);
0209 return state_store.ref();
0210 }
0211
0212
0213
0214
0215
0216
0217
0218 template<class S, class F>
0219 void apply_to_all_streams(S&& store, F&& func)
0220 {
0221
0222 for (StreamId s : range(StreamId{store.num_streams()}))
0223 {
0224 if (auto* state = store.template state<MemSpace::host>(s))
0225 {
0226 func(*state);
0227 }
0228 }
0229
0230
0231 for (StreamId s : range(StreamId{store.num_streams()}))
0232 {
0233 if (auto* state = store.template state<MemSpace::device>(s))
0234 {
0235 func(*state);
0236 }
0237 }
0238 }
0239
0240
0241
0242
0243
0244 template<class S, class F, class T>
0245 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0246 {
0247 std::vector<T> temp_host;
0248
0249
0250 for (StreamId s : range(StreamId{store.num_streams()}))
0251 {
0252 if (auto* state = store.template state<MemSpace::host>(s))
0253 {
0254 auto data = func(*state)[AllItems<T>{}];
0255 CELER_EXPECT(data.size() == result->size());
0256 for (auto i : range(data.size()))
0257 {
0258 (*result)[i] += data[i];
0259 }
0260 }
0261 }
0262
0263
0264 for (StreamId s : range(StreamId{store.num_streams()}))
0265 {
0266 if (auto* state = store.template state<MemSpace::device>(s))
0267 {
0268 auto data = func(*state);
0269 CELER_EXPECT(data.size() == result->size());
0270
0271 if (temp_host.empty())
0272 {
0273 temp_host.resize(result->size());
0274 }
0275 copy_to_host(data, make_span(temp_host));
0276
0277 for (auto i : range(data.size()))
0278 {
0279 (*result)[i] += temp_host[i];
0280 }
0281 }
0282 }
0283 }
0284
0285
0286 }