Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:29:55

0001 //---------------------------------------------------------------------------//
0002 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
0003 //
0004 // Distributed under the Boost Software License, Version 1.0
0005 // See accompanying file LICENSE_1_0.txt or copy at
0006 // http://www.boost.org/LICENSE_1_0.txt
0007 //
0008 // See http://boostorg.github.com/compute for more information.
0009 //---------------------------------------------------------------------------//
0010 
0011 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
0012 #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
0013 
0014 #include <boost/compute/kernel.hpp>
0015 #include <boost/compute/detail/meta_kernel.hpp>
0016 #include <boost/compute/command_queue.hpp>
0017 #include <boost/compute/container/vector.hpp>
0018 #include <boost/compute/detail/iterator_range_size.hpp>
0019 #include <boost/compute/memory/local_buffer.hpp>
0020 #include <boost/compute/iterator/buffer_iterator.hpp>
0021 
0022 namespace boost {
0023 namespace compute {
0024 namespace detail {
0025 
/// Meta-kernel "local_scan": first pass of the GPU scan.
///
/// Each work-item loads one input element into local memory. The block is
/// then scanned there with a log-step (Hillis-Steele style) loop, and each
/// scanned value is written to the output at the same global index. One
/// work-item per block stores the block's total into `block_sums`.
///
/// Kernel arguments (their indices are kept in the m_*_arg members):
///   block_sums - global buffer of T, one slot per work-group
///   scratch    - local-memory buffer of T used for the in-block scan
///   block_size - elements per block; NOTE(review): assumed to equal the
///                work-group size (scan_impl launches it that way)
///   count      - number of valid elements; work-items with gid >= count
///                load 0 and write no output element
///   init       - starting value, read only by the exclusive variant (gid 0)
///
/// T is the value type of InputIterator.
///
/// NOTE(review): the generated code uses the literal 0 as the start value of
/// every block after the first in the exclusive variant. Combined with the
/// later block-sum pass, that is only correct when 0 is the identity of `op`
/// (e.g. addition). Confirm that callers restrict this path accordingly.
template<class InputIterator, class OutputIterator, class BinaryOperator>
class local_scan_kernel : public meta_kernel
{
public:
    local_scan_kernel(InputIterator first,
                      InputIterator last,
                      OutputIterator result,
                      bool exclusive,
                      BinaryOperator op)
        : meta_kernel("local_scan")
    {
        typedef typename std::iterator_traits<InputIterator>::value_type T;

        // the range end is not used here: the generated code bounds-checks
        // against the `count` kernel argument instead
        (void) last;

        // always emit the gid < count guards
        bool checked = true;

        m_block_sums_arg = add_arg<T *>(memory_object::global_memory, "block_sums");
        m_scratch_arg = add_arg<T *>(memory_object::local_memory, "scratch");
        m_block_size_arg = add_arg<const cl_uint>("block_size");
        m_count_arg = add_arg<const cl_uint>("count");
        m_init_value_arg = add_arg<const T>("init");

        // work-item parameters
        *this <<
            "const uint gid = get_global_id(0);\n" <<
            "const uint lid = get_local_id(0);\n";

        // check against data size
        if(checked){
            *this <<
                "if(gid < count){\n";
        }

        // copy values from input to local memory
        if(exclusive){
            // shifted load: element lid holds input gid-1; the first element
            // of each block holds local_init (init for gid 0, 0 otherwise)
            *this <<
                decl<const T>("local_init") << "= (gid == 0) ? init : 0;\n" <<
                "if(lid == 0){ scratch[lid] = local_init; }\n" <<
                "else { scratch[lid] = " << first[expr<cl_uint>("gid-1")] << "; }\n";
        }
        else{
            *this <<
                "scratch[lid] = " << first[expr<cl_uint>("gid")] << ";\n";
        }

        if(checked){
            // work-items past the end pad the block with 0
            *this <<
                "}\n"
                "else {\n" <<
                "    scratch[lid] = 0;\n" <<
                "}\n";
        }

        // wait for all threads to read from input
        *this <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n";

        // perform scan: at step i every element with lid >= i combines the
        // value i places to its left. The first barrier keeps all reads of a
        // step ahead of its writes; the second completes the step's writes
        // before the next step reads.
        *this <<
            "for(uint i = 1; i < block_size; i <<= 1){\n" <<
            "    " << decl<const T>("x") << " = lid >= i ? scratch[lid-i] : 0;\n" <<
            "    barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "    if(lid >= i){\n" <<
            "        scratch[lid] = " << op(var<T>("scratch[lid]"), var<T>("x")) << ";\n" <<
            "    }\n" <<
            "    barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "}\n";

        // copy results to output
        if(checked){
            *this <<
                "if(gid < count){\n";
        }

        *this <<
            result[expr<cl_uint>("gid")] << " = scratch[lid];\n";

        if(checked){
            *this << "}\n";
        }

        // store sum for the block
        if(exclusive){
            // the exclusive value at the last element leaves out that
            // element's own input, so combine it back in to get the block
            // total; skipped when the last element is out of range.
            // NOTE: this re-reads first[gid] after result[gid] was written
            // above, so first and result must not alias here.
            *this <<
                "if(lid == block_size - 1 && gid < count) {\n" <<
                "    block_sums[get_group_id(0)] = " <<
                       op(first[expr<cl_uint>("gid")], var<T>("scratch[lid]")) <<
                       ";\n" <<
                "}\n";
        }
        else {
            // inclusive: the last scanned value is the block total
            *this <<
                "if(lid == block_size - 1){\n" <<
                "    block_sums[get_group_id(0)] = scratch[lid];\n" <<
                "}\n";
        }
    }

    // kernel argument indices returned by add_arg()
    size_t m_block_sums_arg;
    size_t m_scratch_arg;
    size_t m_block_size_arg;
    size_t m_count_arg;
    size_t m_init_value_arg;
};
0131 
0132 template<class T, class BinaryOperator>
0133 class write_scanned_output_kernel : public meta_kernel
0134 {
0135 public:
0136     write_scanned_output_kernel(BinaryOperator op)
0137         : meta_kernel("write_scanned_output")
0138     {
0139         bool checked = true;
0140 
0141         m_output_arg = add_arg<T *>(memory_object::global_memory, "output");
0142         m_block_sums_arg = add_arg<const T *>(memory_object::global_memory, "block_sums");
0143         m_count_arg = add_arg<const cl_uint>("count");
0144 
0145         // work-item parameters
0146         *this <<
0147             "const uint gid = get_global_id(0);\n" <<
0148             "const uint block_id = get_group_id(0);\n";
0149 
0150         // check against data size
0151         if(checked){
0152             *this << "if(gid < count){\n";
0153         }
0154 
0155         // write output
0156         *this <<
0157             "output[gid] = " <<
0158                 op(var<T>("block_sums[block_id]"), var<T>("output[gid] ")) << ";\n";
0159 
0160         if(checked){
0161             *this << "}\n";
0162         }
0163     }
0164 
0165     size_t m_output_arg;
0166     size_t m_block_sums_arg;
0167     size_t m_count_arg;
0168 };
0169 
/// Chooses the work-group size used by the scan kernels.
///
/// Returns 0 for an empty range. Otherwise returns the smallest power of two
/// that is >= the number of elements, capped at 256.
template<class InputIterator>
inline size_t pick_scan_block_size(InputIterator first, InputIterator last)
{
    const size_t n = iterator_range_size(first, last);
    if(n == 0){
        return 0;
    }

    const size_t max_block = 256;
    size_t block = 1;
    while(block < n && block < max_block){
        block *= 2;
    }
    return block;
}
0186 
0187 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
0188 inline OutputIterator scan_impl(InputIterator first,
0189                                 InputIterator last,
0190                                 OutputIterator result,
0191                                 bool exclusive,
0192                                 T init,
0193                                 BinaryOperator op,
0194                                 command_queue &queue)
0195 {
0196     typedef typename
0197         std::iterator_traits<InputIterator>::value_type
0198         input_type;
0199     typedef typename
0200         std::iterator_traits<InputIterator>::difference_type
0201         difference_type;
0202     typedef typename
0203         std::iterator_traits<OutputIterator>::value_type
0204         output_type;
0205 
0206     const context &context = queue.get_context();
0207     const size_t count = detail::iterator_range_size(first, last);
0208 
0209     size_t block_size = pick_scan_block_size(first, last);
0210     size_t block_count = count / block_size;
0211 
0212     if(block_count * block_size < count){
0213         block_count++;
0214     }
0215 
0216     ::boost::compute::vector<input_type> block_sums(block_count, context);
0217 
0218     // zero block sums
0219     input_type zero;
0220     std::memset(&zero, 0, sizeof(input_type));
0221     ::boost::compute::fill(block_sums.begin(), block_sums.end(), zero, queue);
0222 
0223     // local scan
0224     local_scan_kernel<InputIterator, OutputIterator, BinaryOperator>
0225         local_scan_kernel(first, last, result, exclusive, op);
0226 
0227     ::boost::compute::kernel kernel = local_scan_kernel.compile(context);
0228     kernel.set_arg(local_scan_kernel.m_scratch_arg, local_buffer<input_type>(block_size));
0229     kernel.set_arg(local_scan_kernel.m_block_sums_arg, block_sums);
0230     kernel.set_arg(local_scan_kernel.m_block_size_arg, static_cast<cl_uint>(block_size));
0231     kernel.set_arg(local_scan_kernel.m_count_arg, static_cast<cl_uint>(count));
0232     kernel.set_arg(local_scan_kernel.m_init_value_arg, static_cast<output_type>(init));
0233 
0234     queue.enqueue_1d_range_kernel(kernel,
0235                                   0,
0236                                   block_count * block_size,
0237                                   block_size);
0238 
0239     // inclusive scan block sums
0240     if(block_count > 1){
0241         scan_impl(block_sums.begin(),
0242                   block_sums.end(),
0243                   block_sums.begin(),
0244                   false,
0245                   init,
0246                   op,
0247                   queue
0248         );
0249     }
0250 
0251     // add block sums to each block
0252     if(block_count > 1){
0253         write_scanned_output_kernel<input_type, BinaryOperator>
0254             write_output_kernel(op);
0255         kernel = write_output_kernel.compile(context);
0256         kernel.set_arg(write_output_kernel.m_output_arg, result.get_buffer());
0257         kernel.set_arg(write_output_kernel.m_block_sums_arg, block_sums);
0258         kernel.set_arg(write_output_kernel.m_count_arg, static_cast<cl_uint>(count));
0259 
0260         queue.enqueue_1d_range_kernel(kernel,
0261                                       block_size,
0262                                       block_count * block_size,
0263                                       block_size);
0264     }
0265 
0266     return result + static_cast<difference_type>(count);
0267 }
0268 
0269 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
0270 inline OutputIterator dispatch_scan(InputIterator first,
0271                                     InputIterator last,
0272                                     OutputIterator result,
0273                                     bool exclusive,
0274                                     T init,
0275                                     BinaryOperator op,
0276                                     command_queue &queue)
0277 {
0278     return scan_impl(first, last, result, exclusive, init, op, queue);
0279 }
0280 
0281 template<class InputIterator, class T, class BinaryOperator>
0282 inline InputIterator dispatch_scan(InputIterator first,
0283                                    InputIterator last,
0284                                    InputIterator result,
0285                                    bool exclusive,
0286                                    T init,
0287                                    BinaryOperator op,
0288                                    command_queue &queue)
0289 {
0290     typedef typename std::iterator_traits<InputIterator>::value_type value_type;
0291 
0292     if(first == result){
0293         // scan input in-place
0294         const context &context = queue.get_context();
0295 
0296         // make a temporary copy the input
0297         size_t count = iterator_range_size(first, last);
0298         vector<value_type> tmp(count, context);
0299         copy(first, last, tmp.begin(), queue);
0300 
0301         // scan from temporary values
0302         return scan_impl(tmp.begin(), tmp.end(), first, exclusive, init, op, queue);
0303     }
0304     else {
0305         // scan input to output
0306         return scan_impl(first, last, result, exclusive, init, op, queue);
0307     }
0308 }
0309 
0310 template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
0311 inline OutputIterator scan_on_gpu(InputIterator first,
0312                                   InputIterator last,
0313                                   OutputIterator result,
0314                                   bool exclusive,
0315                                   T init,
0316                                   BinaryOperator op,
0317                                   command_queue &queue)
0318 {
0319     if(first == last){
0320         return result;
0321     }
0322 
0323     return dispatch_scan(first, last, result, exclusive, init, op, queue);
0324 }
0325 
0326 } // end detail namespace
0327 } // end compute namespace
0328 } // end boost namespace
0329 
0330 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP