Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ucm/store/device/cuda/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ target_compile_options(storedevice PRIVATE
--diag-suppress=128 --diag-suppress=2417 --diag-suppress=2597
-Wall -fPIC
)
add_library(Cuda::cudart UNKNOWN IMPORTED)
set_target_properties(Cuda::cudart PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${CUDA_ROOT}/include"
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcudart.so"
IMPORTED_LOCATION "${CUDA_ROOT}/lib64/libcufile.so"
)
target_link_libraries(storedevice PUBLIC Cuda::cudart)
108 changes: 108 additions & 0 deletions ucm/store/device/cuda/cuda_device.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@
#include <cuda_runtime.h>
#include "ibuffered_device.h"
#include "logger/logger.h"
#include <cufile.h>
#include <mutex>
#include <fcntl.h>
#include <unistd.h>
#include <cerrno>
#include <cstring>
#include <unordered_map>
#include <cstdlib>
#include "infra/template/handle_recorder.h"

#define CUDA_TRANS_UNIT_SIZE (sizeof(uint64_t) * 2)
#define CUDA_TRANS_BLOCK_NUMBER (32)
Expand Down Expand Up @@ -90,6 +99,25 @@ struct fmt::formatter<cudaError_t> : formatter<int32_t> {

namespace UC {

static Status CreateCuFileHandle(int fd, CUfileHandle_t& cuFileHandle)
{
if (fd < 0) {
UC_ERROR("Invalid file descriptor: {}", fd);
return Status::Error();
}

CUfileDescr_t cfDescr{};
cfDescr.handle.fd = fd;
cfDescr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
CUfileError_t err = cuFileHandleRegister(&cuFileHandle, &cfDescr);
if (err.err != CU_FILE_SUCCESS) {
UC_ERROR("Failed to register cuFile handle for fd {}: error {}",
fd, static_cast<int>(err.err));
return Status::Error();
}

return Status::OK();
}
template <typename Api, typename... Args>
Status CudaApi(const char* caller, const char* file, const size_t line, const char* name, Api&& api,
Args&&... args)
Expand Down Expand Up @@ -133,12 +161,23 @@ class CudaDevice : public IBufferedDevice {
return nullptr;
}
static void ReleaseDeviceArray(void* deviceArray) { CUDA_API(cudaFree, deviceArray); }
static std::once_flag gdsOnce_;

public:
static Status InitGdsOnce();
CudaDevice(const int32_t deviceId, const size_t bufferSize, const size_t bufferNumber)
: IBufferedDevice{deviceId, bufferSize, bufferNumber}, stream_{nullptr}
{
}
~CudaDevice() {
HandlePool<int, CUfileHandle_t>::Instance().ClearAll([](CUfileHandle_t h) {
cuFileHandleDeregister(h);
});

if (stream_ != nullptr) {
cudaStreamDestroy((cudaStream_t)stream_);
}
}
Status Setup() override
{
auto status = Status::OK();
Expand All @@ -165,6 +204,52 @@ public:
{
return CUDA_API(cudaMemcpyAsync, dst, src, count, cudaMemcpyDeviceToHost, this->stream_);
}
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
CUfileHandle_t cuFileHandle = nullptr;
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
[fd](CUfileHandle_t& handle) -> Status {
return CreateCuFileHandle(fd, handle);
});
if (status.Failure()) {
return status;
}
ssize_t bytesRead = cuFileRead(cuFileHandle, address, length, fileOffset, devOffset);
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
if (h != nullptr) {
cuFileHandleDeregister(h);
}
});

if (bytesRead < 0 || (size_t)bytesRead != length) {
UC_ERROR("cuFileRead failed for fd {}: expected {}, got {}", fd, length, bytesRead);
return Status::Error();
}
return Status::OK();
}
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
CUfileHandle_t cuFileHandle = nullptr;
auto status = HandlePool<int, CUfileHandle_t>::Instance().Get(fd, cuFileHandle,
[fd](CUfileHandle_t& handle) -> Status {
return CreateCuFileHandle(fd, handle);
});
if (status.Failure()) {
return status;
}
ssize_t bytesWrite = cuFileWrite(cuFileHandle, address, length, fileOffset, devOffset);
HandlePool<int, CUfileHandle_t>::Instance().Put(fd, [](CUfileHandle_t h) {
if (h != nullptr) {
cuFileHandleDeregister(h);
}
});

if (bytesWrite < 0 || (size_t)bytesWrite != length) {
UC_ERROR("cuFileWrite failed for fd {}: expected {}, got {}", fd, length, bytesWrite);
return Status::Error();
}
return Status::OK();
}
Status AppendCallback(std::function<void(bool)> cb) override
{
auto* c = new (std::nothrow) Closure(cb);
Expand Down Expand Up @@ -226,6 +311,14 @@ private:
cudaStream_t stream_;
};

Status DeviceFactory::Setup(bool useDirect)
{
if (useDirect) {
return CudaDevice::InitGdsOnce();
}
return Status::OK();
}

std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_t bufferSize,
const size_t bufferNumber)
{
Expand All @@ -237,5 +330,20 @@ std::unique_ptr<IDevice> DeviceFactory::Make(const int32_t deviceId, const size_
return nullptr;
}
}
std::once_flag CudaDevice::gdsOnce_{};
Status CudaDevice::InitGdsOnce()
{
Status result = Status::OK();
std::call_once(gdsOnce_, [&result]() {
CUfileError_t ret = cuFileDriverOpen();
if (ret.err == CU_FILE_SUCCESS) {
UC_INFO("GDS driver initialized successfully");
} else {
UC_ERROR("GDS driver initialization failed with error code: {}", static_cast<int>(ret.err));
result = Status::Error();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error handling is insufficient, failures need to be visible to the caller.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Received,I'll fix it.

});
return result;
}

} // namespace UC
3 changes: 3 additions & 0 deletions ucm/store/device/idevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class IDevice {
const size_t count) = 0;
virtual Status D2HBatchSync(std::byte* hArr[], const std::byte* dArr[], const size_t number,
const size_t count) = 0;
virtual Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;
virtual Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) = 0;

protected:
virtual std::shared_ptr<std::byte> MakeBuffer(const size_t size) = 0;
Expand All @@ -59,6 +61,7 @@ class IDevice {

class DeviceFactory {
public:
static Status Setup(bool useDirect = false);
static std::unique_ptr<IDevice> Make(const int32_t deviceId, const size_t bufferSize,
const size_t bufferNumber);
};
Expand Down
8 changes: 8 additions & 0 deletions ucm/store/device/simu/simu_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,14 @@ class SimuDevice : public IBufferedDevice {
this->backend_.Push([=] { std::copy(src, src + count, dst); });
return Status::OK();
}
Status S2DSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
return Status::Unsupported();
}
Status D2SSync(int fd, void* address, const size_t length, const size_t fileOffset, const size_t devOffset) override
{
return Status::Unsupported();
}
Status AppendCallback(std::function<void(bool)> cb) override
{
this->backend_.Push([=] { cb(true); });
Expand Down
14 changes: 14 additions & 0 deletions ucm/store/infra/file/file.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,20 @@ Status File::Write(const std::string& path, const size_t offset, const size_t le
return status;
}

Status File::OpenForDirectIO(const std::string& path, uint32_t flags, int& fd)
{
auto file = std::make_unique<FileImpl>(path);
auto status = file->Open(flags);
if (status.Failure()) {
UC_ERROR("Failed to open file({}) with flags({}).", path, flags);
fd = -1;
return status;
}
fd = file->GetHandle();
file.release();
return Status::OK();
}

void File::MUnmap(void* addr, size_t size) { FileImpl{{}}.MUnmap(addr, size); }

void File::ShmUnlink(const std::string& path) { FileImpl{path}.ShmUnlink(); }
Expand Down
1 change: 1 addition & 0 deletions ucm/store/infra/file/file.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class File {
uintptr_t address, const bool directIo = false);
static Status Write(const std::string& path, const size_t offset, const size_t length,
const uintptr_t address, const bool directIo = false);
static Status OpenForDirectIO(const std::string& path, uint32_t flags, int& fd);
static void MUnmap(void* addr, size_t size);
static void ShmUnlink(const std::string& path);
static void Remove(const std::string& path);
Expand Down
1 change: 1 addition & 0 deletions ucm/store/infra/file/ifile.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class IFile {
IFile(const std::string& path) : path_{path} {}
virtual ~IFile() = default;
const std::string& Path() const { return this->path_; }
virtual int32_t GetHandle() const = 0;
virtual Status MkDir() = 0;
virtual Status RmDir() = 0;
virtual Status Rename(const std::string& newName) = 0;
Expand Down
1 change: 1 addition & 0 deletions ucm/store/infra/file/posix_file.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class PosixFile : public IFile {
public:
PosixFile(const std::string& path) : IFile{path}, handle_{-1} {}
~PosixFile() override;
int32_t GetHandle() const override { return handle_; }
Status MkDir() override;
Status RmDir() override;
Status Rename(const std::string& newName) override;
Expand Down
83 changes: 83 additions & 0 deletions ucm/store/infra/template/handle_recorder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#ifndef UC_INFRA_HANDLE_POOL_H
#define UC_INFRA_HANDLE_POOL_H

#include <functional>
#include "status/status.h"
#include "hashmap.h"

namespace UC {

template <typename KeyType, typename HandleType>
class HandlePool {
private:
struct PoolEntry {
HandleType handle;
uint64_t refCount;
};
using PoolMap = HashMap<KeyType, PoolEntry, std::hash<KeyType>, 10>;
PoolMap pool_;

public:
HandlePool() = default;
HandlePool(const HandlePool&) = delete;
HandlePool& operator=(const HandlePool&) = delete;

static HandlePool& Instance()
{
static HandlePool instance;
return instance;
}

Status Get(const KeyType& key, HandleType& handle,
std::function<Status(HandleType&)> instantiate)
{
auto result = pool_.GetOrCreate(key, [&instantiate](PoolEntry& entry) -> bool {
HandleType h{};

auto status = instantiate(h);
if (status.Failure()) {
return false;
}

entry.handle = h;
entry.refCount = 1;
return true;
});

if (!result.has_value()) {
return Status::Error();
}

auto& entry = result.value().get();
entry.refCount++;
handle = entry.handle;
return Status::OK();
}

void Put(const KeyType& key,
std::function<void(HandleType)> cleanup)
{
pool_.Upsert(key, [&cleanup](PoolEntry& entry) -> bool {
entry.refCount--;
if (entry.refCount > 0) {
return false;
}
cleanup(entry.handle);
return true;
});
}

void ClearAll(std::function<void(HandleType)> cleanup)
{
pool_.ForEach([&cleanup](const KeyType& key, PoolEntry& entry) {
(void)key;
cleanup(entry.handle);
});
pool_.Clear();
}
};

} // namespace UC

#endif

Loading