|
13 | 13 | #endif
|
14 | 14 |
|
15 | 15 | namespace cp_algo {
|
16 |
| - template <typename T> |
17 |
| - class big_alloc: public std::allocator<T> { |
18 |
| - public: |
| 16 | + template <typename T, std::size_t Align = 32> |
| 17 | + class big_alloc { |
| 18 | + static_assert( Align >= alignof(void*), "Align must be at least pointer-size"); |
| 19 | + static_assert(std::popcount(Align) == 1, "Align must be a power of two"); |
| 20 | + public: |
19 | 21 | using value_type = T;
|
20 |
| - using base = std::allocator<T>; |
| 22 | + template <class U> struct rebind { using other = big_alloc<U, Align>; }; |
21 | 23 |
|
22 | 24 | big_alloc() noexcept = default;
|
| 25 | + template <typename U, std::size_t A> |
| 26 | + big_alloc(const big_alloc<U, A>&) noexcept {} |
23 | 27 |
|
24 |
| - template <typename U> |
25 |
| - big_alloc(const big_alloc<U>&) noexcept {} |
26 |
| - |
27 |
| -#if CP_ALGO_USE_MMAP |
28 | 28 | [[nodiscard]] T* allocate(std::size_t n) {
|
29 |
| - if(n * sizeof(T) < 1024 * 1024) { |
30 |
| - return base::allocate(n); |
| 29 | + std::size_t padded = round_up(n * sizeof(T)); |
| 30 | + std::size_t align = std::max<std::size_t>(alignof(T), Align); |
| 31 | +#if CP_ALGO_USE_MMAP |
| 32 | + if (padded >= MEGABYTE) { |
| 33 | + void* raw = mmap(nullptr, padded, |
| 34 | + PROT_READ | PROT_WRITE, |
| 35 | + MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); |
| 36 | + madvise(raw, padded, MADV_HUGEPAGE); |
| 37 | + madvise(raw, padded, MADV_POPULATE_WRITE); |
| 38 | + return static_cast<T*>(raw); |
31 | 39 | }
|
32 |
| - n *= sizeof(T); |
33 |
| - void* raw = mmap(nullptr, n, |
34 |
| - PROT_READ | PROT_WRITE, |
35 |
| - MAP_PRIVATE | MAP_ANONYMOUS, |
36 |
| - -1, 0); |
37 |
| - madvise(raw, n, MADV_HUGEPAGE); |
38 |
| - madvise(raw, n, MADV_POPULATE_WRITE); |
39 |
| - return static_cast<T*>(raw); |
40 |
| - } |
41 | 40 | #endif
|
| 41 | + return static_cast<T*>(::operator new(padded, std::align_val_t(align))); |
| 42 | + } |
42 | 43 |
|
43 |
| -#if CP_ALGO_USE_MMAP |
44 | 44 | void deallocate(T* p, std::size_t n) noexcept {
|
45 |
| - if(n * sizeof(T) < 1024 * 1024) { |
46 |
| - return base::deallocate(p, n); |
47 |
| - } |
48 |
| - if(p) { |
49 |
| - munmap(p, n * sizeof(T)); |
50 |
| - } |
| 45 | + if (!p) return; |
| 46 | + std::size_t padded = round_up(n * sizeof(T)); |
| 47 | + std::size_t align = std::max<std::size_t>(alignof(T), Align); |
| 48 | + #if CP_ALGO_USE_MMAP |
| 49 | + if (padded >= MEGABYTE) { munmap(p, padded); return; } |
| 50 | + #endif |
| 51 | + ::operator delete(p, padded, std::align_val_t(align)); |
| 52 | + } |
| 53 | + |
| 54 | + private: |
| 55 | + static constexpr std::size_t MEGABYTE = 1 << 20; |
| 56 | + static constexpr std::size_t round_up(std::size_t x) noexcept { |
| 57 | + return (x + Align - 1) / Align * Align; |
51 | 58 | }
|
52 |
| -#endif |
53 | 59 | };
|
54 | 60 | }
|
55 | 61 | #endif // CP_ALGO_UTIL_big_alloc_HPP
|
0 commit comments