Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:29:54

0001 //---------------------------------------------------------------------------//
0002 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
0003 //
0004 // Distributed under the Boost Software License, Version 1.0
0005 // See accompanying file LICENSE_1_0.txt or copy at
0006 // http://www.boost.org/LICENSE_1_0.txt
0007 //
0008 // See http://boostorg.github.com/compute for more information.
0009 //---------------------------------------------------------------------------//
0010 
0011 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
0012 #define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
0013 
#include <climits>
#include <iterator>
#include <sstream>
#include <string>

#include <boost/assert.hpp>
#include <boost/type_traits/is_signed.hpp>
#include <boost/type_traits/is_floating_point.hpp>

#include <boost/mpl/and.hpp>
#include <boost/mpl/not.hpp>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/parameter_cache.hpp>
#include <boost/compute/type_traits/type_definition.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/utility/program_cache.hpp>
0034 
0035 namespace boost {
0036 namespace compute {
0037 namespace detail {
0038 
// Meta-function returning true if type T is radix-sortable by the
// implementation below: T must be a fundamental scalar type and not an
// OpenCL vector type (e.g. int4_), since the kernels reinterpret each
// key as a single same-width unsigned integer (see radix_sort_value_type).
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};
0048 
// Maps the size (in bytes) of the key type to the unsigned integer type
// of the same width that the kernels actually extract radix digits from.
// The primary template is intentionally empty so that key types with an
// unsupported size fail to compile (no ::type member).
template<size_t N>
struct radix_sort_value_type
{
};

// 1-byte keys (char, uchar) are processed as uchar_
template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;
};

// 2-byte keys (short, ushort) are processed as ushort_
template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;
};

// 4-byte keys (int, uint, float) are processed as uint_
template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;
};

// 8-byte keys (long, ulong, double) are processed as ulong_
template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;
};
0077 
// Produces the compiler option telling the OpenCL program whether the
// values type (T2, used in sort-by-key) requires the cl_khr_fp64
// extension; the generated source enables the extension when the macro
// T2_double is 1. The generic overload covers every non-double type.
template<typename T>
inline const char* enable_double()
{
    static const char no_fp64_flag[] = " -DT2_double=0";
    return no_fp64_flag;
}

// Explicit specialization: only double needs fp64 support enabled.
template<>
inline const char* enable_double<double>()
{
    static const char fp64_flag[] = " -DT2_double=1";
    return fp64_flag;
}
0089 
// OpenCL source for the three kernels (count, scan, scatter) used by
// radix_sort_impl(). The host injects these macros at build time:
//   T        - unsigned integer type the radix digits are extracted from
//   T2       - values type (sort-by-key builds only)
//   K_BITS   - number of radix bits handled per pass
//   BLOCK_SIZE - work-group size
//   ASC / IS_SIGNED / IS_FLOATING_POINT / SORT_BY_KEY / T2_double - flags
const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"
// K2_BITS is the bucket count (2^K_BITS); RADIX_MASK selects the low
// K_BITS bits; SIGN_BIT is the bit index of the key's sign bit.
"#define K2_BITS (1 << K_BITS)\n"
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

"#if defined(ASC)\n" // ascending order

// radix() maps a key to its bucket for the current pass. The bit tricks
// make the unsigned comparison order match the source type's order:
// floats have the sign bit flipped (and all bits flipped when negative),
// signed integers just have the sign bit flipped.
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (x >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#else\n" // descending order

// Descending order reuses the ascending transforms on a reversed key:
// for signed/floating types we negate x, and for unsigned types we
// subtract x from the maximum value of its type ((T)(-1) is the max
// value of T when T is an unsigned type).
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#endif\n" // #if defined(ASC)

// count: histogram the current radix digit. Produces per-work-group
// bucket counts (global_counts) and seeds global_offsets with the last
// group's counts (finished by the scan kernel).
"__kernel void count(__global const T *input,\n"
"                    const uint input_offset,\n"
"                    const uint input_size,\n"
"                    __global uint *global_counts,\n"
"                    __global uint *global_offsets,\n"
"                    __local uint *local_counts,\n"
"                    const uint low_bit)\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // zero local counts (first K2_BITS work-items only)
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = 0;\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // accumulate this group's histogram in local memory
"    if(gid < input_size){\n"
"        T value = input[input_offset+gid];\n"
"        uint bucket = radix(value, low_bit);\n"
"        atomic_inc(local_counts + bucket);\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // write this group's bucket counts out
"    if(lid < K2_BITS){\n"
"        global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"

         // last group also seeds global_offsets with its counts; the
         // scan kernel adds these to the exclusive-scanned block counts
"        if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
"            global_offsets[lid] = local_counts[lid];\n"
"        }\n"
"    }\n"
"}\n"

// scan: single-work-item kernel (launched via enqueue_task) that turns
// per-bucket totals into exclusive global bucket offsets. By this point
// block_offsets has been exclusive-scanned on the host side, so
// global_offsets[i] + last_block_offsets[i] is the total count of bucket i.
"__kernel void scan(__global const uint *block_offsets,\n"
"                   __global uint *global_offsets,\n"
"                   const uint block_count)\n"
"{\n"
"    __global const uint *last_block_offsets =\n"
"        block_offsets + K2_BITS * (block_count - 1);\n"

     // exclusive scan of the bucket totals into global_offsets
"    uint sum = 0;\n"
"    for(uint i = 0; i < K2_BITS; i++){\n"
"        uint x = global_offsets[i] + last_block_offsets[i];\n"
"        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX 500/Vega bug, see #811
"        global_offsets[i] = sum;\n"
"        sum += x;\n"
"        mem_fence(CLK_GLOBAL_MEM_FENCE);\n" // work around the RX Vega bug, see #811
"    }\n"
"}\n"

// scatter: writes each element to its final position for this pass:
// global bucket offset + this group's scanned bucket offset + the
// element's rank among equal buckets within the group (keeps the sort
// stable). In sort-by-key builds the values are moved alongside.
"__kernel void scatter(__global const T *input,\n"
"                      const uint input_offset,\n"
"                      const uint input_size,\n"
"                      const uint low_bit,\n"
"                      __global const uint *counts,\n"
"                      __global const uint *global_offsets,\n"
"#ifndef SORT_BY_KEY\n"
"                      __global T *output,\n"
"                      const uint output_offset)\n"
"#else\n"
"                      __global T *keys_output,\n"
"                      const uint keys_output_offset,\n"
"                      __global T2 *values_input,\n"
"                      const uint values_input_offset,\n"
"                      __global T2 *values_output,\n"
"                      const uint values_output_offset)\n"
"#endif\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // compute this element's bucket and share it via local memory
"    T value;\n"
"    uint bucket;\n"
"    __local uint local_input[BLOCK_SIZE];\n"
"    if(gid < input_size){\n"
"        value = input[input_offset+gid];\n"
"        bucket = radix(value, low_bit);\n"
"        local_input[lid] = bucket;\n"
"    }\n"

     // copy this group's scanned bucket offsets to local memory
"    __local uint local_counts[(1 << K_BITS)];\n"
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
"    }\n"

     // wait until local memory is ready
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // out-of-range work-items exit only after the barrier above
"    if(gid >= input_size){\n"
"        return;\n"
"    }\n"

     // base destination for this bucket
"    uint offset = global_offsets[bucket] + local_counts[bucket];\n"

     // rank among same-bucket elements earlier in the group (stability)
"    uint local_offset = 0;\n"
"    for(uint i = 0; i < lid; i++){\n"
"        if(local_input[i] == bucket)\n"
"            local_offset++;\n"
"    }\n"

"#ifndef SORT_BY_KEY\n"
     // write value to output
"    output[output_offset + offset + local_offset] = value;\n"
"#else\n"
     // write key and value if doing sort_by_key
"    keys_output[keys_output_offset+offset + local_offset] = value;\n"
"    values_output[values_output_offset+offset + local_offset] =\n"
"        values_input[values_input_offset+gid];\n"
"#endif\n"
"}\n";
0251 
// Performs an LSD (least-significant-digit) radix sort of the keys in
// [first, last) on the device. When values_first refers to a real buffer
// (non-null), the values are permuted alongside the keys (sort-by-key).
// Each pass processes k radix bits: count per-group bucket histograms,
// exclusive-scan them into offsets, then stably scatter into a temporary
// buffer, ping-ponging between the caller's buffer and the temporary.
//
// NOTE(review): the pass count sizeof(sort_type)*CHAR_BIT/k is even for
// every supported k (1, 2, 4), so after the final swap the sorted data
// ends up back in the caller's buffer.
template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
                            const buffer_iterator<T> last,
                            const buffer_iterator<T2> values_first,
                            const bool ascending,
                            command_queue &queue)
{

    typedef T value_type;
    // same-width unsigned integer type used for radix extraction
    typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;

    const device &device = queue.get_device();
    const context &context = queue.get_context();


    // if we have a valid values iterator then we are doing a
    // sort by key and have to set up the values buffer
    bool sort_by_key = (values_first.get_buffer().get() != 0);

    // cache key identifies the program by key type (and value type when
    // sorting by key) for the program/parameter caches
    std::string cache_key =
        std::string("__boost_radix_sort_") + type_name<value_type>();

    if(sort_by_key){
        cache_key += std::string("_with_") + type_name<T2>();
    }

    boost::shared_ptr<program_cache> cache =
        program_cache::get_global_cache(context);
    boost::shared_ptr<parameter_cache> parameters =
        detail::parameter_cache::get_global_cache(device);

    // tuned sort parameters: radix bits per pass ("k", default 4) and
    // work-group size ("tpb", presumably threads-per-block, default 128)
    const uint_ k = parameters->get(cache_key, "k", 4);
    const uint_ k2 = 1 << k;
    const uint_ block_size = parameters->get(cache_key, "tpb", 128);

    // sort program compiler options (macros consumed by radix_sort_source)
    std::stringstream options;
    options << "-DK_BITS=" << k;
    options << " -DT=" << type_name<sort_type>();
    options << " -DBLOCK_SIZE=" << block_size;

    if(boost::is_floating_point<value_type>::value){
        options << " -DIS_FLOATING_POINT";
    }

    if(boost::is_signed<value_type>::value){
        options << " -DIS_SIGNED";
    }

    if(sort_by_key){
        options << " -DSORT_BY_KEY";
        options << " -DT2=" << type_name<T2>();
        // enables cl_khr_fp64 in the kernel when T2 is double
        options << enable_double<T2>();
    }

    if(ascending){
        options << " -DASC";
    }

    // get type definition if it is a custom struct (empty for built-ins)
    std::string custom_type_def = boost::compute::type_definition<T2>() + "\n";

    // load (or build and cache) the radix sort program
    program radix_sort_program = cache->get_or_build(
       cache_key, options.str(), custom_type_def + radix_sort_source, context
    );

    kernel count_kernel(radix_sort_program, "count");
    kernel scan_kernel(radix_sort_program, "scan");
    kernel scatter_kernel(radix_sort_program, "scatter");

    size_t count = detail::iterator_range_size(first, last);

    // number of work-groups, rounded up to cover a partial final block
    uint_ block_count = static_cast<uint_>(count / block_size);
    if(block_count * block_size != count){
        block_count++;
    }

    // setup temporary buffers: ping-pong output(s), per-bucket global
    // offsets, and per-group bucket counts
    vector<value_type> output(count, context);
    vector<T2> values_output(sort_by_key ? count : 0, context);
    vector<uint_> offsets(k2, context);
    vector<uint_> counts(block_count * k2, context);

    // raw buffer/offset pairs so input and output can be swapped cheaply
    // between passes
    const buffer *input_buffer = &first.get_buffer();
    uint_ input_offset = static_cast<uint_>(first.get_index());
    const buffer *output_buffer = &output.get_buffer();
    uint_ output_offset = 0;
    const buffer *values_input_buffer = &values_first.get_buffer();
    uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
    const buffer *values_output_buffer = &values_output.get_buffer();
    uint_ values_output_offset = 0;

    // one pass per k-bit digit, from least to most significant
    for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
        // histogram the current digit into counts/offsets
        count_kernel.set_arg(0, *input_buffer);
        count_kernel.set_arg(1, input_offset);
        count_kernel.set_arg(2, static_cast<uint_>(count));
        count_kernel.set_arg(3, counts);
        count_kernel.set_arg(4, offsets);
        // __local scratch histogram (only the first k2 entries are used)
        count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
        count_kernel.set_arg(6, i * k);
        queue.enqueue_1d_range_kernel(count_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // exclusive-scan the per-group counts; a whole group's k2-wide
        // histogram is scanned as one vector element (uint2/uint4/uint16)
        if(k == 1){
            typedef uint2_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 2){
            typedef uint4_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 4){
            typedef uint16_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else {
            // only k in {1, 2, 4} is supported by the scan dispatch above
            BOOST_ASSERT(false && "unknown k");
            break;
        }

        // finish the global bucket offsets (single-task kernel)
        scan_kernel.set_arg(0, counts);
        scan_kernel.set_arg(1, offsets);
        scan_kernel.set_arg(2, block_count);
        queue.enqueue_task(scan_kernel);

        // stably scatter keys (and values) to their positions for this pass
        scatter_kernel.set_arg(0, *input_buffer);
        scatter_kernel.set_arg(1, input_offset);
        scatter_kernel.set_arg(2, static_cast<uint_>(count));
        scatter_kernel.set_arg(3, i * k);
        scatter_kernel.set_arg(4, counts);
        scatter_kernel.set_arg(5, offsets);
        scatter_kernel.set_arg(6, *output_buffer);
        scatter_kernel.set_arg(7, output_offset);
        if(sort_by_key){
            scatter_kernel.set_arg(8, *values_input_buffer);
            scatter_kernel.set_arg(9, values_input_offset);
            scatter_kernel.set_arg(10, *values_output_buffer);
            scatter_kernel.set_arg(11, values_output_offset);
        }
        queue.enqueue_1d_range_kernel(scatter_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // swap buffers so this pass's output feeds the next pass
        std::swap(input_buffer, output_buffer);
        std::swap(values_input_buffer, values_output_buffer);
        std::swap(input_offset, output_offset);
        std::swap(values_input_offset, values_output_offset);
    }
}
0427 
0428 template<class Iterator>
0429 inline void radix_sort(Iterator first,
0430                        Iterator last,
0431                        command_queue &queue)
0432 {
0433     radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
0434 }
0435 
0436 template<class KeyIterator, class ValueIterator>
0437 inline void radix_sort_by_key(KeyIterator keys_first,
0438                               KeyIterator keys_last,
0439                               ValueIterator values_first,
0440                               command_queue &queue)
0441 {
0442     radix_sort_impl(keys_first, keys_last, values_first, true, queue);
0443 }
0444 
0445 template<class Iterator>
0446 inline void radix_sort(Iterator first,
0447                        Iterator last,
0448                        const bool ascending,
0449                        command_queue &queue)
0450 {
0451     radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
0452 }
0453 
0454 template<class KeyIterator, class ValueIterator>
0455 inline void radix_sort_by_key(KeyIterator keys_first,
0456                               KeyIterator keys_last,
0457                               ValueIterator values_first,
0458                               const bool ascending,
0459                               command_queue &queue)
0460 {
0461     radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
0462 }
0463 
0464 
0465 } // end detail namespace
0466 } // end compute namespace
0467 } // end boost namespace
0468 
0469 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP