diff --git a/.gitattributes b/.gitattributes deleted file mode 100644 index dfe0770..0000000 --- a/.gitattributes +++ /dev/null @@ -1,2 +0,0 @@ -# Auto detect text files and perform LF normalization -* text=auto diff --git a/.github/workflows/cmake-single-platform.yml b/.github/workflows/cmake-single-platform.yml index a0db99c..78b1e7b 100644 --- a/.github/workflows/cmake-single-platform.yml +++ b/.github/workflows/cmake-single-platform.yml @@ -1,4 +1,4 @@ -name: Build and Test +name: CI on: push: diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b05445..0a974d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,18 +1,50 @@ +cmake_minimum_required(VERSION 3.20) +project(CacheDB VERSION 1.0.0 LANGUAGES CXX) -cmake_minimum_required(VERSION 3.10) -project(CacheDB) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED True) +enable_testing() -add_executable(server server.cpp) -add_executable(client client.cpp) +include(FetchContent) +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/03597a01ee50ed33e9dfd640b249b4be3799d395.zip +) +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) # For Windows: Prevent overriding the parent project's compiler/linker settings +FetchContent_MakeAvailable(googletest) -# enable_testing() -# add_executable(server_test server_test.cpp) -# target_link_libraries(server_test PRIVATE ) -# add_test(NAME ServerTest COMMAND server_test) +file(GLOB_RECURSE CLIENT_FILES "client/*.cc" "client/*.cpp") +file(GLOB_RECURSE SERVER_FILES "server/*.cc" "server/*.cpp") +file(GLOB_RECURSE TEST_FILES "tests/*.cc" "tests/*.cpp") -# add_executable(client_test client_test.cpp) -# target_link_libraries(client_test PRIVATE ) -# add_test(NAME ClientTest COMMAND client_test) \ No newline at end of file +add_executable(cachedb_client ${CLIENT_FILES} ${CMAKE_SOURCE_DIR}/client.cpp) +target_include_directories(cachedb_client + PRIVATE + ${CMAKE_SOURCE_DIR}/include + ${CMAKE_SOURCE_DIR}/client +) + +add_executable(cachedb_server ${SERVER_FILES} ${CMAKE_SOURCE_DIR}/cachedb.cpp) +target_include_directories(cachedb_server + PRIVATE + ${CMAKE_SOURCE_DIR}/include + ${CMAKE_SOURCE_DIR}/server +) + +add_executable(all_tests ${CLIENT_FILES} ${SERVER_FILES} ${TEST_FILES}) +target_link_libraries(all_tests + PRIVATE + GTest::gtest + GTest::gtest_main +) +target_include_directories(all_tests + PRIVATE + ${CMAKE_SOURCE_DIR}/include + ${CMAKE_SOURCE_DIR}/tests + ${CMAKE_SOURCE_DIR}/client + ${CMAKE_SOURCE_DIR}/server +) + +include(GoogleTest) +gtest_discover_tests(all_tests) diff --git a/README.md b/README.md index 621a341..953c340 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,27 @@ CacheDB is an in-memory key-value store designed for high performance and low la - TTL Management: Set expiration times for keys to automatically manage data lifecycle. - Non-Blocking I/O: Handles multiple client connections efficiently. +## Getting Started + +1. **Clone the repository:** + ```bash + git clone https://github.com/etbala/CacheDB.git + cd CacheDB + ``` + +2. **Build the project:** + ```bash + mkdir build + cd build + cmake .. + make -j$(nproc) + ``` + +3. **Run tests:** + ```bash + ./all_tests + ``` + ## Commands ### GET diff --git a/cachedb.cpp b/cachedb.cpp new file mode 100644 index 0000000..9379871 --- /dev/null +++ b/cachedb.cpp @@ -0,0 +1,27 @@ +#include +#include +#include +#include +#include +#include +#include "server/server.h" + +int main() { + int listen_fd = ::socket(AF_INET, SOCK_STREAM, 0); + if (listen_fd < 0) { perror("socket"); return 1; } + + int val = 1; + ::setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(1234); + addr.sin_addr.s_addr = htonl(0); + if (::bind(listen_fd, (sockaddr*)&addr, sizeof(addr)) < 0) { perror("bind"); return 1; } + + if (::listen(listen_fd, SOMAXCONN) < 0) { perror("listen"); return 1; } + + Server server; + server.run(listen_fd); + return 0; +} diff --git a/client.cpp b/client.cpp index 966649f..2be06c1 100644 --- a/client.cpp +++ b/client.cpp @@ -1,215 +1,18 @@ - +#include "client/client_loop.h" +#include "client/transport.h" #include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -void die(const std::string& message) { - int err = errno; - std::cerr << "[" << err << "] " << message << std::endl; - std::exit(EXIT_FAILURE); -} - -// Serialization codes -enum { - SER_STR = 0, - SER_NIL = 1, - SER_INT = 2, - SER_ERR = 3, - SER_ARR = 4, - SER_DBL = 5, -}; - -// Serialize the request according to the server's protocol -void serialize_request(const std::vector& cmd, std::vector& out) { - uint32_t argc = cmd.size(); - uint32_t data_len = 4; // length of argc - for (const std::string& arg : cmd) { - data_len += 4; // length of arg_len - data_len += arg.size(); // length of argument data - } - uint32_t total_len = data_len; // total length of data after the initial len field - - out.resize(4 + data_len); // 4 bytes for len, plus data_len bytes for data - uint32_t offset = 0; - - // Write total_len into the first 4 bytes - memcpy(&out[offset], &total_len, 4); - offset += 4; - - // Write argc - memcpy(&out[offset], &argc, 4); - offset += 4; - - for (const std::string& arg : cmd) { - uint32_t arg_len = arg.size(); - memcpy(&out[offset], &arg_len, 4); - offset += 4; - memcpy(&out[offset], arg.data(), arg_len); - offset += arg_len; - } -} - -void deserialize_response(const std::vector& in, size_t& offset) { - if (offset >= in.size()) { - die("Response parsing error: offset out of bounds"); - } - uint8_t type = in[offset++]; - switch (type) { - case SER_STR: { - if (offset + 4 > in.size()) { - die("Response parsing error: string length"); - } - uint32_t len = 0; - memcpy(&len, &in[offset], 4); - offset += 4; - if (offset + len > in.size()) { - die("Response parsing error: string data"); - } - std::string str((char*)&in[offset], len); - offset += len; - std::cout << str << std::endl; - break; - } - case SER_INT: { - if (offset + 8 > in.size()) { - die("Response parsing error: int data"); - } - int64_t val = 0; - memcpy(&val, &in[offset], 8); - offset += 8; - std::cout << val << std::endl; - break; - } - case SER_DBL: { - if (offset + 8 > in.size()) { - die("Response parsing error: double data"); - } - double val = 0; - memcpy(&val, &in[offset], 8); - offset += 8; - std::cout << val << std::endl; - break; - } - case SER_NIL: { - std::cout << "(nil)" << std::endl; - break; - } - case SER_ERR: { - if (offset + 4 > in.size()) { - die("Response parsing error: error length"); - } - uint32_t len = 0; - memcpy(&len, &in[offset], 4); - offset += 4; - if (offset + len > in.size()) { - die("Response parsing error: error data"); - } - std::string str((char*)&in[offset], len); - offset += len; - std::cerr << "(error) " << str << std::endl; - break; - } - case SER_ARR: { - if (offset + 4 > in.size()) { - die("Response parsing error: array length"); - } - uint32_t len = 0; - memcpy(&len, &in[offset], 4); - offset += 4; - for (uint32_t i = 0; i < len; ++i) { - deserialize_response(in, offset); - } - break; - } - default: - die("Response parsing error: unknown type"); - } -} - -int main() { - // Create a socket - int sockfd = socket(AF_INET, SOCK_STREAM, 0); - if (sockfd < 0) { - die("socket failed"); - } - - // Server address - sockaddr_in serv_addr = {}; - serv_addr.sin_family = AF_INET; - serv_addr.sin_port = htons(1234); - - // Convert IPv4 and IPv6 addresses from text to binary form - if (inet_pton(AF_INET, "127.0.0.1", &serv_addr.sin_addr) <= 0) { - die("Invalid address / Address not supported"); - } - - // Connect to the server - if (connect(sockfd, (sockaddr*)&serv_addr, sizeof(serv_addr)) < 0) { - die("Connection failed"); - } - - std::string line; - while (std::cout << "> ", std::getline(std::cin, line)) { - // Split line into arguments - std::istringstream iss(line); - std::vector cmd; - std::string arg; - while (iss >> arg) { - cmd.push_back(arg); - } - if (cmd.empty()) { - continue; - } - - // Serialize the request - std::vector request; - serialize_request(cmd, request); - - // Send the request - size_t total_sent = 0; - while (total_sent < request.size()) { - ssize_t n = send(sockfd, &request[total_sent], request.size() - total_sent, 0); - if (n < 0) { - die("send failed"); - } - total_sent += n; - } - // Receive the response header (4 bytes indicating the length) - std::vector response_header(4); - ssize_t n = recv(sockfd, &response_header[0], 4, MSG_WAITALL); - if (n <= 0) { - die("recv failed"); - } - uint32_t resp_len = 0; - memcpy(&resp_len, &response_header[0], 4); - if (resp_len > 10 * 1024 * 1024) { - die("Response too large"); - } +int main(int argc, char** argv) { + const char* host = "127.0.0.1"; + uint16_t port = 1234; - // Receive the response body - std::vector response_body(resp_len); - size_t total_received = 0; - while (total_received < resp_len) { - n = recv(sockfd, &response_body[total_received], resp_len - total_received, 0); - if (n <= 0) { - die("recv failed"); - } - total_received += n; - } + // (Optional) allow overrides: client 127.0.0.1 1234 + if (argc >= 2) host = argv[1]; + if (argc >= 3) port = static_cast(std::stoi(argv[2])); - // Parse and print the response - size_t offset = 0; - deserialize_response(response_body, offset); - } + TcpTransport transport; + transport.connect(host, port); - close(sockfd); - return 0; + // Use std::cin/std::cout/std::cerr in production; in tests you'll pass string streams. + return run_client_repl(transport, std::cin, std::cout, std::cerr); } diff --git a/client/client_loop.cpp b/client/client_loop.cpp new file mode 100644 index 0000000..c124b45 --- /dev/null +++ b/client/client_loop.cpp @@ -0,0 +1,51 @@ +#include "client/client_loop.h" +#include "client/protocol.h" +#include "client/util.h" + +#include +#include +#include +#include +#include + +int run_client_repl(ITransport& transport, + std::istream& in, + std::ostream& out, + std::ostream& err) { + + std::string line; + while (true) { + out << "> " << std::flush; + if (!std::getline(in, line)) break; + + std::istringstream iss(line); + std::vector cmd; + std::string arg; + while (iss >> arg) cmd.push_back(arg); + if (cmd.empty()) continue; + + // build request (unchanged logic) + std::vector request; + serialize_request(cmd, request); + + // send request + transport.send_all(request.data(), request.size()); + + // header (4 bytes length) + uint8_t header[4]; + transport.recv_all(header, 4); + uint32_t resp_len = 0; + std::memcpy(&resp_len, header, 4); + if (resp_len > 10 * 1024 * 1024) die("Response too large"); + + // body + std::vector body(resp_len); + transport.recv_all(body.data(), resp_len); + + // parse & print + size_t offset = 0; + deserialize_response(body, offset, out, err); + } + + return 0; +} diff --git a/client/protocol.cpp b/client/protocol.cpp new file mode 100644 index 0000000..71ac6e2 --- /dev/null +++ b/client/protocol.cpp @@ -0,0 +1,105 @@ +#include "client/protocol.h" +#include "client/util.h" +#include +#include + +void serialize_request(const std::vector& cmd, std::vector& out) { + uint32_t argc = static_cast(cmd.size()); + uint32_t data_len = 4; // length of argc + for (const std::string& arg : cmd) { + data_len += 4; // length of arg_len + data_len += static_cast(arg.size()); + } + uint32_t total_len = data_len; + + out.resize(4 + data_len); // 4 bytes for len, plus data_len bytes for data + uint32_t offset = 0; + + std::memcpy(&out[offset], &total_len, 4); + offset += 4; + + std::memcpy(&out[offset], &argc, 4); + offset += 4; + + for (const std::string& arg : cmd) { + uint32_t arg_len = static_cast(arg.size()); + std::memcpy(&out[offset], &arg_len, 4); + offset += 4; + std::memcpy(&out[offset], arg.data(), arg_len); + offset += arg_len; + } +} + +// helper so both overloads share identical logic +static void deserialize_impl(const std::vector& in, size_t& offset, + std::ostream& out, std::ostream& err) { + if (offset >= in.size()) { + die("Response parsing error: offset out of bounds"); + } + uint8_t type = in[offset++]; + switch (type) { + case SER_STR: { + if (offset + 4 > in.size()) die("Response parsing error: string length"); + uint32_t len = 0; + std::memcpy(&len, &in[offset], 4); + offset += 4; + if (offset + len > in.size()) die("Response parsing error: string data"); + std::string str(reinterpret_cast(&in[offset]), len); + offset += len; + out << str << std::endl; + break; + } + case SER_INT: { + if (offset + 8 > in.size()) die("Response parsing error: int data"); + int64_t val = 0; + std::memcpy(&val, &in[offset], 8); + offset += 8; + out << val << std::endl; + break; + } + case SER_DBL: { + if (offset + 8 > in.size()) die("Response parsing error: double data"); + double val = 0; + std::memcpy(&val, &in[offset], 8); + offset += 8; + out << val << std::endl; + break; + } + case SER_NIL: { + out << "(nil)" << std::endl; + break; + } + case SER_ERR: { + if (offset + 4 > in.size()) die("Response parsing error: error length"); + uint32_t len = 0; + std::memcpy(&len, &in[offset], 4); + offset += 4; + if (offset + len > in.size()) die("Response parsing error: error data"); + std::string str(reinterpret_cast(&in[offset]), len); + offset += len; + err << "(error) " << str << std::endl; + break; + } + case SER_ARR: { + if (offset + 4 > in.size()) die("Response parsing error: array length"); + uint32_t len = 0; + std::memcpy(&len, &in[offset], 4); + offset += 4; + for (uint32_t i = 0; i < len; ++i) { + deserialize_impl(in, offset, out, err); + } + break; + } + default: + die("Response parsing error: unknown type"); + } +} + +void deserialize_response(const std::vector& in, size_t& offset) { + deserialize_impl(in, offset, std::cout, std::cerr); +} + +void deserialize_response(const std::vector& in, size_t& offset, + std::ostream& out, std::ostream& err) { + deserialize_impl(in, offset, out, err); +} diff --git a/client/transport.cpp b/client/transport.cpp new file mode 100644 index 0000000..9b54b91 --- /dev/null +++ b/client/transport.cpp @@ -0,0 +1,53 @@ +#include "client/transport.h" +#include "client/util.h" + +#include +#include +#include +#include +#include +#include +#include + +TcpTransport::TcpTransport() : sockfd_(-1) {} +TcpTransport::~TcpTransport() { close(); } + +void TcpTransport::connect(const std::string& host, uint16_t port) { + sockfd_ = ::socket(AF_INET, SOCK_STREAM, 0); + if (sockfd_ < 0) die("socket failed"); + + sockaddr_in addr{}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + if (::inet_pton(AF_INET, host.c_str(), &addr.sin_addr) <= 0) { + die("Invalid address / Address not supported"); + } + if (::connect(sockfd_, reinterpret_cast(&addr), sizeof(addr)) < 0) { + die("Connection failed"); + } +} + +void TcpTransport::send_all(const uint8_t* data, size_t len) { + size_t sent = 0; + while (sent < len) { + ssize_t n = ::send(sockfd_, data + sent, len - sent, 0); + if (n < 0) die("send failed"); + sent += static_cast(n); + } +} + +void TcpTransport::recv_all(uint8_t* data, size_t len) { + size_t recvd = 0; + while (recvd < len) { + ssize_t n = ::recv(sockfd_, data + recvd, len - recvd, 0); + if (n <= 0) die("recv failed"); + recvd += static_cast(n); + } +} + +void TcpTransport::close() { + if (sockfd_ >= 0) { + ::close(sockfd_); + sockfd_ = -1; + } +} diff --git a/client/util.cpp b/client/util.cpp new file mode 100644 index 0000000..0567de5 --- /dev/null +++ b/client/util.cpp @@ -0,0 +1,10 @@ +#include "client/util.h" +#include +#include +#include + +[[noreturn]] void die(const std::string& message) { + int err = errno; + std::cerr << "[" << err << "] " << message << std::endl; + std::exit(EXIT_FAILURE); +} diff --git a/client_test.cpp b/client_test.cpp deleted file mode 100644 index e69de29..0000000 diff --git a/include/client/client_loop.h b/include/client/client_loop.h new file mode 100644 index 0000000..1a7ca0e --- /dev/null +++ b/include/client/client_loop.h @@ -0,0 +1,10 @@ +#pragma once +#include "transport.h" +#include + +// A function that runs the interactive loop (REPL). +// It depends on a transport and streams that can be swapped in tests. +int run_client_repl(ITransport& transport, + std::istream& in, + std::ostream& out, + std::ostream& err); diff --git a/include/client/protocol.h b/include/client/protocol.h new file mode 100644 index 0000000..417381d --- /dev/null +++ b/include/client/protocol.h @@ -0,0 +1,16 @@ +#pragma once +#include +#include +#include +#include +#include +#include "common/serialization.h" + +// Moved from client.cpp. Same signatures/logic so you can unit test them. +void serialize_request(const std::vector& cmd, std::vector& out); + +// Original function wrote to std::cout/std::cerr directly. +// Keep it, but also provide an overload that takes output streams (handy for tests). +void deserialize_response(const std::vector& in, size_t& offset); +void deserialize_response(const std::vector& in, size_t& offset, + std::ostream& out, std::ostream& err); diff --git a/include/client/transport.h b/include/client/transport.h new file mode 100644 index 0000000..6bb8e5e --- /dev/null +++ b/include/client/transport.h @@ -0,0 +1,29 @@ +#pragma once +#include +#include +#include +#include + +// Thin transport interface so logic can be tested with a fake transport in gtest. +struct ITransport { + virtual ~ITransport() = default; + virtual void connect(const std::string& host, uint16_t port) = 0; + virtual void send_all(const uint8_t* data, size_t len) = 0; + virtual void recv_all(uint8_t* data, size_t len) = 0; + virtual void close() = 0; +}; + +// Concrete TCP transport used by main program. +class TcpTransport : public ITransport { +public: + TcpTransport(); + ~TcpTransport() override; + + void connect(const std::string& host, uint16_t port) override; + void send_all(const uint8_t* data, size_t len) override; + void recv_all(uint8_t* data, size_t len) override; + void close() override; + +private: + int sockfd_; +}; diff --git a/include/client/util.h b/include/client/util.h new file mode 100644 index 0000000..e29c777 --- /dev/null +++ b/include/client/util.h @@ -0,0 +1,5 @@ +#pragma once +#include + +// Same die() you had, made reusable. +[[noreturn]] void die(const std::string& message); diff --git a/include/common/serialization.h b/include/common/serialization.h new file mode 100644 index 0000000..be62ca3 --- /dev/null +++ b/include/common/serialization.h @@ -0,0 +1,10 @@ +#pragma once + +enum { + SER_STR = 0, + SER_NIL = 1, + SER_INT = 2, + SER_ERR = 3, + SER_ARR = 4, + SER_DBL = 5, +}; \ No newline at end of file diff --git a/include/server/entry.h b/include/server/entry.h new file mode 100644 index 0000000..25b90fc --- /dev/null +++ b/include/server/entry.h @@ -0,0 +1,21 @@ +#pragma once +#include +#include "zset.h" + +class Entry { +public: + enum Type { STRING, ZSET }; + std::string key; + Type type; + std::string str_value; // For STRING type + ZSet* zset_value; // For ZSET type + + Entry(const std::string& k, const std::string& val) + : key(k), type(STRING), str_value(val), zset_value(nullptr) {} + + ~Entry() { + if (type == ZSET && zset_value) { + delete zset_value; + } + } +}; diff --git a/hashtable.h b/include/server/hashtable.h similarity index 99% rename from hashtable.h rename to include/server/hashtable.h index 1fb0c6d..0b32423 100644 --- a/hashtable.h +++ b/include/server/hashtable.h @@ -1,4 +1,3 @@ - #pragma once #include diff --git a/include/server/protocol.h b/include/server/protocol.h new file mode 100644 index 0000000..d228d72 --- /dev/null +++ b/include/server/protocol.h @@ -0,0 +1,17 @@ +#pragma once +#include +#include +#include +#include "common/serialization.h" + +// Serialization helpers (same behavior/signatures as before) +void out_string(std::string& out, const std::string& str); +void out_nil(std::string& out); +void out_int(std::string& out, int64_t val); +void out_error(std::string& out, const std::string& msg); +void out_ok(std::string& out); +void out_array(std::string& out, const std::vector& arr); +void out_double(std::string& out, double val); + +// Request parsing (same behavior as before) +int parse_request(const uint8_t* data, size_t len, std::vector& out); diff --git a/include/server/server.h b/include/server/server.h new file mode 100644 index 0000000..510f2ad --- /dev/null +++ b/include/server/server.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include +#include +#include "server/hashtable.h" +#include "server/entry.h" + +enum ConnectionState { + STATE_REQ, + STATE_RES, + STATE_END +}; + +struct Connection { + int fd; + ConnectionState state; + std::vector rbuf; + std::vector wbuf; + size_t wbuf_sent; + + static const size_t k_max_msg = 4096; + + explicit Connection(int fd_) + : fd(fd_), state(STATE_REQ), wbuf_sent(0) { + rbuf.reserve(4 + k_max_msg); + wbuf.reserve(4 + k_max_msg); + } +}; + +class Server { +public: + Server(); + ~Server(); + + // Runs the poll loop on an already-bound+listening socket. + void run(int listen_fd); + + void handle_command(const std::vector& cmd, std::string& out); + +private: + void accept_new_connection(int listen_fd); + void close_connection(Connection* conn); + void handle_connection_io(Connection* conn); + void handle_read(Connection* conn); + void handle_write(Connection* conn); + +private: + HashTable db_; + std::vector fd2conn_; +}; diff --git a/include/server/zset.h b/include/server/zset.h new file mode 100644 index 0000000..bee3b49 --- /dev/null +++ b/include/server/zset.h @@ -0,0 +1,74 @@ +#pragma once + +#include +#include +#include +#include + +class ZSet { +private: + // Node structure for tree_by_score + class ScoreNode { + public: + double score; + std::string member; + int height; + ScoreNode* left; + ScoreNode* right; + + ScoreNode(double s, const std::string& m) + : score(s), member(m), height(1), left(nullptr), right(nullptr) {} + }; + + // Node structure for tree_by_member + class MemberNode { + public: + std::string member; + double score; + int height; + MemberNode* left; + MemberNode* right; + + MemberNode(const std::string& m, double s) + : member(m), score(s), height(1), left(nullptr), right(nullptr) {} + }; + + ScoreNode* tree_by_score; + MemberNode* tree_by_member; + +public: + ZSet(); + ~ZSet(); + + bool zadd(const std::string& member, double score); + bool zrem(const std::string& member); + bool zscore(const std::string& member, double& out_score); + std::vector> zquery(double min_score, const std::string& min_member, int offset, int limit); + +private: + // AVL tree functions for ScoreNode + int height(ScoreNode* node); + void updateHeight(ScoreNode* node); + int balanceFactor(ScoreNode* node); + ScoreNode* rotateLeft(ScoreNode* x); + ScoreNode* rotateRight(ScoreNode* y); + ScoreNode* balance(ScoreNode* node); + ScoreNode* insert(ScoreNode* node, double score, const std::string& member); + ScoreNode* remove(ScoreNode* node, double score, const std::string& member); + void inorder(ScoreNode* node, double min_score, const std::string& min_member, int& offset, int limit, std::vector>& result); + void destroy(ScoreNode* node); + + // AVL tree functions for MemberNode + int height(MemberNode* node); + void updateHeight(MemberNode* node); + int balanceFactor(MemberNode* node); + MemberNode* rotateLeft(MemberNode* x); + MemberNode* rotateRight(MemberNode* y); + MemberNode* balance(MemberNode* node); + MemberNode* insert(MemberNode* node, const std::string& member, double score); + MemberNode* remove(MemberNode* node, const std::string& member); + MemberNode* find(MemberNode* node, const std::string& member); + void destroy(MemberNode* node); +}; + + diff --git a/server.cpp b/server.cpp deleted file mode 100644 index d580481..0000000 --- a/server.cpp +++ /dev/null @@ -1,465 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "hashtable.h" -#include "zset.h" - -// Serialization codes -enum { - SER_STR = 0, - SER_NIL = 1, - SER_INT = 2, - SER_ERR = 3, - SER_ARR = 4, - SER_DBL = 5, -}; - -// Entry class representing a key-value pair -class Entry { -public: - enum Type { - STRING, - ZSET - }; - std::string key; - Type type; - std::string str_value; // For STRING type - ZSet* zset_value; // For ZSET type - - Entry(const std::string& k, const std::string& val) - : key(k), type(STRING), str_value(val), zset_value(nullptr) {} - ~Entry() { - if (type == ZSET && zset_value) { - delete zset_value; - } - } -}; - -enum ConnectionState { - STATE_REQ, - STATE_RES, - STATE_END -}; - -// Connection struct representing a client connection -struct Connection { - int fd; - ConnectionState state; - std::vector rbuf; - std::vector wbuf; - size_t wbuf_sent; - - static const size_t k_max_msg = 4096; - - Connection(int fd_) : fd(fd_), state(STATE_REQ), wbuf_sent(0) { - rbuf.reserve(4 + k_max_msg); - wbuf.reserve(4 + k_max_msg); - } -}; - -// Global variables -HashTable db; -std::vector fd2conn; - -// Function prototypes -void handle_command(const std::vector& cmd, std::string& out); -void out_string(std::string& out, const std::string& str); -void out_nil(std::string& out); -void out_int(std::string& out, int64_t val); -void out_error(std::string& out, const std::string& msg); -void out_ok(std::string& out); -void out_array(std::string& out, const std::vector& arr); -void out_double(std::string& out, double val); -int parse_request(const uint8_t* data, size_t len, std::vector& out); -void close_connection(Connection* conn); -void handle_connection_io(Connection* conn); -void accept_new_connection(int listen_fd); -void handle_read(Connection* conn); -void handle_write(Connection* conn); - -int main() { - int listen_fd = socket(AF_INET, SOCK_STREAM, 0); - if (listen_fd < 0) { - perror("socket"); - exit(1); - } - - int val = 1; - setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)); - - sockaddr_in addr = {}; - addr.sin_family = AF_INET; - addr.sin_port = htons(1234); - addr.sin_addr.s_addr = htonl(0); - if (bind(listen_fd, (sockaddr*)&addr, sizeof(addr)) < 0) { - perror("bind"); - exit(1); - } - - if (listen(listen_fd, SOMAXCONN) < 0) { - perror("listen"); - exit(1); - } - - fcntl(listen_fd, F_SETFL, O_NONBLOCK); - - while (true) { - std::vector pollfds; - pollfds.push_back({listen_fd, POLLIN, 0}); - for (Connection* conn : fd2conn) { - if (conn) { - pollfd pfd = {conn->fd, 0, 0}; - if (conn->state == STATE_REQ) { - pfd.events = POLLIN; - } else if (conn->state == STATE_RES) { - pfd.events = POLLOUT; - } - pfd.events |= POLLERR; - pollfds.push_back(pfd); - } - } - - int rv = poll(pollfds.data(), pollfds.size(), 1000); - if (rv < 0) { - perror("poll"); - exit(1); - } - - size_t idx = 0; - if (pollfds[idx++].revents & POLLIN) { - accept_new_connection(listen_fd); - } - - for (; idx < pollfds.size(); ++idx) { - pollfd& pfd = pollfds[idx]; - Connection* conn = fd2conn[pfd.fd]; - if (!conn) continue; - if (pfd.revents & (POLLIN | POLLOUT | POLLERR)) { - handle_connection_io(conn); - if (conn->state == STATE_END) { - close_connection(conn); - fd2conn[pfd.fd] = nullptr; - delete conn; - } - } - } - } - return 0; -} - -void accept_new_connection(int listen_fd) { - sockaddr_in client_addr; - socklen_t socklen = sizeof(client_addr); - int conn_fd = accept(listen_fd, (sockaddr*)&client_addr, &socklen); - if (conn_fd >= 0) { - fcntl(conn_fd, F_SETFL, O_NONBLOCK); - Connection* conn = new Connection(conn_fd); - if (fd2conn.size() <= (size_t)conn_fd) { - fd2conn.resize(conn_fd + 1, nullptr); - } - fd2conn[conn_fd] = conn; - } -} - -void close_connection(Connection* conn) { - close(conn->fd); -} - -void handle_connection_io(Connection* conn) { - if (conn->state == STATE_REQ) { - handle_read(conn); - } else if (conn->state == STATE_RES) { - handle_write(conn); - } -} - -void handle_read(Connection* conn) { - while (true) { - uint8_t buf[4096]; - ssize_t n = read(conn->fd, buf, sizeof(buf)); - if (n < 0) { - if (errno == EAGAIN) { - break; - } else { - perror("read"); - conn->state = STATE_END; - break; - } - } else if (n == 0) { - conn->state = STATE_END; - break; - } else { - conn->rbuf.insert(conn->rbuf.end(), buf, buf + n); - while (true) { - if (conn->rbuf.size() < 4) { - break; - } - uint32_t len = 0; - memcpy(&len, &conn->rbuf[0], 4); - if (len > Connection::k_max_msg) { - std::cerr << "Message too long\n"; - conn->state = STATE_END; - break; - } - if (conn->rbuf.size() < 4 + len) { - break; - } - std::vector cmd; - if (parse_request(&conn->rbuf[4], len, cmd) != 0) { - std::cerr << "Bad request\n"; - conn->state = STATE_END; - break; - } - std::string response; - handle_command(cmd, response); - uint32_t wlen = response.size(); - conn->wbuf.resize(4 + wlen); - memcpy(&conn->wbuf[0], &wlen, 4); - memcpy(&conn->wbuf[4], response.data(), wlen); - conn->wbuf_sent = 0; - conn->state = STATE_RES; - conn->rbuf.erase(conn->rbuf.begin(), conn->rbuf.begin() + 4 + len); - break; - } - } - } -} - -void handle_write(Connection* conn) { - while (conn->wbuf_sent < conn->wbuf.size()) { - ssize_t n = write(conn->fd, &conn->wbuf[conn->wbuf_sent], conn->wbuf.size() - conn->wbuf_sent); - if (n < 0) { - if (errno == EAGAIN) { - break; - } else { - perror("write"); - conn->state = STATE_END; - break; - } - } else { - conn->wbuf_sent += n; - } - } - if (conn->wbuf_sent == conn->wbuf.size()) { - conn->wbuf.clear(); - conn->wbuf_sent = 0; - conn->state = STATE_REQ; - } -} - -void handle_command(const std::vector& cmd, std::string& out) { - if (cmd.empty()) { - out_error(out, "Empty command"); - return; - } - const std::string& command = cmd[0]; - if (command == "get") { - if (cmd.size() != 2) { - out_error(out, "Invalid number of arguments for 'get'"); - return; - } - Entry* entry = db.get(cmd[1]); - if (entry && entry->type == Entry::STRING) { - out_string(out, entry->str_value); - } else { - out_nil(out); - } - } else if (command == "set") { - if (cmd.size() != 3) { - out_error(out, "Invalid number of arguments for 'set'"); - return; - } - Entry* entry = db.get(cmd[1]); - if (entry) { - if (entry->type != Entry::STRING) { - out_error(out, "Wrong type"); - return; - } - entry->str_value = cmd[2]; - } else { - entry = new Entry(cmd[1], cmd[2]); - db.put(cmd[1], entry); - } - out_ok(out); - } else if (command == "del") { - if (cmd.size() != 2) { - out_error(out, "Invalid number of arguments for 'del'"); - return; - } - db.remove(cmd[1]); - out_int(out, 1); - } else if (command == "keys") { - if (cmd.size() != 1) { - out_error(out, "Invalid number of arguments for 'keys'"); - return; - } - std::vector keys = db.keys(); - out_array(out, keys); - } else if (command == "zadd") { - if (cmd.size() != 4) { - out_error(out, "Invalid number of arguments for 'zadd'"); - return; - } - const std::string& key = cmd[1]; - double score = std::stod(cmd[2]); - const std::string& member = cmd[3]; - - Entry* entry = db.get(key); - if (!entry) { - entry = new Entry(key, ""); - entry->type = Entry::ZSET; - entry->zset_value = new ZSet(); - db.put(key, entry); - } - if (entry->type != Entry::ZSET) { - out_error(out, "Wrong type"); - return; - } - bool added = entry->zset_value->zadd(member, score); - out_int(out, added ? 1 : 0); - } else if (command == "zrem") { - if (cmd.size() != 3) { - out_error(out, "Invalid number of arguments for 'zrem'"); - return; - } - const std::string& key = cmd[1]; - const std::string& member = cmd[2]; - - Entry* entry = db.get(key); - if (!entry || entry->type != Entry::ZSET) { - out_error(out, "Wrong type or key does not exist"); - return; - } - bool removed = entry->zset_value->zrem(member); - out_int(out, removed ? 1 : 0); - } else if (command == "zscore") { - if (cmd.size() != 3) { - out_error(out, "Invalid number of arguments for 'zscore'"); - return; - } - const std::string& key = cmd[1]; - const std::string& member = cmd[2]; - - Entry* entry = db.get(key); - if (!entry || entry->type != Entry::ZSET) { - out_error(out, "Wrong type or key does not exist"); - return; - } - double score; - if (entry->zset_value->zscore(member, score)) { - out_double(out, score); - } else { - out_nil(out); - } - } else if (command == "zquery") { - if (cmd.size() != 6) { - out_error(out, "Invalid number of arguments for 'zquery'"); - return; - } - const std::string& key = cmd[1]; - double min_score = std::stod(cmd[2]); - const std::string& min_member = cmd[3]; - int offset = std::stoi(cmd[4]); - int limit = std::stoi(cmd[5]); - - Entry* entry = db.get(key); - if (!entry || entry->type != Entry::ZSET) { - out_error(out, "Wrong type or key does not exist"); - return; - } - std::vector> result = entry->zset_value->zquery(min_score, min_member, offset, limit); - out.push_back(SER_ARR); - uint32_t len = result.size() * 2; - out.append((char*)&len, 4); - for (const auto& pair : result) { - out_string(out, pair.first); - out_double(out, pair.second); - } - } else { - out_error(out, "Unknown command"); - } -} - -// Serialization functions -void out_string(std::string& out, const std::string& str) { - out.push_back(SER_STR); - uint32_t len = str.size(); - out.append((char*)&len, 4); - out.append(str); -} - -void out_nil(std::string& out) { - out.push_back(SER_NIL); -} - -void out_int(std::string& out, int64_t val) { - out.push_back(SER_INT); - out.append((char*)&val, 8); -} - -void out_error(std::string& out, const std::string& msg) { - out.push_back(SER_ERR); - uint32_t len = msg.size(); - out.append((char*)&len, 4); - out.append(msg); -} - -void out_ok(std::string& out) { - out_string(out, "OK"); -} - -void out_array(std::string& out, const std::vector& arr) { - out.push_back(SER_ARR); - uint32_t len = arr.size(); - out.append((char*)&len, 4); - for (const std::string& s : arr) { - out_string(out, s); - } -} - -void out_double(std::string& out, double val) { - out.push_back(SER_DBL); - out.append((char*)&val, 8); -} - -// Parsing request from client -int parse_request(const uint8_t* data, size_t len, std::vector& out) { - if (len < 4) { - return -1; - } - uint32_t argc = 0; - memcpy(&argc, data, 4); - if (argc > 1024) { - return -1; - } - size_t pos = 4; - for (uint32_t i = 0; i < argc; ++i) { - if (pos + 4 > len) { - return -1; - } - uint32_t arg_len = 0; - memcpy(&arg_len, data + pos, 4); - pos += 4; - if (pos + arg_len > len) { - return -1; - } - out.push_back(std::string((char*)data + pos, arg_len)); - pos += arg_len; - } - if (pos != len) { - return -1; - } - return 0; -} diff --git a/server/protocol.cpp b/server/protocol.cpp new file mode 100644 index 0000000..bbdb130 --- /dev/null +++ b/server/protocol.cpp @@ -0,0 +1,64 @@ +#include "server/protocol.h" +#include + +void out_string(std::string& out, const std::string& str) { + out.push_back(SER_STR); + uint32_t len = static_cast(str.size()); + out.append((char*)&len, 4); + out.append(str); +} + +void out_nil(std::string& out) { + out.push_back(SER_NIL); +} + +void out_int(std::string& out, int64_t val) { + out.push_back(SER_INT); + out.append((char*)&val, 8); +} + +void out_error(std::string& out, const std::string& msg) { + out.push_back(SER_ERR); + uint32_t len = static_cast(msg.size()); + out.append((char*)&len, 4); + out.append(msg); +} + +void out_ok(std::string& out) { + out_string(out, "OK"); +} + +void out_array(std::string& out, const std::vector& arr) { + out.push_back(SER_ARR); + uint32_t len = static_cast(arr.size()); + out.append((char*)&len, 4); + for (const std::string& s : arr) { + out_string(out, s); + } +} + +void out_double(std::string& out, double val) { + out.push_back(SER_DBL); + out.append((char*)&val, 8); +} + +// Parsing request from client (unchanged) +int parse_request(const uint8_t* data, size_t len, std::vector& out) { + if (len < 4) return -1; + uint32_t argc = 0; + std::memcpy(&argc, data, 4); + if (argc > 1024) return -1; + + size_t pos = 4; + for (uint32_t i = 0; i < argc; ++i) { + if (pos + 4 > len) return -1; + uint32_t arg_len = 0; + std::memcpy(&arg_len, data + pos, 4); + pos += 4; + if (pos + arg_len > len) return -1; + out.emplace_back((char*)data + pos, arg_len); + pos += arg_len; + } + if (pos != len) return -1; + return 0; +} diff --git a/server/server.cpp b/server/server.cpp new file mode 100644 index 0000000..ea83986 --- /dev/null +++ b/server/server.cpp @@ -0,0 +1,262 @@ +#include "server/server.h" +#include "server/protocol.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +Server::Server() = default; +Server::~Server() = default; + +void Server::run(int listen_fd) { + fcntl(listen_fd, F_SETFL, O_NONBLOCK); + + while (true) { + std::vector pollfds; + pollfds.push_back({listen_fd, POLLIN, 0}); + for (Connection* conn : fd2conn_) { + if (conn) { + pollfd pfd = {conn->fd, 0, 0}; + if (conn->state == STATE_REQ) { + pfd.events = POLLIN; + } else if (conn->state == STATE_RES) { + pfd.events = POLLOUT; + } + pfd.events |= POLLERR; + pollfds.push_back(pfd); + } + } + + int rv = ::poll(pollfds.data(), pollfds.size(), 1000); + if (rv < 0) { + perror("poll"); + exit(1); + } + + size_t idx = 0; + if (pollfds[idx++].revents & POLLIN) { + accept_new_connection(listen_fd); + } + + for (; idx < pollfds.size(); ++idx) { + pollfd& pfd = pollfds[idx]; + Connection* conn = fd2conn_[pfd.fd]; + if (!conn) continue; + if (pfd.revents & (POLLIN | POLLOUT | POLLERR)) { + handle_connection_io(conn); + if (conn->state == STATE_END) { + close_connection(conn); + fd2conn_[pfd.fd] = nullptr; + delete conn; + } + } + } + } +} + +void Server::accept_new_connection(int listen_fd) { + sockaddr_in client_addr{}; + socklen_t socklen = sizeof(client_addr); + int conn_fd = ::accept(listen_fd, (sockaddr*)&client_addr, &socklen); + if (conn_fd >= 0) { + fcntl(conn_fd, F_SETFL, O_NONBLOCK); + Connection* conn = new Connection(conn_fd); + if (fd2conn_.size() <= (size_t)conn_fd) { + fd2conn_.resize(conn_fd + 1, nullptr); + } + fd2conn_[conn_fd] = conn; + } +} + +void Server::close_connection(Connection* conn) { + ::close(conn->fd); +} + +void Server::handle_connection_io(Connection* conn) { + if (conn->state == STATE_REQ) { + handle_read(conn); + } else if (conn->state == STATE_RES) { + handle_write(conn); + } +} + +void Server::handle_read(Connection* conn) { + while (true) { + uint8_t buf[4096]; + ssize_t n = ::read(conn->fd, buf, sizeof(buf)); + if (n < 0) { + if (errno == EAGAIN) { + break; + } else { + perror("read"); + conn->state = STATE_END; + break; + } + } else if (n == 0) { + conn->state = STATE_END; + break; + } else { + conn->rbuf.insert(conn->rbuf.end(), buf, buf + n); + while (true) { + if (conn->rbuf.size() < 4) break; + uint32_t len = 0; + std::memcpy(&len, &conn->rbuf[0], 4); + if (len > Connection::k_max_msg) { + std::cerr << "Message too long\n"; + conn->state = STATE_END; + break; + } + if (conn->rbuf.size() < 4 + len) break; + + std::vector cmd; + if (parse_request(&conn->rbuf[4], len, cmd) != 0) { + std::cerr << "Bad request\n"; + conn->state = STATE_END; + break; + } + std::string response; + handle_command(cmd, response); + + uint32_t wlen = static_cast(response.size()); + conn->wbuf.resize(4 + wlen); + std::memcpy(&conn->wbuf[0], &wlen, 4); + std::memcpy(&conn->wbuf[4], response.data(), wlen); + conn->wbuf_sent = 0; + conn->state = STATE_RES; + + conn->rbuf.erase(conn->rbuf.begin(), conn->rbuf.begin() + 4 + len); + break; + } + } + } +} + +void Server::handle_write(Connection* conn) { + while (conn->wbuf_sent < conn->wbuf.size()) { + ssize_t n = ::write(conn->fd, &conn->wbuf[conn->wbuf_sent], conn->wbuf.size() - conn->wbuf_sent); + if (n < 0) { + if (errno == EAGAIN) { + break; + } else { + perror("write"); + conn->state = STATE_END; + break; + } + } else { + conn->wbuf_sent += static_cast(n); + } + } + if (conn->wbuf_sent == conn->wbuf.size()) { + conn->wbuf.clear(); + conn->wbuf_sent = 0; + conn->state = STATE_REQ; + } +} + +// ===== Commands (same logic as your original handle_command) ===== + +#include + +void Server::handle_command(const std::vector& cmd, std::string& out) { + if (cmd.empty()) { + out_error(out, "Empty command"); + return; + } + const std::string& command = cmd[0]; + + if (command == "get") { + if (cmd.size() != 2) { out_error(out, "Invalid number of arguments for 'get'"); return; } + Entry* entry = db_.get(cmd[1]); + if (entry && entry->type == Entry::STRING) out_string(out, entry->str_value); + else out_nil(out); + + } else if (command == "set") { + if (cmd.size() != 3) { out_error(out, "Invalid number of arguments for 'set'"); return; } + Entry* entry = db_.get(cmd[1]); + if (entry) { + if (entry->type != Entry::STRING) { out_error(out, "Wrong type"); return; } + entry->str_value = cmd[2]; + } else { + entry = new Entry(cmd[1], cmd[2]); + db_.put(cmd[1], entry); + } + out_ok(out); + + } else if (command == "del") { + if (cmd.size() != 2) { out_error(out, "Invalid number of arguments for 'del'"); return; } + db_.remove(cmd[1]); + out_int(out, 1); + + } else if (command == "keys") { + if (cmd.size() != 1) { out_error(out, "Invalid number of arguments for 'keys'"); return; } + std::vector keys = db_.keys(); + out_array(out, keys); + + } else if (command == "zadd") { + if (cmd.size() != 4) { out_error(out, "Invalid number of arguments for 'zadd'"); return; } + const std::string& key = cmd[1]; + double score = std::stod(cmd[2]); + const std::string& member = cmd[3]; + + Entry* entry = db_.get(key); + if (!entry) { + entry = new Entry(key, ""); + entry->type = Entry::ZSET; + entry->zset_value = new ZSet(); + db_.put(key, entry); + } + if (entry->type != Entry::ZSET) { out_error(out, "Wrong type"); return; } + bool added = entry->zset_value->zadd(member, score); + out_int(out, added ? 1 : 0); + + } else if (command == "zrem") { + if (cmd.size() != 3) { out_error(out, "Invalid number of arguments for 'zrem'"); return; } + const std::string& key = cmd[1]; + const std::string& member = cmd[2]; + + Entry* entry = db_.get(key); + if (!entry || entry->type != Entry::ZSET) { out_error(out, "Wrong type or key does not exist"); return; } + bool removed = entry->zset_value->zrem(member); + out_int(out, removed ? 1 : 0); + + } else if (command == "zscore") { + if (cmd.size() != 3) { out_error(out, "Invalid number of arguments for 'zscore'"); return; } + const std::string& key = cmd[1]; + const std::string& member = cmd[2]; + + Entry* entry = db_.get(key); + if (!entry || entry->type != Entry::ZSET) { out_error(out, "Wrong type or key does not exist"); return; } + double score; + if (entry->zset_value->zscore(member, score)) out_double(out, score); + else out_nil(out); + + } else if (command == "zquery") { + if (cmd.size() != 6) { out_error(out, "Invalid number of arguments for 'zquery'"); return; } + const std::string& key = cmd[1]; + double min_score = std::stod(cmd[2]); + const std::string& min_member = cmd[3]; + int offset = std::stoi(cmd[4]); + int limit = std::stoi(cmd[5]); + + Entry* entry = db_.get(key); + if (!entry || entry->type != Entry::ZSET) { out_error(out, "Wrong type or key does not exist"); return; } + std::vector> result = entry->zset_value->zquery(min_score, min_member, offset, limit); + + out.push_back(SER_ARR); + uint32_t len = static_cast(result.size() * 2); + out.append((char*)&len, 4); + for (const auto& pair : result) { + out_string(out, pair.first); + out_double(out, pair.second); + } + + } else { + out_error(out, "Unknown command"); + } +} diff --git a/zset.h b/server/zset.cpp similarity index 77% rename from zset.h rename to server/zset.cpp index 2eee25c..9c1c739 100644 --- a/zset.h +++ b/server/zset.cpp @@ -1,77 +1,4 @@ - -#pragma once - -#include -#include -#include -#include - -class ZSet { -private: - // Node structure for tree_by_score - class ScoreNode { - public: - double score; - std::string member; - int height; - ScoreNode* left; - ScoreNode* right; - - ScoreNode(double s, const std::string& m) - : score(s), member(m), height(1), left(nullptr), right(nullptr) {} - }; - - // Node structure for tree_by_member - class MemberNode { - public: - std::string member; - double score; - int height; - MemberNode* left; - MemberNode* right; - - MemberNode(const std::string& m, double s) - : member(m), score(s), height(1), left(nullptr), right(nullptr) {} - }; - - ScoreNode* tree_by_score; - MemberNode* tree_by_member; - -public: - ZSet(); - ~ZSet(); - - bool zadd(const std::string& member, double score); - bool zrem(const std::string& member); - bool zscore(const std::string& member, double& out_score); - std::vector> zquery(double min_score, const std::string& min_member, int offset, int limit); - -private: - // AVL tree functions for ScoreNode - int height(ScoreNode* node); - void updateHeight(ScoreNode* node); - int balanceFactor(ScoreNode* node); - ScoreNode* rotateLeft(ScoreNode* x); - ScoreNode* rotateRight(ScoreNode* y); - ScoreNode* balance(ScoreNode* node); - ScoreNode* insert(ScoreNode* node, double score, const std::string& member); - ScoreNode* remove(ScoreNode* node, double score, const std::string& member); - void inorder(ScoreNode* node, double min_score, const std::string& min_member, int& offset, int limit, std::vector>& result); - void destroy(ScoreNode* node); - - // AVL tree functions for MemberNode - int height(MemberNode* node); - void updateHeight(MemberNode* node); - int balanceFactor(MemberNode* node); - MemberNode* rotateLeft(MemberNode* x); - MemberNode* rotateRight(MemberNode* y); - MemberNode* balance(MemberNode* node); - MemberNode* insert(MemberNode* node, const std::string& member, double score); - MemberNode* remove(MemberNode* node, const std::string& member); - MemberNode* find(MemberNode* node, const std::string& member); - void destroy(MemberNode* node); -}; - +#include "server/zset.h" ZSet::ZSet() : tree_by_score(nullptr), tree_by_member(nullptr) {} diff --git a/server_test.cpp b/server_test.cpp deleted file mode 100644 index e69de29..0000000 diff --git a/tests/client_test.cpp b/tests/client_test.cpp new file mode 100644 index 0000000..eb208a7 --- /dev/null +++ b/tests/client_test.cpp @@ -0,0 +1,341 @@ +#include +#include +#include +#include +#include +#include + +#include "client/protocol.h" +#include "client/client_loop.h" +#include "client/transport.h" +#include "client/util.h" + +// ---------- Test helpers (build reply frames and verify requests) ---------- + +namespace testutil { + +// Build a single RESP-like body element (not framed) per your protocol. +inline std::vector ser_str(const std::string& s) { + std::vector b; + b.push_back(SER_STR); + uint32_t len = static_cast(s.size()); + const uint8_t* p = reinterpret_cast(&len); + b.insert(b.end(), p, p+4); + b.insert(b.end(), s.begin(), s.end()); + return b; +} + +inline std::vector ser_int(int64_t v) { + std::vector b; + b.push_back(SER_INT); + const uint8_t* p = reinterpret_cast(&v); + b.insert(b.end(), p, p+8); + return b; +} + +inline std::vector ser_dbl(double v) { + std::vector b; + b.push_back(SER_DBL); + const uint8_t* p = reinterpret_cast(&v); + b.insert(b.end(), p, p+8); + return b; +} + +inline std::vector ser_nil() { + return std::vector{ static_cast(SER_NIL) }; +} + +inline std::vector ser_err(const std::string& msg) { + std::vector b; + b.push_back(SER_ERR); + uint32_t len = static_cast(msg.size()); + const uint8_t* p = reinterpret_cast(&len); + b.insert(b.end(), p, p+4); + b.insert(b.end(), msg.begin(), msg.end()); + return b; +} + +inline std::vector ser_arr(const std::vector>& elements) { + std::vector b; + b.push_back(SER_ARR); + uint32_t len = static_cast(elements.size()); + const uint8_t* p = reinterpret_cast(&len); + b.insert(b.end(), p, p+4); + for (const auto& e : elements) { + b.insert(b.end(), e.begin(), e.end()); + } + return b; +} + +// Frame a body with a 4-byte little-endian length prefix (as your client expects). +inline std::vector frame(const std::vector& body) { + std::vector out(4); + uint32_t n = static_cast(body.size()); + std::memcpy(out.data(), &n, 4); + out.insert(out.end(), body.begin(), body.end()); + return out; +} + +// Convenience to build the expected request buffer for a given command vector. +inline std::vector build_request(const std::vector& cmd) { + std::vector r; + serialize_request(cmd, r); + return r; +} + +} // namespace testutil + +// ---------- Fake transport for REPL tests ---------- + +struct FakeTransport : ITransport { + // What the "server" will send to the client (a sequence of framed replies). + std::vector scripted_recv; + size_t recv_off = 0; + + // What the client sent to the server. + std::vector captured_send; + + void connect(const std::string&, uint16_t) override {} + void send_all(const uint8_t* data, size_t len) override { + captured_send.insert(captured_send.end(), data, data + len); + } + void recv_all(uint8_t* data, size_t len) override { + if (recv_off + len > scripted_recv.size()) { + // Simulate EOF/failure -> this will trigger die("recv failed"), so we just mimic short read + // but tests that want death should ensure this. + std::memset(data, 0, len); + return; + } + std::memcpy(data, scripted_recv.data() + recv_off, len); + recv_off += len; + } + void close() override {} +}; + +// ======================= Protocol: serialize tests ========================== + +TEST(ProtocolSerialize, CorrectHeaderAndArgs) { + std::vector cmd = {"SET", "k", "v"}; + auto buf = testutil::build_request(cmd); + + ASSERT_GE(buf.size(), 4u); + uint32_t total_len = 0; + std::memcpy(&total_len, buf.data(), 4); + EXPECT_EQ(total_len + 4, buf.size()); // first 4 bytes = payload size + + // argc at offset 4 + uint32_t argc = 0; + std::memcpy(&argc, buf.data() + 4, 4); + EXPECT_EQ(argc, 3u); + + // First arg length should start at offset 8. + uint32_t arg0_len = 0; + std::memcpy(&arg0_len, buf.data() + 8, 4); + EXPECT_EQ(arg0_len, 3u); // "SET" +} + +// ======================= Protocol: deserialize tests ======================== + +TEST(ProtocolDeserialize, StringIntDoubleNilError) { + using namespace testutil; + auto body = ser_arr({ + ser_str("hello"), + ser_int(42), + ser_dbl(3.5), + ser_nil(), + ser_err("boom"), + }); + + size_t off = 0; + std::ostringstream out, err; + deserialize_response(body, off, out, err); + + // Order matters and each prints with '\n' + std::string expected = + "hello\n" + "42\n" + "3.5\n" + "(nil)\n"; + EXPECT_EQ(out.str(), expected); + EXPECT_EQ(err.str(), "(error) boom\n"); + EXPECT_EQ(off, body.size()); +} + +TEST(ProtocolDeserialize, NestedArrays) { + using namespace testutil; + auto body = ser_arr({ + ser_str("outer"), + ser_arr({ + ser_str("inner1"), + ser_int(7), + }), + ser_str("tail"), + }); + + size_t off = 0; + std::ostringstream out, err; + deserialize_response(body, off, out, err); + std::string expected = + "outer\n" + "inner1\n" + "7\n" + "tail\n"; + EXPECT_EQ(out.str(), expected); + EXPECT_TRUE(err.str().empty()); +} + +TEST(ProtocolDeserializeDeath, BadLengthDies) { + using namespace testutil; + // Construct a string with declared length bigger than buffer to trigger die() + std::vector body; + body.push_back(SER_STR); + uint32_t len = 1000; // lie + const uint8_t* p = reinterpret_cast(&len); + body.insert(body.end(), p, p+4); + body.push_back('x'); // only 1 byte of actual data + + size_t off = 0; + std::ostringstream out, err; + // die() calls std::exit, so use a death test. + EXPECT_DEATH(deserialize_response(body, off, out, err), + "Response parsing error: string data"); +} + +// ======================= REPL tests via FakeTransport ======================= + +TEST(ClientRepl, SingleCommandRoundTrip) { + using namespace testutil; + + // Simulate server replying "PONG" after "PING" + auto reply_body = ser_str("PONG"); + auto framed = frame(reply_body); + + FakeTransport t; + t.scripted_recv = framed; // Repl will read header(4), then body + + std::istringstream in("PING\n"); // user enters one line + std::ostringstream out, err; + + // The REPL prints a prompt before reading and again after handling the line. + int rc = run_client_repl(t, in, out, err); + EXPECT_EQ(rc, 0); + + // Out should contain prompt, then server output, then prompt (blocked by EOF so not printed again) + // Our loop prints prompt before each getline; after last iteration (EOF), it exits without printing a final prompt. + EXPECT_EQ(out.str(), "> PONG\n> "); + + // Verify the request that got sent matches protocol encoding for ["PING"] + auto expected_req = build_request({"PING"}); + EXPECT_EQ(t.captured_send, expected_req); + EXPECT_TRUE(err.str().empty()); +} + +TEST(ClientRepl, MultipleCommandsAndEmptyLines) { + using namespace testutil; + + // Server replies 3 lines for three non-empty inputs; empty lines are ignored by client. + std::vector script; + auto a = frame(ser_int(1)); + auto b = frame(ser_str("two")); + auto c = frame(ser_nil()); + script.insert(script.end(), a.begin(), a.end()); + script.insert(script.end(), b.begin(), b.end()); + script.insert(script.end(), c.begin(), c.end()); + + FakeTransport t; + t.scripted_recv = script; + + std::istringstream in("\nINCR\nECHO two\nNULL\n"); + std::ostringstream out, err; + + int rc = run_client_repl(t, in, out, err); + EXPECT_EQ(rc, 0); + + // One prompt per getline; empty first line prints a prompt but no output; then 3 outputs. + std::string expected_out = + "> " // before empty line + "> " // before INCR + "1\n" + "> " // before ECHO two + "two\n" + "> " // before NULL + "(nil)\n" + "> "; // after last line (EOF, so no more output) + EXPECT_EQ(out.str(), expected_out); + EXPECT_TRUE(err.str().empty()); + + // Verify three requests were sent in order + auto r1 = build_request({"INCR"}); + auto r2 = build_request({"ECHO", "two"}); + auto r3 = build_request({"NULL"}); + + // Concatenate expected requests; REPL sends each back-to-back. + std::vector expected; + expected.insert(expected.end(), r1.begin(), r1.end()); + expected.insert(expected.end(), r2.begin(), r2.end()); + expected.insert(expected.end(), r3.begin(), r3.end()); + EXPECT_EQ(t.captured_send, expected); +} + +TEST(ClientReplDeath, ResponseTooLarge) { + FakeTransport t; + + // Header says > 10MB which should trigger die("Response too large") + uint32_t huge = 10 * 1024 * 1024 + 1; + std::vector framed(4); + std::memcpy(framed.data(), &huge, 4); + t.scripted_recv = framed; + + std::istringstream in("ANY\n"); + std::ostringstream out, err; + + EXPECT_DEATH( + { + run_client_repl(t, in, out, err); + }, + "Response too large" + ); +} + +// ======================= Integration-ish serialization check ================= + +TEST(ProtocolSerialize, ArgumentsMarshaledCorrectly) { + std::vector cmd = {"MSET", "key", "value with space", "123"}; + std::vector buf; + serialize_request(cmd, buf); + + // Read argc + uint32_t argc = 0; + std::memcpy(&argc, buf.data() + 4, 4); + ASSERT_EQ(argc, 4u); + + size_t off = 8; + auto read_len = [&](uint32_t& L) { + std::memcpy(&L, buf.data() + off, 4); + off += 4; + }; + auto read_bytes = [&](std::string& s, uint32_t L) { + s.assign(reinterpret_cast(buf.data() + off), L); + off += L; + }; + + uint32_t L0; read_len(L0); + std::string a0; read_bytes(a0, L0); + EXPECT_EQ(a0, "MSET"); + + uint32_t L1; read_len(L1); + std::string a1; read_bytes(a1, L1); + EXPECT_EQ(a1, "key"); + + uint32_t L2; read_len(L2); + std::string a2; read_bytes(a2, L2); + EXPECT_EQ(a2, "value with space"); + + uint32_t L3; read_len(L3); + std::string a3; read_bytes(a3, L3); + EXPECT_EQ(a3, "123"); + + // Buffer should end exactly here + EXPECT_EQ(off, buf.size()); +} + diff --git a/tests/server_test.cpp b/tests/server_test.cpp new file mode 100644 index 0000000..bab6017 --- /dev/null +++ b/tests/server_test.cpp @@ -0,0 +1,306 @@ +#include +#include +#include +#include +#include +#include +#include "server/server.h" +#include "server/protocol.h" + +// -------- Test-only decoder for the server's serialization format -------- +enum TKind { T_STR, T_NIL, T_INT, T_ERR, T_ARR, T_DBL }; + +struct TVal { + TKind kind; + std::string s; + int64_t i = 0; + double d = 0.0; + std::vector arr; +}; + +static bool parseOne(const std::string& buf, size_t& pos, TVal& out) { + if (pos >= buf.size()) return false; + uint8_t tag = static_cast(buf[pos++]); + switch (tag) { + case SER_STR: { + if (pos + 4 > buf.size()) return false; + uint32_t len = 0; std::memcpy(&len, &buf[pos], 4); pos += 4; + if (pos + len > buf.size()) return false; + out.kind = T_STR; out.s.assign(&buf[pos], &buf[pos + len]); pos += len; + return true; + } + case SER_NIL: + out.kind = T_NIL; return true; + case SER_INT: { + if (pos + 8 > buf.size()) return false; + int64_t v = 0; std::memcpy(&v, &buf[pos], 8); pos += 8; + out.kind = T_INT; out.i = v; return true; + } + case SER_ERR: { + if (pos + 4 > buf.size()) return false; + uint32_t len = 0; std::memcpy(&len, &buf[pos], 4); pos += 4; + if (pos + len > buf.size()) return false; + out.kind = T_ERR; out.s.assign(&buf[pos], &buf[pos + len]); pos += len; + return true; + } + case SER_ARR: { + if (pos + 4 > buf.size()) return false; + uint32_t n = 0; std::memcpy(&n, &buf[pos], 4); pos += 4; + out.kind = T_ARR; out.arr.clear(); out.arr.reserve(n); + for (uint32_t k = 0; k < n; ++k) { + TVal elem; + if (!parseOne(buf, pos, elem)) return false; + out.arr.push_back(std::move(elem)); + } + return true; + } + case SER_DBL: { + if (pos + 8 > buf.size()) return false; + double v = 0; std::memcpy(&v, &buf[pos], 8); pos += 8; + out.kind = T_DBL; out.d = v; return true; + } + default: + return false; + } +} + +static TVal decode(const std::string& buf) { + size_t pos = 0; + TVal v{}; + bool ok = parseOne(buf, pos, v); + if (!ok || pos != buf.size()) { + // For tests, fail hard with a clear message. + ADD_FAILURE() << "Failed to decode response buffer (size=" << buf.size() << ", pos=" << pos << ")"; + } + return v; +} + +// ------------------------------ Tests ------------------------------ + +TEST(ServerCommands, EmptyCommandErrors) { + Server s; + std::string out; + s.handle_command({}, out); + TVal v = decode(out); + ASSERT_EQ(v.kind, T_ERR); + EXPECT_NE(v.s.find("Empty command"), std::string::npos); +} + +TEST(ServerCommands, UnknownCommandErrors) { + Server s; + std::string out; + s.handle_command({"wtf"}, out); + TVal v = decode(out); + ASSERT_EQ(v.kind, T_ERR); + EXPECT_NE(v.s.find("Unknown command"), std::string::npos); +} + +TEST(ServerCommands, SetGetRoundTrip) { + Server s; + std::string out; + + s.handle_command({"set", "k1", "v1"}, out); + auto ok = decode(out); + ASSERT_EQ(ok.kind, T_STR); // out_ok => "OK" + EXPECT_EQ(ok.s, "OK"); + + out.clear(); + s.handle_command({"get", "k1"}, out); + auto got = decode(out); + ASSERT_EQ(got.kind, T_STR); + EXPECT_EQ(got.s, "v1"); +} + +TEST(ServerCommands, GetMissingIsNil) { + Server s; + std::string out; + s.handle_command({"get", "nope"}, out); + auto v = decode(out); + ASSERT_EQ(v.kind, T_NIL); +} + +TEST(ServerCommands, DelAlwaysReturnsOne) { + Server s; + std::string out; + + // Delete nonexistent + s.handle_command({"del", "missing"}, out); + auto v1 = decode(out); + ASSERT_EQ(v1.kind, T_INT); + EXPECT_EQ(v1.i, 1); + + out.clear(); + // Put something then delete + s.handle_command({"set", "todel", "x"}, out); + out.clear(); + s.handle_command({"del", "todel"}, out); + auto v2 = decode(out); + ASSERT_EQ(v2.kind, T_INT); + EXPECT_EQ(v2.i, 1); +} + +TEST(ServerCommands, KeysListsAllKeys) { + Server s; + std::string out; + + s.handle_command({"set", "a", "1"}, out); out.clear(); + s.handle_command({"set", "b", "2"}, out); out.clear(); + s.handle_command({"set", "c", "3"}, out); out.clear(); + + s.handle_command({"keys"}, out); + auto v = decode(out); + ASSERT_EQ(v.kind, T_ARR); + + std::vector got; + for (auto& e : v.arr) { + ASSERT_EQ(e.kind, T_STR); + got.push_back(e.s); + } + // Order is unspecified, so check membership. + auto has = [&](const char* k){ return std::find(got.begin(), got.end(), k) != got.end(); }; + EXPECT_TRUE(has("a")); + EXPECT_TRUE(has("b")); + EXPECT_TRUE(has("c")); +} + +TEST(ServerCommands, WrongTypeErrors) { + Server s; + std::string out; + + // Create a ZSET key + s.handle_command({"zadd", "myz", "1.5", "m"}, out); out.clear(); + + // set on ZSET -> error + s.handle_command({"set", "myz", "xxx"}, out); + auto e1 = decode(out); + ASSERT_EQ(e1.kind, T_ERR); + EXPECT_NE(e1.s.find("Wrong type"), std::string::npos); + + out.clear(); + // zadd on STRING -> error + s.handle_command({"set", "skey", "str"}, out); out.clear(); + s.handle_command({"zadd", "skey", "2.0", "m"}, out); + auto e2 = decode(out); + ASSERT_EQ(e2.kind, T_ERR); + EXPECT_NE(e2.s.find("Wrong type"), std::string::npos); + + out.clear(); + // zscore on STRING -> error + s.handle_command({"zscore", "skey", "m"}, out); + auto e3 = decode(out); + ASSERT_EQ(e3.kind, T_ERR); + EXPECT_NE(e3.s.find("Wrong type"), std::string::npos); +} + +TEST(ServerCommands, ZAddZScoreZRem) { + Server s; + std::string out; + + // zadd returns int (implementation always 1) + s.handle_command({"zadd", "myz", "10.5", "alice"}, out); + auto a1 = decode(out); + ASSERT_EQ(a1.kind, T_INT); + EXPECT_EQ(a1.i, 1); + out.clear(); + + // zscore -> 10.5 + s.handle_command({"zscore", "myz", "alice"}, out); + auto sc1 = decode(out); + ASSERT_EQ(sc1.kind, T_DBL); + EXPECT_DOUBLE_EQ(sc1.d, 10.5); + out.clear(); + + // update score (still returns 1) + s.handle_command({"zadd", "myz", "12.0", "alice"}, out); + auto a2 = decode(out); + ASSERT_EQ(a2.kind, T_INT); + EXPECT_EQ(a2.i, 1); + out.clear(); + + s.handle_command({"zscore", "myz", "alice"}, out); + auto sc2 = decode(out); + ASSERT_EQ(sc2.kind, T_DBL); + EXPECT_DOUBLE_EQ(sc2.d, 12.0); + out.clear(); + + // remove + s.handle_command({"zrem", "myz", "alice"}, out); + auto r1 = decode(out); + ASSERT_EQ(r1.kind, T_INT); + EXPECT_EQ(r1.i, 1); + out.clear(); + + // score now NIL + s.handle_command({"zscore", "myz", "alice"}, out); + auto scnil = decode(out); + ASSERT_EQ(scnil.kind, T_NIL); + out.clear(); + + // removing again -> 0 + s.handle_command({"zrem", "myz", "alice"}, out); + auto r0 = decode(out); + ASSERT_EQ(r0.kind, T_INT); + EXPECT_EQ(r0.i, 0); +} + +TEST(ServerCommands, ZScoreMissingMemberIsNil) { + Server s; std::string out; + s.handle_command({"zadd", "myz", "1", "a"}, out); out.clear(); + s.handle_command({"zscore", "myz", "nope"}, out); + auto v = decode(out); + ASSERT_EQ(v.kind, T_NIL); +} + +TEST(ServerCommands, ZQueryPagingAndOrdering) { + Server s; std::string out; + // Prepare sorted set with multiple members + s.handle_command({"zadd", "myz", "1.0", "a"}, out); out.clear(); + s.handle_command({"zadd", "myz", "1.0", "b"}, out); out.clear(); + s.handle_command({"zadd", "myz", "2.0", "c"}, out); out.clear(); + s.handle_command({"zadd", "myz", "3.0", "d"}, out); out.clear(); + + // Query from min_score=1.0, min_member="b", offset=0, limit=3 + s.handle_command({"zquery", "myz", "1.0", "b", "0", "3"}, out); + auto q = decode(out); + ASSERT_EQ(q.kind, T_ARR); + // Response is [member, score, member, score, ...] + ASSERT_EQ(q.arr.size(), 3u * 2u); + ASSERT_EQ(q.arr[0].kind, T_STR); EXPECT_EQ(q.arr[0].s, "b"); + ASSERT_EQ(q.arr[1].kind, T_DBL); EXPECT_DOUBLE_EQ(q.arr[1].d, 1.0); + ASSERT_EQ(q.arr[2].kind, T_STR); EXPECT_EQ(q.arr[2].s, "c"); + ASSERT_EQ(q.arr[3].kind, T_DBL); EXPECT_DOUBLE_EQ(q.arr[3].d, 2.0); + ASSERT_EQ(q.arr[4].kind, T_STR); EXPECT_EQ(q.arr[4].s, "d"); + ASSERT_EQ(q.arr[5].kind, T_DBL); EXPECT_DOUBLE_EQ(q.arr[5].d, 3.0); + + out.clear(); + // With offset=1, limit=2 from same start: expect c, d + s.handle_command({"zquery", "myz", "1.0", "b", "1", "2"}, out); + auto q2 = decode(out); + ASSERT_EQ(q2.kind, T_ARR); + ASSERT_EQ(q2.arr.size(), 2u * 2u); + EXPECT_EQ(q2.arr[0].s, "c"); + EXPECT_DOUBLE_EQ(q2.arr[1].d, 2.0); + EXPECT_EQ(q2.arr[2].s, "d"); + EXPECT_DOUBLE_EQ(q2.arr[3].d, 3.0); +} + +TEST(ServerCommands, ArityErrors) { + Server s; std::string out; + + s.handle_command({"get"}, out); + auto e1 = decode(out); + ASSERT_EQ(e1.kind, T_ERR); + EXPECT_NE(e1.s.find("Invalid number of arguments"), std::string::npos); + out.clear(); + + s.handle_command({"set", "konly"}, out); + auto e2 = decode(out); + ASSERT_EQ(e2.kind, T_ERR); + EXPECT_NE(e2.s.find("Invalid number of arguments"), std::string::npos); + out.clear(); + + s.handle_command({"zadd", "myz", "1.0"}, out); + auto e3 = decode(out); + ASSERT_EQ(e3.kind, T_ERR); + EXPECT_NE(e3.s.find("Invalid number of arguments"), std::string::npos); +}