File indexing completed on 2025-10-31 09:00:27
0001 
0002 
0003 
0004 
0005 
0006 
0007 
0008 
0009 
0010 
0011 
0012 
0013 
0014 
0015 
0016 
0017 
0018 
0019 
0020 
0021 
0022 
0023 #if defined(EIGEN_USE_SYCL) && \
0024     !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
0025 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
0026 
#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <unordered_map>
#include <utility>
0035 
0036 namespace Eigen {
0037 namespace TensorSycl {
0038 namespace internal {
0039 
0040 using sycl_acc_target = cl::sycl::access::target;
0041 using sycl_acc_mode = cl::sycl::access::mode;
0042 
0043 
0044 
0045 
0046 using buffer_data_type_t = uint8_t;
0047 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
0048 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
0049 
0050 
0051 
0052 
0053 
0054 
0055 class PointerMapper {
0056  public:
0057   using base_ptr_t = std::intptr_t;
0058 
0059   
0060 
0061 
0062 
0063 
0064 
0065   struct virtual_pointer_t {
0066     
0067 
0068     base_ptr_t m_contents;
0069 
0070     
0071 
0072 
0073     operator void *() const { return reinterpret_cast<void *>(m_contents); }
0074 
0075     
0076 
0077 
0078     operator base_ptr_t() const { return m_contents; }
0079 
0080     
0081 
0082 
0083 
0084     virtual_pointer_t operator+(size_t off) { return m_contents + off; }
0085 
0086     
0087     bool operator<(virtual_pointer_t rhs) const {
0088       return (static_cast<base_ptr_t>(m_contents) <
0089               static_cast<base_ptr_t>(rhs.m_contents));
0090     }
0091 
0092     bool operator>(virtual_pointer_t rhs) const {
0093       return (static_cast<base_ptr_t>(m_contents) >
0094               static_cast<base_ptr_t>(rhs.m_contents));
0095     }
0096 
0097     
0098 
0099 
0100     bool operator==(virtual_pointer_t rhs) const {
0101       return (static_cast<base_ptr_t>(m_contents) ==
0102               static_cast<base_ptr_t>(rhs.m_contents));
0103     }
0104 
0105     
0106 
0107 
0108     bool operator!=(virtual_pointer_t rhs) const {
0109       return !(this->operator==(rhs));
0110     }
0111 
0112     
0113 
0114 
0115 
0116 
0117 
0118     virtual_pointer_t(const void *ptr)
0119         : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
0120 
0121     
0122 
0123 
0124 
0125     virtual_pointer_t(base_ptr_t u) : m_contents(u){};
0126   };
0127 
0128   
0129 
0130   const virtual_pointer_t null_virtual_ptr = nullptr;
0131 
0132   
0133 
0134 
0135 
0136   static inline bool is_nullptr(virtual_pointer_t ptr) {
0137     return (static_cast<void *>(ptr) == nullptr);
0138   }
0139 
0140   
0141 
0142   using buffer_t = cl::sycl::buffer_mem;
0143 
0144   
0145 
0146 
0147 
0148 
0149   struct pMapNode_t {
0150     buffer_t m_buffer;
0151     size_t m_size;
0152     bool m_free;
0153 
0154     pMapNode_t(buffer_t b, size_t size, bool f)
0155         : m_buffer{b}, m_size{size}, m_free{f} {
0156       m_buffer.set_final_data(nullptr);
0157     }
0158 
0159     bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
0160   };
0161 
0162   
0163 
0164   using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
0165 
0166   
0167 
0168 
0169 
0170 
0171   typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
0172     typename pointerMap_t::iterator retVal;
0173     bool reuse = false;
0174     if (!m_freeList.empty()) {
0175       
0176       for (auto freeElem : m_freeList) {
0177         if (freeElem->second.m_size >= requiredSize) {
0178           retVal = freeElem;
0179           reuse = true;
0180           
0181           m_freeList.erase(freeElem);
0182           break;
0183         }
0184       }
0185     }
0186     if (!reuse) {
0187       retVal = std::prev(m_pointerMap.end());
0188     }
0189     return retVal;
0190   }
0191 
0192   
0193 
0194 
0195 
0196 
0197 
0198 
0199 
0200 
0201 
0202   typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
0203     if (this->count() == 0) {
0204       m_pointerMap.clear();
0205       EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
0206 
0207     }
0208     if (is_nullptr(ptr)) {
0209       m_pointerMap.clear();
0210       EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
0211     }
0212     
0213     
0214     auto node = m_pointerMap.lower_bound(ptr);
0215     
0216     
0217     if (node == std::end(m_pointerMap)) {
0218       --node;
0219     } else if (node->first != ptr) {
0220       if (node == std::begin(m_pointerMap)) {
0221         m_pointerMap.clear();
0222         EIGEN_THROW_X(
0223             std::out_of_range("The pointer is not registered in the map\n"));
0224 
0225       }
0226       --node;
0227     }
0228 
0229     return node;
0230   }
0231 
0232   
0233 
0234 
0235   template <typename buffer_data_type = buffer_data_type_t>
0236   cl::sycl::buffer<buffer_data_type, 1> get_buffer(
0237       const virtual_pointer_t ptr) {
0238     using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
0239 
0240     
0241     
0242     
0243     
0244     auto node = get_node(ptr);
0245     eigen_assert(node->first == ptr || node->first < ptr);
0246     eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
0247                                                       node->first));
0248     return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
0249   }
0250 
0251   
0252 
0253 
0254 
0255 
0256 
0257   template <sycl_acc_mode access_mode = default_acc_mode,
0258             sycl_acc_target access_target = default_acc_target,
0259             typename buffer_data_type = buffer_data_type_t>
0260   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
0261   get_access(const virtual_pointer_t ptr) {
0262     auto buf = get_buffer<buffer_data_type>(ptr);
0263     return buf.template get_access<access_mode, access_target>();
0264   }
0265 
0266   
0267 
0268 
0269 
0270 
0271 
0272 
0273 
0274   template <sycl_acc_mode access_mode = default_acc_mode,
0275             sycl_acc_target access_target = default_acc_target,
0276             typename buffer_data_type = buffer_data_type_t>
0277   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
0278   get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
0279     auto buf = get_buffer<buffer_data_type>(ptr);
0280     return buf.template get_access<access_mode, access_target>(cgh);
0281   }
0282 
0283   
0284 
0285 
0286   inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
0287     
0288     
0289     auto node = get_node(ptr);
0290     auto start = node->first;
0291     eigen_assert(start == ptr || start < ptr);
0292     eigen_assert(ptr < start + node->second.m_size);
0293     return (ptr - start);
0294   }
0295 
0296   
0297 
0298 
0299 
0300   template <typename buffer_data_type>
0301   inline size_t get_element_offset(const virtual_pointer_t ptr) {
0302     return get_offset(ptr) / sizeof(buffer_data_type);
0303   }
0304 
0305   
0306 
0307 
0308   PointerMapper(base_ptr_t baseAddress = 4096)
0309       : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
0310     if (m_baseAddress == 0) {
0311       EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
0312     }
0313   };
0314 
0315   
0316 
0317 
0318   PointerMapper(const PointerMapper &) = delete;
0319 
0320   
0321 
0322 
0323   inline void clear() {
0324     m_freeList.clear();
0325     m_pointerMap.clear();
0326   }
0327 
0328   
0329 
0330 
0331   inline virtual_pointer_t add_pointer(const buffer_t &b) {
0332     return add_pointer_impl(b);
0333   }
0334 
0335   
0336 
0337 
0338   inline virtual_pointer_t add_pointer(buffer_t &&b) {
0339     return add_pointer_impl(b);
0340   }
0341 
0342   
0343 
0344 
0345 
0346 
0347 
0348   void fuse_forward(typename pointerMap_t::iterator &node) {
0349     while (node != std::prev(m_pointerMap.end())) {
0350       
0351       
0352       auto fwd_node = std::next(node);
0353       if (!fwd_node->second.m_free) {
0354         break;
0355       }
0356       auto fwd_size = fwd_node->second.m_size;
0357       m_freeList.erase(fwd_node);
0358       m_pointerMap.erase(fwd_node);
0359 
0360       node->second.m_size += fwd_size;
0361     }
0362   }
0363 
0364   
0365 
0366 
0367 
0368 
0369 
0370   void fuse_backward(typename pointerMap_t::iterator &node) {
0371     while (node != m_pointerMap.begin()) {
0372       
0373       
0374       auto prev_node = std::prev(node);
0375       if (!prev_node->second.m_free) {
0376         break;
0377       }
0378       prev_node->second.m_size += node->second.m_size;
0379 
0380       
0381       m_freeList.erase(node);
0382       m_pointerMap.erase(node);
0383 
0384       
0385       node = prev_node;
0386     }
0387   }
0388 
0389   
0390 
0391 
0392 
0393   template <bool ReUse = true>
0394   void remove_pointer(const virtual_pointer_t ptr) {
0395     if (is_nullptr(ptr)) {
0396       return;
0397     }
0398     auto node = this->get_node(ptr);
0399 
0400     node->second.m_free = true;
0401     m_freeList.emplace(node);
0402 
0403     
0404     
0405     fuse_forward(node);
0406     fuse_backward(node);
0407 
0408     
0409     
0410     if (node == std::prev(m_pointerMap.end())) {
0411       m_freeList.erase(node);
0412       m_pointerMap.erase(node);
0413     }
0414   }
0415 
0416   
0417 
0418 
0419 
0420   size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
0421 
0422  private:
0423   
0424 
0425 
0426 
0427   template <class BufferT>
0428   virtual_pointer_t add_pointer_impl(BufferT b) {
0429     virtual_pointer_t retVal = nullptr;
0430     size_t bufSize = b.get_count();
0431     pMapNode_t p{b, bufSize, false};
0432     
0433     if (m_pointerMap.empty()) {
0434       virtual_pointer_t initialVal{m_baseAddress};
0435       m_pointerMap.emplace(initialVal, p);
0436       return initialVal;
0437     }
0438 
0439     auto lastElemIter = get_insertion_point(bufSize);
0440     
0441     if (lastElemIter->second.m_free) {
0442       lastElemIter->second.m_buffer = b;
0443       lastElemIter->second.m_free = false;
0444 
0445       
0446       
0447       if (lastElemIter->second.m_size > bufSize) {
0448         
0449         auto remainingSize = lastElemIter->second.m_size - bufSize;
0450         pMapNode_t p2{b, remainingSize, true};
0451 
0452         
0453         lastElemIter->second.m_size = bufSize;
0454 
0455         
0456         auto newFreePtr = lastElemIter->first + bufSize;
0457         auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
0458         m_freeList.emplace(freeNode);
0459       }
0460 
0461       retVal = lastElemIter->first;
0462     } else {
0463       size_t lastSize = lastElemIter->second.m_size;
0464       retVal = lastElemIter->first + lastSize;
0465       m_pointerMap.emplace(retVal, p);
0466     }
0467     return retVal;
0468   }
0469 
0470   
0471 
0472 
0473 
0474   struct SortBySize {
0475     bool operator()(typename pointerMap_t::iterator a,
0476                     typename pointerMap_t::iterator b) const {
0477       return ((a->first < b->first) && (a->second <= b->second)) ||
0478              ((a->first < b->first) && (b->second <= a->second));
0479     }
0480   };
0481 
0482   
0483 
0484   pointerMap_t m_pointerMap;
0485 
0486   
0487 
0488   std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
0489 
0490   
0491 
0492   std::intptr_t m_baseAddress;
0493 };
0494 
0495 
0496 
0497 
0498 
0499 template <>
0500 inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
0501   if (is_nullptr(ptr)) {
0502     return;
0503   }
0504   m_pointerMap.erase(this->get_node(ptr));
0505 }
0506 
0507 
0508 
0509 
0510 
0511 
0512 
0513 
0514 inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
0515   if (size == 0) {
0516     return nullptr;
0517   }
0518   
0519   using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
0520   auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
0521   
0522   return static_cast<void *>(thePointer);
0523 }
0524 
0525 
0526 
0527 
0528 
0529 
0530 
0531 
/**
 * \brief Releases the range of \p pMap that contains \p ptr.
 *
 * Forwards to `pMap.remove_pointer<ReUse>(ptr)`.
 *
 * \tparam ReUse Passed through as the template argument of remove_pointer.
 * \param ptr Virtual pointer to release.
 * \param pMap Mapper in which the pointer was registered.
 */
template <bool ReUse = true, typename PointerMapper>
inline void SYCLfree(void *ptr, PointerMapper &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}
0536 
0537 
0538 
0539 
/**
 * \brief Releases everything registered in \p pMap by calling `pMap.clear()`.
 */
template <typename PointerMapper>
inline void SYCLfreeAll(PointerMapper &pMap) {
  pMap.clear();
}
0544 
0545 template <cl::sycl::access::mode AcMd, typename T>
0546 struct RangeAccess {
0547   static const auto global_access = cl::sycl::access::target::global_buffer;
0548   static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
0549   typedef T scalar_t;
0550   typedef scalar_t &ref_t;
0551   typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
0552 
0553   
0554   typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
0555       accessor;
0556 
0557   typedef RangeAccess<AcMd, T> self_t;
0558   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
0559                                                     size_t offset,
0560                                                     std::intptr_t virtual_ptr)
0561       : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
0562 
0563   RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
0564                   cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
0565       : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
0566 
0567   
0568   RangeAccess(std::nullptr_t) : RangeAccess() {}
0569   
0570   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
0571     return (access_.get_pointer().get() + offset_);
0572   }
0573   template <typename Index>
0574   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
0575     offset_ += (offset);
0576     return *this;
0577   }
0578   template <typename Index>
0579   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
0580     return self_t(access_, offset_ + offset, virtual_ptr_);
0581   }
0582   template <typename Index>
0583   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
0584     return self_t(access_, offset_ - offset, virtual_ptr_);
0585   }
0586   template <typename Index>
0587   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
0588     offset_ -= offset;
0589     return *this;
0590   }
0591 
0592   
0593   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
0594       const RangeAccess &lhs, std::nullptr_t) {
0595     return ((lhs.virtual_ptr_ == -1));
0596   }
0597   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
0598       const RangeAccess &lhs, std::nullptr_t i) {
0599     return !(lhs == i);
0600   }
0601 
0602   
0603   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
0604       std::nullptr_t, const RangeAccess &rhs) {
0605     return ((rhs.virtual_ptr_ == -1));
0606   }
0607   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
0608       std::nullptr_t i, const RangeAccess &rhs) {
0609     return !(i == rhs);
0610   }
0611   
0612   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
0613     offset_++;
0614     return (*this);
0615   }
0616 
0617   
0618   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
0619     EIGEN_UNUSED_VARIABLE(i);
0620     self_t temp_iterator(*this);
0621     offset_++;
0622     return temp_iterator;
0623   }
0624 
0625   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
0626     return (access_.get_count() - offset_);
0627   }
0628 
0629   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
0630     return offset_;
0631   }
0632 
0633   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
0634     offset_ = offset;
0635   }
0636 
0637   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
0638     return *get_pointer();
0639   }
0640 
0641   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
0642     return *get_pointer();
0643   }
0644 
0645   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
0646 
0647   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
0648     return *(get_pointer() + x);
0649   }
0650 
0651   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
0652     return *(get_pointer() + x);
0653   }
0654 
0655   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
0656     return reinterpret_cast<scalar_t *>(virtual_ptr_ +
0657                                         (offset_ * sizeof(scalar_t)));
0658   }
0659 
0660   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
0661     return (virtual_ptr_ != -1);
0662   }
0663 
0664   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
0665     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
0666   }
0667 
0668   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0669   operator RangeAccess<AcMd, const T>() const {
0670     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
0671   }
0672   
0673   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
0674       cl::sycl::handler &cgh) const {
0675     cgh.require(access_);
0676   }
0677 
0678  private:
0679   accessor access_;
0680   size_t offset_;
0681   std::intptr_t virtual_ptr_;  
0682 };
0683 
0684 template <cl::sycl::access::mode AcMd, typename T>
0685 struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
0686   typedef RangeAccess<AcMd, T> Base;
0687   using Base::Base;
0688 };
0689 
0690 }  
0691 }  
0692 }  
0693 
0694 #endif