File indexing completed on 2025-01-19 09:51:46
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023 #if defined(EIGEN_USE_SYCL) && \
0024 !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
0025 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
0026
#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <unordered_map>
0034 #include <unordered_map>
0035
0036 namespace Eigen {
0037 namespace TensorSycl {
0038 namespace internal {
0039
0040 using sycl_acc_target = cl::sycl::access::target;
0041 using sycl_acc_mode = cl::sycl::access::mode;
0042
0043
0044
0045
0046 using buffer_data_type_t = uint8_t;
0047 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
0048 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
0049
0050
0051
0052
0053
0054
0055 class PointerMapper {
0056 public:
0057 using base_ptr_t = std::intptr_t;
0058
0059
0060
0061
0062
0063
0064
0065 struct virtual_pointer_t {
0066
0067
0068 base_ptr_t m_contents;
0069
0070
0071
0072
0073 operator void *() const { return reinterpret_cast<void *>(m_contents); }
0074
0075
0076
0077
0078 operator base_ptr_t() const { return m_contents; }
0079
0080
0081
0082
0083
0084 virtual_pointer_t operator+(size_t off) { return m_contents + off; }
0085
0086
0087 bool operator<(virtual_pointer_t rhs) const {
0088 return (static_cast<base_ptr_t>(m_contents) <
0089 static_cast<base_ptr_t>(rhs.m_contents));
0090 }
0091
0092 bool operator>(virtual_pointer_t rhs) const {
0093 return (static_cast<base_ptr_t>(m_contents) >
0094 static_cast<base_ptr_t>(rhs.m_contents));
0095 }
0096
0097
0098
0099
0100 bool operator==(virtual_pointer_t rhs) const {
0101 return (static_cast<base_ptr_t>(m_contents) ==
0102 static_cast<base_ptr_t>(rhs.m_contents));
0103 }
0104
0105
0106
0107
0108 bool operator!=(virtual_pointer_t rhs) const {
0109 return !(this->operator==(rhs));
0110 }
0111
0112
0113
0114
0115
0116
0117
0118 virtual_pointer_t(const void *ptr)
0119 : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
0120
0121
0122
0123
0124
0125 virtual_pointer_t(base_ptr_t u) : m_contents(u){};
0126 };
0127
0128
0129
0130 const virtual_pointer_t null_virtual_ptr = nullptr;
0131
0132
0133
0134
0135
0136 static inline bool is_nullptr(virtual_pointer_t ptr) {
0137 return (static_cast<void *>(ptr) == nullptr);
0138 }
0139
0140
0141
0142 using buffer_t = cl::sycl::buffer_mem;
0143
0144
0145
0146
0147
0148
0149 struct pMapNode_t {
0150 buffer_t m_buffer;
0151 size_t m_size;
0152 bool m_free;
0153
0154 pMapNode_t(buffer_t b, size_t size, bool f)
0155 : m_buffer{b}, m_size{size}, m_free{f} {
0156 m_buffer.set_final_data(nullptr);
0157 }
0158
0159 bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
0160 };
0161
0162
0163
0164 using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
0165
0166
0167
0168
0169
0170
0171 typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
0172 typename pointerMap_t::iterator retVal;
0173 bool reuse = false;
0174 if (!m_freeList.empty()) {
0175
0176 for (auto freeElem : m_freeList) {
0177 if (freeElem->second.m_size >= requiredSize) {
0178 retVal = freeElem;
0179 reuse = true;
0180
0181 m_freeList.erase(freeElem);
0182 break;
0183 }
0184 }
0185 }
0186 if (!reuse) {
0187 retVal = std::prev(m_pointerMap.end());
0188 }
0189 return retVal;
0190 }
0191
0192
0193
0194
0195
0196
0197
0198
0199
0200
0201
0202 typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
0203 if (this->count() == 0) {
0204 m_pointerMap.clear();
0205 EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
0206
0207 }
0208 if (is_nullptr(ptr)) {
0209 m_pointerMap.clear();
0210 EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
0211 }
0212
0213
0214 auto node = m_pointerMap.lower_bound(ptr);
0215
0216
0217 if (node == std::end(m_pointerMap)) {
0218 --node;
0219 } else if (node->first != ptr) {
0220 if (node == std::begin(m_pointerMap)) {
0221 m_pointerMap.clear();
0222 EIGEN_THROW_X(
0223 std::out_of_range("The pointer is not registered in the map\n"));
0224
0225 }
0226 --node;
0227 }
0228
0229 return node;
0230 }
0231
0232
0233
0234
0235 template <typename buffer_data_type = buffer_data_type_t>
0236 cl::sycl::buffer<buffer_data_type, 1> get_buffer(
0237 const virtual_pointer_t ptr) {
0238 using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
0239
0240
0241
0242
0243
0244 auto node = get_node(ptr);
0245 eigen_assert(node->first == ptr || node->first < ptr);
0246 eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
0247 node->first));
0248 return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
0249 }
0250
0251
0252
0253
0254
0255
0256
0257 template <sycl_acc_mode access_mode = default_acc_mode,
0258 sycl_acc_target access_target = default_acc_target,
0259 typename buffer_data_type = buffer_data_type_t>
0260 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
0261 get_access(const virtual_pointer_t ptr) {
0262 auto buf = get_buffer<buffer_data_type>(ptr);
0263 return buf.template get_access<access_mode, access_target>();
0264 }
0265
0266
0267
0268
0269
0270
0271
0272
0273
0274 template <sycl_acc_mode access_mode = default_acc_mode,
0275 sycl_acc_target access_target = default_acc_target,
0276 typename buffer_data_type = buffer_data_type_t>
0277 cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
0278 get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
0279 auto buf = get_buffer<buffer_data_type>(ptr);
0280 return buf.template get_access<access_mode, access_target>(cgh);
0281 }
0282
0283
0284
0285
0286 inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
0287
0288
0289 auto node = get_node(ptr);
0290 auto start = node->first;
0291 eigen_assert(start == ptr || start < ptr);
0292 eigen_assert(ptr < start + node->second.m_size);
0293 return (ptr - start);
0294 }
0295
0296
0297
0298
0299
0300 template <typename buffer_data_type>
0301 inline size_t get_element_offset(const virtual_pointer_t ptr) {
0302 return get_offset(ptr) / sizeof(buffer_data_type);
0303 }
0304
0305
0306
0307
0308 PointerMapper(base_ptr_t baseAddress = 4096)
0309 : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
0310 if (m_baseAddress == 0) {
0311 EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
0312 }
0313 };
0314
0315
0316
0317
0318 PointerMapper(const PointerMapper &) = delete;
0319
0320
0321
0322
0323 inline void clear() {
0324 m_freeList.clear();
0325 m_pointerMap.clear();
0326 }
0327
0328
0329
0330
0331 inline virtual_pointer_t add_pointer(const buffer_t &b) {
0332 return add_pointer_impl(b);
0333 }
0334
0335
0336
0337
0338 inline virtual_pointer_t add_pointer(buffer_t &&b) {
0339 return add_pointer_impl(b);
0340 }
0341
0342
0343
0344
0345
0346
0347
0348 void fuse_forward(typename pointerMap_t::iterator &node) {
0349 while (node != std::prev(m_pointerMap.end())) {
0350
0351
0352 auto fwd_node = std::next(node);
0353 if (!fwd_node->second.m_free) {
0354 break;
0355 }
0356 auto fwd_size = fwd_node->second.m_size;
0357 m_freeList.erase(fwd_node);
0358 m_pointerMap.erase(fwd_node);
0359
0360 node->second.m_size += fwd_size;
0361 }
0362 }
0363
0364
0365
0366
0367
0368
0369
0370 void fuse_backward(typename pointerMap_t::iterator &node) {
0371 while (node != m_pointerMap.begin()) {
0372
0373
0374 auto prev_node = std::prev(node);
0375 if (!prev_node->second.m_free) {
0376 break;
0377 }
0378 prev_node->second.m_size += node->second.m_size;
0379
0380
0381 m_freeList.erase(node);
0382 m_pointerMap.erase(node);
0383
0384
0385 node = prev_node;
0386 }
0387 }
0388
0389
0390
0391
0392
0393 template <bool ReUse = true>
0394 void remove_pointer(const virtual_pointer_t ptr) {
0395 if (is_nullptr(ptr)) {
0396 return;
0397 }
0398 auto node = this->get_node(ptr);
0399
0400 node->second.m_free = true;
0401 m_freeList.emplace(node);
0402
0403
0404
0405 fuse_forward(node);
0406 fuse_backward(node);
0407
0408
0409
0410 if (node == std::prev(m_pointerMap.end())) {
0411 m_freeList.erase(node);
0412 m_pointerMap.erase(node);
0413 }
0414 }
0415
0416
0417
0418
0419
0420 size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
0421
0422 private:
0423
0424
0425
0426
0427 template <class BufferT>
0428 virtual_pointer_t add_pointer_impl(BufferT b) {
0429 virtual_pointer_t retVal = nullptr;
0430 size_t bufSize = b.get_count();
0431 pMapNode_t p{b, bufSize, false};
0432
0433 if (m_pointerMap.empty()) {
0434 virtual_pointer_t initialVal{m_baseAddress};
0435 m_pointerMap.emplace(initialVal, p);
0436 return initialVal;
0437 }
0438
0439 auto lastElemIter = get_insertion_point(bufSize);
0440
0441 if (lastElemIter->second.m_free) {
0442 lastElemIter->second.m_buffer = b;
0443 lastElemIter->second.m_free = false;
0444
0445
0446
0447 if (lastElemIter->second.m_size > bufSize) {
0448
0449 auto remainingSize = lastElemIter->second.m_size - bufSize;
0450 pMapNode_t p2{b, remainingSize, true};
0451
0452
0453 lastElemIter->second.m_size = bufSize;
0454
0455
0456 auto newFreePtr = lastElemIter->first + bufSize;
0457 auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
0458 m_freeList.emplace(freeNode);
0459 }
0460
0461 retVal = lastElemIter->first;
0462 } else {
0463 size_t lastSize = lastElemIter->second.m_size;
0464 retVal = lastElemIter->first + lastSize;
0465 m_pointerMap.emplace(retVal, p);
0466 }
0467 return retVal;
0468 }
0469
0470
0471
0472
0473
0474 struct SortBySize {
0475 bool operator()(typename pointerMap_t::iterator a,
0476 typename pointerMap_t::iterator b) const {
0477 return ((a->first < b->first) && (a->second <= b->second)) ||
0478 ((a->first < b->first) && (b->second <= a->second));
0479 }
0480 };
0481
0482
0483
0484 pointerMap_t m_pointerMap;
0485
0486
0487
0488 std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
0489
0490
0491
0492 std::intptr_t m_baseAddress;
0493 };
0494
0495
0496
0497
0498
0499 template <>
0500 inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
0501 if (is_nullptr(ptr)) {
0502 return;
0503 }
0504 m_pointerMap.erase(this->get_node(ptr));
0505 }
0506
0507
0508
0509
0510
0511
0512
0513
0514 inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
0515 if (size == 0) {
0516 return nullptr;
0517 }
0518
0519 using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
0520 auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
0521
0522 return static_cast<void *>(thePointer);
0523 }
0524
0525
0526
0527
0528
0529
0530
0531
/**
 * Releases the allocation that `ptr` belongs to by forwarding to
 * pMap.remove_pointer<ReUse>(ptr).
 *
 * @tparam ReUse  selects the remove_pointer variant; defaults to true.
 * @tparam PointerMapperT  the mapper type (deduced); renamed so it no longer
 *                         shadows the PointerMapper class.
 */
template <bool ReUse = true, typename PointerMapperT>
inline void SYCLfree(void *ptr, PointerMapperT &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}
0536
0537
0538
0539
/**
 * Forgets every allocation tracked by `pMap` by calling pMap.clear().
 *
 * @tparam PointerMapperT  the mapper type (deduced); renamed so it no longer
 *                         shadows the PointerMapper class.
 */
template <typename PointerMapperT>
inline void SYCLfreeAll(PointerMapperT &pMap) {
  pMap.clear();
}
0544
0545 template <cl::sycl::access::mode AcMd, typename T>
0546 struct RangeAccess {
0547 static const auto global_access = cl::sycl::access::target::global_buffer;
0548 static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
0549 typedef T scalar_t;
0550 typedef scalar_t &ref_t;
0551 typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
0552
0553
0554 typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
0555 accessor;
0556
0557 typedef RangeAccess<AcMd, T> self_t;
0558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
0559 size_t offset,
0560 std::intptr_t virtual_ptr)
0561 : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
0562
0563 RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
0564 cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
0565 : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
0566
0567
0568 RangeAccess(std::nullptr_t) : RangeAccess() {}
0569
0570 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
0571 return (access_.get_pointer().get() + offset_);
0572 }
0573 template <typename Index>
0574 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
0575 offset_ += (offset);
0576 return *this;
0577 }
0578 template <typename Index>
0579 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
0580 return self_t(access_, offset_ + offset, virtual_ptr_);
0581 }
0582 template <typename Index>
0583 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
0584 return self_t(access_, offset_ - offset, virtual_ptr_);
0585 }
0586 template <typename Index>
0587 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
0588 offset_ -= offset;
0589 return *this;
0590 }
0591
0592
0593 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
0594 const RangeAccess &lhs, std::nullptr_t) {
0595 return ((lhs.virtual_ptr_ == -1));
0596 }
0597 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
0598 const RangeAccess &lhs, std::nullptr_t i) {
0599 return !(lhs == i);
0600 }
0601
0602
0603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
0604 std::nullptr_t, const RangeAccess &rhs) {
0605 return ((rhs.virtual_ptr_ == -1));
0606 }
0607 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
0608 std::nullptr_t i, const RangeAccess &rhs) {
0609 return !(i == rhs);
0610 }
0611
0612 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
0613 offset_++;
0614 return (*this);
0615 }
0616
0617
0618 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
0619 EIGEN_UNUSED_VARIABLE(i);
0620 self_t temp_iterator(*this);
0621 offset_++;
0622 return temp_iterator;
0623 }
0624
0625 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
0626 return (access_.get_count() - offset_);
0627 }
0628
0629 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
0630 return offset_;
0631 }
0632
0633 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
0634 offset_ = offset;
0635 }
0636
0637 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
0638 return *get_pointer();
0639 }
0640
0641 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
0642 return *get_pointer();
0643 }
0644
0645 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
0646
0647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
0648 return *(get_pointer() + x);
0649 }
0650
0651 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
0652 return *(get_pointer() + x);
0653 }
0654
0655 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
0656 return reinterpret_cast<scalar_t *>(virtual_ptr_ +
0657 (offset_ * sizeof(scalar_t)));
0658 }
0659
0660 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
0661 return (virtual_ptr_ != -1);
0662 }
0663
0664 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
0665 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
0666 }
0667
0668 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
0669 operator RangeAccess<AcMd, const T>() const {
0670 return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
0671 }
0672
0673 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
0674 cl::sycl::handler &cgh) const {
0675 cgh.require(access_);
0676 }
0677
0678 private:
0679 accessor access_;
0680 size_t offset_;
0681 std::intptr_t virtual_ptr_;
0682 };
0683
0684 template <cl::sycl::access::mode AcMd, typename T>
0685 struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
0686 typedef RangeAccess<AcMd, T> Base;
0687 using Base::Base;
0688 };
0689
0690 }
0691 }
0692 }
0693
0694 #endif