Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 08:54:05

0001 //------------------------------- -*- C++ -*- -------------------------------//
0002 // Copyright Celeritas contributors: see top-level COPYRIGHT file for details
0003 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0004 //---------------------------------------------------------------------------//
0005 //! \file corecel/data/DeviceVector.hh
0006 //---------------------------------------------------------------------------//
0007 #pragma once
0008 
0009 #include <type_traits>
0010 
0011 #include "corecel/cont/InitializedValue.hh"
0012 #include "corecel/cont/Span.hh"
0013 #include "corecel/sys/ThreadId.hh"
0014 
0015 #include "DeviceAllocation.hh"
0016 #include "ObserverPtr.hh"
0017 
0018 namespace celeritas
0019 {
0020 //---------------------------------------------------------------------------//
0021 /*!
0022  * Host vector for managing uninitialized device-storage data.
0023  *
0024  * This is a class used only in host memory (not passed to kernels) to manage
0025  * device allocation and host/device copies.  It does \em not perform
0026  * initialization on the data: the host code must define and copy over suitable
0027  * data.
0028  *
0029  * For more complex data usage (dynamic size increases, std vector-like access,
0030  * object initialization), use \c thrust::device_vector inside a \c .cu file.
0031  *
0032  * When a \c StreamId is passed as the last constructor argument,
0033  * all memory operations are asynchronous and ordered within that stream.
0034  *
0035  * \code
0036     DeviceVector<double> myvec(100);
0037     myvec.copy_to_device(make_span(hostvec));
0038     myvec.copy_to_host(make_span(hostvec));
0039    \endcode
0040  *
0041  * - TODO: remove stream? it complicates things
0042  * - TODO: move to detail since this is basically only a backend for Collection
0043  */
0044 template<class T>
0045 class DeviceVector
0046 {
0047 #if !CELERITAS_USE_HIP
0048     // rocrand states have nontrivial destructors, and some HIP integer types
0049     // are not trivially copyable
0050     static_assert(std::is_trivially_copyable<T>::value,
0051                   "DeviceVector element is not trivially copyable");
0052 
0053     static_assert(std::is_trivially_destructible<T>::value,
0054                   "DeviceVector element is not trivially destructible");
0055 #endif
0056 
0057   public:
0058     //!@{
0059     //! \name Type aliases
0060     using value_type = T;
0061     using SpanT = Span<T>;
0062     using SpanConstT = Span<T const>;
0063     //!@}
0064 
0065   public:
0066     // Construct with no elements
0067     DeviceVector() = default;
0068 
0069     // Construct with no elements
0070     explicit DeviceVector(StreamId stream);
0071 
0072     // Construct with a number of elements
0073     explicit DeviceVector(size_type count);
0074 
0075     // Construct with a number of elements
0076     DeviceVector(size_type count, StreamId stream);
0077 
0078     // Swap with another vector
0079     inline void swap(DeviceVector& other) noexcept;
0080 
0081     // Allocate and copy from host pointers
0082     void assign(T const* first, T const* last);
0083 
0084     //// ACCESSORS ////
0085 
0086     //! Get the number of elements
0087     size_type size() const { return size_; }
0088 
0089     //! Whether any elements are stored
0090     bool empty() const { return size_ == 0; }
0091 
0092     //// DEVICE ACCESSORS ////
0093 
0094     // Copy data to device
0095     inline void copy_to_device(SpanConstT host_data);
0096 
0097     // Copy data to host
0098     inline void copy_to_host(SpanT host_data) const;
0099 
0100     // Get a mutable view to device data
0101     SpanT device_ref() { return {this->data(), this->size()}; }
0102 
0103     // Get a const view to device data
0104     SpanConstT device_ref() const { return {this->data(), this->size()}; }
0105 
0106     // Raw pointer to device data (dangerous!)
0107     inline T* data();
0108 
0109     // Raw pointer to device data (dangerous!)
0110     inline T const* data() const;
0111 
0112   private:
0113     DeviceAllocation allocation_;
0114     InitializedValue<size_type> size_;
0115 };
0116 
0117 // Swap two vectors.
0118 template<class T>
0119 inline void swap(DeviceVector<T>& a, DeviceVector<T>& b) noexcept;
0120 
0121 //---------------------------------------------------------------------------//
0122 // INLINE DEFINITIONS
0123 //---------------------------------------------------------------------------//
0124 /*!
0125  * Construct with a stream.
0126  */
0127 template<class T>
0128 DeviceVector<T>::DeviceVector(StreamId stream) : allocation_{stream}, size_{0}
0129 {
0130 }
0131 
0132 //---------------------------------------------------------------------------//
0133 /*!
0134  * Construct with a number of allocated elements.
0135  */
0136 template<class T>
0137 DeviceVector<T>::DeviceVector(size_type count)
0138     : allocation_{count * sizeof(T)}, size_{count}
0139 {
0140 }
0141 
0142 //---------------------------------------------------------------------------//
0143 /*!
0144  * Construct with a number of allocated elements and a stream.
0145  *
0146  * To make resizing eaasier, the stream may be null.
0147  */
0148 template<class T>
0149 DeviceVector<T>::DeviceVector(size_type count, StreamId stream)
0150 {
0151     if (stream)
0152     {
0153         allocation_ = DeviceAllocation{count * sizeof(T), stream};
0154     }
0155     else
0156     {
0157         allocation_ = DeviceAllocation{count * sizeof(T)};
0158     }
0159     size_ = count;
0160 }
0161 
0162 //---------------------------------------------------------------------------//
0163 /*!
0164  * Get the device data pointer.
0165  */
0166 template<class T>
0167 void DeviceVector<T>::swap(DeviceVector& other) noexcept
0168 {
0169     using std::swap;
0170     swap(size_, other.size_);
0171     swap(allocation_, other.allocation_);
0172 }
0173 
0174 //---------------------------------------------------------------------------//
0175 /*!
0176  * Allocate and copy from \em host pointers.
0177  *
0178  * Not exception safe: if the copy fails, the original contents are lost.
0179  */
0180 template<class T>
0181 void DeviceVector<T>::assign(T const* first, T const* last)
0182 {
0183     auto const new_size = static_cast<size_type>(last - first);
0184     if (new_size > size_ && new_size * sizeof(T) > allocation_.size())
0185     {
0186         // Reallocate
0187         *this = DeviceVector<T>(new_size, allocation_.stream_id());
0188     }
0189     else
0190     {
0191         // Update size to fit capacity
0192         size_ = new_size;
0193     }
0194 
0195     this->copy_to_device({first, new_size});
0196 }
0197 
0198 //---------------------------------------------------------------------------//
0199 /*!
0200  * Copy data to device.
0201  */
0202 template<class T>
0203 void DeviceVector<T>::copy_to_device(SpanConstT data)
0204 {
0205     CELER_EXPECT(data.size() == this->size());
0206     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
0207     allocation_.copy_to_device({reinterpret_cast<std::byte const*>(data.data()),
0208                                 data.size() * sizeof(T)});
0209 }
0210 
0211 //---------------------------------------------------------------------------//
0212 /*!
0213  * Copy data to host.
0214  */
0215 template<class T>
0216 void DeviceVector<T>::copy_to_host(SpanT data) const
0217 {
0218     CELER_EXPECT(data.size() == this->size());
0219     allocation_.copy_to_host(
0220         // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
0221         {reinterpret_cast<std::byte*>(data.data()), data.size() * sizeof(T)});
0222 }
0223 
0224 //---------------------------------------------------------------------------//
0225 /*!
0226  * Get a device data pointer.
0227  */
0228 template<class T>
0229 T* DeviceVector<T>::data()
0230 {
0231     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
0232     return reinterpret_cast<T*>(allocation_.device_ref().data());
0233 }
0234 
0235 //---------------------------------------------------------------------------//
0236 /*!
0237  * Get a device data pointer.
0238  */
0239 template<class T>
0240 T const* DeviceVector<T>::data() const
0241 {
0242     // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
0243     return reinterpret_cast<T const*>(allocation_.device_ref().data());
0244 }
0245 
0246 //---------------------------------------------------------------------------//
0247 /*!
0248  * Swap two vectors.
0249  */
0250 template<class T>
0251 void swap(DeviceVector<T>& a, DeviceVector<T>& b) noexcept
0252 {
0253     return a.swap(b);
0254 }
0255 
0256 //---------------------------------------------------------------------------//
0257 /*!
0258  * Prevent accidental construction of Span from a device vector.
0259  *
0260  * Use \c dv.device_ref() to get a span.
0261  */
0262 template<class T>
0263 CELER_FUNCTION Span<T const> make_span(DeviceVector<T> const& dv)
0264 {
0265     static_assert(sizeof(T) == 0, "Cannot 'make_span' from a device vector");
0266     return {dv.data(), dv.size()};
0267 }
0268 
0269 //---------------------------------------------------------------------------//
0270 //! Prevent accidental construction of Span from a device vector.
0271 template<class T>
0272 CELER_FUNCTION Span<T> make_span(DeviceVector<T>& dv)
0273 {
0274     static_assert(sizeof(T) == 0, "Cannot 'make_span' from a device vector");
0275     return {dv.data(), dv.size()};
0276 }
0277 
0278 //---------------------------------------------------------------------------//
0279 //! Create an observer pointer from a device vector.
0280 template<class T>
0281 ObserverPtr<T, MemSpace::device> make_observer(DeviceVector<T>& vec) noexcept
0282 {
0283     return ObserverPtr<T, MemSpace::device>{vec.data()};
0284 }
0285 
0286 //---------------------------------------------------------------------------//
0287 //! Create an observer pointer from a pointer in the native memspace.
0288 template<class T>
0289 ObserverPtr<T const, MemSpace::device>
0290 make_observer(DeviceVector<T> const& vec) noexcept
0291 {
0292     return ObserverPtr<T const, MemSpace::device>{vec.data()};
0293 }
0294 
0295 //---------------------------------------------------------------------------//
0296 }  // namespace celeritas