Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-17 08:54:12

0001 /*
0002  * SPDX-PackageName: "covfie, a part of the ACTS project"
0003  * SPDX-FileCopyrightText: 2022 CERN
0004  * SPDX-License-Identifier: MPL-2.0
0005  */
0006 
0007 #pragma once
0008 
#include <cstddef>
#include <limits>
#include <memory>
#include <new>
#include <optional>
#include <type_traits>

#include <hip/hip_runtime.h>

#include <covfie/hip/error_check.hpp>
#include <covfie/hip/utility/unique_ptr.hpp>
0016 
0017 namespace covfie::utility::hip {
0018 template <typename T>
0019 unique_device_ptr<T> device_allocate()
0020 {
0021     static_assert(
0022         !(std::is_array_v<T> && std::extent_v<T> == 0),
0023         "Allocation pointer type cannot be an unbounded array."
0024     );
0025 
0026     using pointer_t =
0027         std::conditional_t<std::is_array_v<T>, std::decay_t<T>, T *>;
0028 
0029     pointer_t p;
0030 
0031     hipErrorCheck(hipMalloc(&p, sizeof(T)));
0032 
0033     return unique_device_ptr<T>(p);
0034 }
0035 
0036 template <typename T>
0037 unique_device_ptr<T> device_allocate(std::size_t n)
0038 {
0039     static_assert(
0040         std::is_array_v<T>, "Allocation pointer type must be an array type."
0041     );
0042     static_assert(
0043         std::extent_v<T> == 0, "Allocation pointer type must be unbounded."
0044     );
0045 
0046     using pointer_t =
0047         std::conditional_t<std::is_array_v<T>, std::decay_t<T>, T *>;
0048 
0049     pointer_t p;
0050 
0051     hipErrorCheck(hipMalloc(&p, n * sizeof(std::remove_extent_t<T>)));
0052 
0053     return unique_device_ptr<T>(p);
0054 }
0055 
0056 template <typename T>
0057 unique_device_ptr<T[]>
0058 device_copy_h2d(const T * h, std::optional<hipStream_t> stream = std::nullopt)
0059 {
0060     unique_device_ptr<T[]> r = device_allocate<T[]>();
0061 
0062     if (stream.has_value()) {
0063         hipErrorCheck(hipMemcpyAsync(
0064             r.get(), h, sizeof(T), hipMemcpyHostToDevice, *stream
0065         ));
0066         hipErrorCheck(hipStreamSynchronize(*stream));
0067     } else {
0068         hipErrorCheck(hipMemcpy(r.get(), h, sizeof(T), hipMemcpyHostToDevice));
0069     }
0070 
0071     return r;
0072 }
0073 
0074 template <typename T>
0075 unique_device_ptr<T[]> device_copy_h2d(
0076     const T * h, std::size_t n, std::optional<hipStream_t> stream = std::nullopt
0077 )
0078 {
0079     unique_device_ptr<T[]> r = device_allocate<T[]>(n);
0080 
0081     if (stream.has_value()) {
0082         hipErrorCheck(hipMemcpyAsync(
0083             r.get(),
0084             h,
0085             n * sizeof(std::remove_extent_t<T>),
0086             hipMemcpyHostToDevice,
0087             *stream
0088         ));
0089         hipErrorCheck(hipStreamSynchronize(*stream));
0090     } else {
0091         hipErrorCheck(hipMemcpy(
0092             r.get(),
0093             h,
0094             n * sizeof(std::remove_extent_t<T>),
0095             hipMemcpyHostToDevice
0096         ));
0097     }
0098 
0099     return r;
0100 }
0101 
0102 template <typename T>
0103 unique_device_ptr<T[]>
0104 device_copy_d2d(const T * h, std::optional<hipStream_t> stream = std::nullopt)
0105 {
0106     unique_device_ptr<T[]> r = device_allocate<T[]>();
0107 
0108     if (stream.has_value()) {
0109         hipErrorCheck(hipMemcpyAsync(
0110             r.get(), h, sizeof(T), hipMemcpyDeviceToDevice, *stream
0111         ));
0112         hipErrorCheck(hipStreamSynchronize(*stream));
0113     } else {
0114         hipErrorCheck(hipMemcpy(r.get(), h, sizeof(T), hipMemcpyDeviceToDevice)
0115         );
0116     }
0117 
0118     return r;
0119 }
0120 
0121 template <typename T>
0122 unique_device_ptr<T[]> device_copy_d2d(
0123     const T * h, std::size_t n, std::optional<hipStream_t> stream = std::nullopt
0124 )
0125 {
0126     unique_device_ptr<T[]> r = device_allocate<T[]>(n);
0127 
0128     if (stream.has_value()) {
0129         hipErrorCheck(hipMemcpyAsync(
0130             r.get(),
0131             h,
0132             n * sizeof(std::remove_extent_t<T>),
0133             hipMemcpyDeviceToDevice,
0134             *stream
0135         ));
0136         hipErrorCheck(hipStreamSynchronize(*stream));
0137     } else {
0138         hipErrorCheck(hipMemcpy(
0139             r.get(),
0140             h,
0141             n * sizeof(std::remove_extent_t<T>),
0142             hipMemcpyDeviceToDevice
0143         ));
0144     }
0145 
0146     return r;
0147 }
0148 }