File indexing completed on 2025-01-18 09:12:16
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010 #include "Acts/Plugins/Cuda/Seeding2/Details/CountDublets.hpp"
0011 #include "Acts/Plugins/Cuda/Seeding2/Details/Types.hpp"
0012
0013 #include "../Utilities/ErrorCheck.cuh"
0014
0015
0016 #include <cuda_runtime.h>
0017
0018
0019 #include <algorithm>
0020
0021 namespace Acts {
0022 namespace Cuda {
0023 namespace Kernels {
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038
/// Kernel reducing the per-middle-spacepoint counts of one thread block
///
/// Every thread folds up to two input elements (the ones at @c middleIndex
/// and at @c middleIndex + blockDim.x) into a private
/// @c Details::DubletCounts object. The block then combines these objects
/// through shared memory, accumulating the sums of @c nDublets / @c nTriplets
/// and the maxima of @c maxMBDublets / @c maxMTDublets / @c maxTriplets.
/// Thread 0 stores the block's result in @c dubletCounts[blockIdx.x].
///
/// Launch requirements:
///  - 1-dimensional grid and block (only the @c x components are used);
///  - dynamic shared memory of at least
///    <tt>blockDim.x * sizeof(Details::DubletCounts)</tt> bytes;
///  - block @c b covers the input elements
///    <tt>[2 * b * blockDim.x, 2 * (b + 1) * blockDim.x)</tt>, blocks lying
///    fully beyond @c nMiddleSPs write an "empty" result.
///
/// The block size does not need to be a power of two.
///
/// @param[in] nMiddleSPs Number of elements in the two input arrays
/// @param[in] middleBottomCounts Per-middle-spacepoint count array
///            (presumably the number of middle-bottom dublets; to be
///            confirmed against the code filling it)
/// @param[in] middleTopCounts Per-middle-spacepoint count array
///            (presumably the number of middle-top dublets; to be
///            confirmed against the code filling it)
/// @param[out] dubletCounts Output array with one element per thread block
///
__global__ void countDublets(std::size_t nMiddleSPs,
                             const unsigned int* middleBottomCounts,
                             const unsigned int* middleTopCounts,
                             Details::DubletCounts* dubletCounts) {
  // Scratch area shared by all threads of the block, one slot per thread.
  extern __shared__ Details::DubletCounts sum[];

  // Index of the first element handled by this thread. Computed in
  // std::size_t so that it can neither wrap around in 32-bit arithmetic nor
  // be compared as a signed value against nMiddleSPs.
  const std::size_t middleIndex =
      static_cast<std::size_t>(blockIdx.x) * blockDim.x * 2 + threadIdx.x;
  // Index of the second element handled by this thread.
  const std::size_t secondIndex = middleIndex + blockDim.x;

  // Value-initialise the accumulator, so that threads without any input
  // element contribute a well defined "empty" object, independent of whether
  // DubletCounts declares default member initialisers.
  Details::DubletCounts thisSum{};
  if (middleIndex < nMiddleSPs) {
    const unsigned int nBottom = middleBottomCounts[middleIndex];
    const unsigned int nTop = middleTopCounts[middleIndex];
    thisSum.nDublets = (nBottom + nTop);
    thisSum.nTriplets = (nBottom * nTop);
    thisSum.maxMBDublets = nBottom;
    thisSum.maxMTDublets = nTop;
    thisSum.maxTriplets = thisSum.nTriplets;
  }
  if (secondIndex < nMiddleSPs) {
    const unsigned int nBottom = middleBottomCounts[secondIndex];
    const unsigned int nTop = middleTopCounts[secondIndex];
    thisSum.nDublets += (nBottom + nTop);
    thisSum.nTriplets += (nBottom * nTop);
    thisSum.maxMBDublets = max(nBottom, thisSum.maxMBDublets);
    thisSum.maxMTDublets = max(nTop, thisSum.maxMTDublets);
    thisSum.maxTriplets = max((nBottom * nTop), thisSum.maxTriplets);
  }

  // Publish the per-thread result, and wait for all threads to do the same.
  sum[threadIdx.x] = thisSum;
  __syncthreads();

  // Tree reduction in shared memory. "active" is the number of slots still
  // holding partial results; in every step the upper part of that range is
  // folded onto the lower part. Rounding the half size *up* keeps the middle
  // slot of an odd-sized range alive for the next step, which makes the
  // reduction correct for any block size. (For power-of-two block sizes this
  // is exactly the classic halving scheme.)
  //
  // Slots [0, active - half) are written while slots [half, active) are read.
  // These ranges never overlap, so one barrier per step is sufficient. The
  // loop condition is uniform across the block, so every thread reaches
  // every barrier.
  for (unsigned int active = blockDim.x; active > 1;) {
    const unsigned int half = (active + 1) / 2;
    if (threadIdx.x + half < active) {
      const Details::DubletCounts& otherSum = sum[threadIdx.x + half];
      thisSum.nDublets += otherSum.nDublets;
      thisSum.nTriplets += otherSum.nTriplets;
      thisSum.maxMBDublets = max(thisSum.maxMBDublets, otherSum.maxMBDublets);
      thisSum.maxMTDublets = max(thisSum.maxMTDublets, otherSum.maxMTDublets);
      thisSum.maxTriplets = max(thisSum.maxTriplets, otherSum.maxTriplets);
      sum[threadIdx.x] = thisSum;
    }
    __syncthreads();
    active = half;
  }

  // The first thread holds the result of the whole block at this point.
  if (threadIdx.x == 0) {
    dubletCounts[blockIdx.x] = thisSum;
  }
}
0098
0099 }
0100
0101 namespace Details {
0102
/// Summarise the per-middle-spacepoint counts held in device memory
///
/// Launches @c Kernels::countDublets to reduce the two input arrays into one
/// @c DubletCounts object per thread block, copies those objects to the host,
/// and merges them there (sums for @c nDublets / @c nTriplets, maxima for
/// @c maxMBDublets / @c maxMTDublets / @c maxTriplets).
///
/// The function blocks until the kernel has finished. Errors reported by the
/// CUDA runtime are handled by @c ACTS_CUDA_ERROR_CHECK.
///
/// @param[in] maxBlockSize Number of threads to use per block, also used to
///            size the kernel's shared memory; must be larger than zero
/// @param[in] nMiddleSP Number of elements in the two input arrays
/// @param[in] middleBottomCountArray Device array passed to the kernel as its
///            @c middleBottomCounts argument
/// @param[in] middleTopCountArray Device array passed to the kernel as its
///            @c middleTopCounts argument
/// @return The merged counts; a value-initialised object if @c nMiddleSP is 0
///
DubletCounts countDublets(
    std::size_t maxBlockSize, std::size_t nMiddleSP,
    const device_array<unsigned int>& middleBottomCountArray,
    const device_array<unsigned int>& middleTopCountArray) {
  // The object to be returned in the end.
  DubletCounts result{};

  // Without any input there is nothing to count. Return right away, as a
  // kernel launch with an empty grid is rejected by the CUDA runtime with an
  // "invalid configuration" error.
  if (nMiddleSP == 0) {
    return result;
  }

  // Calculate the parallelisation for the dublet counting. Every thread of
  // the kernel processes two input elements, so one block covers
  // 2 * maxBlockSize of them.
  const std::size_t elementsPerBlock = 2 * maxBlockSize;
  const std::size_t numBlocks =
      (nMiddleSP + elementsPerBlock - 1) / elementsPerBlock;
  // One DubletCounts slot of shared memory per thread.
  const std::size_t sharedMem = maxBlockSize * sizeof(DubletCounts);

  // Create the small memory block in which we will get the per-block counts
  // back.
  auto dubletCountsDevice = make_device_array<DubletCounts>(numBlocks);

  // Run the reduction kernel.
  Kernels::countDublets<<<numBlocks, maxBlockSize, sharedMem>>>(
      nMiddleSP, middleBottomCountArray.get(), middleTopCountArray.get(),
      dubletCountsDevice.get());
  // Catch launch configuration errors...
  ACTS_CUDA_ERROR_CHECK(cudaGetLastError());
  // ...and errors occurring during the execution of the kernel.
  ACTS_CUDA_ERROR_CHECK(cudaDeviceSynchronize());

  // Copy the per-block results back to the host.
  auto dubletCountsHost = make_host_array<DubletCounts>(numBlocks);
  copyToHost(dubletCountsHost, dubletCountsDevice, numBlocks);

  // Merge the per-block results on the host. The number of blocks is small
  // compared to the number of input elements, so this is done serially.
  const DubletCounts* const blockCounts = dubletCountsHost.get();
  for (std::size_t i = 0; i < numBlocks; ++i) {
    const DubletCounts& block = blockCounts[i];
    result.nDublets += block.nDublets;
    result.nTriplets += block.nTriplets;
    result.maxMBDublets = std::max(block.maxMBDublets, result.maxMBDublets);
    result.maxMTDublets = std::max(block.maxMTDublets, result.maxMTDublets);
    result.maxTriplets = std::max(block.maxTriplets, result.maxTriplets);
  }
  return result;
}
0143
0144 }
0145 }
0146 }