Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:30:03

0001 //---------------------------------------------------------------------------//
0002 // Copyright (c) 2013-2014 Kyle Lutz <kyle.r.lutz@gmail.com>
0003 //
0004 // Distributed under the Boost Software License, Version 1.0
0005 // See accompanying file LICENSE_1_0.txt or copy at
0006 // http://www.boost.org/LICENSE_1_0.txt
0007 //
0008 // See http://boostorg.github.com/compute for more information.
0009 //---------------------------------------------------------------------------//
0010 
0011 #ifndef BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
0012 #define BOOST_COMPUTE_MEMORY_SVM_PTR_HPP
0013 
0014 #include <boost/type_traits.hpp>
0015 #include <boost/static_assert.hpp>
0016 #include <boost/assert.hpp>
0017 
0018 #include <boost/compute/cl.hpp>
0019 #include <boost/compute/kernel.hpp>
0020 #include <boost/compute/context.hpp>
0021 #include <boost/compute/command_queue.hpp>
0022 #include <boost/compute/type_traits/is_device_iterator.hpp>
0023 
0024 namespace boost {
0025 namespace compute {
0026 
0027 // forward declaration for svm_ptr<T>
0028 template<class T>
0029 class svm_ptr;
0030 
0031 // svm functions require OpenCL 2.0
0032 #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
0033 namespace detail {
0034 
0035 template<class T, class IndexExpr>
0036 struct svm_ptr_index_expr
0037 {
0038     typedef T result_type;
0039 
0040     svm_ptr_index_expr(const svm_ptr<T> &svm_ptr,
0041                        const IndexExpr &expr)
0042         : m_svm_ptr(svm_ptr),
0043           m_expr(expr)
0044     {
0045     }
0046 
0047     operator T() const
0048     {
0049         BOOST_STATIC_ASSERT_MSG(boost::is_integral<IndexExpr>::value,
0050                                 "Index expression must be integral");
0051 
0052         BOOST_ASSERT(m_svm_ptr.get());
0053 
0054         const context &context = m_svm_ptr.get_context();
0055         const device &device = context.get_device();
0056         command_queue queue(context, device);
0057 
0058         T value;
0059         T* ptr =
0060             static_cast<T*>(m_svm_ptr.get()) + static_cast<std::ptrdiff_t>(m_expr);
0061         queue.enqueue_svm_map(static_cast<void*>(ptr), sizeof(T), CL_MAP_READ);
0062         value = *(ptr);
0063         queue.enqueue_svm_unmap(static_cast<void*>(ptr)).wait();
0064 
0065         return value;
0066     }
0067 
0068     const svm_ptr<T> &m_svm_ptr;
0069     IndexExpr m_expr;
0070 };
0071 
0072 } // end detail namespace
0073 #endif
0074 
0075 template<class T>
0076 class svm_ptr
0077 {
0078 public:
0079     typedef T value_type;
0080     typedef std::ptrdiff_t difference_type;
0081     typedef T* pointer;
0082     typedef T& reference;
0083     typedef std::random_access_iterator_tag iterator_category;
0084 
0085     svm_ptr()
0086         : m_ptr(0)
0087     {
0088     }
0089 
0090     svm_ptr(void *ptr, const context &context)
0091         : m_ptr(static_cast<T*>(ptr)),
0092           m_context(context)
0093     {
0094     }
0095 
0096     svm_ptr(const svm_ptr<T> &other)
0097         : m_ptr(other.m_ptr),
0098           m_context(other.m_context)
0099     {
0100     }
0101 
0102     svm_ptr<T>& operator=(const svm_ptr<T> &other)
0103     {
0104         m_ptr = other.m_ptr;
0105         m_context = other.m_context;
0106         return *this;
0107     }
0108 
0109     ~svm_ptr()
0110     {
0111     }
0112 
0113     void* get() const
0114     {
0115         return m_ptr;
0116     }
0117 
0118     svm_ptr<T> operator+(difference_type n)
0119     {
0120         return svm_ptr<T>(m_ptr + n, m_context);
0121     }
0122 
0123     difference_type operator-(svm_ptr<T> other)
0124     {
0125         BOOST_ASSERT(other.m_context == m_context);
0126         return m_ptr - other.m_ptr;
0127     }
0128 
0129     const context& get_context() const
0130     {
0131         return m_context;
0132     }
0133 
0134     bool operator==(const svm_ptr<T>& other) const
0135     {
0136         return (other.m_context == m_context) && (m_ptr == other.m_ptr);
0137     }
0138 
0139     bool operator!=(const svm_ptr<T>& other) const
0140     {
0141         return (other.m_context != m_context) || (m_ptr != other.m_ptr);
0142     }
0143 
0144     // svm functions require OpenCL 2.0
0145     #if defined(BOOST_COMPUTE_CL_VERSION_2_0) || defined(BOOST_COMPUTE_DOXYGEN_INVOKED)
0146     /// \internal_
0147     template<class Expr>
0148     detail::svm_ptr_index_expr<T, Expr>
0149     operator[](const Expr &expr) const
0150     {
0151         BOOST_ASSERT(m_ptr);
0152 
0153         return detail::svm_ptr_index_expr<T, Expr>(*this,
0154                                                    expr);
0155     }
0156     #endif
0157 
0158 private:
0159     T *m_ptr;
0160     context m_context;
0161 };
0162 
0163 namespace detail {
0164 
0165 /// \internal_
0166 template<class T>
0167 struct set_kernel_arg<svm_ptr<T> >
0168 {
0169     void operator()(kernel &kernel_, size_t index, const svm_ptr<T> &ptr)
0170     {
0171         kernel_.set_arg_svm_ptr(index, ptr.get());
0172     }
0173 };
0174 
0175 } // end detail namespace
0176 
0177 /// \internal_ (is_device_iterator specialization for svm_ptr)
0178 template<class T>
0179 struct is_device_iterator<svm_ptr<T> > : boost::true_type {};
0180 
0181 } // end compute namespace
0182 } // end boost namespace
0183 
0184 #endif // BOOST_COMPUTE_MEMORY_SVM_PTR_HPP