diff --git a/CMakeLists.txt b/CMakeLists.txt index 8de0cb5e..0fd094fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -98,6 +98,10 @@ if (QUESTDB_TESTS_AND_EXAMPLES) line_sender_c_example examples/concat.c examples/line_sender_c_example.c) + compile_example( + line_sender_c_example_array + examples/concat.c + examples/line_sender_c_example_array.c) compile_example( line_sender_c_example_auth examples/concat.c @@ -123,6 +127,9 @@ if (QUESTDB_TESTS_AND_EXAMPLES) compile_example( line_sender_cpp_example examples/line_sender_cpp_example.cpp) + compile_example( + line_sender_cpp_example_array + examples/line_sender_cpp_example_array.cpp) compile_example( line_sender_cpp_example_auth examples/line_sender_cpp_example_auth.cpp) diff --git a/ci/compile.yaml b/ci/compile.yaml index aa9e1059..4c892ce4 100644 --- a/ci/compile.yaml +++ b/ci/compile.yaml @@ -4,6 +4,10 @@ steps: rustup default $(toolchain) condition: ne(variables['toolchain'], '') displayName: "Update and set Rust toolchain" + - script: | + python -m pip install --upgrade pip + pip install numpy + displayName: 'Install Python Dependencies' - script: cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DQUESTDB_TESTS_AND_EXAMPLES=ON env: JAVA_HOME: $(JAVA_HOME_11_X64) diff --git a/ci/run_all_tests.py b/ci/run_all_tests.py index b4e4dc4a..ee261896 100644 --- a/ci/run_all_tests.py +++ b/ci/run_all_tests.py @@ -41,7 +41,7 @@ def main(): build_cxx20_dir.glob(f'**/test_line_sender{exe_suffix}'))) system_test_path = pathlib.Path('system_test') / 'test.py' - qdb_v = '8.2.3' # The version of QuestDB we'll test against. + #qdb_v = '8.2.3' # The version of QuestDB we'll test against. run_cmd('cargo', 'test', '--', '--nocapture', cwd='questdb-rs') @@ -53,7 +53,8 @@ def main(): '--', '--nocapture', cwd='questdb-rs') run_cmd(str(test_line_sender_path)) run_cmd(str(test_line_sender_path_CXX20)) - run_cmd('python3', str(system_test_path), 'run', '--versions', qdb_v, '-v') + #run_cmd('python3', str(system_test_path), 'run', '--versions', qdb_v, '-v') + run_cmd('python3', str(system_test_path), 'run', '--repo', './questdb_nd_arr', '-v') if __name__ == '__main__': diff --git a/ci/run_tests_pipeline.yaml b/ci/run_tests_pipeline.yaml index 1fced6be..03b8d129 100644 --- a/ci/run_tests_pipeline.yaml +++ b/ci/run_tests_pipeline.yaml @@ -54,14 +54,25 @@ stages: cd questdb-rs cargo build --examples --features almost-all-features displayName: "Build Rust examples" + ############################# temp for test begin ##################### + - script: | + git clone -b nd_arr --depth 1 https://github.com/questdb/questdb.git ./questdb_nd_arr + displayName: git clone questdb + - task: Maven@3 + displayName: "Compile QuestDB" + inputs: + mavenPOMFile: 'questdb_nd_arr/pom.xml' + jdkVersionOption: '1.11' + options: "-DskipTests -Pbuild-web-console" + ############################# temp for test end ##################### - script: python3 ci/run_all_tests.py env: JAVA_HOME: $(JAVA_HOME_11_X64) displayName: "Tests" - - task: PublishBuildArtifacts@1 - inputs: - pathToPublish: ./build - displayName: "Publish build directory" + # - task: PublishBuildArtifacts@1 + # inputs: + # pathToPublish: ./build + # displayName: "Publish build directory" - job: FormatAndLinting displayName: "cargo fmt and clippy" pool: @@ -115,7 +126,7 @@ stages: submodules: false - template: compile.yaml - script: | - git clone --depth 1 https://github.com/questdb/questdb.git + git clone -b nd_arr --depth 1 https://github.com/questdb/questdb.git displayName: git clone questdb - task: Maven@3 displayName: "Compile QuestDB" diff --git a/cpp_test/mock_server.cpp b/cpp_test/mock_server.cpp index 0e8a1e85..e21d77d0 100644 --- a/cpp_test/mock_server.cpp +++ b/cpp_test/mock_server.cpp @@ -25,6 +25,7 @@ #include "mock_server.hpp" #include <string.h> +#include <string> #if defined(PLATFORM_UNIX) # include <fcntl.h> @@ -190,43 +191,81 @@ bool mock_server::wait_for_data(std::optional<double> wait_timeout_sec) return !!count; } +int32_t bytes_to_int32_le(const std::byte* bytes) +{ + return static_cast<int32_t>( + (bytes[0] << 0) | (bytes[1] << 8) | (bytes[2] << 16) | + (bytes[3] << 24)); +} + size_t mock_server::recv(double wait_timeout_sec) { if (!wait_for_data(wait_timeout_sec)) return 0; - char chunk[1024]; + std::byte chunk[1024]; size_t chunk_len{sizeof(chunk)}; - std::vector<char> accum; + std::vector<std::byte> accum; for (;;) { wait_for_data(); - sock_ssize_t count = - ::recv(_conn_fd, &chunk[0], static_cast<sock_len_t>(chunk_len), 0); + sock_ssize_t count = ::recv( + _conn_fd, + reinterpret_cast<char*>(&chunk[0]), + static_cast<sock_len_t>(chunk_len), + 0); if (count == -1) throw std::runtime_error{"Bad `recv()`."}; const size_t u_count = static_cast<size_t>(count); accum.insert(accum.end(), chunk, chunk + u_count); if (accum.size() < 2) continue; - if ((accum[accum.size() - 1] == '\n') && - (accum[accum.size() - 2] != '\\')) + if ((accum[accum.size() - 1] == std::byte('\n')) && + (accum[accum.size() - 2] != std::byte('\\'))) break; } size_t received_count{0}; - const char* head{&accum[0]}; - for (size_t index = 1; index < accum.size(); ++index) + const std::byte* head{&accum[0]}; + size_t index{1}; + while (index < accum.size()) { - const char& last = accum[index]; - const char& prev = accum[index - 1]; - if ((last == '\n') && (prev != '\\')) + const std::byte& last = accum[index]; + const std::byte& prev = accum[index - 1]; + if (last == std::byte('=') && prev == std::byte('=')) + { + index++; + std::byte& binary_type = accum[index]; + if (binary_type == std::byte(16)) // DOUBLE_BINARY_FORMAT_TYPE + index += sizeof(double) + 1; + else if (binary_type == std::byte(14)) // ARRAY_BINARY_FORMAT_TYPE + { + index++; + const std::byte& array_elem_type = accum[index]; + if (array_elem_type == std::byte(10)) + { + index++; + const size_t dims = size_t(accum[index]); + index++; + size_t data_size{sizeof(double)}; + for (size_t i = 0; i < dims; i++) + { + data_size *= bytes_to_int32_le(&accum[index]); + index += sizeof(int32_t); + } + index += data_size; + } + } + continue; + } + else if ((last == std::byte('\n')) && (prev != std::byte('\\'))) { - const char* tail{&last + 1}; - _msgs.emplace_back(head, tail - head); + const std::byte* tail{&last + 1}; + _msgs.emplace_back(head, tail); head = tail; ++received_count; } + index++; } return received_count; } diff --git a/cpp_test/mock_server.hpp b/cpp_test/mock_server.hpp index ba66efb0..8adb6ad7 100644 --- a/cpp_test/mock_server.hpp +++ b/cpp_test/mock_server.hpp @@ -24,13 +24,17 @@ #pragma once +#include <cassert> #include <vector> -#include <string> #include <cstdint> #include <optional> #include <stdexcept> - #include "build_env.h" +#if __cplusplus < 202002L +# include "questdb/ingress/line_sender.hpp" +#else +# include <span> +#endif #if defined(PLATFORM_UNIX) typedef int socketfd_t; @@ -60,9 +64,14 @@ class mock_server size_t recv(double wait_timeout_sec = 0.1); - const std::vector<std::string>& msgs() const +#if __cplusplus >= 202002L + using buffer_view = std::span<const std::byte>; +#endif + + buffer_view msgs(size_t index) const { - return _msgs; + assert(index < _msgs.size()); + return {_msgs[index].data(), _msgs[index].size()}; } void close(); @@ -75,7 +84,7 @@ class mock_server socketfd_t _listen_fd; socketfd_t _conn_fd; uint16_t _port; - std::vector<std::string> _msgs; + std::vector<std::vector<std::byte>> _msgs; }; } // namespace questdb::ingress::test diff --git a/cpp_test/test_line_sender.cpp b/cpp_test/test_line_sender.cpp index 1510aef2..e6f4f663 100644 --- a/cpp_test/test_line_sender.cpp +++ b/cpp_test/test_line_sender.cpp @@ -56,6 +56,68 @@ class on_scope_exit F _f; }; +#if __cplusplus >= 202002L +template <size_t N> +bool operator==(std::span<const std::byte> lhs, const char (&rhs)[N]) +{ + constexpr size_t bytelen = N - 1; // Exclude null terminator + const std::span<const std::byte> rhs_span{ + reinterpret_cast<const std::byte*>(rhs), bytelen}; + return lhs.size() == bytelen && std::ranges::equal(lhs, rhs_span); +} + +bool operator==(std::span<const std::byte> lhs, const std::string& rhs) +{ + const std::span<const std::byte> rhs_span{ + reinterpret_cast<const std::byte*>(rhs.data()), rhs.size()}; + return lhs.size() == rhs.size() && std::ranges::equal(lhs, rhs_span); +} +#else +template <size_t N> +bool operator==( + const questdb::ingress::buffer_view lhs_view, const char (&rhs)[N]) +{ + constexpr size_t bytelen = N - 1; // Exclude null terminator + const questdb::ingress::buffer_view rhs_view{ + reinterpret_cast<const std::byte*>(rhs), bytelen}; + return lhs_view == rhs_view; +} + +bool operator==( + const questdb::ingress::buffer_view lhs_view, const std::string& rhs) +{ + const questdb::ingress::buffer_view rhs_view{ + reinterpret_cast<const std::byte*>(rhs.data()), rhs.size()}; + return lhs_view == rhs_view; +} +#endif + +template <size_t N> +std::string& push_double_arr_to_buffer( + std::string& buffer, + std::array<double, N> data, + size_t rank, + uintptr_t* shapes) +{ + buffer.push_back(14); + buffer.push_back(10); + buffer.push_back(static_cast<char>(rank)); + for (size_t i = 0; i < rank; ++i) + buffer.append( + reinterpret_cast<const char*>(&shapes[i]), sizeof(uint32_t)); + buffer.append( + reinterpret_cast<const char*>(data.data()), + data.size() * sizeof(double)); + return buffer; +} + +std::string& push_double_to_buffer(std::string& buffer, double data) +{ + buffer.push_back(16); + buffer.append(reinterpret_cast<const char*>(&data), sizeof(double)); + return buffer; +} + TEST_CASE("line_sender c api basics") { questdb::ingress::test::mock_server server; @@ -95,13 +157,45 @@ TEST_CASE("line_sender c api basics") CHECK(::line_sender_buffer_table(buffer, table_name, &err)); CHECK(::line_sender_buffer_symbol(buffer, t1_name, v1_utf8, &err)); CHECK(::line_sender_buffer_column_f64(buffer, f1_name, 0.5, &err)); + + line_sender_column_name arr_name = QDB_COLUMN_NAME_LITERAL("a1"); + // 3D array of doubles + size_t rank = 3; + uintptr_t shapes[] = {2, 3, 2}; + intptr_t strides[] = {48, 16, 8}; + std::array<double, 12> arr_data = { + 48123.5, + 2.4, + 48124.0, + 1.8, + 48124.5, + 0.9, + 48122.5, + 3.1, + 48122.0, + 2.7, + 48121.5, + 4.3}; + CHECK(::line_sender_buffer_column_f64_arr( + buffer, + arr_name, + rank, + shapes, + strides, + reinterpret_cast<uint8_t*>(arr_data.data()), + sizeof(arr_data), + &err)); CHECK(::line_sender_buffer_at_nanos(buffer, 10000000, &err)); CHECK(server.recv() == 0); - CHECK(::line_sender_buffer_size(buffer) == 27); + CHECK(::line_sender_buffer_size(buffer) == 150); CHECK(::line_sender_flush(sender, buffer, &err)); ::line_sender_buffer_free(buffer); CHECK(server.recv() == 1); - CHECK(server.msgs().front() == "test,t1=v1 f1=0.5 10000000\n"); + std::string expect{"test,t1=v1 f1=="}; + push_double_to_buffer(expect, 0.5).append(",a1=="); + push_double_arr_to_buffer(expect, arr_data, 3, shapes) + .append(" 10000000\n"); + CHECK(server.msgs(0) == expect); } TEST_CASE("Opts service API tests") @@ -157,10 +251,12 @@ TEST_CASE("line_sender c++ api basics") .at(questdb::ingress::timestamp_nanos{10000000}); CHECK(server.recv() == 0); - CHECK(buffer.size() == 31); + CHECK(buffer.size() == 38); sender.flush(buffer); CHECK(server.recv() == 1); - CHECK(server.msgs().front() == "test,t1=v1,t2= f1=0.5 10000000\n"); + std::string expect{"test,t1=v1,t2= f1=="}; + push_double_to_buffer(expect, 0.5).append(" 10000000\n"); + CHECK(server.msgs(0) == expect); } TEST_CASE("test multiple lines") @@ -193,16 +289,15 @@ TEST_CASE("test multiple lines") .at_now(); CHECK(server.recv() == 0); - CHECK(buffer.size() == 137); + CHECK(buffer.size() == 142); sender.flush(buffer); CHECK(server.recv() == 2); + std::string expect{"metric1,t1=val1,t2=val2 f1=t,f2=12345i,f3=="}; + push_double_to_buffer(expect, 10.75) + .append(",f4=\"val3\",f5=\"val4\",f6=\"val5\" 111222233333\n"); + CHECK(server.msgs(0) == expect); CHECK( - server.msgs()[0] == - ("metric1,t1=val1,t2=val2 f1=t,f2=12345i," - "f3=10.75,f4=\"val3\",f5=\"val4\",f6=\"val5\" 111222233333\n")); - CHECK( - server.msgs()[1] == - "metric1,tag3=value\\ 3,tag\\ 4=value:4 field5=f\n"); + server.msgs(1) == "metric1,tag3=value\\ 3,tag\\ 4=value:4 field5=f\n"); } TEST_CASE("State machine testing -- flush without data.") @@ -243,7 +338,7 @@ TEST_CASE("One symbol only - flush before server accept") // but the server hasn't actually accepted the client connection yet. server.accept(); CHECK(server.recv() == 1); - CHECK(server.msgs()[0] == "test,t1=v1\n"); + CHECK(server.msgs(0) == "test,t1=v1\n"); } TEST_CASE("One column only - server.accept() after flush, before close") @@ -263,7 +358,7 @@ TEST_CASE("One column only - server.accept() after flush, before close") sender.close(); CHECK(server.recv() == 1); - CHECK(server.msgs()[0] == "test t1=\"v1\"\n"); + CHECK(server.msgs(0) == "test t1=\"v1\"\n"); } TEST_CASE("Symbol after column") @@ -386,42 +481,6 @@ TEST_CASE("Validation of bad chars in key names.") } } -#if __cplusplus >= 202002L -template <size_t N> -bool operator==(std::span<const std::byte> lhs, const char (&rhs)[N]) -{ - constexpr size_t bytelen = N - 1; // Exclude null terminator - const std::span<const std::byte> rhs_span{ - reinterpret_cast<const std::byte*>(rhs), bytelen}; - return lhs.size() == bytelen && std::ranges::equal(lhs, rhs_span); -} - -bool operator==(std::span<const std::byte> lhs, const std::string& rhs) -{ - const std::span<const std::byte> rhs_span{ - reinterpret_cast<const std::byte*>(rhs.data()), rhs.size()}; - return lhs.size() == rhs.size() && std::ranges::equal(lhs, rhs_span); -} -#else -template <size_t N> -bool operator==( - const questdb::ingress::buffer_view lhs_view, const char (&rhs)[N]) -{ - constexpr size_t bytelen = N - 1; // Exclude null terminator - const questdb::ingress::buffer_view rhs_view{ - reinterpret_cast<const std::byte*>(rhs), bytelen}; - return lhs_view == rhs_view; -} - -bool operator==( - const questdb::ingress::buffer_view lhs_view, const std::string& rhs) -{ - const questdb::ingress::buffer_view rhs_view{ - reinterpret_cast<const std::byte*>(rhs.data()), rhs.size()}; - return lhs_view == rhs_view; -} -#endif - TEST_CASE("Buffer move and copy ctor testing") { const size_t init_buf_size = 128; @@ -634,13 +693,19 @@ TEST_CASE("os certs") { questdb::ingress::opts opts{ - questdb::ingress::protocol::https, "localhost", server.port()}; + questdb::ingress::protocol::https, + "localhost", + server.port(), + true}; opts.tls_ca(questdb::ingress::ca::os_roots); } { questdb::ingress::opts opts{ - questdb::ingress::protocol::https, "localhost", server.port()}; + questdb::ingress::protocol::https, + "localhost", + server.port(), + true}; opts.tls_ca(questdb::ingress::ca::webpki_and_os_roots); } } @@ -669,9 +734,12 @@ TEST_CASE("Opts copy ctor, assignment and move testing.") { questdb::ingress::opts opts1{ - questdb::ingress::protocol::https, "localhost", "9009"}; + questdb::ingress::protocol::https, "localhost", "9009", true}; questdb::ingress::opts opts2{ - questdb::ingress::protocol::https, "altavista.digital.com", "9009"}; + questdb::ingress::protocol::https, + "altavista.digital.com", + "9009", + true}; opts1 = opts2; } } @@ -715,7 +783,7 @@ TEST_CASE("Test timestamp column.") sender.close(); CHECK(server.recv() == 1); - CHECK(server.msgs()[0] == exp); + CHECK(server.msgs(0) == exp); } TEST_CASE("test timestamp_micros and timestamp_nanos::now()") @@ -841,15 +909,15 @@ TEST_CASE("Opts from conf") TEST_CASE("HTTP basics") { questdb::ingress::opts opts1{ - questdb::ingress::protocol::http, "localhost", 1}; + questdb::ingress::protocol::http, "localhost", 1, true}; questdb::ingress::opts opts1conf = questdb::ingress::opts::from_conf( "http::addr=localhost:1;username=user;password=pass;request_timeout=" - "5000;retry_timeout=5;"); + "5000;retry_timeout=5;disable_line_protocol_validation=on;"); questdb::ingress::opts opts2{ - questdb::ingress::protocol::https, "localhost", "1"}; + questdb::ingress::protocol::https, "localhost", "1", true}; questdb::ingress::opts opts2conf = questdb::ingress::opts::from_conf( "http::addr=localhost:1;token=token;request_min_throughput=1000;retry_" - "timeout=0;"); + "timeout=0;disable_line_protocol_validation=on;"); opts1.username("user") .password("pass") .max_buf_size(1000000) @@ -873,4 +941,30 @@ TEST_CASE("HTTP basics") questdb::ingress::opts::from_conf( "http::addr=localhost:1;bind_interface=0.0.0.0;"), questdb::ingress::line_sender_error); -} \ No newline at end of file +} + +TEST_CASE("line sender protocol version v1") +{ + questdb::ingress::test::mock_server server; + questdb::ingress::line_sender sender{ + questdb::ingress::protocol::tcp, + std::string("localhost"), + std::to_string(server.port())}; + CHECK_FALSE(sender.must_close()); + server.accept(); + CHECK(server.recv() == 0); + + questdb::ingress::line_sender_buffer buffer{line_protocol_version_1}; + buffer.table("test") + .symbol("t1", "v1") + .symbol("t2", "") + .column("f1", 0.5) + .at(questdb::ingress::timestamp_nanos{10000000}); + + CHECK(server.recv() == 0); + CHECK(buffer.size() == 31); + sender.flush(buffer); + CHECK(server.recv() == 1); + std::string expect{"test,t1=v1,t2= f1=0.5 10000000\n"}; + CHECK(server.msgs(0) == expect); +} diff --git a/examples/line_sender_c_example_array.c b/examples/line_sender_c_example_array.c new file mode 100644 index 00000000..6f5650c7 --- /dev/null +++ b/examples/line_sender_c_example_array.c @@ -0,0 +1,99 @@ +#include <questdb/ingress/line_sender.h> +#include <stdio.h> +#include <stdbool.h> +#include <string.h> +#include "concat.h" + +static bool example(const char* host, const char* port) +{ + line_sender_error* err = NULL; + line_sender* sender = NULL; + line_sender_buffer* buffer = NULL; + char* conf_str = concat("tcp::addr=", host, ":", port, ";"); + if (!conf_str) + { + fprintf(stderr, "Could not concatenate configuration string.\n"); + return false; + } + + line_sender_utf8 conf_str_utf8 = {0, NULL}; + if (!line_sender_utf8_init( + &conf_str_utf8, strlen(conf_str), conf_str, &err)) + goto on_error; + + sender = line_sender_from_conf(conf_str_utf8, &err); + if (!sender) + goto on_error; + + free(conf_str); + conf_str = NULL; + + buffer = line_sender_buffer_new(); + line_sender_buffer_reserve(buffer, 64 * 1024); // 64KB 初始缓冲 + + line_sender_table_name table_name = QDB_TABLE_NAME_LITERAL("market_orders"); + line_sender_column_name symbol_col = QDB_COLUMN_NAME_LITERAL("symbol"); + line_sender_column_name book_col = QDB_COLUMN_NAME_LITERAL("order_book"); + + if (!line_sender_buffer_table(buffer, table_name, &err)) + goto on_error; + + line_sender_utf8 symbol_val = QDB_UTF8_LITERAL("BTC-USD"); + if (!line_sender_buffer_symbol(buffer, symbol_col, symbol_val, &err)) + goto on_error; + + size_t array_rank = 3; + uintptr_t array_shapes[] = {2, 3, 2}; + intptr_t array_strides[] = {48, 16, 8}; + + double array_data[] = { + 48123.5, + 2.4, + 48124.0, + 1.8, + 48124.5, + 0.9, + 48122.5, + 3.1, + 48122.0, + 2.7, + 48121.5, + 4.3}; + + if (!line_sender_buffer_column_f64_arr( + buffer, + book_col, + array_rank, + array_shapes, + array_strides, + (const uint8_t*)array_data, + sizeof(array_data), + &err)) + goto on_error; + + if (!line_sender_buffer_at_nanos(buffer, line_sender_now_nanos(), &err)) + goto on_error; + + if (!line_sender_flush(sender, buffer, &err)) + goto on_error; + + line_sender_close(sender); + return true; + +on_error:; + size_t err_len = 0; + const char* err_msg = line_sender_error_msg(err, &err_len); + fprintf(stderr, "Error: %.*s\n", (int)err_len, err_msg); + free(conf_str); + line_sender_error_free(err); + line_sender_buffer_free(buffer); + line_sender_close(sender); + return false; +} + +int main(int argc, const char* argv[]) +{ + const char* host = (argc >= 2) ? argv[1] : "localhost"; + const char* port = (argc >= 3) ? argv[2] : "9009"; + return !example(host, port); +} diff --git a/examples/line_sender_c_example_http.c b/examples/line_sender_c_example_http.c index 427ab705..f3bd3248 100644 --- a/examples/line_sender_c_example_http.c +++ b/examples/line_sender_c_example_http.c @@ -29,6 +29,9 @@ static bool example(const char* host, const char* port) buffer = line_sender_buffer_new(); line_sender_buffer_reserve(buffer, 64 * 1024); // 64KB buffer initial size. + if (!line_sender_buffer_set_line_protocol_version( + buffer, line_sender_default_line_protocol_version(sender), &err)) + goto on_error; line_sender_table_name table_name = QDB_TABLE_NAME_LITERAL("c_trades_http"); line_sender_column_name symbol_name = QDB_COLUMN_NAME_LITERAL("symbol"); diff --git a/examples/line_sender_cpp_example.cpp b/examples/line_sender_cpp_example.cpp index 2d90812f..d65ecc31 100644 --- a/examples/line_sender_cpp_example.cpp +++ b/examples/line_sender_cpp_example.cpp @@ -21,8 +21,7 @@ static bool example(std::string_view host, std::string_view port) const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -40,10 +39,7 @@ static bool example(std::string_view host, std::string_view port) } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return false; } @@ -56,12 +52,11 @@ static bool displayed_help(int argc, const char* argv[]) const std::string_view arg{argv[index]}; if ((arg == "-h"sv) || (arg == "--help"sv)) { - std::cerr - << "Usage:\n" - << "line_sender_c_example: [HOST [PORT]]\n" - << " HOST: ILP host (defaults to \"localhost\").\n" - << " PORT: ILP port (defaults to \"9009\")." - << std::endl; + std::cerr << "Usage:\n" + << "line_sender_c_example: [HOST [PORT]]\n" + << " HOST: ILP host (defaults to \"localhost\").\n" + << " PORT: ILP port (defaults to \"9009\")." + << std::endl; return true; } } diff --git a/examples/line_sender_cpp_example_array.cpp b/examples/line_sender_cpp_example_array.cpp new file mode 100644 index 00000000..e07cd21b --- /dev/null +++ b/examples/line_sender_cpp_example_array.cpp @@ -0,0 +1,61 @@ +#include <questdb/ingress/line_sender.hpp> +#include <iostream> +#include <vector> + +using namespace std::literals::string_view_literals; +using namespace questdb::ingress::literals; + +static bool array_example(std::string_view host, std::string_view port) +{ + try + { + auto sender = questdb::ingress::line_sender::from_conf( + "tcp::addr=" + std::string{host} + ":" + std::string{port} + ";"); + + const auto table_name = "cpp_market_orders"_tn; + const auto symbol_col = "symbol"_cn; + const auto book_col = "order_book"_cn; + size_t rank = 3; + std::vector<uintptr_t> shape{2, 3, 2}; + std::vector<intptr_t> strides{48, 16, 8}; + std::array<double, 12> arr_data = { + 48123.5, + 2.4, + 48124.0, + 1.8, + 48124.5, + 0.9, + 48122.5, + 3.1, + 48122.0, + 2.7, + 48121.5, + 4.3}; + + questdb::ingress::line_sender_buffer buffer; + buffer.table(table_name) + .symbol(symbol_col, "BTC-USD"_utf8) + .column(book_col, 3, shape, strides, arr_data) + .at(questdb::ingress::timestamp_nanos::now()); + sender.flush(buffer); + return true; + } + catch (const questdb::ingress::line_sender_error& err) + { + std::cerr << "[ERROR] " << err.what() << std::endl; + return false; + } +} + +int main(int argc, const char* argv[]) +{ + auto host = "localhost"sv; + if (argc >= 2) + host = std::string_view{argv[1]}; + + auto port = "9009"sv; + if (argc >= 3) + port = std::string_view{argv[2]}; + + return !array_example(host, port); +} diff --git a/examples/line_sender_cpp_example_auth.cpp b/examples/line_sender_cpp_example_auth.cpp index 85e0d6e1..4c229617 100644 --- a/examples/line_sender_cpp_example_auth.cpp +++ b/examples/line_sender_cpp_example_auth.cpp @@ -9,7 +9,8 @@ static bool example(std::string_view host, std::string_view port) try { auto sender = questdb::ingress::line_sender::from_conf( - "tcp::addr=" + std::string{host} + ":" + std::string{port} + ";" + "tcp::addr=" + std::string{host} + ":" + std::string{port} + + ";" "username=admin;" "token=5UjEMuA0Pj5pjK8a-fa24dyIf-Es5mYny3oE_Wmus48;" "token_x=fLKYEaoEb9lrn3nkwLDA-M_xnuFOdSt9y0Z7_vWSHLU;" @@ -25,8 +26,7 @@ static bool example(std::string_view host, std::string_view port) const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -44,10 +44,7 @@ static bool example(std::string_view host, std::string_view port) } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return false; } @@ -60,12 +57,11 @@ static bool displayed_help(int argc, const char* argv[]) const std::string_view arg{argv[index]}; if ((arg == "-h"sv) || (arg == "--help"sv)) { - std::cerr - << "Usage:\n" - << "line_sender_c_example: [HOST [PORT]]\n" - << " HOST: ILP host (defaults to \"localhost\").\n" - << " PORT: ILP port (defaults to \"9009\")." - << std::endl; + std::cerr << "Usage:\n" + << "line_sender_c_example: [HOST [PORT]]\n" + << " HOST: ILP host (defaults to \"localhost\").\n" + << " PORT: ILP port (defaults to \"9009\")." + << std::endl; return true; } } diff --git a/examples/line_sender_cpp_example_auth_tls.cpp b/examples/line_sender_cpp_example_auth_tls.cpp index f100dd04..f202fc75 100644 --- a/examples/line_sender_cpp_example_auth_tls.cpp +++ b/examples/line_sender_cpp_example_auth_tls.cpp @@ -4,14 +4,13 @@ using namespace std::literals::string_view_literals; using namespace questdb::ingress::literals; -static bool example( - std::string_view host, - std::string_view port) +static bool example(std::string_view host, std::string_view port) { try { auto sender = questdb::ingress::line_sender::from_conf( - "tcps::addr=" + std::string{host} + ":" + std::string{port} + ";" + "tcps::addr=" + std::string{host} + ":" + std::string{port} + + ";" "username=admin;" "token=5UjEMuA0Pj5pjK8a-fa24dyIf-Es5mYny3oE_Wmus48;" "token_x=fLKYEaoEb9lrn3nkwLDA-M_xnuFOdSt9y0Z7_vWSHLU;" @@ -27,8 +26,7 @@ static bool example( const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -46,10 +44,7 @@ static bool example( } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return false; } @@ -62,12 +57,11 @@ static bool displayed_help(int argc, const char* argv[]) const std::string_view arg{argv[index]}; if ((arg == "-h"sv) || (arg == "--help"sv)) { - std::cerr - << "Usage:\n" - << "line_sender_c_example: CA_PATH [HOST [PORT]]\n" - << " HOST: ILP host (defaults to \"localhost\").\n" - << " PORT: ILP port (defaults to \"9009\")." - << std::endl; + std::cerr << "Usage:\n" + << "line_sender_c_example: CA_PATH [HOST [PORT]]\n" + << " HOST: ILP host (defaults to \"localhost\").\n" + << " PORT: ILP port (defaults to \"9009\")." + << std::endl; return true; } } diff --git a/examples/line_sender_cpp_example_from_conf.cpp b/examples/line_sender_cpp_example_from_conf.cpp index bb71c6e3..2a2fe510 100644 --- a/examples/line_sender_cpp_example_from_conf.cpp +++ b/examples/line_sender_cpp_example_from_conf.cpp @@ -21,8 +21,7 @@ int main(int argc, const char* argv[]) const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -40,10 +39,7 @@ int main(int argc, const char* argv[]) } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return 1; } diff --git a/examples/line_sender_cpp_example_from_env.cpp b/examples/line_sender_cpp_example_from_env.cpp index 63e99b26..54acd658 100644 --- a/examples/line_sender_cpp_example_from_env.cpp +++ b/examples/line_sender_cpp_example_from_env.cpp @@ -20,8 +20,7 @@ int main(int argc, const char* argv[]) const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -39,10 +38,7 @@ int main(int argc, const char* argv[]) } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return 1; } diff --git a/examples/line_sender_cpp_example_http.cpp b/examples/line_sender_cpp_example_http.cpp index 217845a2..800e11aa 100644 --- a/examples/line_sender_cpp_example_http.cpp +++ b/examples/line_sender_cpp_example_http.cpp @@ -20,9 +20,9 @@ static bool example(std::string_view host, std::string_view port) const auto price_name = "price"_cn; const auto amount_name = "amount"_cn; - questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + questdb::ingress::line_sender_buffer buffer{ + sender.default_line_protocol_version()}; + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -40,10 +40,7 @@ static bool example(std::string_view host, std::string_view port) } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return false; } @@ -56,12 +53,11 @@ static bool displayed_help(int argc, const char* argv[]) const std::string_view arg{argv[index]}; if ((arg == "-h"sv) || (arg == "--help"sv)) { - std::cerr - << "Usage:\n" - << "line_sender_c_example: [HOST [PORT]]\n" - << " HOST: ILP host (defaults to \"localhost\").\n" - << " PORT: ILP port (defaults to \"9009\")." - << std::endl; + std::cerr << "Usage:\n" + << "line_sender_c_example: [HOST [PORT]]\n" + << " HOST: ILP host (defaults to \"localhost\").\n" + << " PORT: ILP port (defaults to \"9009\")." + << std::endl; return true; } } diff --git a/examples/line_sender_cpp_example_tls_ca.cpp b/examples/line_sender_cpp_example_tls_ca.cpp index c0327e96..ac4d4743 100644 --- a/examples/line_sender_cpp_example_tls_ca.cpp +++ b/examples/line_sender_cpp_example_tls_ca.cpp @@ -5,19 +5,19 @@ using namespace std::literals::string_view_literals; using namespace questdb::ingress::literals; static bool example( - std::string_view ca_path, - std::string_view host, - std::string_view port) + std::string_view ca_path, std::string_view host, std::string_view port) { try { auto sender = questdb::ingress::line_sender::from_conf( - "tcps::addr=" + std::string{host} + ":" + std::string{port} + ";" + "tcps::addr=" + std::string{host} + ":" + std::string{port} + + ";" "username=admin;" "token=5UjEMuA0Pj5pjK8a-fa24dyIf-Es5mYny3oE_Wmus48;" "token_x=fLKYEaoEb9lrn3nkwLDA-M_xnuFOdSt9y0Z7_vWSHLU;" "token_y=Dt5tbS1dEDMSYfym3fgMv0B99szno-dFc1rYF9t0aac;" - "tls_roots=" + std::string{ca_path} + ";"); // path to custom `.pem` file. + "tls_roots=" + + std::string{ca_path} + ";"); // path to custom `.pem` file. // We prepare all our table names and column names in advance. // If we're inserting multiple rows, this allows us to avoid @@ -29,8 +29,7 @@ static bool example( const auto amount_name = "amount"_cn; questdb::ingress::line_sender_buffer buffer; - buffer - .table(table_name) + buffer.table(table_name) .symbol(symbol_name, "ETH-USD"_utf8) .symbol(side_name, "sell"_utf8) .column(price_name, 2615.54) @@ -48,10 +47,7 @@ static bool example( } catch (const questdb::ingress::line_sender_error& err) { - std::cerr - << "Error running example: " - << err.what() - << std::endl; + std::cerr << "Error running example: " << err.what() << std::endl; return false; } @@ -64,13 +60,12 @@ static bool displayed_help(int argc, const char* argv[]) const std::string_view arg{argv[index]}; if ((arg == "-h"sv) || (arg == "--help"sv)) { - std::cerr - << "Usage:\n" - << "line_sender_c_example: CA_PATH [HOST [PORT]]\n" - << " CA_PATH: Certificate authority pem file.\n" - << " HOST: ILP host (defaults to \"localhost\").\n" - << " PORT: ILP port (defaults to \"9009\")." - << std::endl; + std::cerr << "Usage:\n" + << "line_sender_c_example: CA_PATH [HOST [PORT]]\n" + << " CA_PATH: Certificate authority pem file.\n" + << " HOST: ILP host (defaults to \"localhost\").\n" + << " PORT: ILP port (defaults to \"9009\")." + << std::endl; return true; } } diff --git a/include/questdb/ingress/line_sender.h b/include/questdb/ingress/line_sender.h index 281b9d78..959fd149 100644 --- a/include/questdb/ingress/line_sender.h +++ b/include/questdb/ingress/line_sender.h @@ -77,6 +77,19 @@ typedef enum line_sender_error_code /** Bad configuration. */ line_sender_error_config_error, + + /** Currently, only arrays with a maximum 32 dimensions are supported. */ + line_sender_error_array_large_dim, + + /** ArrayView internal error, such as failure to get the size of a valid + * dimension. */ + line_sender_error_array_view_internal_error, + + /** Write arrayView to sender buffer error. */ + line_sender_error_array_view_write_to_buffer_error, + + /** Line sender protocol version error. */ + line_sender_error_line_protocol_version_error, } line_sender_error_code; /** The protocol used to connect with. */ @@ -95,6 +108,18 @@ typedef enum line_sender_protocol line_sender_protocol_https, } line_sender_protocol; +/** The line protocol version used to write data to buffer. */ +typedef enum line_protocol_version +{ + /** Version 1 of InfluxDB Line Protocol. + Uses text format serialization for f64. */ + line_protocol_version_1 = 1, + + /** Version 2 of InfluxDB Line Protocol. + Uses binary format serialization for f64, and support array data type.*/ + line_protocol_version_2 = 2, +} line_protocol_version; + /** Possible sources of the root certificates used to validate the server's * TLS certificate. */ typedef enum line_sender_ca @@ -296,6 +321,23 @@ line_sender_buffer* line_sender_buffer_new(); LINESENDER_API line_sender_buffer* line_sender_buffer_with_max_name_len(size_t max_name_len); +/** + * Sets the Line Protocol version for line_sender_buffer. + * + * The buffer defaults is line_protocol_version_2 which uses + * binary format f64 serialization and support array data type. Call this to + * switch to version 1 (text format f64) when connecting to servers that don't + * support line_protocol_version_2(under 8.3.2). + * + * Must be called before adding any data to the buffer. Protocol version cannot + * be changed after the buffer contains data. + */ +LINESENDER_API +bool line_sender_buffer_set_line_protocol_version( + line_sender_buffer* buffer, + line_protocol_version version, + line_sender_error** err_out); + /** Release the `line_sender_buffer` object. */ LINESENDER_API void line_sender_buffer_free(line_sender_buffer* buffer); @@ -461,6 +503,32 @@ bool line_sender_buffer_column_str( line_sender_utf8 value, line_sender_error** err_out); +/** + * Record a multidimensional array of double for the given column. + * The array data must be stored in row-major order (C-style contiguous layout). + * + * @param[in] buffer Line buffer object. + * @param[in] name Column name. + * @param[in] rank Number of dimensions of the array. + * @param[in] shapes Array of dimension sizes (length = `rank`). + * Each element must be a positive integer. + * @param[in] strides Array strides. + * @param[in] data_buffer First array element data. + * @param[in] data_buffer_len Bytes length of the array data. + * @param[out] err_out Set to an error object on failure (if non-NULL). + * @return true on success, false on error. + */ +LINESENDER_API +bool line_sender_buffer_column_f64_arr( + line_sender_buffer* buffer, + line_sender_column_name name, + size_t rank, + const uintptr_t* shape, + const intptr_t* strides, + const uint8_t* data_buffer, + size_t data_buffer_len, + line_sender_error** err_out); + /** * Record a nanosecond timestamp value for the given column. * @param[in] buffer Line buffer object. @@ -693,6 +761,13 @@ bool line_sender_opts_token_y( line_sender_utf8 token_y, line_sender_error** err_out); +/** + * Set the ECDSA public key Y for TCP authentication. + */ +LINESENDER_API +bool line_sender_opts_disable_line_protocol_validation( + line_sender_opts* opts, line_sender_error** err_out); + /** * Configure how long to wait for messages from the QuestDB server during * the TLS handshake and authentication process. @@ -852,6 +927,25 @@ line_sender* line_sender_from_conf( LINESENDER_API line_sender* line_sender_from_env(line_sender_error** err_out); +/** + * Returns the QuestDB server's recommended default line protocol version. + * Will be used to [`line_sender_buffer_set_line_protocol_version`] + * + * The version selection follows these rules: + * 1. TCP/TCPS Protocol: Always returns [`LineProtocolVersion::V2`] + * 2. HTTP/HTTPS Protocol: + * - If line protocol auto-detection is disabled + * [`line_sender_opts_disable_line_protocol_validation`], returns + * [`LineProtocolVersion::V2`] + * - If line protocol auto-detection is enabled: + * - Uses the server's default version if supported by the client + * - Otherwise uses the highest mutually supported version from the + * intersection of client and server compatible versions. + */ +LINESENDER_API +line_protocol_version line_sender_default_line_protocol_version( + const line_sender* sender); + /** * Tell whether the sender is no longer usable and must be closed. * This happens when there was an earlier failure. diff --git a/include/questdb/ingress/line_sender.hpp b/include/questdb/ingress/line_sender.hpp index e118a22f..21906aa9 100644 --- a/include/questdb/ingress/line_sender.hpp +++ b/include/questdb/ingress/line_sender.hpp @@ -26,6 +26,7 @@ #include "line_sender.h" +#include <array> #include <chrono> #include <cstddef> #include <cstdint> @@ -33,6 +34,7 @@ #include <stdexcept> #include <string> #include <type_traits> +#include <vector> #if __cplusplus >= 202002L # include <span> #endif @@ -399,10 +401,27 @@ class line_sender_buffer { } + line_sender_buffer( + size_t init_buf_size, + size_t max_name_len, + line_protocol_version version) noexcept + : _impl{nullptr} + , _init_buf_size{init_buf_size} + , _max_name_len{max_name_len} + , _line_protocol_version{version} + { + } + + line_sender_buffer(line_protocol_version version) noexcept + : line_sender_buffer{64 * 1024, 127, version} + { + } + line_sender_buffer(const line_sender_buffer& other) noexcept : _impl{::line_sender_buffer_clone(other._impl)} , _init_buf_size{other._init_buf_size} , _max_name_len{other._max_name_len} + , _line_protocol_version{other._line_protocol_version} { } @@ -410,6 +429,7 @@ class line_sender_buffer : _impl{other._impl} , _init_buf_size{other._init_buf_size} , _max_name_len{other._max_name_len} + , _line_protocol_version{other._line_protocol_version} { other._impl = nullptr; } @@ -425,6 +445,7 @@ class line_sender_buffer _impl = nullptr; _init_buf_size = other._init_buf_size; _max_name_len = other._max_name_len; + _line_protocol_version = other._line_protocol_version; } return *this; } @@ -437,11 +458,32 @@ class line_sender_buffer _impl = other._impl; _init_buf_size = other._init_buf_size; _max_name_len = other._max_name_len; + _line_protocol_version = other._line_protocol_version; other._impl = nullptr; } return *this; } + /** + * Sets the Line Protocol version for line_sender_buffer. + * + * The buffer defaults is line_protocol_version_2 which uses + * binary format f64 serialization and support array data type. Call this to + * switch to version 1 (text format f64) when connecting to servers that + * don't support line_protocol_version_2(under 8.3.2). + * + * Must be called before adding any data to the buffer. Protocol version + * cannot be changed after the buffer contains data. + */ + line_sender_buffer& set_line_protocol_version(line_protocol_version v) + { + may_init(); + line_sender_error::wrapped_call( + ::line_sender_buffer_set_line_protocol_version, _impl, v); + _line_protocol_version = v; + return *this; + } + /** * Pre-allocate to ensure the buffer has enough capacity for at least * the specified additional byte count. This may be rounded up. @@ -624,6 +666,38 @@ class line_sender_buffer return *this; } + /** + * Record a multidimensional double-precision array for the given column. + * + * @param name Column name. + * @param shape Array dimensions (e.g., [2,3] for a 2x3 matrix). + * @param data Array first element data. Size must match product of + * dimensions. + */ + template <typename T, size_t N> + line_sender_buffer& column( + column_name_view name, + const size_t rank, + const std::vector<uintptr_t>& shapes, + const std::vector<intptr_t>& strides, + const std::array<T, N>& data) + { + static_assert( + std::is_same_v<T, double>, + "Only double types are supported for arrays"); + may_init(); + line_sender_error::wrapped_call( + ::line_sender_buffer_column_f64_arr, + _impl, + name._impl, + rank, + shapes.data(), + strides.data(), + reinterpret_cast<const uint8_t*>(data.data()), + sizeof(double) * N); + return *this; + } + /** * Record a string value for the given column. * @param name Column name. @@ -769,12 +843,17 @@ class line_sender_buffer { _impl = ::line_sender_buffer_with_max_name_len(_max_name_len); ::line_sender_buffer_reserve(_impl, _init_buf_size); + line_sender_error::wrapped_call( + line_sender_buffer_set_line_protocol_version, + _impl, + _line_protocol_version); } } ::line_sender_buffer* _impl; size_t _init_buf_size; size_t _max_name_len; + line_protocol_version _line_protocol_version{::line_protocol_version_2}; friend class line_sender; }; @@ -834,13 +913,24 @@ class opts * @param[in] protocol The protocol to use. * @param[in] host The QuestDB database host. * @param[in] port The QuestDB tcp or http port. + * @param[in] disable_line_protocol_validation disable line protocol version + * validation. */ - opts(protocol protocol, utf8_view host, uint16_t port) noexcept + opts( + protocol protocol, + utf8_view host, + uint16_t port, + bool disable_line_protocol_validation = false) noexcept : _impl{::line_sender_opts_new( static_cast<::line_sender_protocol>(protocol), host._impl, port)} { line_sender_error::wrapped_call( ::line_sender_opts_user_agent, _impl, _user_agent::name()); + if (disable_line_protocol_validation) + { + line_sender_error::wrapped_call( + ::line_sender_opts_disable_line_protocol_validation, _impl); + } } /** @@ -849,8 +939,13 @@ class opts * @param[in] protocol The protocol to use. * @param[in] host The QuestDB database host. * @param[in] port The QuestDB tcp or http port as service name. + * @param[in] disable_line_protocol_validation disable line protocol version */ - opts(protocol protocol, utf8_view host, utf8_view port) noexcept + opts( + protocol protocol, + utf8_view host, + utf8_view port, + bool disable_line_protocol_validation = false) noexcept : _impl{::line_sender_opts_new_service( static_cast<::line_sender_protocol>(protocol), host._impl, @@ -858,6 +953,11 @@ class opts { line_sender_error::wrapped_call( ::line_sender_opts_user_agent, _impl, _user_agent::name()); + if (disable_line_protocol_validation) + { + line_sender_error::wrapped_call( + ::line_sender_opts_disable_line_protocol_validation, _impl); + } } opts(const opts& other) noexcept @@ -964,6 +1064,16 @@ class opts return *this; } + /** + * Disable the validation of the line protocol version. + */ + opts& disable_line_protocol_validation() + { + line_sender_error::wrapped_call( + ::line_sender_opts_disable_line_protocol_validation, _impl); + return *this; + } + /** * Configure how long to wait for messages from the QuestDB server during * the TLS handshake and authentication process. @@ -1150,13 +1260,23 @@ class line_sender return {opts::from_env()}; } - line_sender(protocol protocol, utf8_view host, uint16_t port) - : line_sender{opts{protocol, host, port}} + line_sender( + protocol protocol, + utf8_view host, + uint16_t port, + bool disable_line_protocol_validation = false) + : line_sender{ + opts{protocol, host, port, disable_line_protocol_validation}} { } - line_sender(protocol protocol, utf8_view host, utf8_view port) - : line_sender{opts{protocol, host, port}} + line_sender( + protocol protocol, + utf8_view host, + utf8_view port, + bool disable_line_protocol_validation = false) + : line_sender{ + opts{protocol, host, port, disable_line_protocol_validation}} { } @@ -1246,6 +1366,15 @@ class line_sender } } + /** + * Returns the QuestDB server's recommended default line protocol version. + */ + line_protocol_version default_line_protocol_version() + { + ensure_impl(); + return line_sender_default_line_protocol_version(_impl); + } + /** * Check if an error occurred previously and the sender must be closed. * This happens when there was an earlier failure. diff --git a/proj b/proj new file mode 100755 index 00000000..d4c80095 --- /dev/null +++ b/proj @@ -0,0 +1,2 @@ +#!/bin/bash +python3 proj.py "$@" diff --git a/proj.bat b/proj.bat new file mode 100644 index 00000000..027f5468 --- /dev/null +++ b/proj.bat @@ -0,0 +1,2 @@ +@echo off +python3 proj.py %* diff --git a/proj.ps1 b/proj.ps1 new file mode 100644 index 00000000..7f9e830e --- /dev/null +++ b/proj.ps1 @@ -0,0 +1 @@ +python3 proj.py $args diff --git a/proj.py b/proj.py new file mode 100755 index 00000000..bd36a449 --- /dev/null +++ b/proj.py @@ -0,0 +1,220 @@ +#!/usr/bin/env python3 + +import sys +sys.dont_write_bytecode = True +import pathlib +import shutil +import shlex +import subprocess +import os + + +PROJ_ROOT = pathlib.Path(__file__).parent + + +def _run(*args, env=None, cwd=None): + """ + Log and run a command within the build dir. + On error, exit with child's return code. + """ + args = [str(arg) for arg in args] + cwd = cwd or PROJ_ROOT + sys.stderr.write('[CMD] ') + if env is not None: + env_str = ' '.join(f'{k}={shlex.quote(v)}' for k, v in env.items()) + sys.stderr.write(f'{env_str} ') + env = {**os.environ, **env} + escaped_cmd = ' '.join(shlex.quote(arg) for arg in args) + sys.stderr.write(f'{escaped_cmd}\n') + ret_code = subprocess.run(args, cwd=str(cwd), env=env).returncode + if ret_code != 0: + sys.exit(ret_code) + + +def _rm(path: pathlib.Path, pattern: str): + paths = path.glob(pattern) + for path in paths: + sys.stderr.write(f'[RM] {path}\n') + path.unlink() + + +def _rmtree(path: pathlib.Path): + if not path.exists(): + return + sys.stderr.write(f'[RMTREE] {path}\n') + shutil.rmtree(path, ignore_errors=True) + + +def _has_command(command: str) -> bool: + """ + Check if a command is available in the system. + """ + return shutil.which(command) is not None + + +COMMANDS = [] + + +def command(fn): + COMMANDS.append(fn.__name__) + return fn + + +@command +def clean(): + _rmtree(PROJ_ROOT / 'build') + _rmtree(PROJ_ROOT / 'build_CXX20') + _rmtree(PROJ_ROOT / 'questdb-rs' / 'target') + _rmtree(PROJ_ROOT / 'questdb-rs-ffi' / 'target') + + +@command +def cmake_cxx17(): + _rmtree(PROJ_ROOT / 'build') + cmd = [ + 'cmake', + '-S', '.', + '-B', 'build', + '-DCMAKE_BUILD_TYPE=Release', + '-DQUESTDB_TESTS_AND_EXAMPLES=ON'] + if _has_command('ninja'): + cmd.insert(1, '-G') + cmd.insert(2, 'Ninja') + _run(*cmd) + + +@command +def cmake_cxx20(): + _rmtree(PROJ_ROOT / 'build_CXX20') + cmd = [ + 'cmake', + '-S', '.', + '-B', 'build_CXX20', + '-DCMAKE_BUILD_TYPE=Release', + '-DQUESTDB_TESTS_AND_EXAMPLES=ON', + '-DCMAKE_CXX_STANDARD=20'] + if _has_command('ninja'): + cmd.insert(1, '-G') + cmd.insert(2, 'Ninja') + _run(*cmd) + + +@command +def build_cxx17(): + if not (PROJ_ROOT / 'build').exists(): + cmake_cxx17() + _run('cmake', '--build', 'build') + + +@command +def build_cxx20(): + if not (PROJ_ROOT / 'build_CXX20').exists(): + cmake_cxx20() + _run('cmake', '--build', 'build_CXX20') + + +@command +def build(): + build_cxx17() + build_cxx20() + + +@command +def lint_rust(): + questdb_rs_path = PROJ_ROOT / 'questdb-rs' + questdb_rs_ffi_path = PROJ_ROOT / 'questdb-rs-ffi' + _run('cargo', 'fmt', '--all', '--', '--check', cwd=questdb_rs_path) + _run('cargo', 'clippy', '--all-targets', '--features', 'almost-all-features', '--', '-D', 'warnings', cwd=questdb_rs_path) + _run('cargo', 'fmt', '--all', '--', '--check', cwd=questdb_rs_ffi_path) + _run('cargo', 'clippy', '--all-targets', '--all-features', '--', '-D', 'warnings', cwd=questdb_rs_ffi_path) + + +@command +def lint_cpp(): + try: + _run( + sys.executable, + PROJ_ROOT / 'ci' / 'format_cpp.py', + '--check') + except subprocess.CalledProcessError: + sys.stderr.write('REMINDER: To fix any C++ formatting issues, run: ./proj format_cpp\n') + raise + + +@command +def lint(): + lint_rust() + lint_cpp() + + +@command +def format_rust(): + questdb_rs_path = PROJ_ROOT / 'questdb-rs' + questdb_rs_ffi_path = PROJ_ROOT / 'questdb-rs-ffi' + _run('cargo', 'fmt', '--all', cwd=questdb_rs_path) + _run('cargo', 'fmt', '--all', cwd=questdb_rs_ffi_path) + + +@command +def format_cpp(): + _run( + sys.executable, + PROJ_ROOT / 'ci' / 'format_cpp.py') + + +@command +def test(): + build() + _run( + sys.executable, + PROJ_ROOT / 'ci' / 'run_all_tests.py') + + +@command +def build_latest_questdb(branch='master'): + questdb_path = PROJ_ROOT / 'questdb' + if not questdb_path.exists(): + _run('git', 'clone', 'https://github.com/questdb/questdb.git') + _run('git', 'fetch', 'origin', branch, cwd=questdb_path) + _run('git', 'switch', branch=questdb_path) + _run('git', 'pull', 'origin', branch=questdb_path) + _run('git', 'submodule', 'update', '--init', '--recursive', cwd=questdb_path) + _run('mvn', 'clean', 'package', '-DskipTests', '-Pbuild-web-console', cwd=questdb_path) + + +@command +def test_vs_latest_questdb(): + questdb_path = PROJ_ROOT / 'questdb' + if not questdb_path.exists(): + build_latest_questdb() + _run( + sys.executable, + PROJ_ROOT / 'system_test' / 'test.py', + '--repo', PROJ_ROOT / 'questdb', + '-v') + + +@command +def all(): + clean() + build() + lint() + test() + test_vs_latest_questdb() + + +def main(): + if len(sys.argv) < 2: + sys.stderr.write('Usage: python3 proj.py <command>\n') + sys.stderr.write('Commands:\n') + for command in COMMANDS: + sys.stderr.write(f' {command}\n') + sys.stderr.write('\n') + sys.exit(0) + fn = sys.argv[1] + args = list(sys.argv)[2:] + globals()[fn](*args) + + +if __name__ == '__main__': + main() diff --git a/questdb-rs-ffi/src/lib.rs b/questdb-rs-ffi/src/lib.rs index 2db6a917..c0815d5d 100644 --- a/questdb-rs-ffi/src/lib.rs +++ b/questdb-rs-ffi/src/lib.rs @@ -34,6 +34,7 @@ use std::slice; use std::str; use questdb::{ + ingress, ingress::{ Buffer, CertificateAuthority, ColumnName, Protocol, Sender, SenderBuilder, TableName, TimestampMicros, TimestampNanos, @@ -135,6 +136,18 @@ pub enum line_sender_error_code { /// Bad configuration. line_sender_error_config_error, + + /// Currently, only arrays with a maximum 32 dimensions are supported. + line_sender_error_array_large_dim, + + /// ArrayView internal error, such as failure to get the size of a valid dimension. + line_sender_error_array_view_internal_error, + + /// Write arrayView to sender buffer error. + line_sender_error_array_view_write_to_buffer_error, + + /// Line sender protocol version error. + line_sender_error_line_protocol_version_error, } impl From<ErrorCode> for line_sender_error_code { @@ -159,6 +172,18 @@ impl From<ErrorCode> for line_sender_error_code { line_sender_error_code::line_sender_error_server_flush_error } ErrorCode::ConfigError => line_sender_error_code::line_sender_error_config_error, + ErrorCode::ArrayHasTooManyDims => { + line_sender_error_code::line_sender_error_array_large_dim + } + ErrorCode::ArrayViewError => { + line_sender_error_code::line_sender_error_array_view_internal_error + } + ErrorCode::ArrayWriteToBufferError => { + line_sender_error_code::line_sender_error_array_view_write_to_buffer_error + } + ErrorCode::LineProtocolVersionError => { + line_sender_error_code::line_sender_error_line_protocol_version_error + } } } } @@ -202,6 +227,37 @@ impl From<line_sender_protocol> for Protocol { } } +/// The version of Line Protocol used for [`Buffer`]. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub enum LineProtocolVersion { + /// Version 1 of Line Protocol. + /// Uses text format serialization for f64. + V1 = 1, + + /// Version 2 of InfluxDB Line Protocol. + /// Uses binary format serialization for f64, and support array data type. + V2 = 2, +} + +impl From<LineProtocolVersion> for ingress::LineProtocolVersion { + fn from(version: LineProtocolVersion) -> Self { + match version { + LineProtocolVersion::V1 => ingress::LineProtocolVersion::V1, + LineProtocolVersion::V2 => ingress::LineProtocolVersion::V2, + } + } +} + +impl From<ingress::LineProtocolVersion> for LineProtocolVersion { + fn from(version: ingress::LineProtocolVersion) -> Self { + match version { + ingress::LineProtocolVersion::V1 => LineProtocolVersion::V1, + ingress::LineProtocolVersion::V2 => LineProtocolVersion::V2, + } + } +} + /// Possible sources of the root certificates used to validate the server's TLS certificate. #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -561,6 +617,17 @@ pub unsafe extern "C" fn line_sender_buffer_with_max_name_len( Box::into_raw(Box::new(line_sender_buffer(buffer))) } +#[no_mangle] +pub unsafe extern "C" fn line_sender_buffer_set_line_protocol_version( + buffer: *mut line_sender_buffer, + version: LineProtocolVersion, + err_out: *mut *mut line_sender_error, +) -> bool { + let buffer = unwrap_buffer_mut(buffer); + bubble_err_to_c!(err_out, buffer.set_line_proto_version(version.into())); + true +} + /// Release the `line_sender_buffer` object. #[no_mangle] pub unsafe extern "C" fn line_sender_buffer_free(buffer: *mut line_sender_buffer) { @@ -804,6 +871,41 @@ pub unsafe extern "C" fn line_sender_buffer_column_str( true } +/// Record a float multidimensional array value for the given column. +/// @param[in] buffer Line buffer object. +/// @param[in] name Column name. +/// @param[in] rank Array dims. +/// @param[in] shape Array shapes. +/// @param[in] strides Array strides. +/// @param[in] data_buffer Array **first element** data memory ptr. +/// @param[in] data_buffer_len Array data memory length. +/// @param[out] err_out Set on error. +/// # Safety +/// - All pointer parameters must be valid and non-null +/// - shape must point to an array of `rank` integers +/// - data_buffer must point to a buffer of size `data_buffer_len` bytes +#[no_mangle] +pub unsafe extern "C" fn line_sender_buffer_column_f64_arr( + buffer: *mut line_sender_buffer, + name: line_sender_column_name, + rank: size_t, + shape: *const usize, + strides: *const isize, + data_buffer: *const u8, + data_buffer_len: size_t, + err_out: *mut *mut line_sender_error, +) -> bool { + let buffer = unwrap_buffer_mut(buffer); + let name = name.as_name(); + let view = + ingress::StrideArrayView::<f64>::new(rank, shape, strides, data_buffer, data_buffer_len); + bubble_err_to_c!( + err_out, + buffer.column_arr::<ColumnName<'_>, ingress::StrideArrayView<'_, f64>, f64>(name, &view) + ); + true +} + /// Record a nanosecond timestamp value for the given column. /// @param[in] buffer Line buffer object. /// @param[in] name Column name. @@ -1066,6 +1168,15 @@ pub unsafe extern "C" fn line_sender_opts_token_y( upd_opts!(opts, err_out, token_y, token_y.as_str()) } +/// Disable the line protocol validation. +#[no_mangle] +pub unsafe extern "C" fn line_sender_opts_disable_line_protocol_validation( + opts: *mut line_sender_opts, + err_out: *mut *mut line_sender_error, +) -> bool { + upd_opts!(opts, err_out, disable_line_protocol_validation) +} + /// Configure how long to wait for messages from the QuestDB server during /// the TLS handshake and authentication process. /// The value is in milliseconds, and the default is 15 seconds. @@ -1295,6 +1406,24 @@ unsafe fn unwrap_sender_mut<'a>(sender: *mut line_sender) -> &'a mut Sender { &mut (*sender).0 } +/// Returns the client's recommended default line protocol version. +/// Will be used to [`line_sender_buffer_set_line_protocol_version`] +/// +/// The version selection follows these rules: +/// 1. **TCP/TCPS Protocol**: Always returns [`LineProtocolVersion::V2`] +/// 2. **HTTP/HTTPS Protocol**: +/// - If line protocol auto-detection is disabled [`line_sender_opts_disable_line_protocol_validation`], returns [`LineProtocolVersion::V2`] +/// - If line protocol auto-detection is enabled: +/// - Uses the server's default version if supported by the client +/// - Otherwise uses the highest mutually supported version from the intersection +/// of client and server compatible versions +#[no_mangle] +pub unsafe extern "C" fn line_sender_default_line_protocol_version( + sender: *const line_sender, +) -> LineProtocolVersion { + unwrap_sender(sender).default_line_protocol_version().into() +} + /// Tell whether the sender is no longer usable and must be closed. /// This happens when there was an earlier failure. /// This fuction is specific to TCP and is not relevant for HTTP. diff --git a/questdb-rs/Cargo.toml b/questdb-rs/Cargo.toml index 90ace7f3..15b3b1f3 100644 --- a/questdb-rs/Cargo.toml +++ b/questdb-rs/Cargo.toml @@ -23,7 +23,7 @@ socket2 = "0.5.5" dns-lookup = "2.0.4" base64ct = { version = "1.7", features = ["alloc"] } rustls-pemfile = "2.0.0" -ryu = "1.0" +ryu = { version = "1.0" } itoa = "1.0" aws-lc-rs = { version = "1.13", optional = true } ring = { version = "0.17.14", optional = true } @@ -39,6 +39,7 @@ ureq = { version = "3.0.10, <3.1.0", default-features = false, features = ["rust serde_json = { version = "1", optional = true } questdb-confstr = "0.1.1" rand = { version = "0.9.0", optional = true } +ndarray = { version = "0.16", optional = true } no-panic = { version = "0.1", optional = true } [target.'cfg(windows)'.dependencies] @@ -55,6 +56,8 @@ mio = { version = "1", features = ["os-poll", "net"] } chrono = "0.4.31" tempfile = "3" webpki-roots = "0.26.8" +criterion = "0.5" +rstest = "0.25.0" [features] default = ["tls-webpki-certs", "ilp-over-http", "aws-lc-crypto"] @@ -83,6 +86,8 @@ json_tests = [] # Enable methods to create timestamp objects from chrono::DateTime objects. chrono_timestamp = ["chrono"] +benchmark = [] + # The `aws-lc-crypto` and `ring-crypto` features are mutually exclusive, # thus compiling with `--all-features` will not work. # Instead compile with `--features almost-all-features`. @@ -93,12 +98,18 @@ almost-all-features = [ "aws-lc-crypto", "insecure-skip-verify", "json_tests", - "chrono_timestamp" + "chrono_timestamp", + "ndarray" ] +[[bench]] +name = "ndarr" +harness = false +required-features = ["benchmark", "ndarray"] + [[example]] name = "basic" -required-features = ["chrono_timestamp"] +required-features = ["chrono_timestamp", "ndarray"] [[example]] name = "auth" @@ -110,4 +121,8 @@ required-features = ["chrono_timestamp"] [[example]] name = "http" -required-features = ["ilp-over-http"] +required-features = ["ilp-over-http", "ndarray"] + +[[example]] +name = "line_protocol_version" +required-features = ["ilp-over-http", "ndarray"] diff --git a/questdb-rs/benches/ndarr.rs b/questdb-rs/benches/ndarr.rs new file mode 100644 index 00000000..9491644c --- /dev/null +++ b/questdb-rs/benches/ndarr.rs @@ -0,0 +1,115 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use ndarray::{Array, Array2}; +use questdb::ingress::{Buffer, ColumnName, StrideArrayView}; + +/// run with +/// ```shell +/// cargo bench --bench ndarr --features="benchmark, ndarray" +/// ``` +fn bench_write_array_data(c: &mut Criterion) { + let mut group = c.benchmark_group("write_array_data"); + let contiguous_array: Array2<f64> = Array::zeros((1000, 1000)); + let non_contiguous_array = contiguous_array.t(); + assert!(contiguous_array.is_standard_layout()); + assert!(!non_contiguous_array.is_standard_layout()); + + let col_name = ColumnName::new("col1").unwrap(); + // Case 1 + group.bench_function("contiguous_writer", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer + .column_arr(col_name, black_box(&contiguous_array.view())) + .unwrap(); + }); + buffer.clear(); + }); + + // Case 2 + group.bench_function("contiguous_raw_buffer", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer + .column_arr_use_raw_buffer(col_name, black_box(&contiguous_array.view())) + .unwrap(); + }); + buffer.clear(); + }); + + // Case 3 + group.bench_function("non_contiguous_writer", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer + .column_arr(col_name, black_box(&non_contiguous_array.view())) + .unwrap(); + }); + buffer.clear(); + }); + + // Case 4 + group.bench_function("non_contiguous_raw_buffer", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer + .column_arr_use_raw_buffer(col_name, black_box(&non_contiguous_array.view())) + .unwrap(); + }); + buffer.clear(); + }); + + group.finish(); +} + +// bench NdArrayView and StridedArrayView write performance. +fn bench_array_view(c: &mut Criterion) { + let mut group = c.benchmark_group("write_array_view"); + let col_name = ColumnName::new("col1").unwrap(); + let array: Array2<f64> = Array::ones((1000, 1000)); + let transposed_view = array.t(); + + // Case 1 + group.bench_function("ndarray_view", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer + .column_arr(col_name, black_box(&transposed_view)) + .unwrap(); + }); + buffer.clear(); + }); + + let elem_size = size_of::<f64>() as isize; + let strides: Vec<isize> = transposed_view + .strides() + .iter() + .map(|&s| s * elem_size) // 转换为字节步长 + .collect(); + let view2: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + transposed_view.ndim(), + transposed_view.shape().as_ptr(), + strides.as_ptr(), + transposed_view.as_ptr() as *const u8, + transposed_view.len() * elem_size as usize, + ) + }; + + // Case 2 + group.bench_function("strides_view", |b| { + let mut buffer = Buffer::new(); + buffer.table("x1").unwrap(); + b.iter(|| { + buffer.column_arr(col_name, black_box(&view2)).unwrap(); + }); + buffer.clear(); + }); +} + +criterion_group!(benches, bench_write_array_data, bench_array_view); +criterion_main!(benches); diff --git a/questdb-rs/build.rs b/questdb-rs/build.rs index 636f55e8..99d0cf10 100644 --- a/questdb-rs/build.rs +++ b/questdb-rs/build.rs @@ -50,6 +50,8 @@ pub mod json_tests { #[derive(Debug, Serialize, Deserialize)] struct Expected { line: Option<String>, + #[serde(rename = "binaryBase64")] + binary_base64: Option<String>, #[serde(rename = "anyLines")] any_lines: Option<Vec<String>>, @@ -95,8 +97,11 @@ pub mod json_tests { indoc! {r#" // This file is auto-generated by build.rs. - use crate::{Result, ingress::{Buffer}}; + use crate::{Result, ingress::{Buffer, LineProtocolVersion}}; use crate::tests::{TestResult}; + use base64ct::Base64; + use base64ct::Encoding; + use rstest::rstest; fn matches_any_line(line: &[u8], expected: &[&str]) -> bool { for &exp in expected { @@ -117,14 +122,17 @@ pub mod json_tests { // for line in serde_json::to_string_pretty(&spec).unwrap().split("\n") { // writeln!(output, "/// {}", line)?; // } - writeln!(output, "#[test]")?; + writeln!(output, "#[rstest]")?; writeln!( output, - "fn test_{:03}_{}() -> TestResult {{", + "fn test_{:03}_{}(\n #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion,\n) -> TestResult {{", index, slugify!(&spec.test_name, separator = "_") )?; - writeln!(output, " let mut buffer = Buffer::new();")?; + writeln!( + output, + " let mut buffer = Buffer::new().with_line_proto_version(version)?;" + )?; let (expected, indent) = match &spec.result { Outcome::Success(line) => (Some(line), ""), @@ -168,7 +176,46 @@ pub mod json_tests { } writeln!(output, "{} .at_now()?;", indent)?; if let Some(expected) = expected { - if let Some(ref line) = expected.line { + if let Some(ref base64) = expected.binary_base64 { + writeln!(output, " if version != LineProtocolVersion::V1 {{")?; + writeln!( + output, + " let exp = Base64::decode_vec(\"{}\").unwrap();", + base64 + )?; + writeln!( + output, + " assert_eq!(buffer.as_bytes(), exp.as_slice());" + )?; + writeln!(output, " }} else {{")?; + if let Some(ref line) = expected.line { + let exp_ln = format!("{}\n", line); + writeln!(output, " let exp = {:?};", exp_ln)?; + writeln!( + output, + " assert_eq!(buffer.as_bytes(), exp.as_bytes());" + )?; + } else { + // 处理 V1 版本的 any_lines + let any: Vec<String> = expected + .any_lines + .as_ref() + .unwrap() + .iter() + .map(|line| format!("{}\n", line)) + .collect(); + writeln!(output, " let any = [")?; + for line in any.iter() { + writeln!(output, " {:?},", line)?; + } + writeln!(output, " ];")?; + writeln!( + output, + " assert!(matches_any_line(buffer.as_bytes(), &any));" + )?; + } + writeln!(output, " }}")?; + } else if let Some(ref line) = expected.line { let exp_ln = format!("{}\n", line); writeln!(output, " let exp = {:?};", exp_ln)?; writeln!(output, " assert_eq!(buffer.as_bytes(), exp.as_bytes());")?; diff --git a/questdb-rs/examples/basic.rs b/questdb-rs/examples/basic.rs index 6fa665d3..5df32873 100644 --- a/questdb-rs/examples/basic.rs +++ b/questdb-rs/examples/basic.rs @@ -1,4 +1,5 @@ use chrono::{TimeZone, Utc}; +use ndarray::arr1; use questdb::{ ingress::{Buffer, Sender, TimestampNanos}, Result, @@ -17,6 +18,7 @@ fn main() -> Result<()> { .symbol("side", "sell")? .column_f64("price", 2615.54)? .column_f64("amount", 0.00044)? + .column_arr("location", &arr1(&[100.0, 100.1, 100.2]).view())? .at(designated_timestamp)?; //// If you want to pass the current system timestamp, replace with: diff --git a/questdb-rs/examples/http.rs b/questdb-rs/examples/http.rs index 74b2f3e9..ead1efde 100644 --- a/questdb-rs/examples/http.rs +++ b/questdb-rs/examples/http.rs @@ -1,3 +1,4 @@ +use ndarray::arr1; use questdb::{ ingress::{Buffer, Sender, TimestampNanos}, Result, @@ -5,13 +6,15 @@ use questdb::{ fn main() -> Result<()> { let mut sender = Sender::from_conf("https::addr=localhost:9000;username=foo;password=bar;")?; - let mut buffer = Buffer::new(); + let mut buffer = + Buffer::new().with_line_proto_version(sender.default_line_protocol_version())?; buffer .table("trades")? .symbol("symbol", "ETH-USD")? .symbol("side", "sell")? .column_f64("price", 2615.54)? .column_f64("amount", 0.00044)? + .column_arr("location", &arr1(&[100.0, 100.1, 100.2]).view())? .at(TimestampNanos::now())?; sender.flush(&mut buffer)?; Ok(()) diff --git a/questdb-rs/examples/line_protocol_version.rs b/questdb-rs/examples/line_protocol_version.rs new file mode 100644 index 00000000..656db6b8 --- /dev/null +++ b/questdb-rs/examples/line_protocol_version.rs @@ -0,0 +1,33 @@ +use ndarray::arr1; +use questdb::ingress::LineProtocolVersion; +use questdb::{ + ingress::{Buffer, Sender, TimestampNanos}, + Result, +}; + +fn main() -> Result<()> { + let mut sender = Sender::from_conf("https::addr=localhost:9000;username=foo;password=bar;")?; + let mut buffer = Buffer::new().with_line_proto_version(LineProtocolVersion::V1)?; + buffer + .table("trades_ilp_v1")? + .symbol("symbol", "ETH-USD")? + .symbol("side", "sell")? + .column_f64("price", 2615.54)? + .column_f64("amount", 0.00044)? + .at(TimestampNanos::now())?; + sender.flush(&mut buffer)?; + + let mut sender2 = Sender::from_conf("https::addr=localhost:9000;username=foo;password=bar;")?; + let mut buffer2 = + Buffer::new().with_line_proto_version(sender2.default_line_protocol_version())?; + buffer2 + .table("trades_ilp_v2")? + .symbol("symbol", "ETH-USD")? + .symbol("side", "sell")? + .column_f64("price", 2615.54)? + .column_f64("amount", 0.00044)? + .column_arr("location", &arr1(&[100.0, 100.1, 100.2]).view())? + .at(TimestampNanos::now())?; + sender2.flush(&mut buffer2)?; + Ok(()) +} diff --git a/questdb-rs/src/error.rs b/questdb-rs/src/error.rs index 45f56650..cadb0971 100644 --- a/questdb-rs/src/error.rs +++ b/questdb-rs/src/error.rs @@ -48,6 +48,18 @@ pub enum ErrorCode { /// Bad configuration. ConfigError, + + /// Array has too many dims. Currently, only arrays with a maximum [`crate::ingress::MAX_DIMS`] dimensions are supported. + ArrayHasTooManyDims, + + /// Array view internal error. + ArrayViewError, + + /// Array write to buffer error. + ArrayWriteToBufferError, + + /// Validate line protocol version error. + LineProtocolVersionError, } /// An error that occurred when using QuestDB client library. diff --git a/questdb-rs/src/ingress/http.rs b/questdb-rs/src/ingress/http.rs index 323a4a92..92c867bf 100644 --- a/questdb-rs/src/ingress/http.rs +++ b/questdb-rs/src/ingress/http.rs @@ -1,4 +1,5 @@ use super::conf::ConfigSetting; +use crate::error::fmt; use crate::{error, Error}; use base64ct::Base64; use base64ct::Encoding; @@ -16,6 +17,7 @@ use ureq::unversioned::transport::{ Buffers, Connector, LazyBuffers, NextTimeout, Transport, TransportAdapter, }; +use crate::ingress::LineProtocolVersion; use ureq::unversioned::*; use ureq::Error::*; use ureq::{http, Body}; @@ -57,6 +59,7 @@ pub(super) struct HttpConfig { pub(super) user_agent: String, pub(super) retry_timeout: ConfigSetting<Duration>, pub(super) request_timeout: ConfigSetting<Duration>, + pub(super) disable_line_proto_validation: ConfigSetting<bool>, } impl Default for HttpConfig { @@ -66,6 +69,7 @@ impl Default for HttpConfig { user_agent: concat!("questdb/rust/", env!("CARGO_PKG_VERSION")).to_string(), retry_timeout: ConfigSetting::new_default(Duration::from_secs(10)), request_timeout: ConfigSetting::new_default(Duration::from_secs(10)), + disable_line_proto_validation: ConfigSetting::new_default(false), } } } @@ -109,6 +113,24 @@ impl HttpHandlerState { Err(err) => (need_retry(Err(err)), response), } } + + pub(crate) fn get_request( + &self, + url: &str, + request_timeout: Duration, + ) -> (bool, Result<Response<Body>, ureq::Error>) { + let request = self + .agent + .get(url) + .config() + .timeout_per_call(Some(request_timeout)) + .build(); + let response = request.call(); + match &response { + Ok(res) => (need_retry(Ok(res.status())), response), + Err(err) => (need_retry(Err(err)), response), + } + } } #[derive(Debug)] @@ -383,3 +405,175 @@ pub(super) fn http_send_with_retries( retry_http_send(state, buf, request_timeout, retry_timeout, last_rep) } + +pub(super) fn get_line_protocol_version( + state: &HttpHandlerState, + settings_url: &str, +) -> Result<(Option<Vec<LineProtocolVersion>>, LineProtocolVersion), Error> { + let mut support_versions: Option<Vec<_>> = None; + let mut default_version = LineProtocolVersion::V1; + + let response = match http_get_with_retries( + state, + settings_url, + *state.config.request_timeout, + Duration::from_secs(1), + ) { + Ok(res) => { + if res.status().is_client_error() || res.status().is_server_error() { + if res.status().as_u16() == 404 { + return Ok((support_versions, default_version)); + } + return Err(fmt!( + LineProtocolVersionError, + "Failed to detect server's line protocol version, settings url: {}, status code: {}.", + settings_url, + res.status() + )); + } else { + res + } + } + Err(err) => { + let e = match err { + ureq::Error::StatusCode(code) => { + if code == 404 { + return Ok((support_versions, default_version)); + } else { + fmt!( + LineProtocolVersionError, + "Failed to detect server's line protocol version, settings url: {}, err: {}.", + settings_url, + err + ) + } + } + e => { + fmt!( + LineProtocolVersionError, + "Failed to detect server's line protocol version, settings url: {}, err: {}.", + settings_url, + e + ) + } + }; + return Err(e); + } + }; + + let (_, body) = response.into_parts(); + let body_content = body.into_with_config().lossy_utf8(true).read_to_string(); + + if let Ok(msg) = body_content { + let json: serde_json::Value = serde_json::from_str(&msg).map_err(|_| { + error::fmt!( + LineProtocolVersionError, + "Malformed server response, settings url: {}, err: response is not valid JSON.", + settings_url, + ) + })?; + + if let Some(serde_json::Value::Array(ref values)) = json.get("line.proto.support.versions") + { + let mut versions = Vec::new(); + for value in values.iter() { + if let Some(v) = value.as_u64() { + match v { + 1 => versions.push(LineProtocolVersion::V1), + 2 => versions.push(LineProtocolVersion::V2), + _ => {} + } + } + } + support_versions = Some(versions); + } + + if let Some(serde_json::Value::Number(ref v)) = json.get("line.proto.default.version") { + default_version = match v.as_u64() { + Some(vu64) => match vu64 { + 1 => LineProtocolVersion::V1, + 2 => LineProtocolVersion::V2, + _ => { + if let Some(ref versions) = support_versions { + if versions.contains(&LineProtocolVersion::V2) { + LineProtocolVersion::V2 + } else if versions.contains(&LineProtocolVersion::V1) { + LineProtocolVersion::V1 + } else { + return Err(error::fmt!( + LineProtocolVersionError, + "Server does not support current client" + )); + } + } else { + return Err(error::fmt!( + LineProtocolVersionError, + "Unexpected response version content." + )); + } + } + }, + None => { + return Err(error::fmt!( + LineProtocolVersionError, + "Not a valid int for line.proto.default.version in response." + )) + } + }; + } + } else { + return Err(error::fmt!( + LineProtocolVersionError, + "Malformed server response, settings url: {}, err: failed to read response body as UTF-8", settings_url + )); + } + Ok((support_versions, default_version)) +} + +#[allow(clippy::result_large_err)] // `ureq::Error` is large enough to cause this warning. +fn retry_http_get( + state: &HttpHandlerState, + url: &str, + request_timeout: Duration, + retry_timeout: Duration, + mut last_rep: Result<Response<Body>, ureq::Error>, +) -> Result<Response<Body>, ureq::Error> { + let mut rng = rand::rng(); + let retry_end = std::time::Instant::now() + retry_timeout; + let mut retry_interval_ms = 10; + let mut need_retry; + loop { + let jitter_ms = rng.random_range(-5i32..5); + let to_sleep_ms = retry_interval_ms + jitter_ms; + let to_sleep = Duration::from_millis(to_sleep_ms as u64); + if (std::time::Instant::now() + to_sleep) > retry_end { + return last_rep; + } + sleep(to_sleep); + if let Ok(last_rep) = last_rep { + // Actively consume the reader to return the connection to the connection pool. + // see https://github.com/algesten/ureq/issues/94 + _ = last_rep.into_body().read_to_vec(); + } + (need_retry, last_rep) = state.get_request(url, request_timeout); + if !need_retry { + return last_rep; + } + retry_interval_ms = (retry_interval_ms * 2).min(1000); + } +} + +#[allow(clippy::result_large_err)] // `ureq::Error` is large enough to cause this warning. +fn http_get_with_retries( + state: &HttpHandlerState, + url: &str, + request_timeout: Duration, + retry_timeout: Duration, +) -> Result<Response<Body>, ureq::Error> { + let (need_retry, last_rep) = state.get_request(url, request_timeout); + if !need_retry || retry_timeout.is_zero() { + return last_rep; + } + + retry_http_get(state, url, request_timeout, retry_timeout, last_rep) +} diff --git a/questdb-rs/src/ingress/mod.rs b/questdb-rs/src/ingress/mod.rs index 77b7d063..c04576ff 100644 --- a/questdb-rs/src/ingress/mod.rs +++ b/questdb-rs/src/ingress/mod.rs @@ -24,22 +24,24 @@ #![doc = include_str!("mod.md")] +pub use self::ndarr::{ArrayElement, NdArrayView, StrideArrayView}; pub use self::timestamp::*; - use crate::error::{self, Error, Result}; use crate::gai; use crate::ingress::conf::ConfigSetting; use base64ct::{Base64, Base64UrlUnpadded, Encoding}; use core::time::Duration; +use ndarr::ArrayElementSealed; use rustls::{ClientConnection, RootCertStore, StreamOwned}; use rustls_pki_types::ServerName; use socket2::{Domain, Protocol as SockProtocol, SockAddr, Socket, Type}; use std::collections::HashMap; use std::convert::Infallible; use std::fmt::{Debug, Display, Formatter, Write}; -use std::io::{self, BufRead, BufReader, ErrorKind, Write as IoWrite}; +use std::io::{self, BufRead, BufReader, Cursor, ErrorKind, Write as IoWrite}; use std::ops::Deref; use std::path::PathBuf; +use std::slice::from_raw_parts_mut; use std::str::FromStr; use std::sync::Arc; @@ -55,6 +57,25 @@ use ring::{ signature::{EcdsaKeyPair, ECDSA_P256_SHA256_FIXED_SIGNING}, }; +/// Defines the maximum allowed dimensions for array data in binary serialization protocols. +pub const MAX_DIMS: usize = 32; + +/// Line Protocol Version supported by current client. +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum LineProtocolVersion { + V1 = 1, + V2 = 2, +} + +impl std::fmt::Display for LineProtocolVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + LineProtocolVersion::V1 => write!(f, "v1"), + LineProtocolVersion::V2 => write!(f, "v2"), + } + } +} + #[derive(Debug, Copy, Clone)] enum Op { Table = 1, @@ -464,6 +485,8 @@ impl BufferState { } } +pub trait Buffer1 {} + /// A reusable buffer to prepare a batch of ILP messages. /// /// # Example @@ -506,6 +529,7 @@ impl BufferState { /// [`column_i64`](Buffer::column_i64), /// [`column_f64`](Buffer::column_f64), /// [`column_str`](Buffer::column_str), +/// [`column_arr`](Buffer::column_arr), /// [`column_ts`](Buffer::column_ts)). /// * Symbols must appear before columns. /// * A row must be terminated with either [`at`](Buffer::at) or @@ -524,6 +548,7 @@ impl BufferState { /// | [`column_i64`](Buffer::column_i64) | [`INTEGER`](https://questdb.io/docs/reference/api/ilp/columnset-types#integer) | /// | [`column_f64`](Buffer::column_f64) | [`FLOAT`](https://questdb.io/docs/reference/api/ilp/columnset-types#float) | /// | [`column_str`](Buffer::column_str) | [`STRING`](https://questdb.io/docs/reference/api/ilp/columnset-types#string) | +/// | [`column_arr`](Buffer::column_arr) | [`ARRAY`](https://questdb.io/docs/reference/api/ilp/columnset-types#array) | /// | [`column_ts`](Buffer::column_ts) | [`TIMESTAMP`](https://questdb.io/docs/reference/api/ilp/columnset-types#timestamp) | /// /// QuestDB supports both `STRING` and `SYMBOL` column types. @@ -555,6 +580,8 @@ pub struct Buffer { state: BufferState, marker: Option<(usize, BufferState)>, max_name_len: usize, + f64serializer: fn(&mut Vec<u8>, f64), + version: LineProtocolVersion, } impl Buffer { @@ -566,6 +593,8 @@ impl Buffer { state: BufferState::new(), marker: None, max_name_len: 127, + f64serializer: f64_binary_series, + version: LineProtocolVersion::V2, } } @@ -582,6 +611,36 @@ impl Buffer { buf } + pub fn with_line_proto_version(mut self, version: LineProtocolVersion) -> Result<Self> { + if self.state.op_case != OpCase::Init { + return Err(error::fmt!( + LineProtocolVersionError, + "Line protocol version must be set before adding any data." + )); + } + self.f64serializer = match version { + LineProtocolVersion::V1 => f64_text_series, + LineProtocolVersion::V2 => f64_binary_series, + }; + self.version = version; + Ok(self) + } + + pub fn set_line_proto_version(&mut self, version: LineProtocolVersion) -> Result<&mut Self> { + if self.state.op_case != OpCase::Init { + return Err(error::fmt!( + LineProtocolVersionError, + "Line protocol version must be set before adding any data." + )); + } + self.f64serializer = match version { + LineProtocolVersion::V1 => f64_text_series, + LineProtocolVersion::V2 => f64_binary_series, + }; + self.version = version; + Ok(self) + } + /// Pre-allocate to ensure the buffer has enough capacity for at least the /// specified additional byte count. This may be rounded up. /// This does not allocate if such additional capacity is already satisfied. @@ -953,8 +1012,7 @@ impl Buffer { Error: From<N::Error>, { self.write_column_key(name)?; - let mut ser = F64Serializer::new(value); - self.output.extend_from_slice(ser.as_str().as_bytes()); + (self.f64serializer)(&mut self.output, value); Ok(self) } @@ -1011,6 +1069,234 @@ impl Buffer { Ok(self) } + /// Record a multidimensional array value for the given column. + /// + /// Supports arrays with up to [`MAX_DIMS`] dimensions. The array elements must + /// implement [`ArrayElement`] trait which provides type-to-[`ElemDataType`] mapping. + /// + /// # Examples + /// + /// Basic usage with direct dimension specification: + /// + /// ``` + /// # #[cfg(feature = "ndarray")] + /// # { + /// # use questdb::Result; + /// # use questdb::ingress::Buffer; + /// # use ndarray::array; + /// # fn main() -> Result<()> { + /// # let mut buffer = Buffer::new(); + /// # buffer.table("x")?; + /// // Record a 2D array of f64 values + /// let array_2d = array![[1.1, 2.2], [3.3, 4.4]]; + /// buffer.column_arr("array_col", &array_2d.view())?; + /// # Ok(()) + /// # } + /// # } + /// + /// ``` + /// + /// Using [`ColumnName`] for validated column names: + /// + /// ``` + /// # #[cfg(feature = "ndarray")] + /// # { + /// # use questdb::Result; + /// # use questdb::ingress::{Buffer, ColumnName}; + /// # use ndarray::Array3; + /// # fn main() -> Result<()> { + /// # let mut buffer = Buffer::new(); + /// # buffer.table("x1")?; + /// // Record a 3D array of f64 values + /// let array_3d = Array3::from_elem((2, 3, 4), 42f64); + /// let col_name = ColumnName::new("col1")?; + /// buffer.column_arr(col_name, &array_3d.view())?; + /// # Ok(()) + /// # } + /// # } + /// ``` + /// # Errors + /// + /// Returns [`Error`] if: + /// - Array dimensions exceed [`MAX_DIMS`] + /// - Failed to get dimension sizes + /// - Column name validation fails + #[allow(private_bounds)] + pub fn column_arr<'a, N, T, D>(&mut self, name: N, view: &T) -> Result<&mut Self> + where + N: TryInto<ColumnName<'a>>, + T: NdArrayView<D>, + D: ArrayElement + ArrayElementSealed, + Error: From<N::Error>, + { + if self.version == LineProtocolVersion::V1 { + return Err(error::fmt!( + LineProtocolVersionError, + "line protocol version v1 does not support array datatype", + )); + } + let ndim = view.ndim(); + if ndim == 0 { + return Err(error::fmt!( + ArrayViewError, + "Zero-dimensional arrays are not supported", + )); + } + + self.write_column_key(name)?; + + // check dimension less equal than max dims + if MAX_DIMS < ndim { + return Err(error::fmt!( + ArrayHasTooManyDims, + "Array dimension mismatch: expected at most {} dimensions, but got {}", + MAX_DIMS, + ndim + )); + } + + // TODO: Remove `check_data_buf` this from the trait. + // It's private impl details that can be coded generically + let array_buf_size = view.check_data_buf()?; + if array_buf_size > i32::MAX as usize { + // TODO: We should probably agree on a significantly + // _smaller_ limit here, since there's no way + // we've ever tested anything that big. + // My gut feeling is that the maximum array buffer should be + // in the order of 100MB or so. + return Err(error::fmt!( + ArrayViewError, + "Array buffer size too big: {}", + array_buf_size + )); + } + + // binary format flag '=' + self.output.push(b'='); + // binary format entity type + self.output.push(ARRAY_BINARY_FORMAT_TYPE); + // ndarr datatype + self.output.push(D::type_tag()); + // ndarr dims + self.output.push(ndim as u8); + + let dim_header_size = size_of::<u32>() * ndim; + self.output.reserve(dim_header_size + array_buf_size); + + for i in 0..ndim { + let dim = view.dim(i).ok_or_else(|| { + error::fmt!( + ArrayViewError, + "Cannot get correct dimensions for dim {}", + i + ) + })?; + + // TODO: check that the dimension is not past + // the maximum size that the java impl will accept. + // I seem to remember that it's 2^28-1 or something like that. + // Must check Java impl. + + // ndarr shapes + self.output + .extend_from_slice((dim as u32).to_le_bytes().as_slice()); + } + + let index = self.output.len(); + let writeable = + unsafe { from_raw_parts_mut(self.output.as_mut_ptr().add(index), array_buf_size) }; + let mut cursor = Cursor::new(writeable); + + // TODO: The next section needs a bit of a rewrite. + // It also needs clear comments that explain the design decisions. + // + // I'd be expecting two code paths here: + // 1. The array is row-major contiguous + // 2. The data needs to be written out via the strides. + // + // The code here seems to do something a bit different and + // is worth explaining. + // I see two code paths that I honestly don't understand, + // the `ndarr::write_array_data` and the `ndarr::write_array_data_use_raw_buffer` + // functions both seem to handle both cases (why?) and then seem + // to construct a vectored IoSlice buffer (why??) before writing + // the strided data out. + + // ndarr data + if view.as_slice().is_some() { + if let Err(e) = ndarr::write_array_data(view, &mut cursor) { + return Err(error::fmt!( + ArrayWriteToBufferError, + "Can not write row major to writer: {}", + e + )); + } + if cursor.position() != (array_buf_size as u64) { + return Err(error::fmt!( + ArrayWriteToBufferError, + "Array write buffer length mismatch (actual: {}, expected: {})", + cursor.position(), + array_buf_size + )); + } + unsafe { self.output.set_len(array_buf_size + index) } + } else { + unsafe { self.output.set_len(array_buf_size + index) } + ndarr::write_array_data_use_raw_buffer(&mut self.output[index..], view); + } + Ok(self) + } + + #[cfg(feature = "benchmark")] + pub fn column_arr_use_raw_buffer<'a, N, T, D>(&mut self, name: N, view: &T) -> Result<&mut Self> + where + N: TryInto<ColumnName<'a>>, + T: NdArrayView<D>, + D: ArrayElement, + Error: From<N::Error>, + { + self.write_column_key(name)?; + + // check dimension less equal than max dims + if MAX_DIMS < view.ndim() { + return Err(error::fmt!( + ArrayHasTooManyDims, + "Array dimension mismatch: expected at most {} dimensions, but got {}", + MAX_DIMS, + view.ndim() + )); + } + + let reserve_size = view.check_data_buf()?; + // binary format flag '=' + self.output.push(b'='); + // binary format entity type + self.output.push(ARRAY_BINARY_FORMAT_TYPE); + // ndarr datatype + self.output.push(D::elem_type().into()); + // ndarr dims + self.output.push(view.ndim() as u8); + + for i in 0..view.ndim() { + let d = view.dim(i).ok_or_else(|| { + error::fmt!( + ArrayViewError, + "Can not get correct dimensions for dim {}", + i + ) + })?; + // ndarr shapes + self.output + .extend_from_slice((d as i32).to_le_bytes().as_slice()); + } + + self.output.reserve(reserve_size); + let index = self.output.len(); + unsafe { self.output.set_len(reserve_size + index) } + ndarr::write_array_data_use_raw_buffer(&mut self.output[index..], view); + Ok(self) + } + /// Record a timestamp value for the given column. /// /// ``` @@ -1060,7 +1346,7 @@ impl Buffer { /// or you can also pass in a `TimestampNanos`. /// /// Note that both `TimestampMicros` and `TimestampNanos` can be constructed - /// easily from either `chrono::DateTime` and `std::time::SystemTime`. + /// easily from either `std::time::SystemTime` or `chrono::DateTime`. /// /// This last option requires the `chrono_timestamp` feature. pub fn column_ts<'a, N, T>(&mut self, name: N, value: T) -> Result<&mut Self> @@ -1114,7 +1400,7 @@ impl Buffer { /// You can also pass in a `TimestampMicros`. /// /// Note that both `TimestampMicros` and `TimestampNanos` can be constructed - /// easily from either `chrono::DateTime` and `std::time::SystemTime`. + /// easily from either `std::time::SystemTime` or `chrono::DateTime`. /// pub fn at<T>(&mut self, timestamp: T) -> Result<()> where @@ -1199,6 +1485,9 @@ pub struct Sender { handler: ProtocolHandler, connected: bool, max_buf_size: usize, + default_line_protocol_version: LineProtocolVersion, + #[cfg(feature = "ilp-over-http")] + supported_line_protocol_versions: Option<Vec<LineProtocolVersion>>, } impl std::fmt::Debug for Sender { @@ -1822,6 +2111,19 @@ impl SenderBuilder { "retry_timeout" => { builder.retry_timeout(Duration::from_millis(parse_conf_value(key, val)?))? } + + #[cfg(feature = "ilp-over-http")] + "disable_line_protocol_validation" => { + if val == "on" { + builder.disable_line_protocol_validation()? + } else if val != "off" { + return Err(error::fmt!( + ConfigError, "invalid \"disable_line_protocol_validation\" [value={val}, allowed-values=[on, off]]]\"]")); + } else { + builder + } + } + // Ignore other parameters. // We don't want to fail on unknown keys as this would require releasing different // library implementations in lock step as soon as a new parameter is added to any of them, @@ -2092,6 +2394,24 @@ impl SenderBuilder { Ok(self) } + #[cfg(feature = "ilp-over-http")] + /// Disables automatic line protocol version validation for ILP-over-HTTP. + /// + /// - When set to `"off"`: Skips the initial server version handshake and disables protocol validation. + /// - When set to `"on"`: Keeps default validation behavior (recommended). + /// + /// Please ensure client's default version ([`LINE_PROTOCOL_VERSION_V2`]) or + /// explicitly set protocol version exactly matches server expectation. + pub fn disable_line_protocol_validation(mut self) -> Result<Self> { + if let Some(http) = &mut self.http { + // ignore "already specified" error + let _ = http + .disable_line_proto_validation + .set_specified("disable_line_protocol_validation", true); + } + Ok(self) + } + #[cfg(feature = "ilp-over-http")] /// Internal API, do not use. /// This is exposed exclusively for the Python client. @@ -2374,23 +2694,52 @@ impl SenderBuilder { agent, url, auth, - config: self.http.as_ref().unwrap().clone(), }) } }; + let mut default_line_protocol_version = LineProtocolVersion::V2; + #[cfg(feature = "ilp-over-http")] + let mut supported_line_protocol_versions: Option<Vec<_>> = None; + + #[cfg(feature = "ilp-over-http")] + match self.protocol { + Protocol::Tcp | Protocol::Tcps => {} + Protocol::Http | Protocol::Https => { + let http_config = self.http.as_ref().unwrap(); + if !*http_config.disable_line_proto_validation.deref() { + if let ProtocolHandler::Http(http_state) = &handler { + let settings_url = &format!( + "{}://{}:{}/settings", + self.protocol.schema(), + self.host.deref(), + self.port.deref() + ); + ( + supported_line_protocol_versions, + default_line_protocol_version, + ) = get_line_protocol_version(http_state, settings_url)?; + } else { + default_line_protocol_version = LineProtocolVersion::V1; + } + } + } + }; + if auth.is_some() { descr.push_str("auth=on]"); } else { descr.push_str("auth=off]"); } - let sender = Sender { descr, handler, connected: true, max_buf_size: *self.max_buf_size, + default_line_protocol_version, + #[cfg(feature = "ilp-over-http")] + supported_line_protocol_versions, }; Ok(sender) @@ -2509,6 +2858,17 @@ fn parse_key_pair(auth: &EcdsaAuthParams) -> Result<EcdsaKeyPair> { }) } +fn f64_text_series(vec: &mut Vec<u8>, value: f64) { + let mut ser = F64Serializer::new(value); + vec.extend_from_slice(ser.as_str().as_bytes()) +} + +fn f64_binary_series(vec: &mut Vec<u8>, value: f64) { + vec.push(b'='); + vec.push(DOUBLE_BINARY_FORMAT_TYPE); + vec.extend_from_slice(&value.to_le_bytes()) +} + pub(crate) struct F64Serializer { buf: ryu::Buffer, n: f64, @@ -2605,6 +2965,9 @@ impl Sender { )); } + #[cfg(feature = "ilp-over-http")] + self.check_line_protocol_version(buf.version)?; + let bytes = buf.as_bytes(); if bytes.is_empty() { return Ok(()); @@ -2725,9 +3088,69 @@ impl Sender { pub fn must_close(&self) -> bool { !self.connected } + + /// Returns the QuestDB server's recommended default line protocol version. + /// Will be used to [`Buffer::with_line_proto_version`] + /// + /// The version selection follows these rules: + /// 1. **TCP/TCPS Protocol**: Always returns [`LineProtocolVersion::V2`] + /// 2. **HTTP/HTTPS Protocol**: + /// - If line protocol auto-detection is disabled [`SenderBuilder::disable_line_protocol_validation`], returns [`LineProtocolVersion::V2`] + /// - If line protocol auto-detection is enabled: + /// - Uses the server's default version if supported by the client + /// - Otherwise uses the highest mutually supported version from the intersection + /// of client and server compatible versions + pub fn default_line_protocol_version(&self) -> LineProtocolVersion { + self.default_line_protocol_version + } + + #[cfg(feature = "ilp-over-http")] + #[cfg(test)] + pub(crate) fn support_line_protocol_versions(&self) -> Option<Vec<LineProtocolVersion>> { + self.supported_line_protocol_versions.clone() + } + + #[cfg(feature = "ilp-over-http")] + #[inline(always)] + fn check_line_protocol_version(&self, version: LineProtocolVersion) -> Result<()> { + match &self.handler { + ProtocolHandler::Socket(_) => Ok(()), + #[cfg(feature = "ilp-over-http")] + ProtocolHandler::Http(http) => { + if *http.config.disable_line_proto_validation.deref() { + Ok(()) + } else { + match self.supported_line_protocol_versions { + Some(ref supported_line_protocols) => { + if supported_line_protocols.contains(&version) { + Ok(()) + } else { + Err(error::fmt!( + LineProtocolVersionError, + "Line protocol version {} is not supported by current QuestDB Server", version)) + } + } + None => { + if version == LineProtocolVersion::V1 { + Ok(()) + } else { + Err(error::fmt!( + LineProtocolVersionError, + "Line protocol version {} is not supported by current QuestDB Server", version)) + } + } + } + } + } + } + } } +pub(crate) const ARRAY_BINARY_FORMAT_TYPE: u8 = 14; +pub(crate) const DOUBLE_BINARY_FORMAT_TYPE: u8 = 16; + mod conf; +pub(crate) mod ndarr; mod timestamp; #[cfg(feature = "ilp-over-http")] diff --git a/questdb-rs/src/ingress/ndarr.rs b/questdb-rs/src/ingress/ndarr.rs new file mode 100644 index 00000000..a4a1a14b --- /dev/null +++ b/questdb-rs/src/ingress/ndarr.rs @@ -0,0 +1,674 @@ +pub trait NdArrayView<T> +where + T: ArrayElement, +{ + type Iter<'a>: Iterator<Item = &'a T> + where + Self: 'a, + T: 'a; + + /// Returns the number of dimensions (rank) of the array. + fn ndim(&self) -> usize; + + /// Returns the size of the specified dimension. + fn dim(&self, index: usize) -> Option<usize>; + + /// Return the array’s data as a slice, if it is c-major-layout. + /// Return `None` otherwise. + fn as_slice(&self) -> Option<&[T]>; + + /// Return an iterator of references to the elements of the array. + /// Iterator element type is `&T`. + fn iter(&self) -> Self::Iter<'_>; + + /// Validates the data buffer size of array is consistency with array shapes. + /// + /// # Returns + /// - `Ok(usize)`: Expected buffer size in bytes if valid + /// - `Err(Error)`: Otherwise + fn check_data_buf(&self) -> Result<usize, Error>; +} + +pub fn write_array_data<W: std::io::Write, A: NdArrayView<T>, T>( + array: &A, + writer: &mut W, +) -> std::io::Result<()> +where + T: ArrayElement, +{ + // First optimization path: write contiguous memory directly + if let Some(contiguous) = array.as_slice() { + let bytes = unsafe { + slice::from_raw_parts(contiguous.as_ptr() as *const u8, size_of_val(contiguous)) + }; + return writer.write_all(bytes); + } + + // Fallback path: non-contiguous memory handling + let elem_size = size_of::<T>(); + let mut io_slices = Vec::new(); + for element in array.iter() { + let bytes = unsafe { slice::from_raw_parts(element as *const T as *const _, elem_size) }; + io_slices.push(std::io::IoSlice::new(bytes)); + } + + let mut io_slices: &mut [IoSlice<'_>] = io_slices.as_mut_slice(); + IoSlice::advance_slices(&mut io_slices, 0); + + while !io_slices.is_empty() { + let written = writer.write_vectored(io_slices)?; + if written == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::WriteZero, + "Failed to write all bytes", + )); + } + IoSlice::advance_slices(&mut io_slices, written); + } + Ok(()) +} + +pub(crate) fn write_array_data_use_raw_buffer<A: NdArrayView<T>, T>(buf: &mut [u8], array: &A) +where + T: ArrayElement, +{ + // First optimization path: write contiguous memory directly + if let Some(contiguous) = array.as_slice() { + let byte_len = size_of_val(contiguous); + unsafe { + std::ptr::copy_nonoverlapping( + contiguous.as_ptr() as *const u8, + buf.as_mut_ptr(), + byte_len, + ) + } + } + + // Fallback path: non-contiguous memory handling + let elem_size = size_of::<T>(); + for (i, &element) in array.iter().enumerate() { + unsafe { + std::ptr::copy_nonoverlapping( + &element as *const T as *const u8, + buf.as_mut_ptr().add(i * elem_size), + elem_size, + ) + } + } +} + +/// Marker trait for valid array element types. +/// +/// Implemented for primitive types that can be stored in arrays. +/// Combines type information with data type classification. +pub trait ArrayElement: Copy + 'static {} + +pub(crate) trait ArrayElementSealed { + /// Returns the binary format identifier for array element types compatible + /// with QuestDB's io.questdb.cairo.ColumnType numeric type constants. + fn type_tag() -> u8; +} + +impl ArrayElement for f64 {} + +impl ArrayElementSealed for f64 { + fn type_tag() -> u8 { + 10 // Double + } +} + +/// A view into a multi-dimensional array with custom memory strides. +#[derive(Debug)] +pub struct StrideArrayView<'a, T> { + dims: usize, + shape: &'a [usize], + strides: &'a [isize], + + // TODO: Why a pointer and len? Shouldn't it be a `&'a [u8]` slice? + buf_len: usize, + buf: *const u8, + _marker: std::marker::PhantomData<T>, +} + +impl<T> NdArrayView<T> for StrideArrayView<'_, T> +where + T: ArrayElement, +{ + type Iter<'b> + = RowMajorIter<'b, T> + where + Self: 'b, + T: 'b; + + fn ndim(&self) -> usize { + self.dims + } + + fn dim(&self, index: usize) -> Option<usize> { + if index >= self.dims { + return None; + } + + Some(self.shape[index]) + } + + fn as_slice(&self) -> Option<&[T]> { + if self.is_c_major() { + Some(unsafe { + slice::from_raw_parts(self.buf as *const T, self.buf_len / size_of::<T>()) + }) + } else { + None + } + } + + fn iter(&self) -> Self::Iter<'_> { + let mut dim_products = Vec::with_capacity(self.dims); + let mut product = 1; + for &dim in self.shape.iter().rev() { + dim_products.push(product); + product *= dim; + } + dim_products.reverse(); + + // consider minus strides + let base_ptr = self + .strides + .iter() + .enumerate() + .fold(self.buf, |ptr, (dim, &stride)| { + if stride < 0 { + let dim_size = self.shape[dim] as isize; + unsafe { ptr.offset(stride * (dim_size - 1)) } + } else { + ptr + } + }); + RowMajorIter { + base_ptr, + array: self, + dim_products, + current_linear: 0, + total_elements: self.shape.iter().product(), + } + } + + fn check_data_buf(&self) -> Result<usize, Error> { + let mut size = size_of::<T>(); + for i in 0..self.dims { + let d = self.shape[i]; + size = size.checked_mul(d).ok_or(error::fmt!( + ArrayViewError, + "Array total elem size overflow" + ))? + } + if size != self.buf_len { + return Err(error::fmt!( + ArrayWriteToBufferError, + "Array buffer length mismatch (actual: {}, expected: {})", + self.buf_len, + size + )); + } + Ok(size) + } +} + +impl<T> StrideArrayView<'_, T> +where + T: ArrayElement, +{ + /// Creates a new strided array view from raw components (unsafe constructor). + /// + /// # Safety + /// Caller must ensure all the following conditions: + /// - `shapes` points to a valid array of at least `dims` elements + /// - `strides` points to a valid array of at least `dims` elements + /// - `data` points to a valid memory block of at least `data_len` bytes + /// - Memory layout must satisfy: + /// 1. `data_len ≥ (shape[0]-1)*abs(strides[0]) + ... + (shape[n-1]-1)*abs(strides[n-1]) + size_of::<T>()` + /// 2. All calculated offsets stay within `[0, data_len - size_of::<T>()]` + /// - Lifetime `'a` must outlive the view's usage + /// - Strides are measured in bytes (not elements) + pub unsafe fn new( + dims: usize, + shape: *const usize, + strides: *const isize, + data: *const u8, + data_len: usize, + ) -> Self { + let shapes = slice::from_raw_parts(shape, dims); + let strides = slice::from_raw_parts(strides, dims); + Self { + dims, + shape: shapes, + strides, + buf_len: data_len, + buf: data, + _marker: std::marker::PhantomData::<T>, + } + } + + /// Verifies if the array follows C-style row-major memory layout. + fn is_c_major(&self) -> bool { + if self.buf.is_null() || self.buf_len == 0 { + return false; + } + + let elem_size = size_of::<T>() as isize; + if self.dims == 1 { + return self.strides[0] == elem_size || self.shape[0] == 1; + } + + let mut expected_stride = elem_size; + for (dim, &stride) in self.shape.iter().zip(self.strides).rev() { + if *dim > 1 && stride != expected_stride { + return false; + } + expected_stride *= *dim as isize; + } + true + } +} + +/// Iterator for traversing a stride array in row-major (C-style) order. +pub struct RowMajorIter<'a, T> { + base_ptr: *const u8, + array: &'a StrideArrayView<'a, T>, + dim_products: Vec<usize>, + current_linear: usize, + total_elements: usize, +} + +impl<'a, T> Iterator for RowMajorIter<'a, T> +where + T: ArrayElement, +{ + type Item = &'a T; + fn next(&mut self) -> Option<Self::Item> { + if self.current_linear >= self.total_elements { + return None; + } + let mut remaining_index = self.current_linear; + let mut offset = 0; + + for (dim, &dim_factor) in self.dim_products.iter().enumerate() { + let coord = remaining_index / dim_factor; + remaining_index %= dim_factor; + let stride = self.array.strides[dim]; + let actual_coord = if stride >= 0 { + coord + } else { + self.array.shape[dim] - 1 - coord + }; + offset += actual_coord * stride.unsigned_abs(); + } + + self.current_linear += 1; + unsafe { + let ptr = self.base_ptr.add(offset); + Some(&*(ptr as *const T)) + } + } +} + +/// impl NdArrayView for one dimension vector +impl<T: ArrayElement> NdArrayView<T> for Vec<T> { + type Iter<'a> + = std::slice::Iter<'a, T> + where + T: 'a; + + fn ndim(&self) -> usize { + 1 + } + + fn dim(&self, idx: usize) -> Option<usize> { + (idx == 0).then_some(self.len()) + } + + fn as_slice(&self) -> Option<&[T]> { + Some(self.as_slice()) + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(self.len() * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for one dimension array +impl<T: ArrayElement, const N: usize> NdArrayView<T> for [T; N] { + type Iter<'a> + = std::slice::Iter<'a, T> + where + T: 'a; + + fn ndim(&self) -> usize { + 1 + } + + fn dim(&self, idx: usize) -> Option<usize> { + (idx == 0).then_some(N) + } + + fn as_slice(&self) -> Option<&[T]> { + Some(self) + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(N * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for one dimension slice +impl<T: ArrayElement> NdArrayView<T> for &[T] { + type Iter<'a> + = std::slice::Iter<'a, T> + where + Self: 'a, + T: 'a; + + fn ndim(&self) -> usize { + 1 + } + + fn dim(&self, idx: usize) -> Option<usize> { + (idx == 0).then_some(self.len()) + } + + fn as_slice(&self) -> Option<&[T]> { + Some(self) + } + + fn iter(&self) -> Self::Iter<'_> { + <[T]>::iter(self) + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(std::mem::size_of_val(*self)) + } +} + +/// impl NdArrayView for two dimensions vector +impl<T: ArrayElement> NdArrayView<T> for Vec<Vec<T>> { + type Iter<'a> + = std::iter::Flatten<std::slice::Iter<'a, Vec<T>>> + where + T: 'a; + + fn ndim(&self) -> usize { + 2 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(self.len()), + 1 => self.first().map(|v| v.len()), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + None + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter().flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + let row_len = self.first().map_or(0, |v| v.len()); + if self.as_slice().iter().any(|v| v.len() != row_len) { + return Err(error::fmt!(ArrayViewError, "Irregular array shape")); + } + Ok(self.len() * row_len * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for two dimensions array +impl<T: ArrayElement, const M: usize, const N: usize> NdArrayView<T> for [[T; M]; N] { + type Iter<'a> + = std::iter::Flatten<std::slice::Iter<'a, [T; M]>> + where + T: 'a; + + fn ndim(&self) -> usize { + 2 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(N), + 1 => Some(M), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + Some(unsafe { std::slice::from_raw_parts(self.as_ptr() as *const T, N * M) }) + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter().flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(N * M * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for two dimensions slices +impl<T: ArrayElement, const M: usize> NdArrayView<T> for &[[T; M]] { + type Iter<'a> + = std::iter::Flatten<std::slice::Iter<'a, [T; M]>> + where + Self: 'a, + T: 'a; + + fn ndim(&self) -> usize { + 2 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(self.len()), + 1 => Some(M), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + Some(unsafe { std::slice::from_raw_parts(self.as_ptr() as *const T, self.len() * M) }) + } + + fn iter(&self) -> Self::Iter<'_> { + <[[T; M]]>::iter(self).flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(self.len() * M * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for three dimensions vector +impl<T: ArrayElement> NdArrayView<T> for Vec<Vec<Vec<T>>> { + type Iter<'a> + = std::iter::Flatten<std::iter::Flatten<std::slice::Iter<'a, Vec<Vec<T>>>>> + where + T: 'a; + + fn ndim(&self) -> usize { + 3 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(self.len()), + 1 => self.first().map(|v| v.len()), + 2 => self.first().and_then(|v2| v2.first()).map(|v3| v3.len()), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + None + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter().flatten().flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + let dim1 = self.first().map_or(0, |v| v.len()); + + if self.as_slice().iter().any(|v2| v2.len() != dim1) { + return Err(error::fmt!(ArrayViewError, "Irregular array shape")); + } + + let dim2 = self + .first() + .and_then(|v2| v2.first()) + .map_or(0, |v3| v3.len()); + + if self + .as_slice() + .iter() + .flat_map(|v2| v2.as_slice().iter()) + .any(|v3| v3.len() != dim2) + { + return Err(error::fmt!(ArrayViewError, "Irregular array shape")); + } + + Ok(self.len() * dim1 * dim2 * std::mem::size_of::<T>()) + } +} + +/// impl NdArrayView for three dimensions array +impl<T: ArrayElement, const M: usize, const N: usize, const L: usize> NdArrayView<T> + for [[[T; M]; N]; L] +{ + type Iter<'a> + = std::iter::Flatten<std::iter::Flatten<std::slice::Iter<'a, [[T; M]; N]>>> + where + T: 'a; + + fn ndim(&self) -> usize { + 3 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(L), + 1 => Some(N), + 2 => Some(M), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + Some(unsafe { std::slice::from_raw_parts(self.as_ptr() as *const T, L * N * M) }) + } + + fn iter(&self) -> Self::Iter<'_> { + self.as_slice().iter().flatten().flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(L * N * M * std::mem::size_of::<T>()) + } +} + +impl<T: ArrayElement, const M: usize, const N: usize> NdArrayView<T> for &[[[T; M]; N]] { + type Iter<'a> + = std::iter::Flatten<std::iter::Flatten<std::slice::Iter<'a, [[T; M]; N]>>> + where + Self: 'a, + T: 'a; + + fn ndim(&self) -> usize { + 3 + } + + fn dim(&self, idx: usize) -> Option<usize> { + match idx { + 0 => Some(self.len()), + 1 => Some(N), + 2 => Some(M), + _ => None, + } + } + + fn as_slice(&self) -> Option<&[T]> { + Some(unsafe { std::slice::from_raw_parts(self.as_ptr() as *const T, self.len() * N * M) }) + } + + fn iter(&self) -> Self::Iter<'_> { + <[[[T; M]; N]]>::iter(self).flatten().flatten() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(self.len() * N * M * std::mem::size_of::<T>()) + } +} + +use crate::{error, Error}; +#[cfg(feature = "ndarray")] +use ndarray::{ArrayView, Axis, Dimension}; +use std::io::IoSlice; +use std::slice; + +#[cfg(feature = "ndarray")] +impl<T, D> NdArrayView<T> for ArrayView<'_, T, D> +where + T: ArrayElement, + D: Dimension, +{ + type Iter<'a> + = ndarray::iter::Iter<'a, T, D> + where + Self: 'a, + T: 'a; + + fn ndim(&self) -> usize { + self.ndim() + } + + fn dim(&self, index: usize) -> Option<usize> { + let len = self.ndim(); + if index < len { + Some(self.len_of(Axis(index))) + } else { + None + } + } + + fn iter(&self) -> Self::Iter<'_> { + self.iter() + } + + fn as_slice(&self) -> Option<&[T]> { + self.as_slice() + } + + fn check_data_buf(&self) -> Result<usize, Error> { + Ok(self.len() * size_of::<T>()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_f64_element_type() { + assert_eq!(<f64 as ArrayElementSealed>::type_tag(), 10); + } +} diff --git a/questdb-rs/src/ingress/tests.rs b/questdb-rs/src/ingress/tests.rs index 8d83a7b0..01a06f51 100644 --- a/questdb-rs/src/ingress/tests.rs +++ b/questdb-rs/src/ingress/tests.rs @@ -447,6 +447,8 @@ fn connect_timeout_uses_request_timeout() { let builder = SenderBuilder::new(Protocol::Http, "127.0.0.2", "1111") .request_timeout(request_timeout) .unwrap() + .disable_line_protocol_validation() + .unwrap() .retry_timeout(Duration::from_millis(10)) .unwrap() .request_min_throughput(0) diff --git a/questdb-rs/src/tests/http.rs b/questdb-rs/src/tests/http.rs index 90ec1222..170182ab 100644 --- a/questdb-rs/src/tests/http.rs +++ b/questdb-rs/src/tests/http.rs @@ -22,18 +22,20 @@ * ******************************************************************************/ -use crate::ingress::{Buffer, Protocol, SenderBuilder, TimestampNanos}; +use crate::ingress::{Buffer, LineProtocolVersion, Protocol, SenderBuilder, TimestampNanos}; use crate::tests::mock::{certs_dir, HttpResponse, MockServer}; +use crate::tests::TestResult; use crate::ErrorCode; +use rstest::rstest; use std::io; use std::io::ErrorKind; use std::time::Duration; -use crate::tests::TestResult; - -#[test] -fn test_two_lines() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_two_lines( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? @@ -46,11 +48,19 @@ fn test_two_lines() -> TestResult { .at_now()?; let buffer2 = buffer.clone(); - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -66,6 +76,7 @@ fn test_two_lines() -> TestResult { Ok(()) }); + let mut sender = sender_builder.build()?; let res = sender.flush(&mut buffer); server_thread.join().unwrap()?; @@ -77,9 +88,11 @@ fn test_two_lines() -> TestResult { Ok(()) } -#[test] -fn test_text_plain_error() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_text_plain_error( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? @@ -87,12 +100,20 @@ fn test_text_plain_error() -> TestResult { .at_now()?; buffer.table("test")?.column_f64("sym", 2.0)?.at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -109,6 +130,7 @@ fn test_text_plain_error() -> TestResult { Ok(()) }); + let mut sender = sender_builder.build()?; let res = sender.flush(&mut buffer); server_thread.join().unwrap()?; @@ -123,9 +145,11 @@ fn test_text_plain_error() -> TestResult { Ok(()) } -#[test] -fn test_bad_json_error() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_bad_json_error( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? @@ -133,12 +157,20 @@ fn test_bad_json_error() -> TestResult { .at_now()?; buffer.table("test")?.column_f64("sym", 2.0)?.at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -156,6 +188,7 @@ fn test_bad_json_error() -> TestResult { Ok(()) }); + let mut sender = sender_builder.build()?; let res = sender.flush_and_keep(&buffer); server_thread.join().unwrap()?; @@ -171,9 +204,11 @@ fn test_bad_json_error() -> TestResult { Ok(()) } -#[test] -fn test_json_error() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_json_error( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? @@ -181,12 +216,20 @@ fn test_json_error() -> TestResult { .at_now()?; buffer.table("test")?.column_f64("sym", 2.0)?.at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -207,7 +250,7 @@ fn test_json_error() -> TestResult { Ok(()) }); - let res = sender.flush_and_keep(&buffer); + let res = sender_builder.build()?.flush_and_keep(&buffer); server_thread.join().unwrap()?; @@ -222,16 +265,20 @@ fn test_json_error() -> TestResult { Ok(()) } -#[test] -fn test_no_connection() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_no_connection( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? .column_f64("x", 1.0)? .at_now()?; - let mut sender = SenderBuilder::new(Protocol::Http, "127.0.0.1", 1).build()?; + let mut sender = SenderBuilder::new(Protocol::Http, "127.0.0.1", 1) + .disable_line_protocol_validation()? + .build()?; let res = sender.flush_and_keep(&buffer); assert!(res.is_err()); let err = res.unwrap_err(); @@ -242,21 +289,31 @@ fn test_no_connection() -> TestResult { Ok(()) } -#[test] -fn test_old_server_without_ilp_http_support() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_old_server_without_ilp_http_support( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? .column_f64("x", 1.0)? .at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -273,7 +330,7 @@ fn test_old_server_without_ilp_http_support() -> TestResult { Ok(()) }); - let res = sender.flush_and_keep(&buffer); + let res = sender_builder.build()?.flush_and_keep(&buffer); server_thread.join().unwrap()?; @@ -288,25 +345,34 @@ fn test_old_server_without_ilp_http_support() -> TestResult { Ok(()) } -#[test] -fn test_http_basic_auth() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_http_basic_auth( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? .column_f64("x", 1.0)? .at_now()?; - let mut server = MockServer::new()?; - let mut sender = server + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server .lsb_http() .username("Aladdin")? - .password("OpenSesame")? - .build()?; + .password("OpenSesame")?; let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; @@ -323,7 +389,7 @@ fn test_http_basic_auth() -> TestResult { Ok(()) }); - let res = sender.flush(&mut buffer); + let res = sender_builder.build()?.flush(&mut buffer); server_thread.join().unwrap()?; @@ -334,21 +400,31 @@ fn test_http_basic_auth() -> TestResult { Ok(()) } -#[test] -fn test_unauthenticated() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_unauthenticated( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? .column_f64("x", 1.0)? .at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -365,7 +441,7 @@ fn test_unauthenticated() -> TestResult { Ok(()) }); - let res = sender.flush(&mut buffer); + let res = sender_builder.build()?.flush(&mut buffer); server_thread.join().unwrap()?; @@ -382,21 +458,31 @@ fn test_unauthenticated() -> TestResult { Ok(()) } -#[test] -fn test_token_auth() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_token_auth( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? .column_f64("x", 1.0)? .at_now()?; - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().token("0123456789")?.build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http().token("0123456789")?; let buffer2 = buffer.clone(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.method(), "POST"); @@ -409,7 +495,7 @@ fn test_token_auth() -> TestResult { Ok(()) }); - let res = sender.flush(&mut buffer); + let res = sender_builder.build()?.flush(&mut buffer); server_thread.join().unwrap()?; @@ -418,9 +504,11 @@ fn test_token_auth() -> TestResult { Ok(()) } -#[test] -fn test_request_timeout() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_request_timeout( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("sym", "bol")? @@ -428,12 +516,13 @@ fn test_request_timeout() -> TestResult { .at_now()?; // Here we use a mock (tcp) server instead and don't send a response back. - let server = MockServer::new()?; + let server = MockServer::new()?.configure_settings_response(2, &[1, 2]); let request_timeout = Duration::from_millis(50); let time_start = std::time::Instant::now(); let mut sender = server .lsb_http() + .disable_line_protocol_validation()? .request_timeout(request_timeout)? .build()?; let res = sender.flush_and_keep(&buffer); @@ -446,12 +535,14 @@ fn test_request_timeout() -> TestResult { Ok(()) } -#[test] -fn test_tls() -> TestResult { +#[rstest] +fn test_tls( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { let mut ca_path = certs_dir(); ca_path.push("server_rootCA.pem"); - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -459,8 +550,12 @@ fn test_tls() -> TestResult { .at(TimestampNanos::new(10000000))?; let buffer2 = buffer.clone(); - let mut server = MockServer::new()?; - let mut sender = server.lsb_https().tls_roots(ca_path)?.build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let mut sender = server + .lsb_https() + .tls_roots(ca_path)? + .disable_line_protocol_validation()? + .build()?; let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept_tls_sync()?; @@ -484,9 +579,11 @@ fn test_tls() -> TestResult { Ok(()) } -#[test] -fn test_user_agent() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_user_agent( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -494,11 +591,15 @@ fn test_user_agent() -> TestResult { .at(TimestampNanos::new(10000000))?; let buffer2 = buffer.clone(); - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().user_agent("wallabies/1.2.99")?.build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http().user_agent("wallabies/1.2.99")?; let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.header("user-agent"), Some("wallabies/1.2.99")); @@ -509,7 +610,7 @@ fn test_user_agent() -> TestResult { Ok(()) }); - let res = sender.flush_and_keep(&buffer); + let res = sender_builder.build()?.flush_and_keep(&buffer); server_thread.join().unwrap()?; @@ -519,11 +620,13 @@ fn test_user_agent() -> TestResult { Ok(()) } -#[test] -fn test_two_retries() -> TestResult { +#[rstest] +fn test_two_retries( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { // Note: This also tests that the _same_ connection is being reused, i.e. tests keepalive. - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -531,14 +634,19 @@ fn test_two_retries() -> TestResult { .at(TimestampNanos::new(10000000))?; let buffer2 = buffer.clone(); - let mut server = MockServer::new()?; - let mut sender = server - .lsb_http() - .retry_timeout(Duration::from_secs(30))? - .build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http().retry_timeout(Duration::from_secs(30))?; let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.body(), buffer2.as_bytes()); @@ -574,7 +682,7 @@ fn test_two_retries() -> TestResult { Ok(()) }); - let res = sender.flush_and_keep(&buffer); + let res = sender_builder.build()?.flush_and_keep(&buffer); server_thread.join().unwrap()?; @@ -584,9 +692,11 @@ fn test_two_retries() -> TestResult { Ok(()) } -#[test] -fn test_one_retry() -> TestResult { - let mut buffer = Buffer::new(); +#[rstest] +fn test_one_retry( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -594,15 +704,16 @@ fn test_one_retry() -> TestResult { .at(TimestampNanos::new(10000000))?; let buffer2 = buffer.clone(); - let mut server = MockServer::new()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); let mut sender = server .lsb_http() .retry_timeout(Duration::from_millis(19))? + .disable_line_protocol_validation() + .unwrap() .build()?; let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; - let req = server.recv_http_q()?; assert_eq!(req.body(), buffer2.as_bytes()); @@ -648,10 +759,12 @@ fn test_one_retry() -> TestResult { Ok(()) } -#[test] -fn test_transactional() -> TestResult { +#[rstest] +fn test_transactional( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { // A buffer with a two tables. - let mut buffer1 = Buffer::new(); + let mut buffer1 = Buffer::new().with_line_proto_version(version)?; buffer1 .table("tab1")? .symbol("t1", "v1")? @@ -665,7 +778,7 @@ fn test_transactional() -> TestResult { assert!(!buffer1.transactional()); // A buffer with a single table. - let mut buffer2 = Buffer::new(); + let mut buffer2 = Buffer::new().with_line_proto_version(version)?; buffer2 .table("test")? .symbol("t1", "v1")? @@ -674,11 +787,19 @@ fn test_transactional() -> TestResult { let buffer3 = buffer2.clone(); assert!(buffer2.transactional()); - let mut server = MockServer::new()?; - let mut sender = server.lsb_http().build()?; + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); let server_thread = std::thread::spawn(move || -> io::Result<()> { server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; let req = server.recv_http_q()?; assert_eq!(req.body(), buffer3.as_bytes()); @@ -688,6 +809,8 @@ fn test_transactional() -> TestResult { Ok(()) }); + let mut sender = sender_builder.build()?; + let res = sender.flush_and_keep_with_flags(&buffer1, true); assert!(res.is_err()); let err = res.unwrap_err(); @@ -707,3 +830,145 @@ fn test_transactional() -> TestResult { Ok(()) } + +#[test] +fn test_sender_line_protocol_version() -> TestResult { + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let sender_builder = server.lsb_http(); + let server_thread = std::thread::spawn(move || -> io::Result<()> { + server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; + Ok(()) + }); + let sender = sender_builder.build()?; + assert_eq!( + sender.default_line_protocol_version(), + LineProtocolVersion::V2 + ); + assert_eq!( + sender.support_line_protocol_versions().unwrap(), + vec![LineProtocolVersion::V1, LineProtocolVersion::V2] + ); + server_thread.join().unwrap()?; + Ok(()) +} + +#[test] +fn test_sender_line_protocol_version_old_server1() -> TestResult { + let mut server = MockServer::new()?.configure_settings_response(0, &[1, 2]); + let sender_builder = server.lsb_http(); + let server_thread = std::thread::spawn(move || -> io::Result<()> { + server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.method(), "GET"); + assert_eq!(req.path(), "/settings"); + assert_eq!( + req.header("user-agent"), + Some(concat!("questdb/rust/", env!("CARGO_PKG_VERSION"))) + ); + server.send_settings_response()?; + Ok(()) + }); + let sender = sender_builder.build()?; + assert_eq!( + sender.default_line_protocol_version(), + LineProtocolVersion::V1 + ); + assert!(sender.support_line_protocol_versions().is_none()); + server_thread.join().unwrap()?; + Ok(()) +} + +#[test] +fn test_sender_line_protocol_version_old_server2() -> TestResult { + let mut server = MockServer::new()?.configure_settings_response(0, &[1, 2]); + let sender_builder = server.lsb_http(); + let server_thread = std::thread::spawn(move || -> io::Result<()> { + server.accept()?; + server.send_http_response_q( + HttpResponse::empty() + .with_status(404, "Not Found") + .with_header("content-type", "text/plain") + .with_body_str("Not Found"), + )?; + Ok(()) + }); + let sender = sender_builder.build()?; + assert_eq!( + sender.default_line_protocol_version(), + LineProtocolVersion::V1 + ); + assert!(sender.support_line_protocol_versions().is_none()); + server_thread.join().unwrap()?; + Ok(()) +} + +#[test] +fn test_sender_line_protocol_version_unsupported_client() -> TestResult { + let mut server = MockServer::new()?.configure_settings_response(3, &[3, 4]); + let sender_builder = server.lsb_http(); + let server_thread = std::thread::spawn(move || -> io::Result<()> { + server.accept()?; + server.send_settings_response()?; + Ok(()) + }); + let res1 = sender_builder.build(); + assert!(res1.is_err()); + let e1 = res1.err().unwrap(); + assert_eq!(e1.code(), ErrorCode::LineProtocolVersionError); + assert!(e1.msg().contains("Server does not support current client")); + server_thread.join().unwrap()?; + Ok(()) +} + +#[test] +fn test_sender_disable_line_protocol_version_validation() -> TestResult { + let mut server = MockServer::new()?.configure_settings_response(2, &[1, 2]); + let mut sender = server + .lsb_http() + .disable_line_protocol_validation()? + .build()?; + let mut buffer = + Buffer::new().with_line_proto_version(sender.default_line_protocol_version())?; + buffer + .table("test")? + .symbol("sym", "bol")? + .column_f64("x", 1.0)? + .at_now()?; + let buffer2 = buffer.clone(); + + let server_thread = std::thread::spawn(move || -> io::Result<()> { + server.accept()?; + let req = server.recv_http_q()?; + assert_eq!(req.body(), buffer2.as_bytes()); + server.send_http_response_q(HttpResponse::empty())?; + Ok(()) + }); + + sender.flush(&mut buffer)?; + server_thread.join().unwrap()?; + Ok(()) +} + +#[test] +fn test_sender_line_protocol_version1_not_support_array() -> TestResult { + let mut buffer = Buffer::new().with_line_proto_version(LineProtocolVersion::V1)?; + let res = buffer + .table("test")? + .symbol("sym", "bol")? + .column_arr("x", &[1.0f64, 2.0]); + assert!(res.is_err()); + let e1 = res.as_ref().err().unwrap(); + assert_eq!(e1.code(), ErrorCode::LineProtocolVersionError); + assert!(e1 + .msg() + .contains("line protocol version v1 does not support array datatype")); + Ok(()) +} diff --git a/questdb-rs/src/tests/interop/ilp-client-interop-test.json b/questdb-rs/src/tests/interop/ilp-client-interop-test.json index d3e0e259..0acedad7 100644 --- a/questdb-rs/src/tests/interop/ilp-client-interop-test.json +++ b/questdb-rs/src/tests/interop/ilp-client-interop-test.json @@ -32,6 +32,7 @@ ], "result": { "status": "SUCCESS", + "binaryBase64": "dGVzdF90YWJsZSxzeW1fY29sPXN5bV92YWwgc3RyX2NvbD0iZm9vIGJhciBiYXoiLGxvbmdfY29sPTQyaSxkb3VibGVfY29sPT0QAAAAAABARUAsYm9vbF9jb2w9dAo=", "line": "test_table,sym_col=sym_val str_col=\"foo bar baz\",long_col=42i,double_col=42.5,bool_col=t" } }, @@ -73,6 +74,7 @@ ], "result": { "status": "SUCCESS", + "binaryBase64": "ZG91YmxlcyBkMD09EAAAAAAAAAAALGRtMD09EAAAAAAAAACALGQxPT0QAAAAAAAA8D8sZEUxMDA9PRB9w5QlrUmyVCxkMDAwMDAwMT09EI3ttaD3xrA+LGROMDAwMDAwMT09EI3ttaD3xrC+Cg==", "anyLines": [ "doubles d0=0,dm0=-0,d1=1,dE100=1E+100,d0000001=1E-06,dN0000001=-1E-06", "doubles d0=0.0,dm0=-0.0,d1=1.0,dE100=1e100,d0000001=1e-6,dN0000001=-1e-6" diff --git a/questdb-rs/src/tests/mock.rs b/questdb-rs/src/tests/mock.rs index 289b41cb..71d63c29 100644 --- a/questdb-rs/src/tests/mock.rs +++ b/questdb-rs/src/tests/mock.rs @@ -36,9 +36,12 @@ use std::net::SocketAddr; use std::sync::Arc; use std::time::Instant; +use crate::ingress; #[cfg(feature = "ilp-over-http")] use std::io::Write; +use super::ndarr::ArrayColumnTypeTag; + const CLIENT: Token = Token(0); #[derive(Debug)] @@ -50,7 +53,9 @@ pub struct MockServer { tls_conn: Option<ServerConnection>, pub host: &'static str, pub port: u16, - pub msgs: Vec<String>, + pub msgs: Vec<Vec<u8>>, + #[cfg(feature = "ilp-over-http")] + settings_response: serde_json::Value, } pub fn certs_dir() -> std::path::PathBuf { @@ -206,6 +211,8 @@ impl MockServer { host: "localhost", port, msgs: Vec::new(), + #[cfg(feature = "ilp-over-http")] + settings_response: serde_json::Value::Null, }) } @@ -302,6 +309,23 @@ impl MockServer { } } + #[cfg(feature = "ilp-over-http")] + pub fn configure_settings_response( + mut self, + default_version: u16, + supported_versions: &[u16], + ) -> Self { + if default_version == 0 { + self.settings_response = serde_json::json!({"version": "8.1.2"}); + } else { + self.settings_response = serde_json::json!({ + "line.proto.default.version": default_version, + "line.proto.support.versions": supported_versions + }); + } + self + } + #[cfg(feature = "ilp-over-http")] fn do_write(&mut self, buf: &[u8]) -> io::Result<usize> { let client = self.client.as_mut().unwrap(); @@ -454,6 +478,15 @@ impl MockServer { Ok(()) } + #[cfg(feature = "ilp-over-http")] + pub fn send_settings_response(&mut self) -> io::Result<()> { + let response = HttpResponse::empty() + .with_status(200, "OK") + .with_body_json(&self.settings_response); + self.send_http_response(response, Some(2.0))?; + Ok(()) + } + #[cfg(feature = "ilp-over-http")] pub fn send_http_response_q(&mut self, response: HttpResponse) -> io::Result<()> { self.send_http_response(response, Some(5.0)) @@ -465,6 +498,14 @@ impl MockServer { let deadline = Instant::now() + Duration::from_secs_f64(wait_timeout_sec); let (pos, method, path) = self.recv_http_method(&mut accum, deadline)?; let (pos, headers) = self.recv_http_headers(pos, &mut accum, deadline)?; + if &method == "GET" { + return Ok(HttpRequest { + method, + path, + headers, + body: vec![], + }); + } let content_length = headers .get("content-length") .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "Missing Content-Length"))? @@ -521,15 +562,47 @@ impl MockServer { let mut received_count = 0usize; let mut head = 0usize; - for index in 1..accum.len() { + let binary_length = 0usize; + let mut index = 1; + + while index < accum.len() { let last = accum[index]; let prev = accum[index - 1]; - if (last == b'\n') && (prev != b'\\') { + if last == b'=' && prev == b'=' { + index += 1; + // calc binary length + let binary_type = accum[index]; + if binary_type == ingress::DOUBLE_BINARY_FORMAT_TYPE { + index += size_of::<f64>() + 1; + } else if binary_type == ingress::ARRAY_BINARY_FORMAT_TYPE { + index += 1; + let element_type = match ArrayColumnTypeTag::try_from(accum[index]) { + Ok(t) => t, + Err(e) => { + return Err(io::Error::new(io::ErrorKind::Other, e)); + } + }; + let mut elems_size = element_type.size(); + index += 1; + let dims = accum[index] as usize; + index += 1; + for _ in 0..dims { + elems_size *= i32::from_le_bytes( + accum[index..index + size_of::<i32>()].try_into().unwrap(), + ) as usize; + index += size_of::<i32>(); + } + index += elems_size; + } + } else if (last == b'\n') && (prev != b'\\' && binary_length == 0) { let tail = index + 1; - let msg = std::str::from_utf8(&accum[head..tail]).unwrap(); - self.msgs.push(msg.to_owned()); + let msg = &accum[head..tail]; + self.msgs.push(msg.to_vec()); head = tail; received_count += 1; + index = tail; + } else { + index += 1; } } Ok(received_count) diff --git a/questdb-rs/src/tests/mod.rs b/questdb-rs/src/tests/mod.rs index f817287e..d68dede9 100644 --- a/questdb-rs/src/tests/mod.rs +++ b/questdb-rs/src/tests/mod.rs @@ -21,6 +21,7 @@ * limitations under the License. * ******************************************************************************/ + mod f64_serializer; #[cfg(feature = "ilp-over-http")] @@ -29,6 +30,8 @@ mod http; mod mock; mod sender; +mod ndarr; + #[cfg(feature = "json_tests")] mod json_tests { include!(concat!(env!("OUT_DIR"), "/json_tests.rs")); diff --git a/questdb-rs/src/tests/ndarr.rs b/questdb-rs/src/tests/ndarr.rs new file mode 100644 index 00000000..ab0b984b --- /dev/null +++ b/questdb-rs/src/tests/ndarr.rs @@ -0,0 +1,1144 @@ +#[cfg(feature = "ndarray")] +use crate::ingress::MAX_DIMS; +use crate::ingress::{Buffer, NdArrayView, StrideArrayView, ARRAY_BINARY_FORMAT_TYPE}; +use crate::tests::TestResult; +use crate::ErrorCode; + +use crate::ingress::ndarr::write_array_data; +#[cfg(feature = "ndarray")] +use ndarray::{arr1, arr2, arr3, s, ArrayD}; +#[cfg(feature = "ndarray")] +use std::iter; +use std::ptr; + +/// QuestDB column type tags that are supported as array element types. +#[derive(Clone, Copy)] +#[repr(u8)] +pub enum ArrayColumnTypeTag { + Double = 10, +} + +impl ArrayColumnTypeTag { + pub fn size(&self) -> usize { + match self { + ArrayColumnTypeTag::Double => std::mem::size_of::<f64>(), + } + } +} + +impl From<ArrayColumnTypeTag> for u8 { + fn from(tag: ArrayColumnTypeTag) -> Self { + tag as u8 + } +} + +impl TryFrom<u8> for ArrayColumnTypeTag { + type Error = String; + + fn try_from(value: u8) -> Result<Self, Self::Error> { + match value { + 10 => Ok(ArrayColumnTypeTag::Double), + _ => Err(format!("Unsupported column type tag {} for arrays", value)), + } + } +} + +fn to_bytes<T: Copy>(data: &[T]) -> Vec<u8> { + data.iter() + .flat_map(|x| { + let bytes = + unsafe { std::slice::from_raw_parts(x as *const T as *const u8, size_of::<T>()) }; + bytes.to_vec() + }) + .collect() +} + +#[test] +fn test_stride_array_view() -> TestResult { + // contiguous layout + let test_data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let shapes = [2usize, 3]; + let strides = [ + (shapes[1] * size_of::<f64>()) as isize, + size_of::<f64>() as isize, + ]; + let array = unsafe { + StrideArrayView::<f64>::new( + shapes.len(), + shapes.as_ptr(), + strides.as_ptr(), + test_data.as_ptr() as *const u8, + test_data.len() * size_of::<f64>(), + ) + }; + + assert_eq!(array.ndim(), 2); + assert_eq!(array.dim(0), Some(2)); + assert_eq!(array.dim(1), Some(3)); + assert_eq!(array.dim(2), None); + assert!(array.as_slice().is_some()); + let mut buf = vec![]; + write_array_data(&array, &mut buf).unwrap(); + let expected = to_bytes(&test_data); + assert_eq!(buf, expected); + Ok(()) +} + +#[test] +fn test_strided_non_contiguous() -> TestResult { + let elem_size = size_of::<f64>() as isize; + let col_major_data = [1.0, 3.0, 5.0, 2.0, 4.0, 6.0]; + let shapes = [3usize, 2]; + let strides = [elem_size, shapes[0] as isize * elem_size]; + + let array_view: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + shapes.len(), + shapes.as_ptr(), + strides.as_ptr(), + col_major_data.as_ptr() as *const u8, + col_major_data.len() * elem_size as usize, + ) + }; + + assert_eq!(array_view.ndim(), 2); + assert_eq!(array_view.dim(0), Some(3)); + assert_eq!(array_view.dim(1), Some(2)); + assert_eq!(array_view.dim(2), None); + assert!(array_view.as_slice().is_none()); + let mut buffer = Vec::new(); + write_array_data(&array_view, &mut buffer)?; + + let expected_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let expected_bytes = unsafe { + std::slice::from_raw_parts( + expected_data.as_ptr() as *const u8, + expected_data.len() * elem_size as usize, + ) + }; + assert_eq!(buffer, expected_bytes); + Ok(()) +} + +#[test] +fn test_negative_strides() -> TestResult { + let elem_size = size_of::<f64>(); + let data = [1f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let view = unsafe { + StrideArrayView::<f64>::new( + 2, + &[3usize, 3] as *const usize, + &[-24isize, 8] as *const isize, + (data.as_ptr() as *const u8).add(48), + data.len() * elem_size, + ) + }; + let collected: Vec<_> = view.iter().copied().collect(); + assert!(view.as_slice().is_none()); + let expected_data = vec![7.0, 8.0, 9.0, 4.0, 5.0, 6.0, 1.0, 2.0, 3.0]; + assert_eq!(collected, expected_data); + let mut buffer = Vec::new(); + write_array_data(&view, &mut buffer)?; + let expected_bytes = unsafe { + std::slice::from_raw_parts( + expected_data.as_ptr() as *const u8, + expected_data.len() * elem_size, + ) + }; + assert_eq!(buffer, expected_bytes); + Ok(()) +} + +#[test] +fn test_basic_edge_cases() { + // empty array + let elem_size = std::mem::size_of::<f64>() as isize; + let empty_view: StrideArrayView<'_, f64> = + unsafe { StrideArrayView::new(2, [0, 0].as_ptr(), [0, 0].as_ptr(), ptr::null(), 0) }; + assert_eq!(empty_view.ndim(), 2); + assert_eq!(empty_view.dim(0), Some(0)); + assert_eq!(empty_view.dim(1), Some(0)); + + // single element array + let single_data = [42.0]; + let single_view: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + 2, + [1, 1].as_ptr(), + [elem_size, elem_size].as_ptr(), + single_data.as_ptr() as *const u8, + elem_size as usize, + ) + }; + let mut buf = vec![]; + write_array_data(&single_view, &mut buf).unwrap(); + assert_eq!(buf, 42.0f64.to_ne_bytes()); +} + +#[test] +fn test_buffer_basic_write() -> TestResult { + let elem_size = std::mem::size_of::<f64>() as isize; + + let test_data = [1.1, 2.2, 3.3, 4.4]; + let array_view: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + 2, + [2, 2].as_ptr(), + [2 * elem_size, elem_size].as_ptr(), + test_data.as_ptr() as *const u8, + test_data.len() * elem_size as usize, + ) + }; + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &array_view)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[24..32], + [2i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[32..64], + &[ + 1.1f64.to_ne_bytes(), + 2.2f64.to_le_bytes(), + 3.3f64.to_le_bytes(), + 4.4f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_size_overflow() -> TestResult { + let overflow_view = unsafe { + StrideArrayView::<f64>::new( + 2, + [u32::MAX as usize, u32::MAX as usize].as_ptr(), + [8, 8].as_ptr(), + ptr::null(), + 0, + ) + }; + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + let result = buffer.column_arr("arr1", &overflow_view); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayViewError); + assert!(err.msg().contains("Array total elem size overflow")); + Ok(()) +} + +#[test] +fn test_array_length_mismatch() -> TestResult { + let elem_size = size_of::<f64>() as isize; + let under_data = [1.1]; + let under_view: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + 2, + [1, 2].as_ptr(), + [elem_size, elem_size].as_ptr(), + under_data.as_ptr() as *const u8, + under_data.len() * elem_size as usize, + ) + }; + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + let result = buffer.column_arr("arr1", &under_view); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayWriteToBufferError); + assert!(err + .msg() + .contains("Array buffer length mismatch (actual: 8, expected: 16)")); + + let over_data = [1.1, 2.2, 3.3]; + let over_view: StrideArrayView<'_, f64> = unsafe { + StrideArrayView::new( + 2, + [1, 2].as_ptr(), + [elem_size, elem_size].as_ptr(), + over_data.as_ptr() as *const u8, + over_data.len() * elem_size as usize, + ) + }; + + buffer.clear(); + buffer.table("my_test")?; + let result = buffer.column_arr("arr1", &over_view); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayWriteToBufferError); + assert!(err + .msg() + .contains("Array buffer length mismatch (actual: 24, expected: 16)")); + Ok(()) +} + +#[test] +fn test_build_in_1d_array_normal() -> TestResult { + let arr = [1.0f64, 2.0, 3.0, 4.0]; + assert_eq!(arr.ndim(), 1); + assert_eq!(arr.dim(0), Some(4)); + assert_eq!(arr.dim(1), None); + assert_eq!(NdArrayView::as_slice(&arr), Some(&[1.0, 2.0, 3.0, 4.0][..])); + let collected: Vec<_> = NdArrayView::iter(&arr).copied().collect(); + assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0]); + assert_eq!(arr.check_data_buf(), Ok(32)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [4i32.to_le_bytes()].concat()); + assert_eq!( + &data[28..60], + &[ + 1.0f64.to_ne_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_1d_array_empty() -> TestResult { + let arr: [f64; 0] = []; + assert_eq!(arr.ndim(), 1); + assert_eq!(arr.dim(0), Some(0)); + assert_eq!(NdArrayView::as_slice(&arr), Some(&[][..])); + assert_eq!(arr.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [0i32.to_le_bytes()].concat()); + Ok(()) +} + +#[test] +fn test_build_in_1d_vec_normal() -> TestResult { + let vec = vec![5.0f64, 6.0, 7.0]; + assert_eq!(vec.ndim(), 1); + assert_eq!(vec.dim(0), Some(3)); + assert_eq!(NdArrayView::as_slice(&vec), Some(&[5.0, 6.0, 7.0][..])); + let collected: Vec<_> = NdArrayView::iter(&vec).copied().collect(); + assert_eq!(collected, vec![5.0, 6.0, 7.0]); + assert_eq!(vec.check_data_buf(), Ok(24)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [3i32.to_le_bytes()].concat()); + assert_eq!( + &data[28..52], + &[ + 5.0f64.to_le_bytes(), + 6.0f64.to_le_bytes(), + 7.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_1d_vec_empty() -> TestResult { + let vec: Vec<f64> = Vec::new(); + assert_eq!(vec.ndim(), 1); + assert_eq!(vec.dim(0), Some(0)); + assert_eq!(NdArrayView::as_slice(&vec), Some(&[][..])); + assert_eq!(vec.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [0i32.to_le_bytes()].concat()); + Ok(()) +} + +#[test] +fn test_build_in_1d_slice_normal() -> TestResult { + let data = [10.0f64, 20.0, 30.0, 40.0]; + let slice = &data[1..3]; + assert_eq!(slice.ndim(), 1); + assert_eq!(slice.dim(0), Some(2)); + assert_eq!(NdArrayView::as_slice(&slice), Some(&[20.0, 30.0][..])); + assert_eq!(slice.check_data_buf(), Ok(16)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [2i32.to_le_bytes()].concat()); + assert_eq!( + &data[28..44], + &[20.0f64.to_le_bytes(), 30.0f64.to_le_bytes(),].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_1d_slice_empty() -> TestResult { + let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let slice = &data[2..2]; + assert_eq!(slice.ndim(), 1); + assert_eq!(slice.dim(0), Some(0)); + assert_eq!(NdArrayView::as_slice(&slice), Some(&[][..])); + assert_eq!(slice.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("temperature", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 1u8 + ] + ); + assert_eq!(&data[24..28], [0i32.to_le_bytes()].concat()); + Ok(()) +} + +#[test] +fn test_build_in_2d_array_normal() -> TestResult { + let arr = [[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]]; + assert_eq!(arr.ndim(), 2); + assert_eq!(arr.dim(0), Some(3)); + assert_eq!(arr.dim(1), Some(2)); + assert_eq!( + NdArrayView::as_slice(&arr), + Some(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0][..]) + ); + let collected: Vec<_> = NdArrayView::iter(&arr).copied().collect(); + assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + assert_eq!(arr.check_data_buf(), Ok(48)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [3i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[28..76], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + 5.0f64.to_le_bytes(), + 6.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_2d_array_empty() -> TestResult { + let arr: [[f64; 0]; 0] = []; + assert_eq!(arr.ndim(), 2); + assert_eq!(arr.dim(0), Some(0)); + assert_eq!(arr.dim(1), Some(0)); + assert_eq!(NdArrayView::as_slice(&arr), Some(&[][..])); + assert_eq!(arr.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [0i32.to_le_bytes(), 0i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_2d_vec_normal() -> TestResult { + let vec = vec![vec![1.0f64, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]]; + assert_eq!(vec.ndim(), 2); + assert_eq!(vec.dim(0), Some(3)); + assert_eq!(vec.dim(1), Some(2)); + assert!(NdArrayView::as_slice(&vec).is_none()); + let collected: Vec<_> = NdArrayView::iter(&vec).copied().collect(); + assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]); + assert_eq!(vec.check_data_buf(), Ok(48)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [3i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[28..76], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + 5.0f64.to_le_bytes(), + 6.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_2d_vec_irregular_shape() -> TestResult { + let irregular_vec = vec![vec![1.0, 2.0], vec![3.0], vec![4.0, 5.0]]; + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + let result = buffer.column_arr("arr", &irregular_vec); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayViewError); + assert!(err.msg().contains("Irregular array shape")); + Ok(()) +} + +#[test] +fn test_build_in_2d_vec_empty() -> TestResult { + let vec: Vec<Vec<f64>> = vec![vec![], vec![], vec![]]; + assert_eq!(vec.ndim(), 2); + assert_eq!(vec.dim(0), Some(3)); + assert_eq!(vec.dim(1), Some(0)); + assert_eq!(vec.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [3i32.to_le_bytes(), 0i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_2d_slice_normal() -> TestResult { + let data = [[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]]; + let slice = &data[..2]; + assert_eq!(slice.ndim(), 2); + assert_eq!(slice.dim(0), Some(2)); + assert_eq!(slice.dim(1), Some(2)); + assert_eq!( + NdArrayView::as_slice(&slice), + Some(&[1.0, 2.0, 3.0, 4.0][..]) + ); + assert_eq!(slice.check_data_buf(), Ok(32)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [2i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[28..60], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_2d_slice_empty() -> TestResult { + let data = [[1.0f64, 2.0], [3.0, 4.0], [5.0, 6.0]]; + let slice = &data[2..2]; + assert_eq!(slice.ndim(), 2); + assert_eq!(slice.dim(0), Some(0)); + assert_eq!(slice.dim(1), Some(2)); + assert_eq!(NdArrayView::as_slice(&slice), Some(&[][..])); + assert_eq!(slice.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("2darray", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"2darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[20..28], + [0i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_3d_array_normal() -> TestResult { + let arr = [[[1.0f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]; + assert_eq!(arr.ndim(), 3); + assert_eq!(arr.dim(0), Some(2)); + assert_eq!(arr.dim(1), Some(2)); + assert_eq!(arr.dim(2), Some(2)); + assert_eq!( + NdArrayView::as_slice(&arr), + Some(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0][..]) + ); + let collected: Vec<_> = NdArrayView::iter(&arr).copied().collect(); + assert_eq!(collected, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + assert_eq!(arr.check_data_buf(), Ok(64)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [2i32.to_le_bytes(), 2i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[32..96], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + 5.0f64.to_le_bytes(), + 6.0f64.to_le_bytes(), + 7.0f64.to_le_bytes(), + 8.0f64.to_le_bytes() + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_3d_array_empty() -> TestResult { + let arr: [[[f64; 2]; 0]; 0] = []; + assert_eq!(arr.ndim(), 3); + assert_eq!(arr.dim(0), Some(0)); + assert_eq!(arr.dim(1), Some(0)); + assert_eq!(arr.dim(2), Some(2)); + assert_eq!(NdArrayView::as_slice(&arr), Some(&[][..])); + assert_eq!(arr.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &arr)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [0i32.to_le_bytes(), 0i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_3d_vec_normal() -> TestResult { + let vec = vec![ + vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]], + vec![vec![7.0, 8.0, 9.0], vec![10.0, 11.0, 12.0]], + ]; + assert_eq!(vec.ndim(), 3); + assert_eq!(vec.dim(0), Some(2)); + assert_eq!(vec.dim(1), Some(2)); + assert_eq!(vec.dim(2), Some(3)); + assert!(NdArrayView::as_slice(&vec).is_none()); + let collected: Vec<_> = NdArrayView::iter(&vec).copied().collect(); + assert_eq!( + collected, + vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0] + ); + assert_eq!(vec.check_data_buf(), Ok(96)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [2i32.to_le_bytes(), 2i32.to_le_bytes(), 3i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[32..128], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + 5.0f64.to_le_bytes(), + 6.0f64.to_le_bytes(), + 7.0f64.to_le_bytes(), + 8.0f64.to_le_bytes(), + 9.0f64.to_le_bytes(), + 10.0f64.to_le_bytes(), + 11.0f64.to_le_bytes(), + 12.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_3d_vec_empty() -> TestResult { + let vec: Vec<Vec<Vec<f64>>> = vec![vec![vec![], vec![]], vec![vec![], vec![]]]; + assert_eq!(vec.ndim(), 3); + assert_eq!(vec.dim(0), Some(2)); + assert_eq!(vec.dim(1), Some(2)); + assert_eq!(vec.dim(2), Some(0)); + assert!(NdArrayView::as_slice(&vec).is_none()); + assert_eq!(vec.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &vec)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [2i32.to_le_bytes(), 2i32.to_le_bytes(), 0i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[test] +fn test_build_in_3d_vec_irregular_shape() -> TestResult { + let irregular1 = vec![vec![vec![1.0, 2.0], vec![3.0, 4.0]], vec![vec![5.0, 6.0]]]; + + let irregular2 = vec![ + vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]], + vec![vec![6.0, 7.0], vec![8.0, 9.0]], + ]; + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + let result = buffer.column_arr("arr", &irregular1); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayViewError); + assert!(err.msg().contains("Irregular array shape")); + + let result = buffer.column_arr("arr", &irregular2); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayViewError); + assert!(err.msg().contains("Irregular array shape")); + Ok(()) +} + +#[test] +fn test_3d_slice_normal() -> TestResult { + let data = [[[1f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]; + let slice = &data[..1]; + assert_eq!(slice.ndim(), 3); + assert_eq!(slice.dim(0), Some(1)); + assert_eq!(slice.dim(1), Some(2)); + assert_eq!(slice.dim(2), Some(2)); + assert_eq!( + NdArrayView::as_slice(&slice), + Some(&[1.0, 2.0, 3.0, 4.0][..]) + ); + assert_eq!(slice.check_data_buf(), Ok(32)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [1i32.to_le_bytes(), 2i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + assert_eq!( + &data[32..64], + &[ + 1.0f64.to_le_bytes(), + 2.0f64.to_le_bytes(), + 3.0f64.to_le_bytes(), + 4.0f64.to_le_bytes(), + ] + .concat() + ); + Ok(()) +} + +#[test] +fn test_3d_slice_empty() -> TestResult { + let data = [[[1f64, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]; + let slice = &data[1..1]; + assert_eq!(slice.ndim(), 3); + assert_eq!(slice.dim(0), Some(0)); + assert_eq!(slice.dim(1), Some(2)); + assert_eq!(slice.dim(2), Some(2)); + assert_eq!(NdArrayView::as_slice(&slice), Some(&[][..])); + assert_eq!(slice.check_data_buf(), Ok(0)); + + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + buffer.column_arr("3darray", &slice)?; + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..15], b"3darray"); + assert_eq!( + &data[15..20], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 3u8 + ] + ); + assert_eq!( + &data[20..32], + [0i32.to_le_bytes(), 2i32.to_le_bytes(), 2i32.to_le_bytes()].concat() + ); + Ok(()) +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_1d_contiguous_ndarray_buffer() -> TestResult { + let array = arr1(&[1.0, 2.0, 3.0, 4.0]); + let view = array.view(); + let mut buf = vec![0u8; 4 * size_of::<f64>()]; + write_array_data(&view, &mut &mut buf[0..])?; + let expected: Vec<u8> = array + .iter() + .flat_map(|&x| x.to_ne_bytes().to_vec()) + .collect(); + assert_eq!(buf, expected); + Ok(()) +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_2d_non_contiguous_ndarray_buffer() -> TestResult { + let array = arr2(&[[1.0, 2.0], [3.0, 4.0]]); + let transposed = array.view().reversed_axes(); + assert!(!transposed.is_standard_layout()); + let mut buf = vec![0u8; 4 * size_of::<f64>()]; + write_array_data(&transposed, &mut &mut buf[0..])?; + let expected = [1.0f64, 3.0, 2.0, 4.0] + .iter() + .flat_map(|&x| x.to_ne_bytes()) + .collect::<Vec<_>>(); + assert_eq!(buf, expected); + Ok(()) +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_strided_ndarray_layout() -> TestResult { + let array = arr2(&[ + [1.0, 2.0, 3.0, 4.0], + [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0], + [13.0, 14.0, 15.0, 16.0], + ]); + let strided_view = array.slice(s![1..;2, 1..;2]); + assert_eq!(strided_view.dim(), (2, 2)); + let mut buf = vec![0u8; 4 * size_of::<f64>()]; + write_array_data(&strided_view, &mut &mut buf[0..])?; + + // expect:6.0, 8.0, 14.0, 16.0 + let expected = [6.0f64, 8.0, 14.0, 16.0] + .iter() + .flat_map(|&x| x.to_ne_bytes()) + .collect::<Vec<_>>(); + + assert_eq!(buf, expected); + Ok(()) +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_1d_dimension_ndarray_info() { + let array = arr1(&[1.0, 2.0, 3.0]); + let view = array.view(); + + assert_eq!(NdArrayView::ndim(&view), 1); + assert_eq!(NdArrayView::dim(&view, 0), Some(3)); + assert_eq!(NdArrayView::dim(&view, 1), None); +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_complex_ndarray_dimensions() { + let array = arr3(&[[[1.0], [2.0]], [[3.0], [4.0]]]); + let view = array.view(); + + assert_eq!(NdArrayView::ndim(&view), 3); + assert_eq!(NdArrayView::dim(&view, 0), Some(2)); + assert_eq!(NdArrayView::dim(&view, 1), Some(2)); + assert_eq!(NdArrayView::dim(&view, 2), Some(1)); +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_buffer_ndarray_write() -> TestResult { + let mut buffer = Buffer::new(); + buffer.table("my_test")?; + let array_2d = arr2(&[[1.1, 2.2], [3.3, 4.4]]); + buffer.column_arr("temperature", &array_2d.view())?; + + let data = buffer.as_bytes(); + assert_eq!(&data[0..7], b"my_test"); + assert_eq!(&data[8..19], b"temperature"); + assert_eq!( + &data[19..24], + &[ + b'=', + b'=', + ARRAY_BINARY_FORMAT_TYPE, + ArrayColumnTypeTag::Double.into(), + 2u8 + ] + ); + assert_eq!( + &data[24..32], + [2i32.to_le_bytes().as_slice(), 2i32.to_le_bytes().as_slice()].concat() + ); + Ok(()) +} + +#[cfg(feature = "ndarray")] +#[test] +fn test_buffer_write_ndarray_max_dimensions() -> TestResult { + let mut buffer = Buffer::new(); + buffer.table("nd_test")?; + let shape: Vec<usize> = iter::repeat_n(1, MAX_DIMS).collect(); + let array = ArrayD::<f64>::zeros(shape.clone()); + buffer.column_arr("max_dim", &array.view())?; + let data = buffer.as_bytes(); + assert_eq!(data[19], MAX_DIMS as u8); + + // 33 dims error + let shape_invalid: Vec<_> = iter::repeat_n(1, MAX_DIMS + 1).collect(); + let array_invalid = ArrayD::<f64>::zeros(shape_invalid); + let result = buffer.column_arr("invalid", &array_invalid.view()); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.code(), ErrorCode::ArrayHasTooManyDims); + Ok(()) +} diff --git a/questdb-rs/src/tests/sender.rs b/questdb-rs/src/tests/sender.rs index 041e78b5..1b812b5b 100644 --- a/questdb-rs/src/tests/sender.rs +++ b/questdb-rs/src/tests/sender.rs @@ -29,16 +29,27 @@ use crate::{ Error, ErrorCode, }; +use crate::ingress; +#[cfg(feature = "ndarray")] +use crate::ingress::ndarr::write_array_data; +use crate::ingress::LineProtocolVersion; use crate::tests::{ mock::{certs_dir, MockServer}, + ndarr::ArrayColumnTypeTag, TestResult, }; - use core::time::Duration; -use std::{io, time::SystemTime}; +#[cfg(feature = "ndarray")] +use ndarray::{arr1, arr2, ArrayD}; +use rstest::rstest; +use std::io; + +#[rstest] +fn test_basics( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { + use std::time::SystemTime; -#[test] -fn test_basics() -> TestResult { let mut server = MockServer::new()?; let mut sender = server.lsb_tcp().build()?; assert!(!sender.must_close()); @@ -54,7 +65,7 @@ fn test_basics() -> TestResult { let ts_nanos = TimestampNanos::from_systemtime(ts)?; assert_eq!(ts_nanos.as_i64(), ts_nanos_num); - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -65,32 +76,154 @@ fn test_basics() -> TestResult { .at(ts_nanos)?; assert_eq!(server.recv_q()?, 0); - let exp = format!( - "test,t1=v1 f1=0.5,ts1=12345t,ts2={}t,ts3={}t {}\n", - ts_micros_num, - ts_nanos_num / 1000i64, - ts_nanos_num - ); - let exp_byte = exp.as_bytes(); - assert_eq!(buffer.as_bytes(), exp_byte); + let exp = &[ + "test,t1=v1 ".as_bytes(), + f64_to_bytes("f1", 0.5, version).as_slice(), + format!( + ",ts1=12345t,ts2={}t,ts3={}t {}\n", + ts_micros_num, + ts_nanos_num / 1000i64, + ts_nanos_num + ) + .as_bytes(), + ] + .concat(); + assert_eq!(buffer.as_bytes(), exp); + assert_eq!(buffer.len(), exp.len()); + sender.flush(&mut buffer)?; + assert_eq!(buffer.len(), 0); + assert_eq!(buffer.as_bytes(), b""); + assert_eq!(server.recv_q()?, 1); + assert_eq!(server.msgs[0], *exp); + Ok(()) +} + +#[test] +fn test_array_f64_basic() -> TestResult { + let mut server = MockServer::new()?; + let mut sender = server.lsb_tcp().build()?; + server.accept()?; + + let ts = TimestampNanos::now(); + + let mut buffer = + Buffer::new().with_line_proto_version(sender.default_line_protocol_version())?; + buffer + .table("my_table")? + .symbol("device", "A001")? + .column_f64("f1", 25.5)? + .column_arr("arr1d", &[1.0, 2.0, 3.0])? + .at(ts)?; + + assert_eq!(server.recv_q()?, 0); + + let exp = &[ + b"my_table,device=A001 ", + f64_to_bytes("f1", 25.5, LineProtocolVersion::V2).as_slice(), + b",arr1d=", + b"=", // binary field + &[ingress::ARRAY_BINARY_FORMAT_TYPE], + &[ArrayColumnTypeTag::Double.into()], + &[1u8], // 1D array + &3u32.to_le_bytes(), // 3 elements + &1.0f64.to_le_bytes(), + &2.0f64.to_le_bytes(), + &3.0f64.to_le_bytes(), + format!(" {}\n", ts.as_i64()).as_bytes(), + ] + .concat(); + + assert_eq!(buffer.as_bytes(), exp); assert_eq!(buffer.len(), exp.len()); sender.flush(&mut buffer)?; assert_eq!(buffer.len(), 0); assert_eq!(buffer.as_bytes(), b""); assert_eq!(server.recv_q()?, 1); - assert_eq!(server.msgs[0].as_bytes(), exp_byte); + assert_eq!(server.msgs[0].as_slice(), exp); Ok(()) } +#[cfg(feature = "ndarray")] #[test] -fn test_max_buf_size() -> TestResult { +fn test_array_f64_from_ndarray() -> TestResult { + let mut server = MockServer::new()?; + let mut sender = server.lsb_tcp().build()?; + server.accept()?; + + let ts = TimestampNanos::now(); + let array_2d = arr2(&[[1.1, 2.2], [3.3, 4.4]]); + let array_3d = ArrayD::<f64>::ones(vec![2, 3, 4]); + + let mut buffer = + Buffer::new().with_line_proto_version(sender.default_line_protocol_version())?; + buffer + .table("my_table")? + .symbol("device", "A001")? + .column_f64("f1", 25.5)? + .column_arr("arr2d", &array_2d.view())? + .column_arr("arr3d", &array_3d.view())? + .at(ts)?; + + assert_eq!(server.recv_q()?, 0); + + let array_header2d = &[ + &[b'='], + &[ingress::ARRAY_BINARY_FORMAT_TYPE], + &[ArrayColumnTypeTag::Double.into()], + &[2u8], + &2i32.to_le_bytes(), + &2i32.to_le_bytes(), + ] + .concat(); + let mut array_data2d = vec![0u8; 4 * size_of::<f64>()]; + write_array_data(&array_2d.view(), &mut &mut array_data2d[0..])?; + + let array_header3d = &[ + &[b'='][..], + &[ingress::ARRAY_BINARY_FORMAT_TYPE], + &[ArrayColumnTypeTag::Double.into()], + &[3u8], + &2i32.to_le_bytes(), + &3i32.to_le_bytes(), + &4i32.to_le_bytes(), + ] + .concat(); + let mut array_data3d = vec![0u8; 24 * size_of::<f64>()]; + write_array_data(&array_3d.view(), &mut &mut array_data3d[0..])?; + + let exp = &[ + "my_table,device=A001 ".as_bytes(), + f64_to_bytes("f1", 25.5, LineProtocolVersion::V2).as_slice(), + ",arr2d=".as_bytes(), + array_header2d, + array_data2d.as_slice(), + ",arr3d=".as_bytes(), + array_header3d, + array_data3d.as_slice(), + format!(" {}\n", ts_nanos_num).as_bytes(), + ] + .concat(); + + assert_eq!(buffer.as_bytes(), exp); + assert_eq!(buffer.len(), exp.len()); + sender.flush(&mut buffer)?; + assert_eq!(buffer.len(), 0); + assert_eq!(buffer.as_bytes(), b""); + assert_eq!(server.recv_q()?, 1); + assert_eq!(server.msgs[0].as_slice(), exp); + Ok(()) +} + +#[rstest] +fn test_max_buf_size( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { let max = 1024; let mut server = MockServer::new()?; let mut sender = server.lsb_tcp().max_buf_size(max)?.build()?; assert!(!sender.must_close()); server.accept()?; - - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; while buffer.len() < max { buffer @@ -102,10 +235,20 @@ fn test_max_buf_size() -> TestResult { let err = sender.flush(&mut buffer).unwrap_err(); assert_eq!(err.code(), ErrorCode::InvalidApiCall); - assert_eq!( - err.msg(), - "Could not flush buffer: Buffer size of 1026 exceeds maximum configured allowed size of 1024 bytes." - ); + match version { + LineProtocolVersion::V1 => { + assert_eq!( + err.msg(), + "Could not flush buffer: Buffer size of 1026 exceeds maximum configured allowed size of 1024 bytes." + ); + } + LineProtocolVersion::V2 => { + assert_eq!( + err.msg(), + "Could not flush buffer: Buffer size of 1025 exceeds maximum configured allowed size of 1024 bytes." + ); + } + } Ok(()) } @@ -262,6 +405,8 @@ fn test_bad_key( #[test] fn test_timestamp_overloads() -> TestResult { + use std::time::SystemTime; + let tbl_name = TableName::new("tbl_name")?; let mut buffer = Buffer::new(); @@ -357,8 +502,16 @@ fn test_str_column_name_too_long() -> TestResult { column_name_too_long_test_impl!(column_str, "value") } +#[cfg(feature = "ndarray")] #[test] -fn test_tls_with_file_ca() -> TestResult { +fn test_arr_column_name_too_long() -> TestResult { + column_name_too_long_test_impl!(column_arr, &arr1(&[1.0, 2.0, 3.0]).view()) +} + +#[rstest] +fn test_tls_with_file_ca( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { let mut ca_path = certs_dir(); ca_path.push("server_rootCA.pem"); @@ -368,7 +521,7 @@ fn test_tls_with_file_ca() -> TestResult { let mut sender = lsb.build()?; let mut server: MockServer = server_jh.join().unwrap()?; - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -376,12 +529,17 @@ fn test_tls_with_file_ca() -> TestResult { .at(TimestampNanos::new(10000000))?; assert_eq!(server.recv_q()?, 0); - let exp = b"test,t1=v1 f1=0.5 10000000\n"; + let exp = &[ + "test,t1=v1 ".as_bytes(), + f64_to_bytes("f1", 0.5, version).as_slice(), + " 10000000\n".as_bytes(), + ] + .concat(); assert_eq!(buffer.as_bytes(), exp); assert_eq!(buffer.len(), exp.len()); sender.flush(&mut buffer)?; assert_eq!(server.recv_q()?, 1); - assert_eq!(server.msgs[0].as_bytes(), exp); + assert_eq!(server.msgs[0].as_slice(), exp); Ok(()) } @@ -453,15 +611,17 @@ fn test_plain_to_tls_server() -> TestResult { } #[cfg(feature = "insecure-skip-verify")] -#[test] -fn test_tls_insecure_skip_verify() -> TestResult { +#[rstest] +fn test_tls_insecure_skip_verify( + #[values(LineProtocolVersion::V1, LineProtocolVersion::V2)] version: LineProtocolVersion, +) -> TestResult { let server = MockServer::new()?; let lsb = server.lsb_tcps().tls_verify(false)?; let server_jh = server.accept_tls(); let mut sender = lsb.build()?; let mut server: MockServer = server_jh.join().unwrap()?; - let mut buffer = Buffer::new(); + let mut buffer = Buffer::new().with_line_proto_version(version)?; buffer .table("test")? .symbol("t1", "v1")? @@ -469,12 +629,17 @@ fn test_tls_insecure_skip_verify() -> TestResult { .at(TimestampNanos::new(10000000))?; assert_eq!(server.recv_q()?, 0); - let exp = b"test,t1=v1 f1=0.5 10000000\n"; + let exp = &[ + "test,t1=v1 ".as_bytes(), + f64_to_bytes("f1", 0.5, version).as_slice(), + " 10000000\n".as_bytes(), + ] + .concat(); assert_eq!(buffer.as_bytes(), exp); assert_eq!(buffer.len(), exp.len()); sender.flush(&mut buffer)?; assert_eq!(server.recv_q()?, 1); - assert_eq!(server.msgs[0].as_bytes(), exp); + assert_eq!(server.msgs[0].as_slice(), exp); Ok(()) } @@ -495,3 +660,22 @@ fn bad_uppercase_addr() { assert!(err.code() == ErrorCode::ConfigError); assert!(err.msg() == "Missing \"addr\" parameter in config string"); } + +fn f64_to_bytes(name: &str, value: f64, version: LineProtocolVersion) -> Vec<u8> { + let mut buf = Vec::new(); + buf.extend_from_slice(name.as_bytes()); + buf.push(b'='); + + match version { + LineProtocolVersion::V1 => { + let mut ser = crate::ingress::F64Serializer::new(value); + buf.extend_from_slice(ser.as_str().as_bytes()); + } + LineProtocolVersion::V2 => { + buf.push(b'='); + buf.push(crate::ingress::DOUBLE_BINARY_FORMAT_TYPE); + buf.extend_from_slice(&value.to_le_bytes()); + } + } + buf +} diff --git a/system_test/fixture.py b/system_test/fixture.py index 40dca629..c992127a 100644 --- a/system_test/fixture.py +++ b/system_test/fixture.py @@ -23,12 +23,12 @@ ################################################################################ import sys + sys.dont_write_bytecode = True import os import re import pathlib -import textwrap import json import tarfile import shutil @@ -42,11 +42,9 @@ import urllib.error from pprint import pformat - AUTH_TXT = """admin ec-p-256-sha256 fLKYEaoEb9lrn3nkwLDA-M_xnuFOdSt9y0Z7_vWSHLU Dt5tbS1dEDMSYfym3fgMv0B99szno-dFc1rYF9t0aac # [key/user id] [key type] {keyX keyY}""" - # Valid keys as registered with the QuestDB fixture. AUTH = dict( username="admin", @@ -54,18 +52,17 @@ token_x="fLKYEaoEb9lrn3nkwLDA-M_xnuFOdSt9y0Z7_vWSHLU", token_y="Dt5tbS1dEDMSYfym3fgMv0B99szno-dFc1rYF9t0aac") - CA_PATH = (pathlib.Path(__file__).parent.parent / - 'tls_certs' / 'server_rootCA.pem') + 'tls_certs' / 'server_rootCA.pem') def retry( - predicate_task, - timeout_sec=30, - every=0.05, - msg='Timed out retrying', - backoff_till=5.0, - lead_sleep=0.001): + predicate_task, + timeout_sec=30, + every=0.05, + msg='Timed out retrying', + backoff_till=5.0, + lead_sleep=0.001): """ Repeat task every `interval` until it returns a truthy value or times out. """ @@ -121,8 +118,8 @@ def __init__(self): def list_questdb_releases(max_results=1): url = ( - 'https://api.github.com/repos/questdb/questdb/releases?' + - urllib.parse.urlencode({'per_page': max_results})) + 'https://api.github.com/repos/questdb/questdb/releases?' + + urllib.parse.urlencode({'per_page': max_results})) req = urllib.request.Request( url, headers={ @@ -351,8 +348,8 @@ def check_http_up(): def http_sql_query(self, sql_query): url = ( - f'http://{self.host}:{self.http_server_port}/exec?' + - urllib.parse.urlencode({'query': sql_query})) + f'http://{self.host}:{self.http_server_port}/exec?' + + urllib.parse.urlencode({'query': sql_query})) buf = None try: resp = urllib.request.urlopen(url, timeout=5) @@ -370,7 +367,7 @@ def http_sql_query(self, sql_query): if 'error' in data: raise QueryError(data['error']) return data - + def query_version(self): try: res = self.http_sql_query('select build') @@ -397,6 +394,7 @@ def retry_check_table( log_ctx=None): sql_query = f"select * from '{table_name}'" http_response_log = [] + def check_table(): try: resp = self.http_sql_query(sql_query) @@ -414,7 +412,8 @@ def check_table(): except TimeoutError as toe: if log: if log_ctx: - log_ctx = f'\n{textwrap.indent(log_ctx, " ")}\n' + log_ctx_str = log_ctx.decode('utf-8', errors='replace') + log_ctx = f'\n{textwrap.indent(log_ctx_str, " ")}\n' sys.stderr.write( f'Timed out after {timeout_sec} seconds ' + f'waiting for query {sql_query!r}. ' + @@ -493,7 +492,7 @@ def check_started(): self.listen_port = retry( check_started, timeout_sec=180, # Longer to include time to compile. - msg='Timed out waiting for `tls_proxy` to start.',) + msg='Timed out waiting for `tls_proxy` to start.', ) def connect_to_listening_port(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -518,4 +517,3 @@ def stop(self): if self._log_file: self._log_file.close() self._log_file = None - diff --git a/system_test/questdb_line_sender.py b/system_test/questdb_line_sender.py index 5d4eca85..595049fd 100644 --- a/system_test/questdb_line_sender.py +++ b/system_test/questdb_line_sender.py @@ -37,9 +37,10 @@ """ - -from ast import arg import sys + +import numpy + sys.dont_write_bytecode = True import pathlib @@ -61,58 +62,96 @@ c_void_p, c_ssize_t) -from typing import Optional, Tuple, Union +from typing import Optional, Union class c_line_sender(ctypes.Structure): pass + class c_line_sender_buffer(ctypes.Structure): pass + c_line_sender_protocol = ctypes.c_int + class Protocol(Enum): TCP = (c_line_sender_protocol(0), 'tcp') TCPS = (c_line_sender_protocol(1), 'tcps') HTTP = (c_line_sender_protocol(2), 'http') HTTPS = (c_line_sender_protocol(3), 'https') + c_line_sender_ca = ctypes.c_int + class CertificateAuthority(Enum): WEBPKI_ROOTS = (c_line_sender_ca(0), 'webpki_roots') OS_ROOTS = (c_line_sender_ca(1), 'os_roots') WEBPKI_AND_OS_ROOTS = (c_line_sender_ca(2), 'webpki_and_os_roots') PEM_FILE = (c_line_sender_ca(3), 'pem_file') + +c_line_protocol_version = ctypes.c_int + + +class LineProtocolVersion(Enum): + V1 = (c_line_protocol_version(1), 'v1') + V2 = (c_line_protocol_version(2), 'v2') + + @classmethod + def from_int(cls, value: c_line_protocol_version): + for member in cls: + if member.value[0].value == value: + return member + raise ValueError(f"invalid protocol version: {value}") + + class c_line_sender_opts(ctypes.Structure): pass + class c_line_sender_error(ctypes.Structure): pass + c_size_t_p = ctypes.POINTER(c_size_t) +c_ssize_t_p = ctypes.POINTER(c_ssize_t) c_line_sender_p = ctypes.POINTER(c_line_sender) c_line_sender_buffer_p = ctypes.POINTER(c_line_sender_buffer) c_line_sender_opts_p = ctypes.POINTER(c_line_sender_opts) c_line_sender_error_p = ctypes.POINTER(c_line_sender_error) c_line_sender_error_p_p = ctypes.POINTER(c_line_sender_error_p) +c_uint8_p = ctypes.POINTER(c_uint8) + + class c_line_sender_utf8(ctypes.Structure): _fields_ = [("len", c_size_t), ("buf", c_char_p)] + + c_line_sender_utf8_p = ctypes.POINTER(c_line_sender_utf8) + + class c_line_sender_table_name(ctypes.Structure): _fields_ = [("len", c_size_t), ("buf", c_char_p)] + + class line_sender_buffer_view(ctypes.Structure): _fields_ = [("len", c_size_t), - ("buf", ctypes.POINTER(c_uint8))] + ("buf", c_uint8_p)] + c_line_sender_table_name_p = ctypes.POINTER(c_line_sender_table_name) + + class c_line_sender_column_name(ctypes.Structure): _fields_ = [("len", c_size_t), ("buf", c_char_p)] + + c_line_sender_column_name_p = ctypes.POINTER(c_line_sender_column_name) @@ -129,7 +168,7 @@ def _setup_cdll(): 'darwin': 'dylib', 'win32': 'dll'}[sys.platform] dll_path = next( - build_dir.glob(f'**/*questdb_client*.{dll_ext}')) + build_dir.glob(f'**/*questdb_client*.{dll_ext}')) dll = ctypes.CDLL(str(dll_path)) @@ -176,6 +215,12 @@ def set_sig(fn, restype, *argtypes): dll.line_sender_buffer_with_max_name_len, c_line_sender_buffer_p, c_size_t) + set_sig( + dll.line_sender_buffer_set_line_protocol_version, + c_bool, + c_line_sender_buffer_p, + c_line_protocol_version, + c_line_sender_error_p_p) set_sig( dll.line_sender_buffer_free, None, @@ -237,6 +282,17 @@ def set_sig(fn, restype, *argtypes): c_line_sender_column_name, c_line_sender_utf8, c_line_sender_error_p_p) + set_sig( + dll.line_sender_buffer_column_f64_arr, + c_bool, + c_line_sender_buffer_p, + c_line_sender_column_name, + c_size_t, + c_size_t_p, + c_ssize_t_p, + c_uint8_p, + c_size_t, + c_line_sender_error_p_p) set_sig( dll.line_sender_buffer_column_ts_nanos, c_bool, @@ -316,6 +372,11 @@ def set_sig(fn, restype, *argtypes): c_line_sender_opts_p, c_line_sender_utf8, c_line_sender_error_p_p) + set_sig( + dll.line_sender_opts_disable_line_protocol_validation, + c_bool, + c_line_sender_opts_p, + c_line_sender_error_p_p) set_sig( dll.line_sender_opts_auth_timeout, c_bool, @@ -386,6 +447,10 @@ def set_sig(fn, restype, *argtypes): dll.line_sender_from_env, c_line_sender_p, c_line_sender_error_p_p) + set_sig( + dll.line_sender_default_line_protocol_version, + c_line_protocol_version, + c_line_sender_p) set_sig( dll.line_sender_must_close, None, @@ -513,7 +578,10 @@ def __init__(self, host, port, protocol=Protocol.TCP): def __getattr__(self, name: str): fn = getattr(_DLL, 'line_sender_opts_' + name) + def wrapper(*args): + if name == 'disable_line_protocol_validation': + return _error_wrapped_call(fn, self.impl) mapped_args = [ (_utf8(arg) if isinstance(arg, str) else arg) for arg in args] @@ -521,6 +589,7 @@ def wrapper(*args): return _error_wrapped_call(fn, self.impl, *mapped_args) else: return fn(self.impl, *mapped_args) + return wrapper def __del__(self): @@ -533,16 +602,20 @@ def __init__(self, micros: int): class Buffer: - def __init__(self, init_buf_size=65536, max_name_len=127): + def __init__(self, init_buf_size=65536, max_name_len=127, line_protocol_version=LineProtocolVersion.V2): self._impl = _DLL.line_sender_buffer_with_max_name_len( c_size_t(max_name_len)) _DLL.line_sender_buffer_reserve(self._impl, c_size_t(init_buf_size)) + _error_wrapped_call( + _DLL.line_sender_buffer_set_line_protocol_version, + self._impl, + line_protocol_version.value[0]) def __len__(self): return _DLL.line_sender_buffer_size(self._impl) def peek(self) -> str: - # This is a hacky way of doing it because it copies the whole buffer. + # This is a hacky way of doing it because it copies the whole buffer. # Instead the `buffer` should be made to support the buffer protocol: # https://docs.python.org/3/c-api/buffer.html # This way we would not need to `bytes(..)` the object to keep it alive. @@ -554,6 +627,12 @@ def peek(self) -> str: else: return '' + def set_line_protocol_version(self, version: LineProtocolVersion): + _error_wrapped_call( + _DLL.line_sender_buffer_set_line_protocol_version, + self._impl, + version.value[0]) + def reserve(self, additional): _DLL.line_sender_buffer_reserve(self._impl, c_size_t(additional)) @@ -627,6 +706,34 @@ def column( '`bool`, `int`, `float` or `str`.') return self + def column_f64_arr(self, name: str, + rank: int, + shapes: tuple[int, ...], + strides: tuple[int, ...], + data: c_void_p, + length: int): + def _convert_tuple(tpl: tuple[int, ...], c_type: type, name: str) -> ctypes.POINTER: + arr_type = c_type * len(tpl) + try: + return arr_type(*[c_type(v) for v in tpl]) + except OverflowError as e: + raise ValueError( + f"{name} value exceeds {c_type.__name__} range" + ) from e + + c_shapes = _convert_tuple(shapes, c_size_t, "shapes") + c_strides = _convert_tuple(strides, c_ssize_t, "strides") + _error_wrapped_call( + _DLL.line_sender_buffer_column_f64_arr, + self._impl, + _column_name(name), + c_size_t(rank), + c_shapes, + c_strides, + ctypes.cast(data, c_uint8_p), + c_size_t(length) + ) + def at_now(self): _error_wrapped_call( _DLL.line_sender_buffer_at_now, @@ -671,7 +778,7 @@ def __init__( host: str, port: Union[str, int], **kwargs): - + self._build_mode = build_mode self._impl = None self._conf = [ @@ -679,7 +786,6 @@ def __init__( '::', f'addr={host}:{port};'] self._opts = None - self._buffer = Buffer() opts = _Opts(host, port, protocol) for key, value in kwargs.items(): # Build the config string param pair. @@ -716,12 +822,18 @@ def connect(self): def __enter__(self): self.connect() + self._buffer = Buffer( + line_protocol_version=LineProtocolVersion.from_int(self.line_sender_default_line_protocol_version())) return self def _check_connected(self): if not self._impl: raise SenderError('Not connected.') + def line_sender_default_line_protocol_version(self): + self._check_connected() + return _DLL.line_sender_default_line_protocol_version(self._impl) + def table(self, table: str): self._buffer.table(table) return self @@ -736,13 +848,21 @@ def column( self._buffer.column(name, value) return self + def column_f64_arr( + self, name: str, + array: numpy.ndarray): + if array.dtype != numpy.float64: + raise ValueError('expect float64 array') + self._buffer.column_f64_arr(name, array.ndim, array.shape, array.strides, array.ctypes.data, array.nbytes) + return self + def at_now(self): self._buffer.at_now() def at(self, timestamp: int): self._buffer.at(timestamp) - def flush(self, buffer: Optional[Buffer]=None, clear=True, transactional=None): + def flush(self, buffer: Optional[Buffer] = None, clear=True, transactional=None): if (buffer is None) and not clear: raise ValueError( 'Clear flag must be True when using internal buffer') diff --git a/system_test/test.py b/system_test/test.py index c28ccd6d..19dc3c0c 100755 --- a/system_test/test.py +++ b/system_test/test.py @@ -25,14 +25,15 @@ ################################################################################ import sys + sys.dont_write_bytecode = True import os - import pathlib import math import datetime import argparse import unittest +import numpy as np import time import questdb_line_sender as qls import uuid @@ -47,7 +48,6 @@ import subprocess from collections import namedtuple - QDB_FIXTURE: QuestDbFixture = None TLS_PROXY_FIXTURE: TlsProxyFixture = None BUILD_MODE = None @@ -72,7 +72,6 @@ def ns_to_qdb_date(at_ts_ns): token_x="-nSHz3evuPl-rGLIlbIZjwOJeWao0rbk53Cll6XEgak", token_y="9iYksF4L5mfmArupv0CMoyVAWjQ4gNIoupdg6N5noG8") - # Bad malformed key AUTH_MALFORMED1 = dict( username="testUser3", @@ -80,7 +79,6 @@ def ns_to_qdb_date(at_ts_ns): token_x="-nSHz3evuPl-rGLIlbIZjwOJeWao0rbk53Cll6XEgak", token_y="9iYksF4L6mfmArupv0CMoyVAWjQ4gNIoupdg6N5noG8") - # Another malformed key where the keys invalid base 64. AUTH_MALFORMED2 = dict( username="testUser4", @@ -88,7 +86,6 @@ def ns_to_qdb_date(at_ts_ns): token_x="-nSHz3evuPl-rGLIlbIZjwOJeWao0rbk5XEgak", token_y="9iYksF4L6mfmArupv0CMoyVAWjQ4gNIou5noG8") - # All the keys are valid, but the username is wrong. AUTH_MALFORMED3 = dict( username="wrongUser", @@ -98,9 +95,11 @@ def ns_to_qdb_date(at_ts_ns): class TestSender(unittest.TestCase): - def _mk_linesender(self): + def _mk_linesender(self, disable_line_protocol_validation=False): # N.B.: We never connect with TLS here. auth = AUTH if QDB_FIXTURE.auth else {} + if disable_line_protocol_validation: + auth["disable_line_protocol_validation"] = "on" return qls.Sender( BUILD_MODE, qls.Protocol.HTTP if QDB_FIXTURE.http else qls.Protocol.TCP, @@ -115,9 +114,9 @@ def _expect_eventual_disconnect(self, sender): for _ in range(1000): time.sleep(0.1) (sender - .table(table_name) - .symbol('s1', 'v1') - .at_now()) + .table(table_name) + .symbol('s1', 'v1') + .at_now()) sender.flush() def test_insert_three_rows(self): @@ -126,13 +125,13 @@ def test_insert_three_rows(self): with self._mk_linesender() as sender: for _ in range(3): (sender - .table(table_name) - .symbol('name_a', 'val_a') - .column('name_b', True) - .column('name_c', 42) - .column('name_d', 2.5) - .column('name_e', 'val_b') - .at_now()) + .table(table_name) + .symbol('name_a', 'val_a') + .column('name_b', True) + .column('name_c', 42) + .column('name_d', 2.5) + .column('name_e', 'val_b') + .at_now()) pending = sender.buffer.peek() sender.flush() @@ -161,12 +160,12 @@ def test_repeated_symbol_and_column_names(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .symbol('a', 'A') - .symbol('a', 'B') - .column('b', False) - .column('b', 'C') - .at_now()) + .table(table_name) + .symbol('a', 'A') + .symbol('a', 'B') + .column('b', False) + .column('b', 'C') + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -188,10 +187,10 @@ def test_same_symbol_and_col_name(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .symbol('a', 'A') - .column('a', 'B') - .at_now()) + .table(table_name) + .symbol('a', 'A') + .column('a', 'B') + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -209,9 +208,9 @@ def _test_single_symbol_impl(self, sender): pending = None with sender: (sender - .table(table_name) - .symbol('a', 'A') - .at_now()) + .table(table_name) + .symbol('a', 'A') + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -232,10 +231,10 @@ def test_two_columns(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .column('a', 'A') - .column('b', 'B') - .at_now()) + .table(table_name) + .column('a', 'A') + .column('b', 'B') + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -254,13 +253,13 @@ def test_mismatched_types_across_rows(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .column('a', 1) # LONG - .at_now()) + .table(table_name) + .column('a', 1) # LONG + .at_now()) (sender - .table(table_name) - .symbol('a', 'B') # SYMBOL - .at_now()) + .table(table_name) + .symbol('a', 'B') # SYMBOL + .at_now()) pending = sender.buffer.peek() @@ -304,9 +303,9 @@ def test_at(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .symbol('a', 'A') - .at(at_ts_ns)) + .table(table_name) + .symbol('a', 'A') + .at(at_ts_ns)) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) exp_dataset = [['A', ns_to_qdb_date(at_ts_ns)]] @@ -322,9 +321,9 @@ def test_neg_at(self): with self._mk_linesender() as sender: with self.assertRaisesRegex(qls.SenderError, r'.*Timestamp .* is negative.*'): (sender - .table(table_name) - .symbol('a', 'A') - .at(at_ts_ns)) + .table(table_name) + .symbol('a', 'A') + .at(at_ts_ns)) def test_timestamp_col(self): if QDB_FIXTURE.version <= (6, 0, 7, 1): @@ -334,13 +333,13 @@ def test_timestamp_col(self): pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .column('a', qls.TimestampMicros(-1000000)) - .at_now()) + .table(table_name) + .column('a', qls.TimestampMicros(-1000000)) + .at_now()) (sender - .table(table_name) - .column('a', qls.TimestampMicros(1000000)) - .at_now()) + .table(table_name) + .column('a', qls.TimestampMicros(1000000)) + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -353,16 +352,15 @@ def test_timestamp_col(self): scrubbed_dataset = [row[:-1] for row in resp['dataset']] self.assertEqual(scrubbed_dataset, exp_dataset) - def test_underscores(self): table_name = f'_{uuid.uuid4().hex}_' pending = None with self._mk_linesender() as sender: (sender - .table(table_name) - .symbol('_a_b_c_', 'A') - .column('_d_e_f_', True) - .at_now()) + .table(table_name) + .symbol('_a_b_c_', 'A') + .column('_d_e_f_', True) + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -422,16 +420,16 @@ def test_floats(self): 1.23456789012, 1000000000000000000000000.0, -1000000000000000000000000.0, - float("nan"), # Converted to `None`. - float("inf"), # Converted to `None`. + float("nan"), # Converted to `None`. + float("inf"), # Converted to `None`. float("-inf")] # Converted to `None`. - # These values below do not round-trip properly: QuestDB limitation. - # 1.2345678901234567, - # 2.2250738585072014e-308, - # -2.2250738585072014e-308, - # 1.7976931348623157e+308, - # -1.7976931348623157e+308] + # These values below do not round-trip properly: QuestDB limitation. + # 1.2345678901234567, + # 2.2250738585072014e-308, + # -2.2250738585072014e-308, + # 1.7976931348623157e+308, + # -1.7976931348623157e+308] table_name = uuid.uuid4().hex pending = None with self._mk_linesender() as sender: @@ -469,9 +467,9 @@ def test_timestamp_column(self): ts = qls.TimestampMicros(3600000000) # One hour past epoch. with self._mk_linesender() as sender: (sender - .table(table_name) - .column('ts1', ts) - .at_now()) + .table(table_name) + .column('ts1', ts) + .at_now()) pending = sender.buffer.peek() resp = retry_check_table(table_name, log_ctx=pending) @@ -483,6 +481,185 @@ def test_timestamp_column(self): scrubbed_dataset = [row[:-1] for row in resp['dataset']] self.assertEqual(scrubbed_dataset, exp_dataset) + def test_f64_arr_column(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + table_name = uuid.uuid4().hex + array1 = np.array( + [ + [[1.1, 2.2], [3.3, 4.4]], + [[5.5, 6.6], [7.7, 8.8]] + ], + dtype=np.float64 + ) + array2 = array1.T + array3 = array1[::-1, ::-1] + + with self._mk_linesender() as sender: + (sender + .table(table_name) + .column_f64_arr('f64_arr1', array1) + .column_f64_arr('f64_arr2', array2) + .column_f64_arr('f64_arr3', array3) + .at_now()) + + resp = retry_check_table(table_name) + exp_columns = [{'dim': 3, 'elemType': 'DOUBLE', 'name': 'f64_arr1', 'type': 'ARRAY'}, + {'dim': 3, 'elemType': 'DOUBLE', 'name': 'f64_arr2', 'type': 'ARRAY'}, + {'dim': 3, 'elemType': 'DOUBLE', 'name': 'f64_arr3', 'type': 'ARRAY'}, + {'name': 'timestamp', 'type': 'TIMESTAMP'}] + self.assertEqual(resp['columns'], exp_columns) + expected_data = [[[[[1.1, 2.2], [3.3, 4.4]], [[5.5, 6.6], [7.7, 8.8]]], + [[[1.1, 5.5], [3.3, 7.7]], [[2.2, 6.6], [4.4, 8.8]]], + [[[7.7, 8.8], [5.5, 6.6]], [[3.3, 4.4], [1.1, 2.2]]]]] + scrubbed_data = [row[:-1] for row in resp['dataset']] + self.assertEqual(scrubbed_data, expected_data) + + def test_f64_arr_empty(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + table_name = uuid.uuid4().hex + empty_array = np.array([], dtype=np.float64).reshape(0, 0, 0) + with self._mk_linesender() as sender: + (sender.table(table_name) + .column_f64_arr('empty', empty_array) + .at_now()) + + resp = retry_check_table(table_name) + exp_columns = [{'dim': 3, 'elemType': 'DOUBLE', 'name': 'empty', 'type': 'ARRAY'}, + {'name': 'timestamp', 'type': 'TIMESTAMP'}] + self.assertEqual(exp_columns, resp['columns']) + self.assertEqual(resp['dataset'][0][0], []) + + def test_f64_arr_non_contiguous(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + table_name = uuid.uuid4().hex + array = np.array([[1.1, 2.2], [3.3, 4.4]], dtype=np.float64)[:, ::2] + with self._mk_linesender() as sender: + (sender.table(table_name) + .column_f64_arr('non_contiguous', array) + .at_now()) + + resp = retry_check_table(table_name) + exp_columns = [{'dim': 2, 'elemType': 'DOUBLE', 'name': 'non_contiguous', 'type': 'ARRAY'}, + {'name': 'timestamp', 'type': 'TIMESTAMP'}] + self.assertEqual(exp_columns, resp['columns']) + self.assertEqual(resp['dataset'][0][0], [[1.1], [3.3]]) + + def test_f64_arr_zero_dimensional(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + table_name = uuid.uuid4().hex + array = np.array(42.0, dtype=np.float64) + try: + with self._mk_linesender() as sender: + (sender.table(table_name) + .column_f64_arr('scalar', array) + .at_now()) + except qls.SenderError as e: + self.assertIn('Zero-dimensional arrays are not supported', str(e)) + + def test_f64_arr_wrong_datatype(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + table_name = uuid.uuid4().hex + array = np.array([1, 2], dtype=np.int32) + try: + with self._mk_linesender() as sender: + (sender.table(table_name) + .column_f64_arr('wrong', array) + .at_now()) + except ValueError as e: + self.assertIn('expect float64 array', str(e)) + + def test_f64_arr_mix_dims(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array issues support') + + array_2d = np.array([[1.1, 2.2], [3.3, 4.4]], dtype=np.float64) + array_1d = np.array([1.1], dtype=np.float64) + table_name = uuid.uuid4().hex + try: + with self._mk_linesender() as sender: + (sender.table(table_name) + .column_f64_arr('array', array_2d) + .at_now() + ) + (sender.table(table_name) + .column_f64_arr('array', array_1d) + .at_now() + ) + except qls.SenderError as e: + self.assertIn('cast error from protocol type: DOUBLE[] to column type: DOUBLE[][]', str(e)) + + def test_line_protocol_version_v1(self): + if QDB_FIXTURE.version <= (6, 1, 2): + self.skipTest('Float issues support') + numbers = [ + 0.0, + -0.0, + 1.0, + -1.0] # Converted to `None`. + + table_name = uuid.uuid4().hex + pending = None + with self._mk_linesender() as sender: + sender.buffer.set_line_protocol_version(qls.LineProtocolVersion.V1) + for num in numbers: + sender.table(table_name) + sender.column('n', num) + sender.at_now() + pending = sender.buffer.peek() + + resp = retry_check_table( + table_name, + min_rows=len(numbers), + log_ctx=pending) + exp_columns = [ + {'name': 'n', 'type': 'DOUBLE'}, + {'name': 'timestamp', 'type': 'TIMESTAMP'}] + self.assertEqual(resp['columns'], exp_columns) + + def massage(num): + if math.isnan(num) or math.isinf(num): + return None + elif num == -0.0: + return 0.0 + else: + return num + + # Comparison excludes timestamp column. + exp_dataset = [[massage(num)] for num in numbers] + scrubbed_dataset = [row[:-1] for row in resp['dataset']] + self.assertEqual(scrubbed_dataset, exp_dataset) + + def test_line_protocol_version_v1_array_unsupported(self): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array unsupported') + + array1 = np.array( + [ + [[1.1, 2.2], [3.3, 4.4]], + [[5.5, 6.6], [7.7, 8.8]] + ], + dtype=np.float64 + ) + table_name = uuid.uuid4().hex + try: + with self._mk_linesender(True) as sender: + sender.buffer.set_line_protocol_version(qls.LineProtocolVersion.V1) + sender.table(table_name) + sender.column_f64_arr('f64_arr1', array1) + sender.at_now() + except qls.SenderError as e: + self.assertIn('line protocol version v1 does not support array datatype', str(e)) + def _test_example(self, bin_name, table_name, tls=False): if BUILD_MODE != qls.BuildMode.API: self.skipTest('BuildMode.API-only test') @@ -514,7 +691,11 @@ def _test_example(self, bin_name, table_name, tls=False): {'name': 'timestamp', 'type': 'TIMESTAMP'}] self.assertEqual(resp['columns'], exp_columns) - exp_dataset = [['ETH-USD', 'sell', 2615.54, 0.00044]] # Comparison excludes timestamp column. + exp_dataset = [['ETH-USD', + 'sell', + 2615.54, + 0.00044]] + # Comparison excludes timestamp column. scrubbed_dataset = [row[:-1] for row in resp['dataset']] self.assertEqual(scrubbed_dataset, exp_dataset) @@ -544,6 +725,48 @@ def test_cpp_tls_example(self): 'cpp_trades_tls_ca', tls=True) + def test_cpp_array_example(self): + self._test_array_example( + 'line_sender_cpp_example_array', + 'cpp_market_orders') + + def test_c_array_example(self): + self._test_array_example( + 'line_sender_c_example_array', + 'market_orders') + + def _test_array_example(self, bin_name, table_name): + if QDB_FIXTURE.version < (8, 3, 1): + self.skipTest('array unsupported') + if QDB_FIXTURE.http: + self.skipTest('TCP-only test') + if BUILD_MODE != qls.BuildMode.API: + self.skipTest('BuildMode.API-only test') + if QDB_FIXTURE.auth: + self.skipTest('auth') + + proj = Project() + ext = '.exe' if sys.platform == 'win32' else '' + try: + bin_path = next(proj.build_dir.glob(f'**/{bin_name}{ext}')) + except StopIteration: + raise RuntimeError(f'Could not find {bin_name}{ext} in {proj.build_dir}') + port = QDB_FIXTURE.line_tcp_port + args = [str(bin_path)] + args.extend(['localhost', str(port)]) + subprocess.check_call(args, cwd=bin_path.parent) + resp = retry_check_table(table_name) + exp_columns = [ + {'name': 'symbol', 'type': 'SYMBOL'}, + {'dim': 3, 'elemType': 'DOUBLE', 'name': 'order_book', 'type': 'ARRAY'}, + {'name': 'timestamp', 'type': 'TIMESTAMP'}] + self.assertEqual(resp['columns'], exp_columns) + exp_dataset = [['BTC-USD', + [[[48123.5, 2.4], [48124.0, 1.8], [48124.5, 0.9]], + [[48122.5, 3.1], [48122.0, 2.7], [48121.5, 4.3]]]]] + scrubbed_dataset = [row[:-1] for row in resp['dataset']] + self.assertEqual(scrubbed_dataset, exp_dataset) + def test_opposite_auth(self): """ We simulate incorrectly connecting either: @@ -570,9 +793,9 @@ def test_opposite_auth(self): # The sending the first line will not fail. (sender - .table(table_name) - .symbol('s1', 'v1') - .at_now()) + .table(table_name) + .symbol('s1', 'v1') + .at_now()) sender.flush() self._expect_eventual_disconnect(sender) @@ -663,7 +886,7 @@ def test_tls_insecure_skip_verify(self): def test_tls_roots(self): protocol = qls.Protocol.HTTPS if QDB_FIXTURE.http else qls.Protocol.TCPS - auth = auth=AUTH if QDB_FIXTURE.auth else {} + auth = auth = AUTH if QDB_FIXTURE.auth else {} sender = qls.Sender( BUILD_MODE, protocol, @@ -679,7 +902,7 @@ def _test_tls_ca(self, tls_ca): try: os.environ['SSL_CERT_FILE'] = str( Project().tls_certs_dir / 'server_rootCA.pem') - auth = auth=AUTH if QDB_FIXTURE.auth else {} + auth = auth = AUTH if QDB_FIXTURE.auth else {} sender = qls.Sender( BUILD_MODE, protocol, @@ -803,14 +1026,15 @@ def run_with_existing(args): global QDB_FIXTURE MockFixture = namedtuple( 'MockFixture', - ('host', 'line_tcp_port', 'http_server_port', 'version', 'http')) + ('host', 'line_tcp_port', 'http_server_port', 'version', 'http', "auth")) host, line_tcp_port, http_server_port = args.existing.split(':') QDB_FIXTURE = MockFixture( host, int(line_tcp_port), int(http_server_port), (999, 999, 999), - True) + True, + False) unittest.main() @@ -831,11 +1055,11 @@ def iter_versions(args): if versions_args: versions = { version: ( - 'https://github.com/questdb/questdb/releases/download/' + - version + - '/questdb-' + - version + - '-no-jre-bin.tar.gz') + 'https://github.com/questdb/questdb/releases/download/' + + version + + '/questdb-' + + version + + '-no-jre-bin.tar.gz') for version in versions_args} else: last_n = getattr(args, 'last_n', None) or 1 @@ -857,7 +1081,8 @@ def run_with_fixtures(args): for auth in (False, True): for http in (False, True): for build_mode in list(qls.BuildMode): - print(f'Running tests [questdb_dir={questdb_dir}, auth={auth}, http={http}, build_mode={build_mode}]') + print( + f'Running tests [questdb_dir={questdb_dir}, auth={auth}, http={http}, build_mode={build_mode}]') if http and last_version <= (7, 3, 7): print('Skipping ILP/HTTP tests for versions <= 7.3.7') continue