Skip to content

Commit 0dc7345

Browse files
committed
[Offload] Have olMemFree accept a platform as a param
In a future change, most of the allocation tracking will be removed from liboffload itself and be delegated to the plugins. Therefore, we will need to know which plugin is in charge of the allocation.
1 parent 3f3f7d1 commit 0dc7345

File tree

15 files changed

+94
-50
lines changed

15 files changed

+94
-50
lines changed

offload/liboffload/API/Memory.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ def olMemAlloc : Function {
3737
def olMemFree : Function {
3838
let desc = "Frees a memory allocation previously made by olMemAlloc.";
3939
let params = [
40+
Param<"ol_platform_handle_t", "Platform", "handle of the platform that allocated this memory", PARAM_IN>,
4041
Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
4142
];
42-
let returns = [];
43+
let returns = [
44+
Return<"OL_ERRC_NOT_FOUND", ["memory was not allocated by this platform"]>
45+
];
4346
}
4447

4548
def olMemcpy : Function {

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
632632
return Error::success();
633633
}
634634

635-
Error olMemFree_impl(void *Address) {
635+
Error olMemFree_impl(ol_platform_handle_t Platform, void *Address) {
636636
ol_device_handle_t Device;
637637
ol_alloc_type_t Type;
638638
{
@@ -646,6 +646,7 @@ Error olMemFree_impl(void *Address) {
646646
Type = AllocInfo.Type;
647647
OffloadContext::get().AllocInfoMap.erase(Address);
648648
}
649+
assert(Platform == Device->Platform);
649650

650651
if (auto Res =
651652
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))

offload/unittests/Conformance/include/mathtest/DeviceContext.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,13 @@ class DeviceContext {
5757
explicit DeviceContext(llvm::StringRef Platform, std::size_t DeviceId = 0);
5858

5959
template <typename T>
60-
ManagedBuffer<T> createManagedBuffer(std::size_t Size) const noexcept {
60+
ManagedBuffer<T> createManagedBuffer(std::size_t Size) noexcept {
6161
void *UntypedAddress = nullptr;
6262

6363
detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress);
6464
T *TypedAddress = static_cast<T *>(UntypedAddress);
6565

66-
return ManagedBuffer<T>(TypedAddress, Size);
66+
return ManagedBuffer<T>(getPlatformHandle(), TypedAddress, Size);
6767
}
6868

6969
[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
@@ -120,6 +120,9 @@ class DeviceContext {
120120

121121
[[nodiscard]] llvm::StringRef getPlatform() const noexcept;
122122

123+
[[nodiscard]] llvm::Expected<ol_platform_handle_t>
124+
getPlatformHandle() noexcept;
125+
123126
private:
124127
[[nodiscard]] llvm::Expected<ol_symbol_handle_t>
125128
getKernelHandle(ol_program_handle_t ProgramHandle,
@@ -131,6 +134,7 @@ class DeviceContext {
131134

132135
std::size_t GlobalDeviceId;
133136
ol_device_handle_t DeviceHandle;
137+
ol_platform_handle_t PlatformHandle = nullptr;
134138
};
135139
} // namespace mathtest
136140

offload/unittests/Conformance/include/mathtest/DeviceResources.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class DeviceContext;
2929

3030
namespace detail {
3131

32-
void freeDeviceMemory(void *Address) noexcept;
32+
void freeDeviceMemory(ol_platform_handle_t Platform, void *Address) noexcept;
3333
} // namespace detail
3434

3535
//===----------------------------------------------------------------------===//
@@ -40,7 +40,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
4040
public:
4141
~ManagedBuffer() noexcept {
4242
if (Address)
43-
detail::freeDeviceMemory(Address);
43+
detail::freeDeviceMemory(Platform, Address);
4444
}
4545

4646
ManagedBuffer(const ManagedBuffer &) = delete;
@@ -57,7 +57,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
5757
return *this;
5858

5959
if (Address)
60-
detail::freeDeviceMemory(Address);
60+
detail::freeDeviceMemory(Platform, Address);
6161

6262
Address = Other.Address;
6363
Size = Other.Size;
@@ -85,9 +85,11 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
8585
private:
8686
friend class DeviceContext;
8787

88-
explicit ManagedBuffer(T *Address, std::size_t Size) noexcept
89-
: Address(Address), Size(Size) {}
88+
explicit ManagedBuffer(ol_platform_handle_t Platform, T *Address,
89+
std::size_t Size) noexcept
90+
: Platform(Platform), Address(Address), Size(Size) {}
9091

92+
ol_platform_handle_t Platform;
9193
T *Address = nullptr;
9294
std::size_t Size = 0;
9395
};

offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class [[nodiscard]] GpuMathTest final {
7575

7676
ResultType run(GeneratorType &Generator,
7777
std::size_t BufferSize = DefaultBufferSize,
78-
uint32_t GroupSize = DefaultGroupSize) const noexcept {
78+
uint32_t GroupSize = DefaultGroupSize) noexcept {
7979
assert(BufferSize > 0 && "Buffer size must be a positive value");
8080
assert(GroupSize > 0 && "Group size must be a positive value");
8181

@@ -128,7 +128,7 @@ class [[nodiscard]] GpuMathTest final {
128128
return *ExpectedKernel;
129129
}
130130

131-
[[nodiscard]] auto createBuffers(std::size_t BufferSize) const {
131+
[[nodiscard]] auto createBuffers(std::size_t BufferSize) {
132132
auto InBuffersTuple = std::apply(
133133
[&](auto... InTypeIdentities) {
134134
return std::make_tuple(

offload/unittests/Conformance/include/mathtest/OffloadForward.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ typedef struct ol_program_impl_t *ol_program_handle_t;
3232
struct ol_symbol_impl_t;
3333
typedef struct ol_symbol_impl_t *ol_symbol_handle_t;
3434

35+
struct ol_platform_impl_t;
36+
typedef struct ol_platform_impl_t *ol_platform_handle_t;
37+
3538
#ifdef __cplusplus
3639
}
3740
#endif // __cplusplus

offload/unittests/Conformance/lib/DeviceContext.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,29 @@ DeviceContext::getKernelHandle(ol_program_handle_t ProgramHandle,
286286
return Handle;
287287
}
288288

289+
llvm::Expected<ol_platform_handle_t>
290+
DeviceContext::getPlatformHandle() noexcept {
291+
if (!PlatformHandle) {
292+
const ol_result_t OlResult =
293+
olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PLATFORM,
294+
sizeof(PlatformHandle), &PlatformHandle);
295+
296+
if (OlResult != OL_SUCCESS) {
297+
PlatformHandle = nullptr;
298+
llvm::StringRef Details =
299+
OlResult->Details ? OlResult->Details : "No details provided";
300+
301+
// clang-format off
302+
return llvm::createStringError(
303+
llvm::Twine(Details) +
304+
" (code " + llvm::Twine(OlResult->Code) + ")");
305+
// clang-format on
306+
}
307+
}
308+
309+
return PlatformHandle;
310+
}
311+
289312
void DeviceContext::launchKernelImpl(
290313
ol_symbol_handle_t KernelHandle, uint32_t NumGroups, uint32_t GroupSize,
291314
const void *KernelArgs, std::size_t KernelArgsSize) const noexcept {

offload/unittests/Conformance/lib/DeviceResources.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ using namespace mathtest;
2424
// Helpers
2525
//===----------------------------------------------------------------------===//
2626

27-
void detail::freeDeviceMemory(void *Address) noexcept {
27+
void detail::freeDeviceMemory(ol_platform_handle_t Platform,
28+
void *Address) noexcept {
2829
if (Address)
29-
OL_CHECK(olMemFree(Address));
30+
OL_CHECK(olMemFree(Platform, Address));
3031
}
3132

3233
//===----------------------------------------------------------------------===//

offload/unittests/OffloadAPI/common/Fixtures.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,20 +137,20 @@ struct OffloadDeviceTest
137137
Device = DeviceParam.Handle;
138138
if (Device == nullptr)
139139
GTEST_SKIP() << "No available devices.";
140+
141+
ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
142+
sizeof(ol_platform_handle_t), &Platform));
140143
}
141144

142145
ol_platform_backend_t getPlatformBackend() const {
143-
ol_platform_handle_t Platform = nullptr;
144-
if (olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
145-
sizeof(ol_platform_handle_t), &Platform))
146-
return OL_PLATFORM_BACKEND_UNKNOWN;
147146
ol_platform_backend_t Backend;
148147
if (olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND,
149148
sizeof(ol_platform_backend_t), &Backend))
150149
return OL_PLATFORM_BACKEND_UNKNOWN;
151150
return Backend;
152151
}
153152

153+
ol_platform_handle_t Platform = nullptr;
154154
ol_device_handle_t Device = nullptr;
155155
};
156156

offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
101101
ASSERT_EQ(Data[i], i);
102102
}
103103

104-
ASSERT_SUCCESS(olMemFree(Mem));
104+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
105105
}
106106

107107
TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
@@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
123123
ASSERT_EQ(Data[i], i);
124124
}
125125

126-
ASSERT_SUCCESS(olMemFree(Mem));
126+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
127127
});
128128
}
129129

@@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
151151
ASSERT_EQ(Data[i], i);
152152
}
153153

154-
ASSERT_SUCCESS(olMemFree(Mem));
154+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
155155
}
156156

157157
TEST_P(olLaunchKernelLocalMemTest, Success) {
@@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
176176
for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
177177
ASSERT_EQ(Data[i], (i % 64) * 2);
178178

179-
ASSERT_SUCCESS(olMemFree(Mem));
179+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
180180
}
181181

182182
TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
@@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
199199
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
200200
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
201201

202-
ASSERT_SUCCESS(olMemFree(Mem));
202+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
203203
}
204204

205205
TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
@@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
222222
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
223223
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
224224

225-
ASSERT_SUCCESS(olMemFree(Mem));
225+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
226226
}
227227

228228
TEST_P(olLaunchKernelGlobalTest, Success) {
@@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
245245
ASSERT_EQ(Data[i], i * 2);
246246
}
247247

248-
ASSERT_SUCCESS(olMemFree(Mem));
248+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
249249
}
250250

251251
TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
@@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
273273
ASSERT_EQ(Data[i], i + 100);
274274
}
275275

276-
ASSERT_SUCCESS(olMemFree(Mem));
276+
ASSERT_SUCCESS(olMemFree(Platform, Mem));
277277
}
278278

279279
TEST_P(olLaunchKernelGlobalDtorTest, Success) {

0 commit comments

Comments
 (0)