File indexing completed on 2025-08-27 08:47:21
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018 #pragma once
0019
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <new>
#include <string>
#include <utility>
0025
0026 #include "arrow/memory_pool.h"
0027 #include "arrow/type_fwd.h"
0028 #include "arrow/util/macros.h"
0029
0030 namespace arrow {
0031 namespace stl {
0032
0033
0034 template <class T>
0035 class allocator {
0036 public:
0037 using value_type = T;
0038 using pointer = T*;
0039 using const_pointer = const T*;
0040 using reference = T&;
0041 using const_reference = const T&;
0042 using size_type = std::size_t;
0043 using difference_type = std::ptrdiff_t;
0044
0045 template <class U>
0046 struct rebind {
0047 using other = allocator<U>;
0048 };
0049
0050
0051 allocator() noexcept : pool_(default_memory_pool()) {}
0052
0053 explicit allocator(MemoryPool* pool) noexcept : pool_(pool) {}
0054
0055 template <class U>
0056 allocator(const allocator<U>& rhs) noexcept : pool_(rhs.pool()) {}
0057
0058 ~allocator() { pool_ = NULLPTR; }
0059
0060 pointer address(reference r) const noexcept { return std::addressof(r); }
0061
0062 const_pointer address(const_reference r) const noexcept { return std::addressof(r); }
0063
0064 pointer allocate(size_type n, const void* = NULLPTR) {
0065 uint8_t* data;
0066 Status s = pool_->Allocate(n * sizeof(T), &data);
0067 if (!s.ok()) throw std::bad_alloc();
0068 return reinterpret_cast<pointer>(data);
0069 }
0070
0071 void deallocate(pointer p, size_type n) {
0072 pool_->Free(reinterpret_cast<uint8_t*>(p), n * sizeof(T));
0073 }
0074
0075 size_type size_max() const noexcept { return size_type(-1) / sizeof(T); }
0076
0077 template <class U, class... Args>
0078 void construct(U* p, Args&&... args) {
0079 new (reinterpret_cast<void*>(p)) U(std::forward<Args>(args)...);
0080 }
0081
0082 template <class U>
0083 void destroy(U* p) {
0084 p->~U();
0085 }
0086
0087 MemoryPool* pool() const noexcept { return pool_; }
0088
0089 private:
0090 MemoryPool* pool_;
0091 };
0092
0093
0094
0095
0096
0097 template <typename Allocator = std::allocator<uint8_t>>
0098 class STLMemoryPool : public MemoryPool {
0099 public:
0100
0101 explicit STLMemoryPool(const Allocator& alloc) : alloc_(alloc) {}
0102
0103 using MemoryPool::Allocate;
0104 using MemoryPool::Free;
0105 using MemoryPool::Reallocate;
0106
0107 Status Allocate(int64_t size, int64_t , uint8_t** out) override {
0108 try {
0109 *out = alloc_.allocate(size);
0110 } catch (std::bad_alloc& e) {
0111 return Status::OutOfMemory(e.what());
0112 }
0113 stats_.DidAllocateBytes(size);
0114 return Status::OK();
0115 }
0116
0117 Status Reallocate(int64_t old_size, int64_t new_size, int64_t ,
0118 uint8_t** ptr) override {
0119 uint8_t* old_ptr = *ptr;
0120 try {
0121 *ptr = alloc_.allocate(new_size);
0122 } catch (std::bad_alloc& e) {
0123 return Status::OutOfMemory(e.what());
0124 }
0125 memcpy(*ptr, old_ptr, std::min(old_size, new_size));
0126 alloc_.deallocate(old_ptr, old_size);
0127 stats_.DidReallocateBytes(old_size, new_size);
0128 return Status::OK();
0129 }
0130
0131 void Free(uint8_t* buffer, int64_t size, int64_t ) override {
0132 alloc_.deallocate(buffer, size);
0133 stats_.DidFreeBytes(size);
0134 }
0135
0136 int64_t bytes_allocated() const override { return stats_.bytes_allocated(); }
0137
0138 int64_t max_memory() const override { return stats_.max_memory(); }
0139
0140 int64_t total_bytes_allocated() const override {
0141 return stats_.total_bytes_allocated();
0142 }
0143
0144 int64_t num_allocations() const override { return stats_.num_allocations(); }
0145
0146 std::string backend_name() const override { return "stl"; }
0147
0148 private:
0149 Allocator alloc_;
0150 arrow::internal::MemoryPoolStats stats_;
0151 };
0152
0153 template <class T1, class T2>
0154 bool operator==(const allocator<T1>& lhs, const allocator<T2>& rhs) noexcept {
0155 return lhs.pool() == rhs.pool();
0156 }
0157
0158 template <class T1, class T2>
0159 bool operator!=(const allocator<T1>& lhs, const allocator<T2>& rhs) noexcept {
0160 return !(lhs == rhs);
0161 }
0162
0163 }
0164 }