File indexing completed on 2025-01-30 10:03:42
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
#include <type_traits>
#include <utility>
#include <vector>

#include "corecel/Assert.hh"
#include "corecel/OpaqueId.hh"
#include "corecel/Types.hh"
#include "corecel/cont/Range.hh"
#include "corecel/sys/Device.hh"
#include "corecel/sys/ThreadId.hh"

#include "Collection.hh"
#include "CollectionMirror.hh"
#include "CollectionStateStore.hh"
0022
0023 namespace celeritas
0024 {
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042
0043
0044
0045
0046
0047
0048
0049
0050 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0051 class StreamStore
0052 {
0053 public:
0054
0055
0056 using ParamsHostVal = P<Ownership::value, MemSpace::host>;
0057
0058
0059 public:
0060
0061 StreamStore() = default;
0062
0063
0064 inline StreamStore(ParamsHostVal&& host, StreamId::size_type num_streams);
0065
0066
0067
0068
0069 explicit operator bool() const { return num_streams_ > 0; }
0070
0071
0072 StreamId::size_type num_streams() const { return num_streams_; }
0073
0074
0075 template<MemSpace M>
0076 inline P<Ownership::const_reference, M> const& params() const;
0077
0078
0079
0080 template<MemSpace M>
0081 inline S<Ownership::reference, M>&
0082 state(StreamId stream_id, size_type size);
0083
0084
0085 template<MemSpace M>
0086 S<Ownership::reference, M> const* state(StreamId stream_id) const
0087 {
0088 return StreamStore::stateptr_impl<M>(*this, stream_id);
0089 }
0090
0091
0092 template<MemSpace M>
0093 S<Ownership::reference, M>* state(StreamId stream_id)
0094 {
0095 return StreamStore::stateptr_impl<M>(*this, stream_id);
0096 }
0097
0098 private:
0099
0100 using ParamMirror = CollectionMirror<P>;
0101 template<MemSpace M>
0102 using StateStore = CollectionStateStore<S, M>;
0103 template<MemSpace M>
0104 using VecSS = std::vector<StateStore<M>>;
0105
0106
0107
0108 CollectionMirror<P> params_;
0109 StreamId::size_type num_streams_{0};
0110 VecSS<MemSpace::host> host_states_;
0111 VecSS<MemSpace::device> device_states_;
0112
0113
0114
0115 template<MemSpace M, class Self>
0116 static constexpr decltype(auto) states_impl(Self&& self)
0117 {
0118 if constexpr (M == MemSpace::host)
0119 {
0120
0121 return (self.host_states_);
0122 }
0123 #ifndef __NVCC__
0124
0125
0126
0127 else
0128 #endif
0129 return (self.device_states_);
0130 }
0131
0132 template<MemSpace M, class Self>
0133 static decltype(auto) stateptr_impl(Self&& self, StreamId stream_id)
0134 {
0135 CELER_EXPECT(stream_id < self.num_streams_ || !self);
0136 using result_type = std::add_pointer_t<
0137 decltype(StreamStore::states_impl<M>(self).front().ref())>;
0138 if (!self)
0139 {
0140 return result_type{nullptr};
0141 }
0142
0143 auto& state_vec = StreamStore::states_impl<M>(self);
0144 CELER_ASSERT(state_vec.size() == self.num_streams_);
0145 auto& state_store = state_vec[stream_id.unchecked_get()];
0146 if (!state_store)
0147 {
0148 return result_type{nullptr};
0149 }
0150
0151 return &state_store.ref();
0152 }
0153 };
0154
0155
0156
0157
0158
0159
0160
0161
0162
0163
0164 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0165 StreamStore<P, S>::StreamStore(ParamsHostVal&& host,
0166 StreamId::size_type num_streams)
0167 : params_(std::move(host)), num_streams_(num_streams)
0168 {
0169 CELER_EXPECT(params_);
0170 CELER_EXPECT(num_streams_ > 0);
0171
0172
0173 host_states_.resize(num_streams_);
0174 device_states_.resize(num_streams_);
0175 }
0176
0177
0178
0179
0180
0181 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0182 template<MemSpace M>
0183 P<Ownership::const_reference, M> const& StreamStore<P, S>::params() const
0184 {
0185 CELER_EXPECT(*this);
0186 return params_.template ref<M>();
0187 }
0188
0189
0190
0191
0192
0193 template<template<Ownership, MemSpace> class P, template<Ownership, MemSpace> class S>
0194 template<MemSpace M>
0195 S<Ownership::reference, M>&
0196 StreamStore<P, S>::state(StreamId stream_id, size_type size)
0197 {
0198 CELER_EXPECT(*this);
0199 CELER_EXPECT(stream_id < num_streams_);
0200
0201 auto& state_vec = StreamStore::states_impl<M>(*this);
0202 CELER_ASSERT(state_vec.size() == num_streams_);
0203 auto& state_store = state_vec[stream_id.unchecked_get()];
0204 if (CELER_UNLIKELY(!state_store))
0205 {
0206 state_store = {this->params<MemSpace::host>(), stream_id, size};
0207 }
0208
0209 CELER_ENSURE(state_store.size() == size);
0210 return state_store.ref();
0211 }
0212
0213
0214
0215
0216
0217
0218
0219 template<class S, class F>
0220 void apply_to_all_streams(S&& store, F&& func)
0221 {
0222
0223 for (StreamId s : range(StreamId{store.num_streams()}))
0224 {
0225 if (auto* state = store.template state<MemSpace::host>(s))
0226 {
0227 func(*state);
0228 }
0229 }
0230
0231
0232 for (StreamId s : range(StreamId{store.num_streams()}))
0233 {
0234 if (auto* state = store.template state<MemSpace::device>(s))
0235 {
0236 func(*state);
0237 }
0238 }
0239 }
0240
0241
0242
0243
0244
0245 template<class S, class F, class T>
0246 void accumulate_over_streams(S&& store, F&& func, std::vector<T>* result)
0247 {
0248 std::vector<T> temp_host;
0249
0250
0251 for (StreamId s : range(StreamId{store.num_streams()}))
0252 {
0253 if (auto* state = store.template state<MemSpace::host>(s))
0254 {
0255 auto data = func(*state)[AllItems<T>{}];
0256 CELER_EXPECT(data.size() == result->size());
0257 for (auto i : range(data.size()))
0258 {
0259 (*result)[i] += data[i];
0260 }
0261 }
0262 }
0263
0264
0265 for (StreamId s : range(StreamId{store.num_streams()}))
0266 {
0267 if (auto* state = store.template state<MemSpace::device>(s))
0268 {
0269 auto data = func(*state);
0270 CELER_EXPECT(data.size() == result->size());
0271
0272 if (temp_host.empty())
0273 {
0274 temp_host.resize(result->size());
0275 }
0276 copy_to_host(data, make_span(temp_host));
0277
0278 for (auto i : range(data.size()))
0279 {
0280 (*result)[i] += temp_host[i];
0281 }
0282 }
0283 }
0284 }
0285
0286
0287 }