diff --git a/include/hyprutils/memory/Atomic.hpp b/include/hyprutils/memory/Atomic.hpp index de0753e..813bae3 100644 --- a/include/hyprutils/memory/Atomic.hpp +++ b/include/hyprutils/memory/Atomic.hpp @@ -54,11 +54,11 @@ namespace Hyprutils::Memory { using validHierarchy = std::enable_if_t&, X>, CAtomicSharedPointer&>; public: - explicit CAtomicSharedPointer(T* object) noexcept : m_ptr(new Atomic_::impl(sc(object), _delete)) { + explicit CAtomicSharedPointer(T* object) noexcept : m_ptr(new Atomic_::impl(sc(object), _delete), sc(object)) { ; } - CAtomicSharedPointer(Impl_::impl_base* impl) noexcept : m_ptr(impl) { + CAtomicSharedPointer(Impl_::impl_base* impl, void* data) noexcept : m_ptr(impl, data) { ; } @@ -219,13 +219,17 @@ namespace Hyprutils::Memory { return m_ptr.impl_ ? m_ptr.impl_->ref() : 0; } + Atomic_::impl* impl() const { + return sc(m_ptr.impl_); + } + private: static void _delete(void* p) { std::default_delete{}(sc(p)); } std::lock_guard implLockGuard() const { - return sc(m_ptr.impl_)->lockGuard(); + return impl()->lockGuard(); } CSharedPointer m_ptr; @@ -391,12 +395,16 @@ namespace Hyprutils::Memory { if (!m_ptr.impl_->dataNonNull() || m_ptr.impl_->destroying() || !m_ptr.impl_->lockable()) return {}; - return CAtomicSharedPointer(m_ptr.impl_); + return CAtomicSharedPointer(m_ptr.impl_, m_ptr.m_data); + } + + Atomic_::impl* impl() const { + return sc(m_ptr.impl_); } private: std::lock_guard implLockGuard() const { - return sc(m_ptr.impl_)->lockGuard(); + return impl()->lockGuard(); } CWeakPointer m_ptr; @@ -411,4 +419,19 @@ namespace Hyprutils::Memory { [[nodiscard]] inline CAtomicSharedPointer makeAtomicShared(Args&&... args) { return CAtomicSharedPointer(new U(std::forward(args)...)); } + + template + CAtomicSharedPointer reinterpretPointerCast(const CAtomicSharedPointer& ref) { + return CAtomicSharedPointer(ref.impl(), ref.m_data); + } + + template + CAtomicSharedPointer dynamicPointerCast(const CAtomicSharedPointer& ref) { + if (!ref) + return nullptr; + T* newPtr = dynamic_cast(sc(ref.impl()->getData())); + if (!newPtr) + return nullptr; + return CAtomicSharedPointer(ref.impl(), newPtr); + } } diff --git a/include/hyprutils/memory/SharedPtr.hpp b/include/hyprutils/memory/SharedPtr.hpp index bd1284d..ee9172b 100644 --- a/include/hyprutils/memory/SharedPtr.hpp +++ b/include/hyprutils/memory/SharedPtr.hpp @@ -28,31 +28,33 @@ namespace Hyprutils { /* creates a new shared pointer managing a resource avoid calling. Could duplicate ownership. Prefer makeShared */ - explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl_base(sc(object), _delete)) { + explicit CSharedPointer(T* object) noexcept : impl_(new Impl_::impl_base(sc(object), _delete)), m_data(sc(object)) { increment(); } /* creates a shared pointer from a reference */ template > - CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { + CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_), m_data(ref.m_data) { increment(); } - CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_) { + CSharedPointer(const CSharedPointer& ref) noexcept : impl_(ref.impl_), m_data(ref.m_data) { increment(); } template > CSharedPointer(CSharedPointer&& ref) noexcept { std::swap(impl_, ref.impl_); + std::swap(m_data, ref.m_data); } CSharedPointer(CSharedPointer&& ref) noexcept { std::swap(impl_, ref.impl_); + std::swap(m_data, ref.m_data); } /* allows weakPointer to create from an impl */ - CSharedPointer(Impl_::impl_base* implementation) noexcept : impl_(implementation) { + CSharedPointer(Impl_::impl_base* implementation, void* data) noexcept : impl_(implementation), m_data(data) { increment(); } @@ -74,7 +76,8 @@ namespace Hyprutils { return *this; decrement(); - impl_ = rhs.impl_; + impl_ = rhs.impl_; + m_data = rhs.m_data; increment(); return *this; } @@ -84,7 +87,8 @@ namespace Hyprutils { return *this; decrement(); - impl_ = rhs.impl_; + impl_ = rhs.impl_; + m_data = rhs.m_data; increment(); return *this; } @@ -92,11 +96,13 @@ namespace Hyprutils { template validHierarchy&> operator=(CSharedPointer&& rhs) { std::swap(impl_, rhs.impl_); + std::swap(m_data, rhs.m_data); return *this; } CSharedPointer& operator=(CSharedPointer&& rhs) noexcept { std::swap(impl_, rhs.impl_); + std::swap(m_data, rhs.m_data); return *this; } @@ -104,6 +110,8 @@ namespace Hyprutils { return impl_ && impl_->dataNonNull(); } + // this compares that the pointed-to object is the same, but in multiple inheritance, + // different typed pointers can be equal if the object is the same bool operator==(const CSharedPointer& rhs) const { return impl_ == rhs.impl_; } @@ -126,11 +134,12 @@ namespace Hyprutils { void reset() { decrement(); - impl_ = nullptr; + impl_ = nullptr; + m_data = nullptr; } T* get() const { - return impl_ ? sc(impl_->getData()) : nullptr; + return impl_ && impl_->dataNonNull() ? sc(m_data) : nullptr; } unsigned int strongRef() const { @@ -139,6 +148,9 @@ namespace Hyprutils { Impl_::impl_base* impl_ = nullptr; + // Never use directly: raw data ptr, could be UAF + void* m_data = nullptr; + private: static void _delete(void* p) { std::default_delete{}(sc(p)); @@ -188,7 +200,17 @@ namespace Hyprutils { template CSharedPointer reinterpretPointerCast(const CSharedPointer& ref) { - return CSharedPointer(ref.impl_); + return CSharedPointer(ref.impl_, ref.m_data); + } + + template + CSharedPointer dynamicPointerCast(const CSharedPointer& ref) { + if (!ref) + return nullptr; + T* newPtr = dynamic_cast(sc(ref.impl_->getData())); + if (!newPtr) + return nullptr; + return CSharedPointer(ref.impl_, newPtr); } } } diff --git a/include/hyprutils/memory/WeakPtr.hpp b/include/hyprutils/memory/WeakPtr.hpp index 5077ee6..e062a49 100644 --- a/include/hyprutils/memory/WeakPtr.hpp +++ b/include/hyprutils/memory/WeakPtr.hpp @@ -26,7 +26,8 @@ namespace Hyprutils { if (!ref.impl_) return; - impl_ = ref.impl_; + impl_ = ref.impl_; + m_data = ref.m_data; incrementWeak(); } @@ -36,7 +37,8 @@ namespace Hyprutils { if (!ref.impl_) return; - impl_ = ref.impl_; + impl_ = ref.impl_; + m_data = ref.impl_->getData(); incrementWeak(); } @@ -46,7 +48,8 @@ namespace Hyprutils { if (!ref.impl_) return; - impl_ = ref.impl_; + impl_ = ref.impl_; + m_data = ref.m_data; incrementWeak(); } @@ -54,17 +57,20 @@ namespace Hyprutils { if (!ref.impl_) return; - impl_ = ref.impl_; + impl_ = ref.impl_; + m_data = ref.m_data; incrementWeak(); } template > CWeakPointer(CWeakPointer&& ref) noexcept { std::swap(impl_, ref.impl_); + std::swap(m_data, ref.m_data); } CWeakPointer(CWeakPointer&& ref) noexcept { std::swap(impl_, ref.impl_); + std::swap(m_data, ref.m_data); } /* create a weak ptr from another weak ptr with assignment */ @@ -74,7 +80,8 @@ namespace Hyprutils { return *this; decrementWeak(); - impl_ = rhs.impl_; + impl_ = rhs.impl_; + m_data = rhs.m_data; incrementWeak(); return *this; } @@ -84,7 +91,8 @@ namespace Hyprutils { return *this; decrementWeak(); - impl_ = rhs.impl_; + impl_ = rhs.impl_; + m_data = rhs.m_data; incrementWeak(); return *this; } @@ -96,7 +104,8 @@ namespace Hyprutils { return *this; decrementWeak(); - impl_ = rhs.impl_; + impl_ = rhs.impl_; + m_data = rhs.m_data; incrementWeak(); return *this; } @@ -125,14 +134,15 @@ namespace Hyprutils { void reset() { decrementWeak(); - impl_ = nullptr; + impl_ = nullptr; + m_data = nullptr; } CSharedPointer lock() const { if (!impl_ || !impl_->dataNonNull() || impl_->destroying() || !impl_->lockable()) return {}; - return CSharedPointer(impl_); + return CSharedPointer(impl_, m_data); } /* this returns valid() */ @@ -169,7 +179,7 @@ namespace Hyprutils { } T* get() const { - return impl_ ? sc(impl_->getData()) : nullptr; + return impl_ && impl_->dataNonNull() ? sc(m_data) : nullptr; } T* operator->() const { @@ -182,6 +192,9 @@ namespace Hyprutils { Impl_::impl_base* impl_ = nullptr; + // Never use directly: raw data ptr, could be UAF + void* m_data = nullptr; + private: /* no-op if there is no impl_ */ void decrementWeak() { @@ -207,6 +220,16 @@ namespace Hyprutils { impl_->incWeak(); } }; + + template + CWeakPointer dynamicPointerCast(const CWeakPointer& ref) { + if (!ref) + return nullptr; + T* newPtr = dynamic_cast(sc(ref.impl_->getData())); + if (!newPtr) + return nullptr; + return CWeakPointer(ref.impl_, newPtr); + } } } diff --git a/tests/memory/Memory.cpp b/tests/memory/Memory.cpp index 65ee091..ffb5022 100644 --- a/tests/memory/Memory.cpp +++ b/tests/memory/Memory.cpp @@ -123,6 +123,119 @@ static void testAtomicImpl() { } } +class InterfaceA { + public: + virtual ~InterfaceA() = default; + int m_ifaceAInt = 69; + int m_ifaceAShit = 1; +}; + +class InterfaceB { + public: + virtual ~InterfaceB() = default; + int m_ifaceBInt = 2; + int m_ifaceBShit = 3; +}; + +class CChild : public InterfaceA, public InterfaceB { + public: + virtual ~CChild() = default; + int m_childInt = 4; +}; + +class CChildA : public InterfaceA { + public: + int m_childAInt = 4; +}; + +static void testHierarchy() { + // Same test for atomic and non-atomic + { + SP childA = makeShared(); + auto ifaceA = SP(childA); + EXPECT_TRUE(ifaceA); + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + + auto ifaceB = dynamicPointerCast(SP{}); + EXPECT_TRUE(!ifaceB); + } + + { + SP child = makeShared(); + SP ifaceA = dynamicPointerCast(child); + SP ifaceB = dynamicPointerCast(child); + EXPECT_TRUE(ifaceA); + EXPECT_TRUE(ifaceB); + + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + EXPECT_EQ(ifaceB->m_ifaceBInt, 2); + + WP ifaceAWeak = ifaceA; + + child.reset(); + EXPECT_TRUE(ifaceAWeak); + EXPECT_TRUE(ifaceA); + EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69); + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + ifaceA.reset(); + EXPECT_TRUE(ifaceAWeak); + EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69); + EXPECT_TRUE(ifaceB); + EXPECT_EQ(ifaceB->m_ifaceBInt, 2); + ifaceB.reset(); + EXPECT_TRUE(!ifaceAWeak); + } + + // + + { + ASP childA = makeAtomicShared(); + auto ifaceA = ASP(childA); + EXPECT_TRUE(ifaceA); + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + + auto ifaceB = dynamicPointerCast(ASP{}); + EXPECT_TRUE(!ifaceB); + } + + { + ASP child = makeAtomicShared(); + ASP ifaceA = dynamicPointerCast(child); + ASP ifaceB = dynamicPointerCast(child); + EXPECT_TRUE(ifaceA); + EXPECT_TRUE(ifaceB); + + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + EXPECT_EQ(ifaceB->m_ifaceBInt, 2); + + AWP ifaceAWeak = ifaceA; + AWP ifaceBWeak = dynamicPointerCast(ifaceA); + + child.reset(); + EXPECT_TRUE(ifaceAWeak); + EXPECT_TRUE(ifaceBWeak); + EXPECT_TRUE(ifaceA); + EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69); + EXPECT_EQ(ifaceA->m_ifaceAInt, 69); + EXPECT_EQ(ifaceBWeak->m_ifaceBInt, 2); + ifaceA.reset(); + EXPECT_TRUE(ifaceAWeak); + EXPECT_EQ(ifaceAWeak->m_ifaceAInt, 69); + EXPECT_TRUE(ifaceB); + EXPECT_EQ(ifaceB->m_ifaceBInt, 2); + EXPECT_EQ(ifaceBWeak->m_ifaceBInt, 2); + ifaceB.reset(); + EXPECT_TRUE(!ifaceAWeak); + EXPECT_TRUE(!ifaceBWeak); + } + + // test for leaks + for (size_t i = 0; i < 10000; ++i) { + auto child = makeAtomicShared(); + auto child2 = makeShared(); + } +} + TEST(Memory, memory) { SP intPtr = makeShared(10); SP intPtr2 = makeShared(-1337); @@ -176,4 +289,6 @@ TEST(Memory, memory) { EXPECT_EQ(*intPtr2, 10); testAtomicImpl(); + + testHierarchy(); }