diff --git a/.github/actions/compile-libraries/action.yml b/.github/actions/compile-libraries/action.yml index b8767a314..fd3c94ae4 100644 --- a/.github/actions/compile-libraries/action.yml +++ b/.github/actions/compile-libraries/action.yml @@ -19,6 +19,4 @@ runs: if: inputs.os == 'Windows' shell: pwsh run: | - cmake --preset windows-x64 - cmake --build --preset windows-x64 --config Debug - cmake --install build_windows_x64 --config Debug + ./scripts/build.ps1 diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml index 370df92a5..52bdca95c 100644 --- a/.github/actions/run-test/action.yml +++ b/.github/actions/run-test/action.yml @@ -9,29 +9,54 @@ inputs: runs: using: "composite" steps: - - name: Install dependencies for MITM tests - if: inputs.os == 'Linux' + - name: Install scapy for MITM tests shell: bash run: uv pip install --system scapy==2.* + - name: Install pydivert and pywin32 for MITM tests (Windows) + if: inputs.os == 'Windows' + shell: pwsh + run: uv pip install --system pydivert pywin32 + - name: Run C++ Tests (Linux) if: inputs.os == 'Linux' shell: bash run: sudo ./scripts/test.sh - # TODO: build wheel first, then run the test + - name: Run C++ Tests (Windows) + if: inputs.os == 'Windows' + shell: pwsh + run: ./scripts/test.ps1 -C Debug + + - name: Build Wheel + if: inputs.os != 'Windows' + shell: bash + run: | + uv pip install --system build + python -m build --wheel + + - name: Install Wheel + if: inputs.os != 'Windows' + shell: bash + run: | + wheel_file=$(ls dist/*.whl | head -n 1) + echo "Installing wheel: $wheel_file" + uv pip install --system "$wheel_file" + - name: Run Unittests + if: inputs.os != 'Windows' shell: bash run: | - PYTHONPATH=src python -m unittest discover -v tests + python -m unittest discover -v tests - name: Run Examples + if: inputs.os != 'Windows' shell: bash run: | uv pip install --system -r examples/applications/requirements_applications.txt for example in "./examples"/*.py; do echo "Running 
$example" - PYTHONPATH=src python $example + python $example done for example in "./examples/applications"/*.py; do if python -c 'import sys; sys.exit(not sys.version_info <= (3, 10))'; then @@ -42,5 +67,5 @@ runs: fi echo "Running $example" - PYTHONPATH=src python $example + python $example done diff --git a/CMakePresets.json b/CMakePresets.json index dd5858ce1..0f79bb889 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -6,7 +6,7 @@ "displayName": "Linux 64 bit Intel build", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_INSTALL_PREFIX": "${sourceDir}" + "CMAKE_INSTALL_PREFIX": "${sourceDir}/src" }, "binaryDir": "${sourceDir}/build_linux_x86_64" }, @@ -15,7 +15,7 @@ "displayName": "Linux 64 bit ARM build", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_INSTALL_PREFIX": "${sourceDir}" + "CMAKE_INSTALL_PREFIX": "${sourceDir}/src" }, "binaryDir": "${sourceDir}/build_linux_aarch64" }, @@ -24,7 +24,7 @@ "displayName": "Apple ARM 64 bit build", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_INSTALL_PREFIX": "${sourceDir}" + "CMAKE_INSTALL_PREFIX": "${sourceDir}/src" }, "binaryDir": "${sourceDir}/build_darwin_arm64" }, @@ -34,7 +34,7 @@ "generator": "Visual Studio 17 2022", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_INSTALL_PREFIX": "${sourceDir}" + "CMAKE_INSTALL_PREFIX": "${sourceDir}/src" }, "architecture": { "value": "x64", diff --git a/pyproject.toml b/pyproject.toml index fe2f93a97..18de20d19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ uvloop = [ ] gui = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", - "nicegui[plotly]==3.2.0; python_version >= '3.9'", + "nicegui[plotly]==3.3.1; python_version >= '3.9'", ] graphblas = [ "python-graphblas", @@ -50,7 +50,7 @@ aws = [ ] all = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", - "nicegui[plotly]==3.2.0; python_version >= '3.9'", + "nicegui[plotly]==3.3.1; python_version >= '3.9'", "python-graphblas", "numpy==1.24.4; python_version 
== '3.8'", "numpy==2.0.2; python_version == '3.9'", diff --git a/scripts/build.ps1 b/scripts/build.ps1 new file mode 100644 index 000000000..77d08eb57 --- /dev/null +++ b/scripts/build.ps1 @@ -0,0 +1,28 @@ +# This script builds and installs in-place the C++ components. +# +# Usage: +# ./scripts/build.ps1 + +$ErrorActionPreference = "Stop" +$OS = "windows" +$ARCH = "x64" +$BUILD_DIR = "build_${OS}_${ARCH}" +$BUILD_PRESET = "${OS}-${ARCH}" + +# Clean up previous build artifacts +if (Test-Path $BUILD_DIR) { + Remove-Item -Recurse -Force $BUILD_DIR +} +Get-ChildItem "scaler/protocol/capnp" -Include *.c++, *.h -ErrorAction SilentlyContinue | Remove-Item -Force + +Write-Host "Build directory: $BUILD_DIR" +Write-Host "Build preset: $BUILD_PRESET" + +# Configure +cmake --preset $BUILD_PRESET @args + +# Build +cmake --build --preset $BUILD_PRESET + +# Install +cmake --install $BUILD_DIR diff --git a/scripts/library_tool.ps1 b/scripts/library_tool.ps1 index f61e5e583..20aacc09c 100755 --- a/scripts/library_tool.ps1 +++ b/scripts/library_tool.ps1 @@ -80,20 +80,25 @@ elseif ($dependency -eq "capnp") } elseif ($action -eq "compile") { - Remove-Item -Path "$THIRD_PARTY_DOWNLOADED\$CAPNP_FOLDER_NAME" -Recurse -Force - tar -xzvf "$THIRD_PARTY_COMPILED\$CAPNP_FOLDER_NAME.tar.gz" -C "$THIRD_PARTY_COMPILED\$CAPNP_FOLDER_NAME" + Remove-Item -Path "$THIRD_PARTY_DOWNLOADED\$CAPNP_FOLDER_NAME" -Recurse -Force -ErrorAction SilentlyContinue + mkdir "$THIRD_PARTY_COMPILED" -Force + tar -xzvf "$THIRD_PARTY_DOWNLOADED\$CAPNP_FOLDER_NAME.tar.gz" -C "$THIRD_PARTY_COMPILED" # Configure and build with Visual Studio using CMake + $oldDir = Get-Location Set-Location -Path "$THIRD_PARTY_COMPILED\$CAPNP_FOLDER_NAME" cmake -G "Visual Studio 17 2022" -B build cmake --build build --config Release Write-Host "Compiled capnp into $THIRD_PARTY_COMPILED\$CAPNP_FOLDER_NAME" + Set-Location $oldDir } elseif ($action -eq "install") { + $oldDir = Get-Location Set-Location -Path 
"$THIRD_PARTY_COMPILED\$CAPNP_FOLDER_NAME" cmake --install build --config Release --prefix $PREFIX Write-Host "Installed capnp into $PREFIX" + Set-Location $oldDir } else { diff --git a/scripts/test.ps1 b/scripts/test.ps1 new file mode 100644 index 000000000..636c32601 --- /dev/null +++ b/scripts/test.ps1 @@ -0,0 +1,13 @@ +# +# This script tests the C++ components. +# +# Usage: +# ./scripts/test.ps1 + +$OS="windows" +$ARCH="x64" +$BUILD_DIR="build_${OS}_${ARCH}" +$BUILD_PRESET="${OS}-${ARCH}" + +# Run tests +ctest --preset $BUILD_PRESET -VV @args diff --git a/slides/bermudan_option.ipynb b/slides/bermudan_option.ipynb new file mode 100644 index 000000000..2a15b905b --- /dev/null +++ b/slides/bermudan_option.ipynb @@ -0,0 +1,464 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "32be14f1-36f4-40ea-a154-87d0224370ac", + "metadata": {}, + "source": [ + "# OpenGRIS Scaler Demo With Multiple Backends (IBM Symphony + AWS ECS)" + ] + }, + { + "cell_type": "markdown", + "id": "0e853b2f-4b1f-4c70-a173-7147f404d5aa", + "metadata": {}, + "source": [ + "## Example: Heston model Bermudan call option pricing using Monte Carlo simulation\n", + "\n", + "A Bermudan option is an exotic option exercisable on predetermined dates, usually monthly. It differs from European options which can only be exercised on the date of expiration and American options which can be exercised at any time before expiration.\n", + "\n", + "Given a basket of stocks, download price data from Yahoo finance and use the Heston model to price a Bermudan call option using Monte Carlo simulation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9b780bb4-941b-41ab-89f2-58c5b9a493a8", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import yfinance as yf\n", + "import logging\n", + "from time import time\n", + "import warnings\n", + "from tqdm import tqdm\n", + "from scaler import Client\n", + "\n", + "warnings.filterwarnings('ignore')\n", + "\n", + "logging.basicConfig(level=logging.INFO)\n", + "logger = logging.getLogger(__name__)" + ] + }, + { + "cell_type": "markdown", + "id": "57acb2c8-f107-40bf-aae9-4b7eb6a39702", + "metadata": {}, + "source": [ + "## Settings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "f9cfdecf-1b89-42a7-b570-fb02741a4956", + "metadata": {}, + "outputs": [], + "source": [ + "TICKERS = [\n", + " 'NVDA', 'MSFT', 'AAPL', 'GOOGL', 'AMZN', 'META', 'AVGO', '2222.SR',\n", + " 'TSLA', 'TSM', 'BRK-B', 'ORCL', 'WMT', 'JPM', 'TCEHY', 'LLY',\n", + " 'V', 'NFLX', 'MA', 'XOM'\n", + "]\n", + "\n", + "HESTON_DEFAULTS = {\n", + " 'kappa': 2.0,\n", + " 'theta': 0.04,\n", + " 'sigma': 0.3,\n", + " 'rho': -0.7,\n", + " 'v0': 0.04,\n", + "}\n", + "\n", + "r_default = 0.05\n", + "q = 0.0\n", + "basis_degree = 3\n", + "barrier_frac = 1.0\n", + "\n", + "scheduler_address=\"tcp://127.0.0.1:8080\"" + ] + }, + { + "cell_type": "markdown", + "id": "36bc4e1c-cc24-45da-8f1c-e3d6775aa234", + "metadata": { + "jp-MarkdownHeadingCollapsed": true + }, + "source": [ + "## Download, Generate Tasks, Calculate " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c99acc15-c51b-422c-9042-dec369b8da24", + "metadata": {}, + "outputs": [], + "source": [ + "def download_prices(tickers, start, end):\n", + " data = yf.download(tickers, start=start, end=end, auto_adjust=False)['Adj Close']\n", + " last_prices = data.iloc[-1].values\n", + " returns = np.log(data / data.shift(1)).dropna()\n", + " corr_matrix = np.corrcoef(returns.T)\n", + " 
np.fill_diagonal(corr_matrix, 1.0)\n", + " logger.info(f\"Last prices: {dict(zip(tickers, last_prices))}\")\n", + " return last_prices, corr_matrix\n", + "\n", + "\n", + "def generate_one_path_using_rng(S0, corr_chol, dt, num_steps, heston, rng):\n", + " \"\"\"\n", + " Generate one path using a provided numpy.random.Generator 'rng' (so it's safe to call\n", + " in different processes with controlled seeding).\n", + " Returns array shape (num_steps+1, n_assets)\n", + " \"\"\"\n", + " n = len(S0)\n", + " S = np.zeros((num_steps + 1, n))\n", + " V = np.zeros((num_steps + 1, n))\n", + " S[0] = S0\n", + " V[0] = heston['v0']\n", + " sqrt_dt = np.sqrt(dt)\n", + " rho = heston['rho']\n", + " sqrt_1mrho2 = np.sqrt(max(0, 1 - rho**2))\n", + " kappa = heston['kappa']\n", + " theta = heston['theta']\n", + " sigma_v = heston['sigma']\n", + "\n", + " for t in range(1, num_steps + 1):\n", + " Z_stock = rng.standard_normal(n)\n", + " dW_stock = corr_chol @ Z_stock * sqrt_dt\n", + " Z_vol_ind = rng.standard_normal(n) * sqrt_dt\n", + " dW_vol = rho * dW_stock + sqrt_1mrho2 * Z_vol_ind\n", + " drift_s = (r_default - q - 0.5 * V[t - 1]) * dt\n", + " S[t] = S[t - 1] * np.exp(drift_s + np.sqrt(np.maximum(V[t - 1], 0)) * dW_stock)\n", + " drift_v = kappa * (theta - V[t - 1]) * dt\n", + " vol_sqrtv = sigma_v * np.sqrt(np.maximum(V[t - 1], 0))\n", + " V[t] = np.maximum(V[t - 1] + drift_v + vol_sqrtv * dW_vol, 1e-6)\n", + " return S\n", + "\n", + "\n", + "def generate_paths_batch(num_paths_batch, S0, corr_chol, dt, num_steps, heston, seed=None):\n", + " \"\"\"\n", + " Worker function for distributed execution. 
Generates 'num_paths_batch' paths and returns\n", + " a numpy array with shape (num_paths_batch, num_steps+1, n_assets).\n", + " This function is picklable and safe to submit to the scaler.Client.\n", + " \"\"\"\n", + " rng = np.random.default_rng(seed)\n", + " batch = []\n", + " for i in range(num_paths_batch):\n", + " path = generate_one_path_using_rng(S0, corr_chol, dt, num_steps, heston, rng)\n", + " batch.append(path)\n", + " return np.stack(batch, axis=0)\n", + "\n", + "\n", + "\n", + "def generate_many_path_tasks(num_paths, S0, corr_chol, dt, num_steps, heston, batch_size):\n", + " \"\"\"\n", + " Use scaler.Client to submit batches of path generation as separate tasks.\n", + " - batch_size controls how many paths each submitted job generates.\n", + " - show_progress toggles a tqdm bar for completed batches.\n", + " Returns stacked paths array shape (num_paths, num_steps+1, n_assets)\n", + " \"\"\"\n", + " # Compute batches\n", + " num_batches = (num_paths + batch_size - 1) // batch_size\n", + " batches = []\n", + " # compute exact sizes for each batch (last may be smaller)\n", + " batch_sizes = [batch_size] * num_batches\n", + " last_mod = num_paths - batch_size * (num_batches - 1)\n", + " batch_sizes[-1] = last_mod if last_mod > 0 else batch_size\n", + "\n", + " base_seed = int(time()) & 0x7FFFFFFF\n", + " \n", + " tasks = []\n", + " for i, bs in enumerate(batch_sizes):\n", + " seed = base_seed + i\n", + " tasks.append((generate_paths_batch, (bs, S0, corr_chol, dt, num_steps, heston, seed)))\n", + " \n", + " return tasks\n", + "\n", + "\n", + "def price_bermudan(paths, K, barrier, r, exercise_indices, ex_times, degree):\n", + " num_paths, num_steps, n_assets = paths.shape\n", + " basket_paths = np.mean(paths, axis=2)\n", + " basket_ex = basket_paths[:, exercise_indices]\n", + " num_ex = basket_ex.shape[1]\n", + " hit_mask = basket_ex < barrier\n", + " has_hit = np.any(hit_mask, axis=1)\n", + " first_hit_idx = np.argmax(hit_mask, axis=1)\n", + " 
first_hit_idx[~has_hit] = num_ex\n", + " cashflow = np.zeros(num_paths)\n", + " alive_at_last = first_hit_idx == num_ex\n", + " cashflow[alive_at_last] = np.maximum(basket_ex[alive_at_last, -1] - K, 0)\n", + " for j in range(num_ex - 2, -1, -1):\n", + " delta_t_j = ex_times[j + 1] - ex_times[j]\n", + " disc_factor = np.exp(-r * delta_t_j)\n", + " alive = first_hit_idx > j\n", + " num_alive = np.sum(alive)\n", + " if num_alive == 0:\n", + " continue\n", + " disc_cash_next = disc_factor * cashflow[alive]\n", + " basket_j = basket_ex[alive, j]\n", + " itm = basket_j > K\n", + " num_itm = np.sum(itm)\n", + " if num_itm < degree + 1:\n", + " cont = np.full(num_alive, np.mean(disc_cash_next))\n", + " else:\n", + " X = np.vander(basket_j[itm], degree + 1, increasing=True)\n", + " y = disc_cash_next[itm]\n", + " beta, _, _, _ = np.linalg.lstsq(X, y, rcond=None)\n", + " X_all = np.vander(basket_j, degree + 1, increasing=True)\n", + " cont = X_all @ beta\n", + " intrinsic = np.maximum(basket_j - K, 0)\n", + " exercise = intrinsic > cont\n", + " exercised_idx = np.where(alive)[0][exercise]\n", + " cashflow[exercised_idx] = intrinsic[exercise]\n", + " price = np.mean(np.exp(-r * ex_times[-1]) * cashflow)\n", + " print(f\"The price of the Bermudan down-and-out call option on the equally-weighted basket is {price:.4f}\")\n", + " return price\n", + "\n", + "\n", + "def generate_tasks(start_date, end_date, num_paths, num_steps, maturity, num_exercises,\n", + " risk_free_rate, batch_size):\n", + " \"\"\"\n", + " Main entry. 
Set use_distributed=True to use scaler.Client and submit batch jobs.\n", + " \"\"\"\n", + " global r_default\n", + " r_default = risk_free_rate\n", + "\n", + " S0, corr = download_prices(TICKERS, start_date, end_date)\n", + " try:\n", + " corr_chol = np.linalg.cholesky(corr)\n", + " except np.linalg.LinAlgError:\n", + " logger.warning(\"Correlation matrix not positive semi-definite, adding jitter.\")\n", + " corr += np.eye(corr.shape[0]) * 1e-6\n", + " corr_chol = np.linalg.cholesky(corr)\n", + "\n", + " basket0 = np.mean(S0)\n", + " K = basket0\n", + " barrier = barrier_frac * K\n", + " dt = maturity / num_steps\n", + " step_per_period = num_steps // num_exercises\n", + " exercise_indices = np.arange(1, num_exercises + 1) * step_per_period\n", + " exercise_indices = exercise_indices[exercise_indices <= num_steps]\n", + " ex_times = exercise_indices * dt\n", + " heston = HESTON_DEFAULTS\n", + "\n", + " tasks = generate_many_path_tasks(num_paths, S0, corr_chol, dt, num_steps, heston, batch_size=batch_size)\n", + " return K, barrier, r_default, exercise_indices, ex_times, basis_degree, tasks\n", + "\n", + "\n", + "def get_bermudan_price(K, barrier, r_default, exercise_indices, ex_times, basis_degree, results):\n", + " paths = np.concatenate(results, axis=0)\n", + " # If we over-allocated for any reason, trim to requested num_paths\n", + " # if paths.shape[0] > num_paths:\n", + " # paths = paths[:num_paths]\n", + "\n", + " return price_bermudan(paths, K, barrier, r_default, exercise_indices, ex_times, basis_degree)\n" + ] + }, + { + "cell_type": "markdown", + "id": "fd54518c-7d49-4368-b6e9-48b9e4c24ec0", + "metadata": {}, + "source": [ + "## Generate Task Batches" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d171c5fd-3682-42cb-8b58-c64e23c2729b", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[*********************100%***********************] 20 of 20 completed\n", + "INFO:__main__:Last 
prices: {'NVDA': 24.97572135925293, 'MSFT': 247.21035766601562, 'AAPL': 214.47000122070312, 'GOOGL': 354.1499938964844, 'AMZN': 488.80999755859375, 'META': 251.4600067138672, 'AVGO': 298.5400085449219, '2222.SR': 818.1784057617188, 'TSLA': 549.8800048828125, 'TSM': 712.0700073242188, 'BRK-B': 511.6099853515625, 'ORCL': 118.35900115966797, 'WMT': 181.80999755859375, 'JPM': 313.0, 'TCEHY': 79.68000030517578, 'LLY': 428.75, 'V': 299.8399963378906, 'NFLX': 334.7369689941406, 'MA': 106.47000122070312, 'XOM': 109.68067169189453}\n" + ] + } + ], + "source": [ + "K, barrier, r_default, exercise_indices, ex_times, basis_degree, tasks = generate_tasks(\n", + " start_date='2024-10-17', end_date='2025-10-17', num_paths=100_000, num_steps=252, maturity=1.0, num_exercises=12, risk_free_rate=0.05, batch_size=10\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "211de5c0-1f53-4ac7-8ce4-f9d6d94cfdfe", + "metadata": {}, + "source": [ + "## Compute On IBM Symphony Only" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e084c8c0-1781-43b0-9be1-835ebcdec3d7", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:ScalerClient: connect to scheduler at tcp://127.0.0.1:8080\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ClientHeartbeatManager: started\n", + "INFO:root:ScalerClient: connect to object storage at tcp://127.0.0.1:8081\n", + "batches: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [05:50<00:00, 28.50batch/s]\n", + "INFO:root:ScalerClient: disconnect from tcp://127.0.0.1:8080\n", + "INFO:root:canceling 0 task(s)\n", + "INFO:root:ClientAgent: client quitting\n", + "INFO:root:ZMQAsyncConnector: exited\n" + ] + }, + 
{ + "name": "stdout", + "output_type": "stream", + "text": [ + "The price of the Bermudan down-and-out call option on the equally-weighted basket is 14.5174\n" + ] + } + ], + "source": [ + "def compute_symphony(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree):\n", + " with Client(scheduler_address) as client:\n", + " # Submit all batch jobs\n", + " sym_futures = [client.submit_verbose(*task, kwargs={}, capabilities={\"symphony\": -1}) for task in tasks]\n", + " \n", + " symphony_result = get_bermudan_price(\n", + " K, barrier, r_default, exercise_indices, ex_times, basis_degree,\n", + " [sym_futures[idx].result() for idx in tqdm(range(len(sym_futures)), desc=\"batches\", unit=\"batch\")]\n", + " )\n", + "\n", + "compute_symphony(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree)" + ] + }, + { + "cell_type": "markdown", + "id": "bf405fe3-42ca-4131-b040-5dc7319fddb2", + "metadata": {}, + "source": [ + "## Compute On AWS ECS Only" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "aa742fad-df5c-41db-acdf-491903e869b0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:ScalerClient: connect to scheduler at tcp://127.0.0.1:8080\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ClientHeartbeatManager: started\n", + "INFO:root:ScalerClient: connect to object storage at tcp://127.0.0.1:8081\n", + "batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [01:29<00:00, 111.57batch/s]\n", + "INFO:root:ScalerClient: disconnect from tcp://127.0.0.1:8080\n", + "INFO:root:canceling 0 task(s)\n", + "INFO:root:ClientAgent: client quitting\n", + 
"INFO:root:ZMQAsyncConnector: exited\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The price of the Bermudan down-and-out call option on the equally-weighted basket is 14.5174\n" + ] + } + ], + "source": [ + "def compute_ecs(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree):\n", + " with Client(scheduler_address) as client:\n", + " # Submit all batch jobs\n", + " futures = [client.submit_verbose(*task, kwargs={}, capabilities={\"ecs\": -1}) for task in tasks]\n", + " \n", + " # Gather results with progress\n", + " results = [futures[idx].result() for idx in tqdm(range(len(futures)), desc=\"batches\", unit=\"batch\")]\n", + " get_bermudan_price(K, barrier, r_default, exercise_indices, ex_times, basis_degree, results)\n", + "\n", + "compute_ecs(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree)" + ] + }, + { + "cell_type": "markdown", + "id": "a559bcda-4bfd-4f1f-92a2-fc82c8012514", + "metadata": {}, + "source": [ + "## Compute on Both IBM Symphony and AWS ECS" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "b7d78262-982e-4f30-b42a-4395f1ac0df4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO:root:ScalerClient: connect to scheduler at tcp://127.0.0.1:8080\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ZMQAsyncConnector: started\n", + "INFO:root:ClientHeartbeatManager: started\n", + "INFO:root:ScalerClient: connect to object storage at tcp://127.0.0.1:8081\n", + "batches: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [06:04<00:00, 27.42batch/s]\n", + "INFO:root:ScalerClient: disconnect from tcp://127.0.0.1:8080\n", + "INFO:root:canceling 0 task(s)\n", + "INFO:root:ClientAgent: client 
quitting\n", + "INFO:root:ZMQAsyncConnector: exited\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The price of the Bermudan down-and-out call option on the equally-weighted basket is 14.6561\n" + ] + } + ], + "source": [ + "def compute_both(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree):\n", + " with Client(scheduler_address) as client:\n", + " # Submit all batch jobs\n", + " futures = [client.submit_verbose(*task, kwargs={}, capabilities={\"python3\": -1}) for task in tasks]\n", + " get_bermudan_price(\n", + " K, barrier, r_default, exercise_indices, ex_times, basis_degree,\n", + " [futures[idx].result() for idx in tqdm(range(len(futures)), desc=\"batches\", unit=\"batch\")]\n", + " )\n", + "\n", + "compute_both(scheduler_address, K, barrier, r_default, exercise_indices, ex_times, basis_degree)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.20" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/cpp/scaler/object_storage/CMakeLists.txt b/src/cpp/scaler/object_storage/CMakeLists.txt index d885bdbbe..36b001807 100644 --- a/src/cpp/scaler/object_storage/CMakeLists.txt +++ b/src/cpp/scaler/object_storage/CMakeLists.txt @@ -45,6 +45,6 @@ target_link_libraries(py_object_storage_server PRIVATE install( TARGETS py_object_storage_server - LIBRARY DESTINATION src/scaler/object_storage/ + LIBRARY DESTINATION scaler/object_storage/ COMPONENT object_storage_server ) diff --git a/src/cpp/scaler/ymq/CMakeLists.txt b/src/cpp/scaler/ymq/CMakeLists.txt index 2c6b1cf8f..b984fa666 100644 --- a/src/cpp/scaler/ymq/CMakeLists.txt +++ b/src/cpp/scaler/ymq/CMakeLists.txt 
@@ -15,10 +15,9 @@ add_library(ymq_objs OBJECT event_loop_thread.cpp event_manager.h - # file_descriptor.h - message_connection_tcp.h - message_connection_tcp.cpp + message_connection.h + message_connection.cpp third_party/concurrentqueue.h interruptive_concurrent_queue.h @@ -31,11 +30,11 @@ add_library(ymq_objs OBJECT io_socket.h io_socket.cpp - tcp_server.h - tcp_server.cpp + stream_server.h + stream_server.cpp - tcp_client.h - tcp_client.cpp + stream_client.h + stream_client.cpp tcp_operations.h @@ -50,16 +49,28 @@ add_library(ymq_objs OBJECT internal/defs.h - internal/raw_connection_tcp_fd.h - internal/raw_connection_tcp_fd.cpp + internal/raw_stream_connection_handle.h - internal/raw_server_tcp_fd.h - internal/raw_server_tcp_fd.cpp + internal/raw_stream_server_handle.h - internal/raw_client_tcp_fd.h - internal/raw_client_tcp_fd.cpp + internal/raw_stream_client_handle.h ) +# System dependent file for YMQ internal +if(LINUX) + target_sources(ymq_objs PRIVATE + internal/raw_stream_connection_handle_linux.cpp + internal/raw_stream_server_handle_linux.cpp + internal/raw_stream_client_handle_linux.cpp + ) +elseif(WIN32) + target_sources(ymq_objs PRIVATE + internal/raw_stream_connection_handle_windows.cpp + internal/raw_stream_server_handle_windows.cpp + internal/raw_stream_client_handle_windows.cpp + ) +endif() + set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/src/scaler/ymq) if(LINUX) @@ -94,7 +105,7 @@ if(LINUX) install( TARGETS py_ymq - LIBRARY DESTINATION src/scaler/io/ymq + LIBRARY DESTINATION scaler/io/ymq ) endif() diff --git a/src/cpp/scaler/ymq/configuration.h b/src/cpp/scaler/ymq/configuration.h index eae8c6adb..65f2f29f5 100644 --- a/src/cpp/scaler/ymq/configuration.h +++ b/src/cpp/scaler/ymq/configuration.h @@ -14,7 +14,7 @@ namespace scaler { namespace ymq { class EpollContext; -class IocpContext; +class IOCPContext; class Message; class IOSocket; @@ -35,7 +35,7 @@ struct Configuration { using PollingContext = EpollContext; #endif // __linux__ 
#ifdef _WIN32 - using PollingContext = IocpContext; + using PollingContext = IOCPContext; #endif // _WIN32 using IOSocketIdentity = std::string; diff --git a/src/cpp/scaler/ymq/event_manager.h b/src/cpp/scaler/ymq/event_manager.h index ecee1e2a9..6c98e2f40 100644 --- a/src/cpp/scaler/ymq/event_manager.h +++ b/src/cpp/scaler/ymq/event_manager.h @@ -46,7 +46,7 @@ class EventManager INHERIT_OVERLAPPED { } #endif // __linux__ #ifdef _WIN32 - if constexpr (std::same_as) { + if constexpr (std::same_as) { onRead(); onWrite(); if (events & IOCP_SOCKET_CLOSED) { diff --git a/src/cpp/scaler/ymq/internal/network_utils.h b/src/cpp/scaler/ymq/internal/network_utils.h new file mode 100644 index 000000000..3a8b67bc1 --- /dev/null +++ b/src/cpp/scaler/ymq/internal/network_utils.h @@ -0,0 +1,70 @@ +#pragma once +#include // std::upper_bound + +#include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/raw_stream_connection_handle.h" + +namespace scaler { +namespace ymq { + +bool setReuseAddress(auto fd) +{ + int optval = 1; + return !(setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (const char*)&optval, sizeof(optval)) == -1); +} + +std::pair tryReadUntilComplete(void* dest, size_t size, auto readBytes) +{ + uint64_t cnt = 0; + while (size) { + const auto current = readBytes((char*)dest + cnt, size); + if (current) { + cnt += current.value(); + size -= current.value(); + } else { + return {cnt, current.error()}; + } + } + return {cnt, RawStreamConnectionHandle::IOStatus::MoreBytesAvailable}; +} + +std::pair tryWriteUntilComplete( + const std::vector>& buffers, auto writeBytes) +{ + if (buffers.empty()) { + return {0, RawStreamConnectionHandle::IOStatus::MoreBytesAvailable}; + } + + std::vector prefixSum(buffers.size() + 1); + for (size_t i = 0; i < buffers.size(); ++i) { + prefixSum[i + 1] = prefixSum[i] + buffers[i].second; + } + const size_t total = prefixSum.back(); + + size_t sent = 0; + while (sent != total) { + auto unfinished = std::upper_bound(prefixSum.begin(), 
prefixSum.end(), sent); + --unfinished; + + std::vector> currentBuffers; + + auto begin = buffers.begin() + std::distance(prefixSum.begin(), unfinished); + const size_t remain = sent - *unfinished; + + currentBuffers.push_back({(char*)begin->first + remain, begin->second - remain}); + while (++begin != buffers.end()) { + currentBuffers.push_back(*begin); + } + + const auto res = writeBytes(currentBuffers); + if (res) { + sent += res.value(); + } else { + return {sent, res.error()}; + } + } + return {total, RawStreamConnectionHandle::IOStatus::MoreBytesAvailable}; +} + +} // namespace ymq +} // namespace scaler diff --git a/src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.cpp b/src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.cpp deleted file mode 100644 index cbb2fe4ba..000000000 --- a/src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.cpp +++ /dev/null @@ -1,430 +0,0 @@ -#include "scaler/ymq/internal/raw_connection_tcp_fd.h" - -#include -#include // assert - -#include "scaler/error/error.h" -#include "scaler/ymq/internal/defs.h" - -namespace scaler { -namespace ymq { - -std::pair RawConnectionTCPFD::tryReadUntilComplete(void* dest, size_t size) -{ - uint64_t cnt = 0; - while (size) { - const auto current = readBytes((char*)dest + cnt, size); - if (current) { - cnt += current.value(); - size -= current.value(); - } else { - return {cnt, current.error()}; - } - } - return {cnt, IOStatus::MoreBytesAvailable}; -} - -std::pair RawConnectionTCPFD::tryWriteUntilComplete( - const std::vector>& buffers) -{ - if (buffers.empty()) { - return {0, IOStatus::MoreBytesAvailable}; - } - - std::vector prefixSum(buffers.size() + 1); - for (size_t i = 0; i < buffers.size(); ++i) { - prefixSum[i + 1] = prefixSum[i] + buffers[i].second; - } - const size_t total = prefixSum.back(); - - size_t sent = 0; - while (sent != total) { - auto unfinished = std::upper_bound(prefixSum.begin(), prefixSum.end(), sent); - --unfinished; - - std::vector> currentBuffers; - - auto begin = 
buffers.begin() + std::distance(prefixSum.begin(), unfinished); - const size_t remain = sent - *unfinished; - - currentBuffers.push_back({(char*)begin->first + remain, begin->second - remain}); - while (++begin != buffers.end()) { - currentBuffers.push_back(*begin); - } - - const auto res = writeBytes(currentBuffers); - if (res) { - sent += res.value(); - } else { - return {sent, res.error()}; - } - } - return {total, IOStatus::MoreBytesAvailable}; -} - -std::expected RawConnectionTCPFD::readBytes(void* dest, size_t size) -{ - assert(_fd); - assert(dest); - assert(size); - - constexpr const int flags = 0; - - const int n = ::recv(_fd, (char*)dest, size, flags); - if (n > 0) { - return n; - } - - if (!n) { - _socketStatus = SocketStatus::Disconnected; - return std::unexpected {IOStatus::Disconnected}; - } - - if (n == -1) { -#ifdef __linux__ - // handle Linux errors - const int myErrno = errno; - if (myErrno == ECONNRESET) { - if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { - _socketStatus = SocketStatus::Disconnected; - return std::unexpected {IOStatus::Disconnected}; - } else { - return std::unexpected {IOStatus::Aborted}; - } - } - if (myErrno == EAGAIN || myErrno == EWOULDBLOCK) { - return std::unexpected {IOStatus::Drained}; - } else { - const int myErrno = errno; - switch (myErrno) { - case EBADF: - case EISDIR: - case EINVAL: - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "recv(2)", - "Errno is", - strerror(myErrno), - "_fd", - _fd, - "dest", - (void*)dest, - "size", - size, - }); - - case EINTR: - unrecoverableError({ - Error::ErrorCode::SignalNotSupported, - "Originated from", - "recv(2)", - "Errno is", - strerror(myErrno), - }); - - case EFAULT: - case EIO: - default: - unrecoverableError({ - Error::ErrorCode::ConfigurationError, - "Originated from", - "recv(2)", - "Errno is", - strerror(myErrno), - }); - } - } -#endif // __linux__ - -#ifdef _WIN32 - const int myErrno = 
WSAGetLastError(); - if (myErrno == WSAEWOULDBLOCK) { - return std::unexpected {IOStatus::Drained}; - } - if (myErrno == WSAECONNRESET || myErrno == WSAENOTSOCK || myErrno == WSAECONNABORTED) { - if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { - _socketStatus = SocketStatus::Disconnected; - return std::unexpected {IOStatus::Disconnected}; - } else { - return std::unexpected {IOStatus::Aborted}; - } - } else { - // NOTE: On Windows we don't have signals and weird IO Errors - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "recv(2)", - "Errno is", - myErrno, - "_fd", - _fd, - "dest", - (void*)dest, - "size", - size, - }); - } -#endif // _WIN32 - } - - std::unreachable(); -} - -std::expected RawConnectionTCPFD::writeBytes( - const std::vector>& buffers) -{ - assert(buffers.size()); -#ifdef _WIN32 -#define iovec ::WSABUF -#define IOV_MAX (1024) -#define iov_base buf -#define iov_len len -#endif // _WIN32 - - std::vector iovecs; - iovecs.reserve(IOV_MAX); - - size_t total = 0; - for (const auto& [ptr, len]: buffers) { - if (iovecs.size() == IOV_MAX) { - break; - } - iovec current; - current.iov_base = (char*)ptr; - current.iov_len = len; - iovecs.push_back(std::move(current)); - total += current.iov_len; - // iovecs.emplace_back((void*)ptr, len); - } - - assert(total); - (void)total; - - if (iovecs.empty()) { - return 0; - } - -#ifdef _WIN32 - DWORD bytesSent {}; - const int sendToResult = WSASendTo(_fd, iovecs.data(), iovecs.size(), &bytesSent, 0, nullptr, 0, nullptr, nullptr); - if (sendToResult == 0) { - return bytesSent; - } - const int myErrno = WSAGetLastError(); - if (myErrno == WSAEWOULDBLOCK) { - return std::unexpected {IOStatus::Drained}; - } - if (myErrno == WSAESHUTDOWN || myErrno == WSAENOTCONN || myErrno == WSAECONNRESET) { - if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { - _socketStatus = SocketStatus::Disconnected; - return 
std::unexpected {IOStatus::Disconnected}; - } else { - return std::unexpected {IOStatus::Aborted}; - } - } - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "WSASendTo", - "Errno is", - myErrno, - "_fd", - _fd, - "iovecs.size()", - iovecs.size(), - }); -#undef iovec -#undef IOV_MAX -#undef iov_base -#undef iov_len -#endif // _WIN32 - -#ifdef __linux__ - struct msghdr msg {}; - msg.msg_iov = iovecs.data(); - msg.msg_iovlen = iovecs.size(); - - ssize_t bytesSent = ::sendmsg(_fd, &msg, MSG_NOSIGNAL); - if (bytesSent == -1) { - if (errno == EAGAIN || errno == EWOULDBLOCK) { - return std::unexpected {IOStatus::Drained}; - } else { - const int myErrno = errno; - switch (myErrno) { - case EAFNOSUPPORT: - case EBADF: - case EINVAL: - case EMSGSIZE: - case ENOTCONN: - case ENOTSOCK: - case EOPNOTSUPP: - case ENAMETOOLONG: - case ENOENT: - case ENOTDIR: - case ELOOP: - case EDESTADDRREQ: - case EHOSTUNREACH: - case EISCONN: - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "sendmsg(2)", - "Errno is", - strerror(myErrno), - "_fd", - _fd, - "msg.msg_iovlen", - msg.msg_iovlen, - }); - break; - - case ECONNRESET: - // We maybe need to handle this differently, conditionally return IOStatus::Disconnected - case EPIPE: { - return std::unexpected {IOStatus::Aborted}; - break; - } - - case EINTR: - unrecoverableError({ - Error::ErrorCode::SignalNotSupported, - "Originated from", - "sendmsg(2)", - "Errno is", - strerror(myErrno), - }); - break; - - case EIO: - case EACCES: - case ENETDOWN: - case ENETUNREACH: - case ENOBUFS: - case ENOMEM: - default: - unrecoverableError({ - Error::ErrorCode::ConfigurationError, - "Originated from", - "sendmsg(2)", - "Errno is", - strerror(myErrno), - }); - break; - } - } - } - - return bytesSent; -#endif // __linux__ -} - -// TODO: This notifyHandle is a bad name but I don't have a better name for it now. -// Later, I will give it a better name. 
The purpose of it is to just "pass something" -// to the event loop. -bool RawConnectionTCPFD::prepareReadBytes(void* notifyHandle) -{ -#ifdef _WIN32 - // TODO: This need rewrite to better logic - if (!_fd) { - return false; - } - const bool ok = ReadFile((HANDLE)(SOCKET)_fd, nullptr, 0, nullptr, (LPOVERLAPPED)notifyHandle); - if (ok) { - // onRead(); - return true; - } - const auto lastError = GetLastError(); - if (lastError == ERROR_IO_PENDING) { - return false; - } - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "ReadFile", - "Errno is", - lastError, - "_fd", - _fd, - }); - std::unreachable(); -#endif // _WIN32 - -#ifdef __linux__ - return false; // ??? - -#endif // __linux__ -} - -// TODO: Think more about this notifyHandle, it used to be _eventManager.get() -std::pair RawConnectionTCPFD::prepareWriteBytes(void* dest, size_t size, void* notifyHandle) -{ - (void)size; -#ifdef _WIN32 - // NOTE: Precondition is the queue still has messages (perhaps a partial one). 
- const bool writeFileRes = WriteFile((HANDLE)(SOCKET)_fd, dest, 1, nullptr, (LPOVERLAPPED)notifyHandle); - if (writeFileRes) { - return {1, true}; - } - - const auto lastError = GetLastError(); - if (lastError == ERROR_IO_PENDING) { - return {1, false}; - } - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "prepareWriteBytes", - "Errno is", - lastError, - }); -#endif // _WIN32 - -#ifdef __linux__ - (void)notifyHandle; - return {0, false}; - -#endif // __linux__ -} - -void RawConnectionTCPFD::shutdownRead() noexcept -{ -#ifdef __linux__ - shutdown(_fd, SHUT_RD); -#endif // __linux__ -#ifdef _WIN32 - shutdown(_fd, SD_RECEIVE); -#endif // _WIN32 -} - -void RawConnectionTCPFD::shutdownWrite() noexcept -{ -#ifdef __linux__ - shutdown(_fd, SHUT_WR); -#endif // __linux__ -#ifdef _WIN32 - shutdown(_fd, SD_SEND); -#endif // _WIN32 - _socketStatus = SocketStatus::Disconnecting; -} - -void RawConnectionTCPFD::shutdownBoth() noexcept -{ - shutdownWrite(); - shutdownRead(); -} - -void RawConnectionTCPFD::closeAndZero() noexcept -{ -#ifdef __linux__ - close(_fd); -#endif // __linux__ -#ifdef _WIN32 - closesocket(_fd); -#endif // _WIN32 - _fd = 0; - _socketStatus = SocketStatus::Disconnected; -} - -} // namespace ymq -} // namespace scaler diff --git a/src/cpp/scaler/ymq/internal/raw_client_tcp_fd.h b/src/cpp/scaler/ymq/internal/raw_stream_client_handle.h similarity index 55% rename from src/cpp/scaler/ymq/internal/raw_client_tcp_fd.h rename to src/cpp/scaler/ymq/internal/raw_stream_client_handle.h index cab850f8d..1f1ae8668 100644 --- a/src/cpp/scaler/ymq/internal/raw_client_tcp_fd.h +++ b/src/cpp/scaler/ymq/internal/raw_stream_client_handle.h @@ -7,15 +7,15 @@ namespace scaler { namespace ymq { -class RawClientTCPFD { +class RawStreamClientHandle { public: - RawClientTCPFD(sockaddr remoteAddr); - ~RawClientTCPFD(); + RawStreamClientHandle(sockaddr remoteAddr); + ~RawStreamClientHandle(); - RawClientTCPFD(RawClientTCPFD&&) = delete; - RawClientTCPFD& 
operator=(RawClientTCPFD&& other) = delete; - RawClientTCPFD(const RawClientTCPFD&) = delete; - RawClientTCPFD& operator=(const RawClientTCPFD&) = delete; + RawStreamClientHandle(RawStreamClientHandle&&) = delete; + RawStreamClientHandle& operator=(RawStreamClientHandle&& other) = delete; + RawStreamClientHandle(const RawStreamClientHandle&) = delete; + RawStreamClientHandle& operator=(const RawStreamClientHandle&) = delete; void create(); void destroy(); diff --git a/src/cpp/scaler/ymq/internal/raw_client_tcp_fd.cpp b/src/cpp/scaler/ymq/internal/raw_stream_client_handle_linux.cpp similarity index 58% rename from src/cpp/scaler/ymq/internal/raw_client_tcp_fd.cpp rename to src/cpp/scaler/ymq/internal/raw_stream_client_handle_linux.cpp index 9095e419a..2f4e25a60 100644 --- a/src/cpp/scaler/ymq/internal/raw_client_tcp_fd.cpp +++ b/src/cpp/scaler/ymq/internal/raw_stream_client_handle_linux.cpp @@ -1,46 +1,16 @@ -#include "scaler/ymq/internal/raw_client_tcp_fd.h" - +#ifdef __linux__ #include "scaler/error/error.h" +#include "scaler/ymq/internal/raw_stream_client_handle.h" namespace scaler { namespace ymq { -RawClientTCPFD::RawClientTCPFD(sockaddr remoteAddr): _clientFD {}, _remoteAddr(std::move(remoteAddr)) +RawStreamClientHandle::RawStreamClientHandle(sockaddr remoteAddr): _clientFD {}, _remoteAddr(std::move(remoteAddr)) { -#ifdef _WIN32 - _connectExFunc = {}; - - auto tmp = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - DWORD res; - GUID guid = WSAID_CONNECTEX; - WSAIoctl( - tmp, - SIO_GET_EXTENSION_FUNCTION_POINTER, - (void*)&guid, - sizeof(GUID), - &_connectExFunc, - sizeof(_connectExFunc), - &res, - 0, - 0); - closesocket(tmp); - if (!_connectExFunc) { - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "WSAIoctl", - "Errno is", - GetErrorCode(), - "_connectExFunc", - (void*)_connectExFunc, - }); - } -#endif // _WIN32 } -void RawClientTCPFD::create() +void RawStreamClientHandle::create() { -#ifdef __linux__ _clientFD = socket(AF_INET, 
SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP); if ((int)_clientFD == -1) { @@ -73,27 +43,10 @@ void RawClientTCPFD::create() } return; } -#endif - -#ifdef _WIN32 - _clientFD = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); - if (_clientFD == -1) { - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "socket(2)", - "Errno is", - strerror(GetErrorCode()), - }); - } - u_long nonblock = 1; - ioctlsocket(_clientFD, FIONBIO, &nonblock); -#endif } -bool RawClientTCPFD::prepConnect(void* notifyHandle) +bool RawStreamClientHandle::prepConnect(void* notifyHandle) { -#ifdef __linux__ const int ret = connect((int)_clientFD, (sockaddr*)&_remoteAddr, sizeof(_remoteAddr)); if (ret >= 0) [[unlikely]] { @@ -151,59 +104,10 @@ bool RawClientTCPFD::prepConnect(void* notifyHandle) _clientFD, }); } -#endif - -#ifdef _WIN32 - sockaddr_in localAddr = {}; - localAddr.sin_family = AF_INET; - const char ip4[] = {127, 0, 0, 1}; - *(int*)&localAddr.sin_addr = *(int*)ip4; - - const int bindRes = bind(_clientFD, (struct sockaddr*)&localAddr, sizeof(struct sockaddr_in)); - if (bindRes == -1) { - unrecoverableError({ - Error::ErrorCode::ConfigurationError, - "Originated from", - "bind", - "Errno is", - GetErrorCode(), - "_clientFD", - _clientFD, - }); - } - - const bool ok = - _connectExFunc(_clientFD, &_remoteAddr, sizeof(struct sockaddr), NULL, 0, NULL, (LPOVERLAPPED)notifyHandle); - if (ok) { - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "connectEx", - "_clientFD", - _clientFD, - }); - } - - const int myErrno = GetErrorCode(); - if (myErrno == ERROR_IO_PENDING) { - return false; - } - - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "connectEx", - "Errno is", - myErrno, - "_clientFD", - _clientFD, - }); -#endif } -bool RawClientTCPFD::needRetry() +bool RawStreamClientHandle::needRetry() { -#ifdef __linux__ int err {}; socklen_t errLen {sizeof(err)}; if (getsockopt((int)_clientFD, SOL_SOCKET, SO_ERROR, &err, &errLen) < 0) { @@ 
-254,35 +158,25 @@ bool RawClientTCPFD::needRetry() }); } return false; -#endif - -#ifdef _WIN32 - const int iResult = setsockopt(_clientFD, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0); - return iResult == -1; -#endif } -void RawClientTCPFD::destroy() +void RawStreamClientHandle::destroy() { -#ifdef _WIN32 - if (_clientFD) { - CancelIoEx((HANDLE)_clientFD, nullptr); - } -#endif if (_clientFD) { CloseAndZeroSocket(_clientFD); } } -void RawClientTCPFD::zeroNativeHandle() noexcept +void RawStreamClientHandle::zeroNativeHandle() noexcept { _clientFD = 0; } -RawClientTCPFD::~RawClientTCPFD() +RawStreamClientHandle::~RawStreamClientHandle() { destroy(); } } // namespace ymq } // namespace scaler +#endif diff --git a/src/cpp/scaler/ymq/internal/raw_stream_client_handle_windows.cpp b/src/cpp/scaler/ymq/internal/raw_stream_client_handle_windows.cpp new file mode 100644 index 000000000..9db07a254 --- /dev/null +++ b/src/cpp/scaler/ymq/internal/raw_stream_client_handle_windows.cpp @@ -0,0 +1,132 @@ +#ifdef _WIN32 +#include "scaler/error/error.h" +#include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/raw_stream_client_handle.h" + +namespace scaler { +namespace ymq { + +RawStreamClientHandle::RawStreamClientHandle(sockaddr remoteAddr): _clientFD {}, _remoteAddr(std::move(remoteAddr)) +{ + _connectExFunc = {}; + + auto tmp = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + DWORD res; + GUID guid = WSAID_CONNECTEX; + WSAIoctl( + tmp, + SIO_GET_EXTENSION_FUNCTION_POINTER, + (void*)&guid, + sizeof(GUID), + &_connectExFunc, + sizeof(_connectExFunc), + &res, + 0, + 0); + closesocket(tmp); + if (!_connectExFunc) { + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "WSAIoctl", + "Errno is", + GetErrorCode(), + "_connectExFunc", + (void*)_connectExFunc, + }); + } +} + +void RawStreamClientHandle::create() +{ + _clientFD = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (_clientFD == -1) { + unrecoverableError({ + Error::ErrorCode::CoreBug, + 
"Originated from", + "socket(2)", + "Errno is", + strerror(GetErrorCode()), + }); + } + u_long nonblock = 1; + ioctlsocket(_clientFD, FIONBIO, &nonblock); +} + +bool RawStreamClientHandle::prepConnect(void* notifyHandle) +{ + sockaddr_in localAddr = {}; + localAddr.sin_family = AF_INET; + const char ip4[] = {127, 0, 0, 1}; + *(int*)&localAddr.sin_addr = *(int*)ip4; + + const int bindRes = bind(_clientFD, (struct sockaddr*)&localAddr, sizeof(struct sockaddr_in)); + if (bindRes == -1) { + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "bind", + "Errno is", + GetErrorCode(), + "_clientFD", + _clientFD, + }); + } + + const bool ok = + _connectExFunc(_clientFD, &_remoteAddr, sizeof(struct sockaddr), NULL, 0, NULL, (LPOVERLAPPED)notifyHandle); + if (ok) { + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "connectEx", + "_clientFD", + _clientFD, + }); + } + + const int myErrno = GetErrorCode(); + if (myErrno == ERROR_IO_PENDING) { + return false; + } + + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "connectEx", + "Errno is", + myErrno, + "_clientFD", + _clientFD, + }); +} + +bool RawStreamClientHandle::needRetry() +{ + const int iResult = setsockopt(_clientFD, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, NULL, 0); + return iResult == -1; +} + +void RawStreamClientHandle::destroy() +{ + if (_clientFD) { + CancelIoEx((HANDLE)_clientFD, nullptr); + } + if (_clientFD) { + CloseAndZeroSocket(_clientFD); + } +} + +void RawStreamClientHandle::zeroNativeHandle() noexcept +{ + _clientFD = 0; +} + +RawStreamClientHandle::~RawStreamClientHandle() +{ + destroy(); +} + +} // namespace ymq +} // namespace scaler +#endif diff --git a/src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.h b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle.h similarity index 68% rename from src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.h rename to src/cpp/scaler/ymq/internal/raw_stream_connection_handle.h index 
b0a64a4a0..490af8af2 100644 --- a/src/cpp/scaler/ymq/internal/raw_connection_tcp_fd.h +++ b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle.h @@ -6,12 +6,10 @@ #include // std::pair #include -#include "scaler/ymq/internal/defs.h" // system compatible header - namespace scaler { namespace ymq { -class RawConnectionTCPFD { +class RawStreamConnectionHandle { public: enum class IOStatus { MoreBytesAvailable, @@ -39,36 +37,26 @@ class RawConnectionTCPFD { // TODO: This might need error handling std::pair prepareWriteBytes(void* dest, size_t len, void* notifyHandle); - // It has to be here, as the return type is different - auto nativeHandle() const noexcept - { -#ifdef _WIN32 - return (SOCKET)_fd; -#endif //_WIN32 - -#ifdef __linux__ - return (int)_fd; -#endif //__linux__ - } + int nativeHandle() const noexcept { return (int)_fd; } void shutdownRead() noexcept; void shutdownWrite() noexcept; void shutdownBoth() noexcept; void closeAndZero() noexcept; - RawConnectionTCPFD(uint64_t fd): _socketStatus(SocketStatus::Connected), _fd(fd) {} - RawConnectionTCPFD(): _fd {} {} - ~RawConnectionTCPFD() noexcept + RawStreamConnectionHandle(uint64_t fd): _socketStatus(SocketStatus::Connected), _fd(fd) {} + RawStreamConnectionHandle(): _fd {} {} + ~RawStreamConnectionHandle() noexcept { if (_fd) { closeAndZero(); } } - RawConnectionTCPFD(const RawConnectionTCPFD&) = delete; - RawConnectionTCPFD(RawConnectionTCPFD&&) = delete; - RawConnectionTCPFD& operator=(const RawConnectionTCPFD&) = delete; - RawConnectionTCPFD& operator=(RawConnectionTCPFD&&) = delete; + RawStreamConnectionHandle(const RawStreamConnectionHandle&) = delete; + RawStreamConnectionHandle(RawStreamConnectionHandle&&) = delete; + RawStreamConnectionHandle& operator=(const RawStreamConnectionHandle&) = delete; + RawStreamConnectionHandle& operator=(RawStreamConnectionHandle&&) = delete; private: std::expected readBytes(void* dest, size_t size); diff --git 
a/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_linux.cpp b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_linux.cpp new file mode 100644 index 000000000..d8ae85211 --- /dev/null +++ b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_linux.cpp @@ -0,0 +1,248 @@ +#ifdef __linux__ +#include // assert + +#include "scaler/error/error.h" +#include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/network_utils.h" +#include "scaler/ymq/internal/raw_stream_connection_handle.h" + +namespace scaler { +namespace ymq { + +std::pair RawStreamConnectionHandle::tryReadUntilComplete( + void* dest, size_t size) +{ + return scaler::ymq::tryReadUntilComplete( + dest, size, [&](char* dest, size_t size) { return this->readBytes(dest, size); }); +} + +std::pair RawStreamConnectionHandle::tryWriteUntilComplete( + const std::vector>& buffers) +{ + return scaler::ymq::tryWriteUntilComplete(buffers, [&](std::vector> currentBuffers) { + return this->writeBytes(currentBuffers); + }); +} + +void RawStreamConnectionHandle::shutdownBoth() noexcept +{ + shutdownWrite(); + shutdownRead(); +} + +std::expected RawStreamConnectionHandle::readBytes( + void* dest, size_t size) +{ + assert(_fd); + assert(dest); + assert(size); + + constexpr const int flags = 0; + + const int n = ::recv(_fd, (char*)dest, size, flags); + if (n > 0) { + return n; + } + + if (!n) { + _socketStatus = SocketStatus::Disconnected; + return std::unexpected {IOStatus::Disconnected}; + } + + if (n == -1) { + // handle Linux errors + const int myErrno = errno; + if (myErrno == ECONNRESET) { + if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { + _socketStatus = SocketStatus::Disconnected; + return std::unexpected {IOStatus::Disconnected}; + } else { + return std::unexpected {IOStatus::Aborted}; + } + } + if (myErrno == EAGAIN || myErrno == EWOULDBLOCK) { + return std::unexpected {IOStatus::Drained}; + } else { + const int myErrno = errno; + 
switch (myErrno) { + case EBADF: + case EISDIR: + case EINVAL: + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "recv(2)", + "Errno is", + strerror(myErrno), + "_fd", + _fd, + "dest", + (void*)dest, + "size", + size, + }); + + case EINTR: + unrecoverableError({ + Error::ErrorCode::SignalNotSupported, + "Originated from", + "recv(2)", + "Errno is", + strerror(myErrno), + }); + + case EFAULT: + case EIO: + default: + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "recv(2)", + "Errno is", + strerror(myErrno), + }); + } + } + } + + std::unreachable(); +} + +std::expected RawStreamConnectionHandle::writeBytes( + const std::vector>& buffers) +{ + assert(buffers.size()); + + std::vector iovecs; + iovecs.reserve(IOV_MAX); + + size_t total = 0; + for (const auto& [ptr, len]: buffers) { + if (iovecs.size() == IOV_MAX) { + break; + } + iovec current; + current.iov_base = (char*)ptr; + current.iov_len = len; + iovecs.push_back(std::move(current)); + total += current.iov_len; + } + + assert(total); + (void)total; + + if (iovecs.empty()) { + return 0; + } + + struct msghdr msg {}; + msg.msg_iov = iovecs.data(); + msg.msg_iovlen = iovecs.size(); + + ssize_t bytesSent = ::sendmsg(_fd, &msg, MSG_NOSIGNAL); + if (bytesSent == -1) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return std::unexpected {IOStatus::Drained}; + } else { + const int myErrno = errno; + switch (myErrno) { + case EAFNOSUPPORT: + case EBADF: + case EINVAL: + case EMSGSIZE: + case ENOTCONN: + case ENOTSOCK: + case EOPNOTSUPP: + case ENAMETOOLONG: + case ENOENT: + case ENOTDIR: + case ELOOP: + case EDESTADDRREQ: + case EHOSTUNREACH: + case EISCONN: + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "sendmsg(2)", + "Errno is", + strerror(myErrno), + "_fd", + _fd, + "msg.msg_iovlen", + msg.msg_iovlen, + }); + break; + + case ECONNRESET: + // We maybe need to handle this differently, conditionally return 
IOStatus::Disconnected + case EPIPE: { + return std::unexpected {IOStatus::Aborted}; + break; + } + + case EINTR: + unrecoverableError({ + Error::ErrorCode::SignalNotSupported, + "Originated from", + "sendmsg(2)", + "Errno is", + strerror(myErrno), + }); + break; + + case EIO: + case EACCES: + case ENETDOWN: + case ENETUNREACH: + case ENOBUFS: + case ENOMEM: + default: + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "sendmsg(2)", + "Errno is", + strerror(myErrno), + }); + break; + } + } + } + + return bytesSent; +} + +bool RawStreamConnectionHandle::prepareReadBytes(void* notifyHandle) +{ + (void)notifyHandle; + return false; +} + +std::pair RawStreamConnectionHandle::prepareWriteBytes(void* dest, size_t size, void* notifyHandle) +{ + (void)size; + (void)notifyHandle; + return {0, false}; +} + +void RawStreamConnectionHandle::shutdownRead() noexcept +{ + shutdown(_fd, SHUT_RD); +} + +void RawStreamConnectionHandle::shutdownWrite() noexcept +{ + shutdown(_fd, SHUT_WR); + _socketStatus = SocketStatus::Disconnecting; +} + +void RawStreamConnectionHandle::closeAndZero() noexcept +{ + close(_fd); + _fd = 0; + _socketStatus = SocketStatus::Disconnected; +} + +} // namespace ymq +} // namespace scaler +#endif diff --git a/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_windows.cpp b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_windows.cpp new file mode 100644 index 000000000..9b8e9777c --- /dev/null +++ b/src/cpp/scaler/ymq/internal/raw_stream_connection_handle_windows.cpp @@ -0,0 +1,225 @@ +#ifdef _WIN32 +#include +#include // assert + +#include "scaler/error/error.h" +#include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/network_utils.h" +#include "scaler/ymq/internal/raw_stream_connection_handle.h" + +namespace scaler { +namespace ymq { + +std::pair RawStreamConnectionHandle::tryReadUntilComplete( + void* dest, size_t size) +{ + return scaler::ymq::tryReadUntilComplete( + dest, size, [&](char* 
dest, size_t size) { return this->readBytes(dest, size); }); +} + +std::pair RawStreamConnectionHandle::tryWriteUntilComplete( + const std::vector>& buffers) +{ + return scaler::ymq::tryWriteUntilComplete(buffers, [&](std::vector> currentBuffers) { + return this->writeBytes(currentBuffers); + }); +} + +void RawStreamConnectionHandle::shutdownBoth() noexcept +{ + shutdownWrite(); + shutdownRead(); +} + +std::expected RawStreamConnectionHandle::readBytes( + void* dest, size_t size) +{ + assert(_fd); + assert(dest); + assert(size); + + constexpr const int flags = 0; + + const int n = ::recv(_fd, (char*)dest, size, flags); + if (n > 0) { + return n; + } + + if (!n) { + _socketStatus = SocketStatus::Disconnected; + return std::unexpected {IOStatus::Disconnected}; + } + + if (n == -1) { + const int myErrno = WSAGetLastError(); + if (myErrno == WSAEWOULDBLOCK) { + return std::unexpected {IOStatus::Drained}; + } + if (myErrno == WSAECONNRESET || myErrno == WSAENOTSOCK || myErrno == WSAECONNABORTED) { + if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { + _socketStatus = SocketStatus::Disconnected; + return std::unexpected {IOStatus::Disconnected}; + } else { + return std::unexpected {IOStatus::Aborted}; + } + } else { + // NOTE: On Windows we don't have signals and weird IO Errors + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "recv(2)", + "Errno is", + myErrno, + "_fd", + _fd, + "dest", + (void*)dest, + "size", + size, + }); + } + } + + std::unreachable(); +} + +std::expected RawStreamConnectionHandle::writeBytes( + const std::vector>& buffers) +{ + assert(buffers.size()); +#define iovec ::WSABUF +#define IOV_MAX (1024) +#define iov_base buf +#define iov_len len + + std::vector iovecs; + iovecs.reserve(IOV_MAX); + + size_t total = 0; + for (const auto& [ptr, len]: buffers) { + if (iovecs.size() == IOV_MAX) { + break; + } + iovec current; + current.iov_base = (char*)ptr; + current.iov_len = len; + 
iovecs.push_back(std::move(current)); + total += current.iov_len; + // iovecs.emplace_back((void*)ptr, len); + } + + assert(total); + (void)total; + + if (iovecs.empty()) { + return 0; + } + + DWORD bytesSent {}; + const int sendToResult = WSASendTo(_fd, iovecs.data(), iovecs.size(), &bytesSent, 0, nullptr, 0, nullptr, nullptr); + if (sendToResult == 0) { + return bytesSent; + } + const int myErrno = WSAGetLastError(); + if (myErrno == WSAEWOULDBLOCK) { + return std::unexpected {IOStatus::Drained}; + } + if (myErrno == WSAESHUTDOWN || myErrno == WSAENOTCONN || myErrno == WSAECONNRESET) { + if (_socketStatus == SocketStatus::Disconnecting || _socketStatus == SocketStatus::Disconnected) { + _socketStatus = SocketStatus::Disconnected; + return std::unexpected {IOStatus::Disconnected}; + } else { + return std::unexpected {IOStatus::Aborted}; + } + } + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "WSASendTo", + "Errno is", + myErrno, + "_fd", + _fd, + "iovecs.size()", + iovecs.size(), + }); +#undef iovec +#undef IOV_MAX +#undef iov_base +#undef iov_len +} + +// TODO: This notifyHandle is a bad name but I don't have a better name for it now. +// Later, I will give it a better name. The purpose of it is to just "pass something" +// to the event loop. 
+bool RawStreamConnectionHandle::prepareReadBytes(void* notifyHandle) +{ + // TODO: This need rewrite to better logic + if (!_fd) { + return false; + } + const bool ok = ReadFile((HANDLE)(SOCKET)_fd, nullptr, 0, nullptr, (LPOVERLAPPED)notifyHandle); + if (ok) { + // onRead(); + return true; + } + const auto lastError = GetLastError(); + if (lastError == ERROR_IO_PENDING) { + return false; + } + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "ReadFile", + "Errno is", + lastError, + "_fd", + _fd, + }); + std::unreachable(); +} + +// TODO: Think more about this notifyHandle, it used to be _eventManager.get() +std::pair RawStreamConnectionHandle::prepareWriteBytes(void* dest, size_t size, void* notifyHandle) +{ + (void)size; + // NOTE: Precondition is the queue still has messages (perhaps a partial one). + const bool writeFileRes = WriteFile((HANDLE)(SOCKET)_fd, dest, 1, nullptr, (LPOVERLAPPED)notifyHandle); + if (writeFileRes) { + return {1, true}; + } + + const auto lastError = GetLastError(); + if (lastError == ERROR_IO_PENDING) { + return {1, false}; + } + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "prepareWriteBytes", + "Errno is", + lastError, + }); +} + +void RawStreamConnectionHandle::shutdownRead() noexcept +{ + shutdown(_fd, SD_RECEIVE); +} + +void RawStreamConnectionHandle::shutdownWrite() noexcept +{ + shutdown(_fd, SD_SEND); + _socketStatus = SocketStatus::Disconnecting; +} + +void RawStreamConnectionHandle::closeAndZero() noexcept +{ + closesocket(_fd); + _fd = 0; + _socketStatus = SocketStatus::Disconnected; +} + +} // namespace ymq +} // namespace scaler +#endif diff --git a/src/cpp/scaler/ymq/internal/raw_server_tcp_fd.h b/src/cpp/scaler/ymq/internal/raw_stream_server_handle.h similarity index 60% rename from src/cpp/scaler/ymq/internal/raw_server_tcp_fd.h rename to src/cpp/scaler/ymq/internal/raw_stream_server_handle.h index 530fa5bdc..48675d9e1 100644 --- 
a/src/cpp/scaler/ymq/internal/raw_server_tcp_fd.h +++ b/src/cpp/scaler/ymq/internal/raw_stream_server_handle.h @@ -7,16 +7,16 @@ namespace scaler { namespace ymq { -class RawServerTCPFD { +class RawStreamServerHandle { public: - RawServerTCPFD(sockaddr addr); + RawStreamServerHandle(sockaddr addr); - RawServerTCPFD(const RawServerTCPFD&) = delete; - RawServerTCPFD(RawServerTCPFD&&) = delete; - RawServerTCPFD& operator=(const RawServerTCPFD&) = delete; - RawServerTCPFD& operator=(RawServerTCPFD&&) = delete; + RawStreamServerHandle(const RawStreamServerHandle&) = delete; + RawStreamServerHandle(RawStreamServerHandle&&) = delete; + RawStreamServerHandle& operator=(const RawStreamServerHandle&) = delete; + RawStreamServerHandle& operator=(RawStreamServerHandle&&) = delete; - ~RawServerTCPFD(); + ~RawStreamServerHandle(); void prepareAcceptSocket(void* notifyHandle); std::vector> getNewConns(); diff --git a/src/cpp/scaler/ymq/internal/raw_stream_server_handle_linux.cpp b/src/cpp/scaler/ymq/internal/raw_stream_server_handle_linux.cpp new file mode 100644 index 000000000..ef633c390 --- /dev/null +++ b/src/cpp/scaler/ymq/internal/raw_stream_server_handle_linux.cpp @@ -0,0 +1,178 @@ +#ifdef __linux__ + +#include // std::move + +#include "scaler/error/error.h" +#include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/network_utils.h" +#include "scaler/ymq/internal/raw_stream_server_handle.h" + +namespace scaler { +namespace ymq { + +RawStreamServerHandle::RawStreamServerHandle(sockaddr addr) +{ + _serverFD = {}; + _addr = std::move(addr); + + _serverFD = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP); + if ((int)_serverFD == -1) { + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "socket(2)", + "Errno is", + strerror(errno), + "_serverFD", + _serverFD, + }); + + return; + } +} + +bool RawStreamServerHandle::setReuseAddress() +{ + if (::scaler::ymq::setReuseAddress(_serverFD)) { + return true; + } else { + 
CloseAndZeroSocket(_serverFD); + return false; + } +} + +void RawStreamServerHandle::bindAndListen() +{ + if (bind(_serverFD, &_addr, sizeof(_addr)) == -1) { + const auto serverFD = _serverFD; + CloseAndZeroSocket(_serverFD); + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "bind(2)", + "Errno is", + strerror(GetErrorCode()), + "_serverFD", + serverFD, + }); + + return; + } + + if (listen(_serverFD, SOMAXCONN) == -1) { + const auto serverFD = _serverFD; + CloseAndZeroSocket(_serverFD); + unrecoverableError({ + Error::ErrorCode::ConfigurationError, + "Originated from", + "listen(2)", + "Errno is", + strerror(GetErrorCode()), + "_serverFD", + serverFD, + }); + + return; + } +} + +RawStreamServerHandle::~RawStreamServerHandle() +{ + if (_serverFD) { + CloseAndZeroSocket(_serverFD); + } +} + +void RawStreamServerHandle::prepareAcceptSocket(void* notifyHandle) +{ + (void)notifyHandle; +} + +std::vector> RawStreamServerHandle::getNewConns() +{ + std::vector> res; + while (true) { + sockaddr remoteAddr {}; + socklen_t remoteAddrLen = sizeof(remoteAddr); + + int fd = accept4(_serverFD, &remoteAddr, &remoteAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC); + if (fd < 0) { + const int myErrno = errno; + switch (myErrno) { + // Not an error + // case EWOULDBLOCK: // same as EAGAIN + case EAGAIN: + case ECONNABORTED: return res; + + case ENOTSOCK: + case EOPNOTSUPP: + case EINVAL: + case EBADF: + unrecoverableError({ + Error::ErrorCode::CoreBug, + "Originated from", + "accept4(2)", + "Errno is", + strerror(myErrno), + "_serverFD", + _serverFD, + }); + + case EINTR: + unrecoverableError({ + Error::ErrorCode::SignalNotSupported, + "Originated from", + "accept4(2)", + "Errno is", + strerror(myErrno), + }); + + // config + case EMFILE: + case ENFILE: + case ENOBUFS: + case ENOMEM: + case EFAULT: + case EPERM: + case EPROTO: + case ENOSR: + case ESOCKTNOSUPPORT: + case EPROTONOSUPPORT: + case ETIMEDOUT: + default: + unrecoverableError({ + 
Error::ErrorCode::ConfigurationError, + "Originated from", + "accept4(2)", + "Errno is", + strerror(myErrno), + }); + } + } + + if (remoteAddrLen > sizeof(remoteAddr)) { + unrecoverableError({ + Error::ErrorCode::IPv6NotSupported, + "Originated from", + "accept4(2)", + "remoteAddrLen", + remoteAddrLen, + "sizeof(remoteAddr)", + sizeof(remoteAddr), + }); + } + res.push_back({fd, remoteAddr}); + } +} + +void RawStreamServerHandle::destroy() +{ + if (_serverFD) { + CloseAndZeroSocket(_serverFD); + } +} + +} // namespace ymq +} // namespace scaler + +#endif diff --git a/src/cpp/scaler/ymq/internal/raw_server_tcp_fd.cpp b/src/cpp/scaler/ymq/internal/raw_stream_server_handle_windows.cpp similarity index 66% rename from src/cpp/scaler/ymq/internal/raw_server_tcp_fd.cpp rename to src/cpp/scaler/ymq/internal/raw_stream_server_handle_windows.cpp index a054dde77..61a289260 100644 --- a/src/cpp/scaler/ymq/internal/raw_server_tcp_fd.cpp +++ b/src/cpp/scaler/ymq/internal/raw_stream_server_handle_windows.cpp @@ -1,43 +1,25 @@ - -#include "scaler/ymq/internal/raw_server_tcp_fd.h" +#ifdef _WIN32 #include // std::move #include "scaler/error/error.h" #include "scaler/ymq/internal/defs.h" +#include "scaler/ymq/internal/network_utils.h" +#include "scaler/ymq/internal/raw_stream_server_handle.h" #include "scaler/ymq/network_utils.h" namespace scaler { namespace ymq { -RawServerTCPFD::RawServerTCPFD(sockaddr addr) +RawStreamServerHandle::RawStreamServerHandle(sockaddr addr) { _serverFD = {}; _addr = std::move(addr); -#ifdef _WIN32 _newConn = {}; _acceptExFunc = {}; memset(_buffer, 0, sizeof(_buffer)); -#endif // _WIN32 -#ifdef __linux__ - _serverFD = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, IPPROTO_TCP); - if ((int)_serverFD == -1) { - unrecoverableError({ - Error::ErrorCode::ConfigurationError, - "Originated from", - "socket(2)", - "Errno is", - strerror(errno), - "_serverFD", - _serverFD, - }); - - return; - } -#endif // __linux__ -#ifdef _WIN32 _serverFD = socket(AF_INET, 
SOCK_STREAM, IPPROTO_TCP); if (_serverFD == -1) { const int myErrno = GetErrorCode(); @@ -93,37 +75,22 @@ RawServerTCPFD::RawServerTCPFD(sockaddr addr) _serverFD, }); } -#endif // _WIN32 -} - -RawServerTCPFD::~RawServerTCPFD() -{ -#ifdef _WIN32 - if (_newConn) { - CloseAndZeroSocket(_newConn); - } -#endif - if (_serverFD) { -#ifdef _WIN32 - CancelIoEx((HANDLE)_serverFD, nullptr); -#endif - CloseAndZeroSocket(_serverFD); - } } -bool RawServerTCPFD::setReuseAddress() +bool RawStreamServerHandle::setReuseAddress() { - int optval = 1; - if (setsockopt(_serverFD, SOL_SOCKET, SO_REUSEADDR, (const char*)&optval, sizeof(optval)) == -1) { + if (::scaler::ymq::setReuseAddress(_serverFD)) { + return true; + } else { CloseAndZeroSocket(_serverFD); return false; } - return true; } -void RawServerTCPFD::bindAndListen() +void RawStreamServerHandle::bindAndListen() { if (bind(_serverFD, &_addr, sizeof(_addr)) == -1) { + const auto serverFD = _serverFD; CloseAndZeroSocket(_serverFD); unrecoverableError({ Error::ErrorCode::ConfigurationError, @@ -132,13 +99,14 @@ void RawServerTCPFD::bindAndListen() "Errno is", strerror(GetErrorCode()), "_serverFD", - _serverFD, + serverFD, }); return; } if (listen(_serverFD, SOMAXCONN) == -1) { + const auto serverFD = _serverFD; CloseAndZeroSocket(_serverFD); unrecoverableError({ Error::ErrorCode::ConfigurationError, @@ -147,16 +115,26 @@ void RawServerTCPFD::bindAndListen() "Errno is", strerror(GetErrorCode()), "_serverFD", - _serverFD, + serverFD, }); return; } } -void RawServerTCPFD::prepareAcceptSocket(void* notifyHandle) +RawStreamServerHandle::~RawStreamServerHandle() +{ + if (_newConn) { + CloseAndZeroSocket(_newConn); + } + if (_serverFD) { + CancelIoEx((HANDLE)_serverFD, nullptr); + CloseAndZeroSocket(_serverFD); + } +} + +void RawStreamServerHandle::prepareAcceptSocket(void* notifyHandle) { -#ifdef _WIN32 _newConn = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (_newConn == INVALID_SOCKET) { const int myErrno = GetErrorCode(); @@ -216,88 
+194,12 @@ void RawServerTCPFD::prepareAcceptSocket(void* notifyHandle) return; } // acceptEx never succeed. -#endif // _WIN32 } -std::vector> RawServerTCPFD::getNewConns() +std::vector> RawStreamServerHandle::getNewConns() { std::vector> res; -#ifdef __linux__ - while (true) { - sockaddr remoteAddr {}; - socklen_t remoteAddrLen = sizeof(remoteAddr); - - int fd = accept4(_serverFD, &remoteAddr, &remoteAddrLen, SOCK_NONBLOCK | SOCK_CLOEXEC); - if (fd < 0) { - const int myErrno = errno; - switch (myErrno) { - // Not an error - // case EWOULDBLOCK: // same as EAGAIN - case EAGAIN: - case ECONNABORTED: return res; - case ENOTSOCK: - case EOPNOTSUPP: - case EINVAL: - case EBADF: - unrecoverableError({ - Error::ErrorCode::CoreBug, - "Originated from", - "accept4(2)", - "Errno is", - strerror(myErrno), - "_serverFD", - _serverFD, - }); - - case EINTR: - unrecoverableError({ - Error::ErrorCode::SignalNotSupported, - "Originated from", - "accept4(2)", - "Errno is", - strerror(myErrno), - }); - - // config - case EMFILE: - case ENFILE: - case ENOBUFS: - case ENOMEM: - case EFAULT: - case EPERM: - case EPROTO: - case ENOSR: - case ESOCKTNOSUPPORT: - case EPROTONOSUPPORT: - case ETIMEDOUT: - default: - unrecoverableError({ - Error::ErrorCode::ConfigurationError, - "Originated from", - "accept4(2)", - "Errno is", - strerror(myErrno), - }); - } - } - - if (remoteAddrLen > sizeof(remoteAddr)) { - unrecoverableError({ - Error::ErrorCode::IPv6NotSupported, - "Originated from", - "accept4(2)", - "remoteAddrLen", - remoteAddrLen, - "sizeof(remoteAddr)", - sizeof(remoteAddr), - }); - } - res.push_back({fd, remoteAddr}); - } -#endif - -#ifdef _WIN32 if (setsockopt( _newConn, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, reinterpret_cast(&_serverFD), sizeof(_serverFD)) == SOCKET_ERROR) { @@ -351,16 +253,13 @@ std::vector> RawServerTCPFD::getNewConns() _newConn = 0; // This _newConn will be handled by connection class return res; -#endif } -void RawServerTCPFD::destroy() +void 
RawStreamServerHandle::destroy() { -#ifdef _WIN32 if (_serverFD) { CancelIoEx((HANDLE)_serverFD, nullptr); } -#endif if (_serverFD) { CloseAndZeroSocket(_serverFD); } @@ -368,3 +267,5 @@ void RawServerTCPFD::destroy() } // namespace ymq } // namespace scaler + +#endif diff --git a/src/cpp/scaler/ymq/io_socket.cpp b/src/cpp/scaler/ymq/io_socket.cpp index ac43cc3f7..f3fe878bb 100644 --- a/src/cpp/scaler/ymq/io_socket.cpp +++ b/src/cpp/scaler/ymq/io_socket.cpp @@ -11,11 +11,11 @@ #include "scaler/error/error.h" #include "scaler/ymq/event_loop_thread.h" #include "scaler/ymq/event_manager.h" -#include "scaler/ymq/internal/raw_connection_tcp_fd.h" -#include "scaler/ymq/message_connection_tcp.h" +#include "scaler/ymq/internal/raw_stream_connection_handle.h" +#include "scaler/ymq/message_connection.h" #include "scaler/ymq/network_utils.h" -#include "scaler/ymq/tcp_client.h" -#include "scaler/ymq/tcp_server.h" +#include "scaler/ymq/stream_client.h" +#include "scaler/ymq/stream_server.h" #include "scaler/ymq/typedefs.h" namespace scaler { @@ -77,13 +77,13 @@ void IOSocket::sendMessage(Message message, SendMessageCallback onMessageSent) n default: break; } - MessageConnectionTCP* conn = nullptr; + MessageConnection* conn = nullptr; if (this->_identityToConnection.contains(address)) { conn = this->_identityToConnection[address].get(); } else { - const auto it = std::ranges::find( - _unestablishedConnection, address, &MessageConnectionTCP::_remoteIOSocketIdentity); + const auto it = + std::ranges::find(_unestablishedConnection, address, &MessageConnection::_remoteIOSocketIdentity); if (it != _unestablishedConnection.end()) { conn = it->get(); } else { @@ -179,7 +179,7 @@ void IOSocket::closeConnection(Identity remoteSocketIdentity) noexcept // TODO: The function should be separated into onConnectionAborted, onConnectionDisconnected, // and probably onConnectionAbortedBeforeEstablished(?) 
-void IOSocket::onConnectionDisconnected(MessageConnectionTCP* conn, bool keepInBook) noexcept +void IOSocket::onConnectionDisconnected(MessageConnection* conn, bool keepInBook) noexcept { if (!conn->_remoteIOSocketIdentity) { auto connIt = std::ranges::find_if(_unestablishedConnection, [&](const auto& x) { return x.get() == conn; }); @@ -227,7 +227,7 @@ void IOSocket::onConnectionDisconnected(MessageConnectionTCP* conn, bool keepInB // - look up _unconnectedConnections to find if there's a connection with the same identity // if so, merge it to this connection that currently resides in _connectingConnections // Similar thing for disconnection as well. -void IOSocket::onConnectionIdentityReceived(MessageConnectionTCP* conn) noexcept +void IOSocket::onConnectionIdentityReceived(MessageConnection* conn) noexcept { auto& s = conn->_remoteIOSocketIdentity; if (socketType() == IOSocketType::Connector) { @@ -269,7 +269,7 @@ void IOSocket::onConnectionIdentityReceived(MessageConnectionTCP* conn) noexcept void IOSocket::onConnectionCreated(std::string remoteIOSocketIdentity) noexcept { _unestablishedConnection.push_back( - std::make_unique( + std::make_unique( _eventLoopThread.get(), this->identity(), std::move(remoteIOSocketIdentity), @@ -281,7 +281,7 @@ void IOSocket::onConnectionCreated(std::string remoteIOSocketIdentity) noexcept void IOSocket::onConnectionCreated(int fd, sockaddr localAddr, sockaddr remoteAddr, bool responsibleForRetry) noexcept { _unestablishedConnection.push_back( - std::make_unique( + std::make_unique( _eventLoopThread.get(), fd, std::move(localAddr), @@ -293,7 +293,7 @@ void IOSocket::onConnectionCreated(int fd, sockaddr localAddr, sockaddr remoteAd _unestablishedConnection.back()->onCreated(); } -void IOSocket::removeConnectedTcpClient() noexcept +void IOSocket::removeConnectedTCPClient() noexcept { if (this->_tcpClient && this->_tcpClient->_connected) { this->_tcpClient.reset(); diff --git a/src/cpp/scaler/ymq/io_socket.h 
b/src/cpp/scaler/ymq/io_socket.h index 52d5bd293..4da2ec1f6 100644 --- a/src/cpp/scaler/ymq/io_socket.h +++ b/src/cpp/scaler/ymq/io_socket.h @@ -19,15 +19,15 @@ // First-party #include "scaler/ymq/configuration.h" #include "scaler/ymq/message.h" -#include "scaler/ymq/tcp_client.h" -#include "scaler/ymq/tcp_server.h" +#include "scaler/ymq/stream_client.h" +#include "scaler/ymq/stream_server.h" #include "scaler/ymq/typedefs.h" namespace scaler { namespace ymq { class EventLoopThread; -class MessageConnectionTCP; +class MessageConnection; class TcpWriteOperation; class IOSocket { @@ -65,9 +65,9 @@ class IOSocket { // TODO: Maybe figure out a better name than keepInBook. When keepInBook is true, the system will remember this // remote identity and will treat the next connection with that identity as the reincarnation of this identity. // Thus, keeping the identity in the book. - void onConnectionDisconnected(MessageConnectionTCP* conn, bool keepInBook = true) noexcept; + void onConnectionDisconnected(MessageConnection* conn, bool keepInBook = true) noexcept; // From Connection Class only - void onConnectionIdentityReceived(MessageConnectionTCP* conn) noexcept; + void onConnectionIdentityReceived(MessageConnection* conn) noexcept; // NOTE: These two functions are called respectively by sendMessage and server/client. 
// Notice that in the each case only the needed information are passed in; so it's less @@ -76,8 +76,8 @@ class IOSocket { void onConnectionCreated(std::string remoteIOSocketIdentity) noexcept; void onConnectionCreated(int fd, sockaddr localAddr, sockaddr remoteAddr, bool responsibleForRetry) noexcept; - // From TcpClient class only - void removeConnectedTcpClient() noexcept; + // From TCPClient class only + void removeConnectedTCPClient() noexcept; void requestStop() noexcept; @@ -91,15 +91,15 @@ class IOSocket { const Identity _identity; const IOSocketType _socketType; - // NOTE: Owning one TcpClient means the user cannot issue another connectTo + // NOTE: Owning one TCPClient means the user cannot issue another connectTo // when some message connection is retring to connect. - std::optional _tcpClient; + std::optional _tcpClient; - // NOTE: Owning one TcpServer means the user cannot bindTo multiple addresses. - std::optional _tcpServer; + // NOTE: Owning one TCPServer means the user cannot bindTo multiple addresses. + std::optional _tcpServer; // Remote identity to connection map - std::map> _identityToConnection; + std::map> _identityToConnection; // NOTE: An unestablished connection can be in the following states: // 1. The underlying socket is not yet defined. This happens when user call sendMessage @@ -111,7 +111,7 @@ class IOSocket { // On the other hand, `Established Connection` are stored in _identityToConnection map. // An established connection is a network connection that is currently connected, and // exchanged their identity. - std::vector> _unestablishedConnection; + std::vector> _unestablishedConnection; // NOTE: This variable needs to present in the IOSocket level because the user // does not care which connection a message is coming from. 
diff --git a/src/cpp/scaler/ymq/iocp_context.cpp b/src/cpp/scaler/ymq/iocp_context.cpp index 4fbe586a5..e46adc5d4 100644 --- a/src/cpp/scaler/ymq/iocp_context.cpp +++ b/src/cpp/scaler/ymq/iocp_context.cpp @@ -11,7 +11,7 @@ namespace scaler { namespace ymq { -void IocpContext::execPendingFunctions() +void IOCPContext::execPendingFunctions() { while (_delayedFunctions.size()) { auto top = std::move(_delayedFunctions.front()); @@ -20,7 +20,7 @@ void IocpContext::execPendingFunctions() } } -void IocpContext::loop() +void IOCPContext::loop() { std::array events {}; ULONG n = 0; @@ -77,7 +77,7 @@ void IocpContext::loop() execPendingFunctions(); } -void IocpContext::addFdToLoop(int fd, uint64_t, EventManager*) +void IOCPContext::addFdToLoop(int fd, uint64_t, EventManager*) { const DWORD threadCount = 1; if (!CreateIoCompletionPort((HANDLE)(SOCKET)fd, _completionPort, ((uint64_t)fd << 32) | _isSocket, threadCount)) { @@ -102,7 +102,7 @@ void IocpContext::addFdToLoop(int fd, uint64_t, EventManager*) // The file handle is automaticaly released when one call closesocket(fd). // This interface is required by the concept, and we need it for select(2) or poll(2). // Instead of relaxing constraint, we leave the implementation empty. -void IocpContext::removeFdFromLoop(int fd) +void IOCPContext::removeFdFromLoop(int fd) { CancelIoEx((HANDLE)(SOCKET)fd, nullptr); _sockets.erase(fd); diff --git a/src/cpp/scaler/ymq/iocp_context.h b/src/cpp/scaler/ymq/iocp_context.h index 490233e7e..021626d3c 100644 --- a/src/cpp/scaler/ymq/iocp_context.h +++ b/src/cpp/scaler/ymq/iocp_context.h @@ -21,7 +21,7 @@ class EventManager; // In the constructor, the epoll context should register eventfd/timerfd from // This way, the queues need not know about the event manager. We don't use callbacks. 
-class IocpContext { +class IOCPContext { public: using Function = Configuration::ExecutionFunction; using DelayedFunctionQueue = std::queue; @@ -29,7 +29,7 @@ class IocpContext { HANDLE _completionPort; // TODO: Handle error with unrecoverable error in the next PR. - IocpContext() + IOCPContext() : _completionPort(CreateIoCompletionPort(INVALID_HANDLE_VALUE, nullptr, (ULONG_PTR)0, 1)) , _timingFunctions(_completionPort, _isTimingFd) , _interruptiveFunctions(_completionPort, _isInterruptiveFd) @@ -45,7 +45,7 @@ class IocpContext { } } - ~IocpContext() { CloseHandle(_completionPort); } + ~IOCPContext() { CloseHandle(_completionPort); } void loop(); diff --git a/src/cpp/scaler/ymq/message_connection_tcp.cpp b/src/cpp/scaler/ymq/message_connection.cpp similarity index 90% rename from src/cpp/scaler/ymq/message_connection_tcp.cpp rename to src/cpp/scaler/ymq/message_connection.cpp index 63f7f6eac..e35357caa 100644 --- a/src/cpp/scaler/ymq/message_connection_tcp.cpp +++ b/src/cpp/scaler/ymq/message_connection.cpp @@ -1,5 +1,5 @@ -#include "scaler/ymq/message_connection_tcp.h" +#include "scaler/ymq/message_connection.h" #include #include @@ -23,7 +23,7 @@ namespace ymq { static constexpr const size_t HEADER_SIZE = sizeof(uint64_t); -constexpr bool MessageConnectionTCP::isCompleteMessage(const TcpReadOperation& x) +constexpr bool MessageConnection::isCompleteMessage(const TcpReadOperation& x) { if (x._cursor < HEADER_SIZE) { return false; @@ -34,7 +34,7 @@ constexpr bool MessageConnectionTCP::isCompleteMessage(const TcpReadOperation& x return false; } -MessageConnectionTCP::MessageConnectionTCP( +MessageConnection::MessageConnection( EventLoopThread* eventLoopThread, int connFd, sockaddr localAddr, @@ -62,7 +62,7 @@ MessageConnectionTCP::MessageConnectionTCP( _eventManager->onError = [this] { this->onError(); }; } -MessageConnectionTCP::MessageConnectionTCP( +MessageConnection::MessageConnection( EventLoopThread* eventLoopThread, std::string localIOSocketIdentity, 
std::string remoteIOSocketIdentity, @@ -88,7 +88,7 @@ MessageConnectionTCP::MessageConnectionTCP( _eventManager->onError = [this] { this->onError(); }; } -void MessageConnectionTCP::onCreated() +void MessageConnection::onCreated() { if (_rawConn.nativeHandle() != 0) { this->_eventLoopThread->_eventLoop.addFdToLoop( @@ -107,12 +107,12 @@ void MessageConnectionTCP::onCreated() } } -bool MessageConnectionTCP::disconnected() +bool MessageConnection::disconnected() { return _rawConn.nativeHandle() == 0; } -std::expected MessageConnectionTCP::tryReadOneMessage() +std::expected MessageConnection::tryReadOneMessage() { if (_receivedReadOperations.empty() || isCompleteMessage(_receivedReadOperations.back())) { _receivedReadOperations.emplace(); @@ -162,18 +162,18 @@ std::expected MessageConnectionTCP::tryRead _readSomeBytes = true; } - if (status != RawConnectionTCPFD::IOStatus::MoreBytesAvailable) { + if (status != RawStreamConnectionHandle::IOStatus::MoreBytesAvailable) { switch (status) { - case RawConnectionTCPFD::IOStatus::Aborted: { + case RawStreamConnectionHandle::IOStatus::Aborted: { return std::unexpected {IOError::Aborted}; } - case RawConnectionTCPFD::IOStatus::Disconnected: { + case RawStreamConnectionHandle::IOStatus::Disconnected: { return std::unexpected {IOError::Disconnected}; } - case RawConnectionTCPFD::IOStatus::Drained: { + case RawStreamConnectionHandle::IOStatus::Drained: { return std::unexpected {IOError::Drained}; } - case RawConnectionTCPFD::IOStatus::MoreBytesAvailable: { + case RawStreamConnectionHandle::IOStatus::MoreBytesAvailable: { std::unreachable(); } } @@ -183,7 +183,7 @@ std::expected MessageConnectionTCP::tryRead } // on Return, unexpected value shall be interpreted as this - 0 = close, other -> errno -std::expected MessageConnectionTCP::tryReadMessages() +std::expected MessageConnection::tryReadMessages() { while (true) { auto res = tryReadOneMessage(); @@ -193,7 +193,7 @@ std::expected MessageConnectionTCP::tryRead } } -void 
MessageConnectionTCP::updateReadOperation() +void MessageConnection::updateReadOperation() { while (_pendingRecvMessageCallbacks->size() && _receivedReadOperations.size()) { if (isCompleteMessage(_receivedReadOperations.front())) { @@ -211,7 +211,7 @@ void MessageConnectionTCP::updateReadOperation() } } -void MessageConnectionTCP::setRemoteIdentity() noexcept +void MessageConnection::setRemoteIdentity() noexcept { if (!_remoteIOSocketIdentity && (_receivedReadOperations.size() && isCompleteMessage(_receivedReadOperations.front()))) { @@ -223,7 +223,7 @@ void MessageConnectionTCP::setRemoteIdentity() noexcept } } -void MessageConnectionTCP::onRead() +void MessageConnection::onRead() { if (_rawConn.nativeHandle() == 0) { return; @@ -294,7 +294,7 @@ void MessageConnectionTCP::onRead() _rawConn.prepareReadBytes(this->_eventManager.get()); } -void MessageConnectionTCP::onWrite() +void MessageConnection::onWrite() { // This is because after disconnected, onRead will be called first, and that will set // _connFd to 0. There's no way to not call onWrite in this case. So we return early. @@ -339,7 +339,7 @@ void MessageConnectionTCP::onWrite() } } -void MessageConnectionTCP::onClose() +void MessageConnection::onClose() { if (_rawConn.nativeHandle()) { if (_remoteIOSocketIdentity) { @@ -360,7 +360,7 @@ void MessageConnectionTCP::onClose() } }; -std::expected MessageConnectionTCP::trySendQueuedMessages() +std::expected MessageConnection::trySendQueuedMessages() { // TODO: Should this accept 0 length send? 
if (_writeOperations.empty()) { @@ -399,16 +399,16 @@ std::expected MessageConnectionTCP::trySe return n; } switch (status) { - case RawConnectionTCPFD::IOStatus::Drained: { + case RawStreamConnectionHandle::IOStatus::Drained: { return std::unexpected {IOError::Drained}; } - case RawConnectionTCPFD::IOStatus::Disconnected: { + case RawStreamConnectionHandle::IOStatus::Disconnected: { return std::unexpected {IOError::Disconnected}; } - case RawConnectionTCPFD::IOStatus::Aborted: { + case RawStreamConnectionHandle::IOStatus::Aborted: { return std::unexpected {IOError::Aborted}; } - case RawConnectionTCPFD::IOStatus::MoreBytesAvailable: { + case RawStreamConnectionHandle::IOStatus::MoreBytesAvailable: { std::unreachable(); } } @@ -418,7 +418,7 @@ std::expected MessageConnectionTCP::trySe // TODO: There is a classic optimization that can (and should) be done. That is, we store // prefix sum in each write operation, and perform binary search instead of linear search // to find the first write operation we haven't complete. 
- gxu -void MessageConnectionTCP::updateWriteOperations(size_t n) +void MessageConnection::updateWriteOperations(size_t n) { auto firstIncomplete = _writeOperations.begin(); _sendCursor += n; @@ -451,7 +451,7 @@ void MessageConnectionTCP::updateWriteOperations(size_t n) // _writeOperations.shrink_to_fit(); } -void MessageConnectionTCP::sendMessage(Message msg, SendMessageCallback onMessageSent) +void MessageConnection::sendMessage(Message msg, SendMessageCallback onMessageSent) { TcpWriteOperation writeOp(std::move(msg), std::move(onMessageSent)); _writeOperations.emplace_back(std::move(writeOp)); @@ -462,7 +462,7 @@ void MessageConnectionTCP::sendMessage(Message msg, SendMessageCallback onMessag onWrite(); } -bool MessageConnectionTCP::recvMessage() +bool MessageConnection::recvMessage() { if (_receivedReadOperations.empty() || _pendingRecvMessageCallbacks->empty() || !isCompleteMessage(_receivedReadOperations.front())) { @@ -473,14 +473,14 @@ bool MessageConnectionTCP::recvMessage() return true; } -void MessageConnectionTCP::disconnect() +void MessageConnection::disconnect() { if (_rawConn.nativeHandle()) { _rawConn.shutdownWrite(); } } -MessageConnectionTCP::~MessageConnectionTCP() noexcept +MessageConnection::~MessageConnection() noexcept { if (_rawConn.nativeHandle() != 0) { _eventLoopThread->_eventLoop.removeFdFromLoop(_rawConn.nativeHandle()); diff --git a/src/cpp/scaler/ymq/message_connection_tcp.h b/src/cpp/scaler/ymq/message_connection.h similarity index 91% rename from src/cpp/scaler/ymq/message_connection_tcp.h rename to src/cpp/scaler/ymq/message_connection.h index 0420d32f2..264f080d8 100644 --- a/src/cpp/scaler/ymq/message_connection_tcp.h +++ b/src/cpp/scaler/ymq/message_connection.h @@ -7,7 +7,7 @@ #include "scaler/logging/logging.h" #include "scaler/ymq/configuration.h" -#include "scaler/ymq/internal/raw_connection_tcp_fd.h" +#include "scaler/ymq/internal/raw_stream_connection_handle.h" #include "scaler/ymq/io_socket.h" #include 
"scaler/ymq/tcp_operations.h" @@ -17,12 +17,12 @@ namespace ymq { class EventLoopThread; class EventManager; -class MessageConnectionTCP { +class MessageConnection { public: using SendMessageCallback = Configuration::SendMessageCallback; using RecvMessageCallback = Configuration::RecvMessageCallback; - MessageConnectionTCP( + MessageConnection( EventLoopThread* eventLoopThread, int connFd, sockaddr localAddr, @@ -32,14 +32,14 @@ class MessageConnectionTCP { std::queue* _pendingRecvMessageCallbacks, std::queue* leftoverMessagesAfterConnectionDied) noexcept; - MessageConnectionTCP( + MessageConnection( EventLoopThread* eventLoopThread, std::string localIOSocketIdentity, std::string remoteIOSocketIdentity, std::queue* _pendingRecvMessageCallbacks, std::queue* leftoverMessagesAfterConnectionDied) noexcept; - ~MessageConnectionTCP() noexcept; + ~MessageConnection() noexcept; void onCreated(); @@ -82,7 +82,7 @@ class MessageConnectionTCP { void setRemoteIdentity() noexcept; std::unique_ptr _eventManager; - RawConnectionTCPFD _rawConn; + RawStreamConnectionHandle _rawConn; sockaddr _localAddr; std::string _localIOSocketIdentity; @@ -102,8 +102,8 @@ class MessageConnectionTCP { bool _readSomeBytes; constexpr static bool isCompleteMessage(const TcpReadOperation& x); - friend void IOSocket::onConnectionIdentityReceived(MessageConnectionTCP* conn) noexcept; - friend void IOSocket::onConnectionDisconnected(MessageConnectionTCP* conn, bool keepInBook) noexcept; + friend void IOSocket::onConnectionIdentityReceived(MessageConnection* conn) noexcept; + friend void IOSocket::onConnectionDisconnected(MessageConnection* conn, bool keepInBook) noexcept; }; } // namespace ymq diff --git a/src/cpp/scaler/ymq/pymod_ymq/python.h b/src/cpp/scaler/ymq/pymod_ymq/python.h index f7df2caf2..1c3f1b48f 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/python.h +++ b/src/cpp/scaler/ymq/pymod_ymq/python.h @@ -1,7 +1,17 @@ #pragma once #define PY_SSIZE_T_CLEAN + +// if on Windows and in debug mode, undefine 
_DEBUG before including Python.h +// this prevents issues including the debug version of the Python library +#if defined(_WIN32) && defined(_DEBUG) +#undef _DEBUG +#include +#define _DEBUG +#else #include +#endif + #include #include "scaler/error/error.h" diff --git a/src/cpp/scaler/ymq/tcp_client.cpp b/src/cpp/scaler/ymq/stream_client.cpp similarity index 90% rename from src/cpp/scaler/ymq/tcp_client.cpp rename to src/cpp/scaler/ymq/stream_client.cpp index e5268f954..833056056 100644 --- a/src/cpp/scaler/ymq/tcp_client.cpp +++ b/src/cpp/scaler/ymq/stream_client.cpp @@ -1,4 +1,4 @@ -#include "scaler/ymq/tcp_client.h" +#include "scaler/ymq/stream_client.h" #include #include @@ -7,14 +7,14 @@ #include "scaler/ymq/event_loop_thread.h" #include "scaler/ymq/event_manager.h" #include "scaler/ymq/io_socket.h" -#include "scaler/ymq/message_connection_tcp.h" +#include "scaler/ymq/message_connection.h" #include "scaler/ymq/network_utils.h" #include "scaler/ymq/timestamp.h" namespace scaler { namespace ymq { -void TcpClient::onCreated() +void StreamClient::onCreated() { assert(_rawClient.nativeHandle() == 0); assert(_eventManager.get() != nullptr); @@ -44,7 +44,7 @@ void TcpClient::onCreated() } } -TcpClient::TcpClient( +StreamClient::StreamClient( EventLoopThread* eventLoopThread, std::string localIOSocketIdentity, sockaddr remoteAddr, @@ -66,11 +66,11 @@ TcpClient::TcpClient( _eventManager->onError = [this] { this->onError(); }; } -void TcpClient::onRead() +void StreamClient::onRead() { } -void TcpClient::onWrite() +void StreamClient::onWrite() { if (!_rawClient.nativeHandle()) { return; @@ -94,10 +94,10 @@ void TcpClient::onWrite() _rawClient.zeroNativeHandle(); _connected = true; - _eventLoopThread->_eventLoop.executeLater([sock] { sock->removeConnectedTcpClient(); }); + _eventLoopThread->_eventLoop.executeLater([sock] { sock->removeConnectedTCPClient(); }); } -void TcpClient::retry() +void StreamClient::retry() { if (_retryTimes > _maxRetryTimes) { 
_logger.log(Logger::LoggingLevel::error, "Retried times has reached maximum: ", _maxRetryTimes); @@ -113,7 +113,7 @@ void TcpClient::retry() _retryIdentifier = _eventLoopThread->_eventLoop.executeAt(at, [this] { this->onCreated(); }); } -void TcpClient::disconnect() +void StreamClient::disconnect() { if (_rawClient.nativeHandle()) { _eventLoopThread->_eventLoop.removeFdFromLoop(_rawClient.nativeHandle()); @@ -121,13 +121,13 @@ void TcpClient::disconnect() } } -TcpClient::~TcpClient() noexcept +StreamClient::~StreamClient() noexcept { disconnect(); if (_retryTimes > 0) { _eventLoopThread->_eventLoop.cancelExecution(_retryIdentifier); } - // TODO: Do we think this is an error? See TcpServer::~TcpServer for detail. + // TODO: Do we think this is an error? See TCPServer::~TCPServer for detail. if (_onConnectReturn) { _onConnectReturn({}); } diff --git a/src/cpp/scaler/ymq/tcp_client.h b/src/cpp/scaler/ymq/stream_client.h similarity index 81% rename from src/cpp/scaler/ymq/tcp_client.h rename to src/cpp/scaler/ymq/stream_client.h index 084257f89..4701e2712 100644 --- a/src/cpp/scaler/ymq/tcp_client.h +++ b/src/cpp/scaler/ymq/stream_client.h @@ -5,7 +5,7 @@ // First-party #include "scaler/logging/logging.h" #include "scaler/ymq/configuration.h" -#include "scaler/ymq/internal/raw_client_tcp_fd.h" +#include "scaler/ymq/internal/raw_stream_client_handle.h" namespace scaler { namespace ymq { @@ -13,19 +13,19 @@ namespace ymq { class EventLoopThread; class EventManager; -class TcpClient { +class StreamClient { public: using ConnectReturnCallback = Configuration::ConnectReturnCallback; - TcpClient( + StreamClient( EventLoopThread* eventLoopThread, std::string localIOSocketIdentity, sockaddr remoteAddr, ConnectReturnCallback onConnectReturn, size_t maxRetryTimes) noexcept; - TcpClient(const TcpClient&) = delete; - TcpClient& operator=(const TcpClient&) = delete; - ~TcpClient() noexcept; + StreamClient(const StreamClient&) = delete; + StreamClient& operator=(const StreamClient&) 
= delete; + ~StreamClient() noexcept; void onCreated(); void retry(); @@ -55,7 +55,7 @@ class TcpClient { const size_t _maxRetryTimes; - RawClientTCPFD _rawClient; + RawStreamClientHandle _rawClient; }; } // namespace ymq diff --git a/src/cpp/scaler/ymq/tcp_server.cpp b/src/cpp/scaler/ymq/stream_server.cpp similarity index 90% rename from src/cpp/scaler/ymq/tcp_server.cpp rename to src/cpp/scaler/ymq/stream_server.cpp index a24ddf9a3..89d950836 100644 --- a/src/cpp/scaler/ymq/tcp_server.cpp +++ b/src/cpp/scaler/ymq/stream_server.cpp @@ -1,4 +1,4 @@ -#include "scaler/ymq/tcp_server.h" +#include "scaler/ymq/stream_server.h" #include #include @@ -7,13 +7,13 @@ #include "scaler/ymq/event_loop_thread.h" #include "scaler/ymq/event_manager.h" #include "scaler/ymq/io_socket.h" -#include "scaler/ymq/message_connection_tcp.h" +#include "scaler/ymq/message_connection.h" #include "scaler/ymq/network_utils.h" namespace scaler { namespace ymq { -bool TcpServer::createAndBindSocket() +bool StreamServer::createAndBindSocket() { if (!_rawServer.setReuseAddress()) { _logger.log( @@ -33,7 +33,7 @@ bool TcpServer::createAndBindSocket() return true; } -TcpServer::TcpServer( +StreamServer::StreamServer( EventLoopThread* eventLoopThread, std::string localIOSocketIdentity, sockaddr addr, @@ -51,7 +51,7 @@ TcpServer::TcpServer( _eventManager->onError = [this] { this->onError(); }; } -void TcpServer::onCreated() +void StreamServer::onCreated() { if (!createAndBindSocket()) { return; @@ -64,7 +64,7 @@ void TcpServer::onCreated() _onBindReturn = {}; } -void TcpServer::disconnect() +void StreamServer::disconnect() { if (_rawServer.nativeHandle()) { _eventLoopThread->_eventLoop.removeFdFromLoop(_rawServer.nativeHandle()); @@ -72,7 +72,7 @@ void TcpServer::disconnect() } } -void TcpServer::onRead() +void StreamServer::onRead() { if (!_rawServer.nativeHandle()) { return; @@ -91,7 +91,7 @@ void TcpServer::onRead() _rawServer.prepareAcceptSocket((void*)_eventManager.get()); } 
-TcpServer::~TcpServer() noexcept +StreamServer::~StreamServer() noexcept { disconnect(); // TODO: Do we think this is an error? In extreme cases: diff --git a/src/cpp/scaler/ymq/tcp_server.h b/src/cpp/scaler/ymq/stream_server.h similarity index 81% rename from src/cpp/scaler/ymq/tcp_server.h rename to src/cpp/scaler/ymq/stream_server.h index afdfb6068..e087a54c6 100644 --- a/src/cpp/scaler/ymq/tcp_server.h +++ b/src/cpp/scaler/ymq/stream_server.h @@ -2,7 +2,7 @@ #include -#include "scaler/ymq/internal/raw_server_tcp_fd.h" +#include "scaler/ymq/internal/raw_stream_server_handle.h" // First-party #include "scaler/logging/logging.h" @@ -14,18 +14,18 @@ namespace ymq { class EventLoopThread; class EventManager; -class TcpServer { +class StreamServer { public: using BindReturnCallback = Configuration::BindReturnCallback; - TcpServer( + StreamServer( EventLoopThread* eventLoop, std::string localIOSocketIdentity, sockaddr addr, BindReturnCallback onBindReturn) noexcept; - TcpServer(const TcpServer&) = delete; - TcpServer& operator=(const TcpServer&) = delete; - ~TcpServer() noexcept; + StreamServer(const StreamServer&) = delete; + StreamServer& operator=(const StreamServer&) = delete; + ~StreamServer() noexcept; void disconnect(); void onCreated(); @@ -51,7 +51,7 @@ class TcpServer { Logger _logger; - RawServerTCPFD _rawServer; + RawStreamServerHandle _rawServer; }; } // namespace ymq diff --git a/src/scaler/entry_points/scheduler.py b/src/scaler/entry_points/scheduler.py index e210bf037..a91fdfda0 100644 --- a/src/scaler/entry_points/scheduler.py +++ b/src/scaler/entry_points/scheduler.py @@ -8,7 +8,6 @@ from scaler.scheduler.allocate_policy.allocate_policy import AllocatePolicy from scaler.scheduler.controllers.scaling_policies.types import ScalingControllerStrategy from scaler.utility.event_loop import EventLoopType -from scaler.utility.network_util import get_available_tcp_port def get_args(): @@ -105,8 +104,9 @@ def main(): object_storage = None if 
object_storage_address is None: + assert scheduler_config.scheduler_address.port is not None, "Scheduler address must have a port" object_storage_address = ObjectStorageConfig( - host=scheduler_config.scheduler_address.host, port=get_available_tcp_port() + host=scheduler_config.scheduler_address.host, port=scheduler_config.scheduler_address.port + 1 ) object_storage = ObjectStorageServerProcess( object_storage_address=object_storage_address, diff --git a/src/scaler/io/sync_object_storage_connector.py b/src/scaler/io/sync_object_storage_connector.py index 97d7e064c..3947e9bdd 100644 --- a/src/scaler/io/sync_object_storage_connector.py +++ b/src/scaler/io/sync_object_storage_connector.py @@ -156,12 +156,14 @@ def __send_buffers(self, buffers: List[bytes]) -> None: if len(buffers) < 1: return + assert self._socket is not None + total_size = sum(len(buffer) for buffer in buffers) # If the message is small enough, first try to send it at once with sendmsg(). This would ensure the message can # be transmitted within a single TCP segment. 
if total_size < MAX_CHUNK_SIZE: - sent = self._socket.sendmsg(buffers) + sent = self._socket.sendmsg(buffers) # type: ignore[attr-defined] if sent <= 0: self.__raise_connection_failure() diff --git a/src/scaler/scheduler/allocate_policy/capability_allocate_policy.py b/src/scaler/scheduler/allocate_policy/capability_allocate_policy.py index 556be8fb8..a60aae4b0 100644 --- a/src/scaler/scheduler/allocate_policy/capability_allocate_policy.py +++ b/src/scaler/scheduler/allocate_policy/capability_allocate_policy.py @@ -104,11 +104,11 @@ def balance(self) -> Dict[WorkerID, List[TaskID]]: # # The overall worst-case time complexity of the balancing algorithm is: # - # O(n_workers • log(n_workers) + n_tasks • n_workers • n_capabilities) + # O(n_workers * log(n_workers) + n_tasks * n_workers * n_capabilities) # # However, if the cluster does not use any capability, time complexity is always: # - # O(n_workers • log(n_workers) + n_tasks • log(n_workers)) + # O(n_workers * log(n_workers) + n_tasks * log(n_workers)) # # If capability constraints are used, this might result in less than optimal balancing. That's because, in some # cases, the optimal balancing might require to move tasks between more than two workers. Consider this @@ -152,7 +152,7 @@ def is_balanced(worker: _WorkerHolder) -> bool: # Then, we sort the remaining workers by the number of queued tasks. # - # Time complexity is O(n_workers • log(n_workers)) + # Time complexity is O(n_workers * log(n_workers)) sorted_workers: SortedList[_WorkerHolder] = SortedList(workers, key=lambda worker: worker.n_tasks()) @@ -161,8 +161,8 @@ def is_balanced(worker: _WorkerHolder) -> bool: # - all workers are balanced; # - we cannot find a low-load worker than can accept tasks from a high-load worker. # - # Worst-case time complexity is O(n_tasks • n_workers • n_capabilities). 
- # If no tag is used in the cluster, complexity is always O(n_tasks • log(n_workers)) + # Worst-case time complexity is O(n_tasks * n_workers * n_capabilities). + # If no tag is used in the cluster, complexity is always O(n_tasks * log(n_workers)) balancing_advice: Dict[WorkerID, List[TaskID]] = defaultdict(list) unbalanceable_tasks: Set[bytes] = set() @@ -218,7 +218,7 @@ def is_balanced(worker: _WorkerHolder) -> bool: def __balance_try_reassign_task(task: _TaskHolder, worker_candidates: Iterable[_WorkerHolder]) -> Optional[int]: """Returns the index of the first worker that can accept the task.""" - # Time complexity is O(n_workers • len(task.capabilities)) + # Time complexity is O(n_workers * len(task.capabilities)) for worker_index, worker in enumerate(worker_candidates): if task.capabilities.issubset(worker.capabilities): @@ -227,7 +227,7 @@ def __balance_try_reassign_task(task: _TaskHolder, worker_candidates: Iterable[_ return None def assign_task(self, task: Task) -> WorkerID: - # Worst-case time complexity is O(n_workers • len(task.capabilities)) + # Worst-case time complexity is O(n_workers * len(task.capabilities)) available_workers = self.__get_available_workers_for_capabilities(task.capabilities) @@ -265,7 +265,7 @@ def statistics(self) -> Dict: } def __get_available_workers_for_capabilities(self, capabilities: Dict[str, int]) -> List[_WorkerHolder]: - # Worst-case time complexity is O(n_workers • len(capabilities)) + # Worst-case time complexity is O(n_workers * len(capabilities)) if any(capability not in self._capability_to_worker_ids for capability in capabilities.keys()): return [] diff --git a/src/scaler/ui/common/__init__.py b/src/scaler/ui/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scaler/ui/constants.py b/src/scaler/ui/common/constants.py similarity index 100% rename from src/scaler/ui/constants.py rename to src/scaler/ui/common/constants.py diff --git a/src/scaler/ui/live_display.py 
b/src/scaler/ui/common/live_display.py similarity index 97% rename from src/scaler/ui/live_display.py rename to src/scaler/ui/common/live_display.py index 5e4a5f1ae..5d730f2a2 100644 --- a/src/scaler/ui/live_display.py +++ b/src/scaler/ui/common/live_display.py @@ -8,7 +8,7 @@ from scaler.protocol.python.common import WorkerState from scaler.protocol.python.message import StateTask, StateWorker from scaler.protocol.python.status import WorkerStatus -from scaler.ui.utility import display_capabilities, format_worker_name +from scaler.ui.common.utility import display_capabilities, format_worker_name from scaler.utility.formatter import format_microseconds, format_seconds @@ -97,7 +97,7 @@ def draw_row(self): ui.label().bind_text_from(self, "queued") ui.label().bind_text_from(self, "suspended") ui.label().bind_text_from(self, "lag") - ui.label().bind_text_from(self, "ITL") + ui.label().bind_text_from(self, "itl") ui.label().bind_text_from(self, "last_seen") ui.label().bind_text_from(self, "display_capabilities") diff --git a/src/scaler/ui/memory_window.py b/src/scaler/ui/common/memory_window.py similarity index 97% rename from src/scaler/ui/memory_window.py rename to src/scaler/ui/common/memory_window.py index 4193a8b1a..75136362a 100644 --- a/src/scaler/ui/memory_window.py +++ b/src/scaler/ui/common/memory_window.py @@ -4,8 +4,8 @@ from nicegui import ui from scaler.protocol.python.message import StateTask, StateWorker -from scaler.ui.setting_page import Settings -from scaler.ui.utility import format_timediff, get_bounds, make_taskstream_ticks, make_tick_text +from scaler.ui.common.setting_page import Settings +from scaler.ui.common.utility import format_timediff, get_bounds, make_taskstream_ticks, make_tick_text from scaler.utility.formatter import format_bytes from scaler.utility.metadata.profile_result import ProfileResult diff --git a/src/scaler/ui/setting_page.py b/src/scaler/ui/common/setting_page.py similarity index 100% rename from 
src/scaler/ui/setting_page.py rename to src/scaler/ui/common/setting_page.py diff --git a/src/scaler/ui/task_graph.py b/src/scaler/ui/common/task_graph.py similarity index 98% rename from src/scaler/ui/task_graph.py rename to src/scaler/ui/common/task_graph.py index c528e2fcc..8331c9ed9 100644 --- a/src/scaler/ui/task_graph.py +++ b/src/scaler/ui/common/task_graph.py @@ -10,8 +10,9 @@ from scaler.protocol.python.common import TaskState, WorkerState from scaler.protocol.python.message import StateTask, StateWorker -from scaler.ui.setting_page import Settings -from scaler.ui.utility import ( +from scaler.ui.util import NICEGUI_MAJOR_VERSION +from scaler.ui.common.setting_page import Settings +from scaler.ui.common.utility import ( COMPLETED_TASK_STATUSES, display_capabilities, format_timediff, @@ -176,8 +177,15 @@ def __free_row_for_task(self, task_id: bytes): def setup_task_stream(self, settings: Settings): self._card = ui.card() self._card.classes("w-full").style("height: 800px; overflow:auto;") + + # TODO: remove when v1 and v2 are separated + def html_func(x: str): + if NICEGUI_MAJOR_VERSION < 3: + return ui.html(x) + return ui.html(x, sanitize=False) # type: ignore[call-arg] + with self._card: - ui.html( + html_func( """
Legend: diff --git a/src/scaler/ui/task_log.py b/src/scaler/ui/common/task_log.py similarity index 91% rename from src/scaler/ui/task_log.py rename to src/scaler/ui/common/task_log.py index eb6e19c41..11b598748 100644 --- a/src/scaler/ui/task_log.py +++ b/src/scaler/ui/common/task_log.py @@ -7,7 +7,8 @@ from scaler.protocol.python.common import TaskState from scaler.protocol.python.message import StateTask, StateWorker -from scaler.ui.utility import COMPLETED_TASK_STATUSES, display_capabilities +from scaler.ui.util import NICEGUI_MAJOR_VERSION +from scaler.ui.common.utility import COMPLETED_TASK_STATUSES, display_capabilities from scaler.utility.formatter import format_bytes from scaler.utility.metadata.profile_result import ProfileResult @@ -52,7 +53,10 @@ def populate( def draw_row(self): color = "color: green" if self.status == TaskState.Success.name else "color: red" - ui.html(TASK_ID_HTML_TEMPLATE.format(task=self.task)) + if NICEGUI_MAJOR_VERSION < 3: + ui.html(TASK_ID_HTML_TEMPLATE.format(task=self.task)) + else: + ui.html(TASK_ID_HTML_TEMPLATE.format(task=self.task), sanitize=False) # type: ignore[call-arg] ui.label(self.function) ui.label(self.duration) ui.label(self.peak_mem) diff --git a/src/scaler/ui/utility.py b/src/scaler/ui/common/utility.py similarity index 97% rename from src/scaler/ui/utility.py rename to src/scaler/ui/common/utility.py index 54abc4dc4..72f9b1afc 100644 --- a/src/scaler/ui/utility.py +++ b/src/scaler/ui/common/utility.py @@ -2,7 +2,7 @@ from typing import List, Set, Tuple from scaler.protocol.python.common import TaskState -from scaler.ui.setting_page import Settings +from scaler.ui.common.setting_page import Settings COMPLETED_TASK_STATUSES = { TaskState.Success, diff --git a/src/scaler/ui/common/webui.py b/src/scaler/ui/common/webui.py new file mode 100644 index 000000000..e4c0fc04e --- /dev/null +++ b/src/scaler/ui/common/webui.py @@ -0,0 +1,80 @@ +import dataclasses +import logging + +from scaler.protocol.python.message import 
StateBalanceAdvice, StateScheduler, StateTask, StateWorker +from scaler.protocol.python.mixins import Message +from scaler.ui.common.live_display import SchedulerSection, WorkersSection +from scaler.ui.common.memory_window import MemoryChart +from scaler.ui.common.setting_page import Settings +from scaler.ui.common.task_graph import TaskStream +from scaler.ui.common.task_log import TaskLogTable +from scaler.ui.common.worker_processors import WorkerProcessors +from scaler.utility.formatter import format_bytes, format_percentage + + +@dataclasses.dataclass +class Sections: + scheduler_section: SchedulerSection + workers_section: WorkersSection + task_stream_section: TaskStream + memory_usage_section: MemoryChart + tasklog_section: TaskLogTable + worker_processors: WorkerProcessors + settings_section: Settings + + +def process_scheduler_message(status: Message, tables: Sections): + if isinstance(status, StateScheduler): + __update_scheduler_state(status, tables) + return + + if isinstance(status, StateWorker): + logging.info(f"Received StateWorker update for worker {status.worker_id.decode()} with {status.state.name}") + tables.scheduler_section.handle_worker_state(status) + tables.workers_section.handle_worker_state(status) + tables.task_stream_section.handle_worker_state(status) + tables.memory_usage_section.handle_worker_state(status) + tables.tasklog_section.handle_worker_state(status) + tables.worker_processors.handle_worker_state(status) + tables.settings_section.handle_worker_state(status) + return + + if isinstance(status, StateTask): + logging.debug(f"Received StateTask update for task {status.task_id.hex()} with {status.state.name}") + tables.scheduler_section.handle_task_state(status) + tables.workers_section.handle_task_state(status) + tables.task_stream_section.handle_task_state(status) + tables.memory_usage_section.handle_task_state(status) + tables.tasklog_section.handle_task_state(status) + tables.worker_processors.handle_task_state(status) + 
tables.settings_section.handle_task_state(status) + return + + if isinstance(status, StateBalanceAdvice): + logging.debug(f"Received StateBalanceAdvice for {status.worker_id.decode()} with {len(status.task_ids)} tasks") + return + + logging.info(f"Unhandled message received: {type(status)}") + + +def __update_scheduler_state(data: StateScheduler, tables: Sections): + tables.scheduler_section.cpu = format_percentage(data.scheduler.cpu) + tables.scheduler_section.rss = format_bytes(data.scheduler.rss) + tables.scheduler_section.rss_free = format_bytes(data.rss_free) + + previous_workers = set(tables.workers_section.workers.keys()) + current_workers = set(worker_data.worker_id.decode() for worker_data in data.worker_manager.workers) + + for worker_data in data.worker_manager.workers: + worker_name = worker_data.worker_id.decode() + tables.workers_section.workers[worker_name].populate(worker_data) + + for died_worker in previous_workers - current_workers: + tables.workers_section.workers.pop(died_worker) + tables.worker_processors.remove_worker(died_worker) + tables.task_stream_section.mark_dead_worker(died_worker) + + if previous_workers != current_workers: + tables.workers_section.draw_section.refresh() + + tables.worker_processors.update_data(data.worker_manager.workers) diff --git a/src/scaler/ui/worker_processors.py b/src/scaler/ui/common/worker_processors.py similarity index 98% rename from src/scaler/ui/worker_processors.py rename to src/scaler/ui/common/worker_processors.py index eb5c137cc..1d54c0d84 100644 --- a/src/scaler/ui/worker_processors.py +++ b/src/scaler/ui/common/worker_processors.py @@ -8,7 +8,7 @@ from scaler.protocol.python.common import WorkerState from scaler.protocol.python.message import StateTask, StateWorker from scaler.protocol.python.status import ProcessorStatus, WorkerStatus -from scaler.ui.utility import format_worker_name +from scaler.ui.common.utility import format_worker_name @dataclasses.dataclass diff --git a/src/scaler/ui/util.py 
b/src/scaler/ui/util.py new file mode 100644 index 000000000..0381dae52 --- /dev/null +++ b/src/scaler/ui/util.py @@ -0,0 +1,9 @@ +from packaging.version import parse + +try: + from nicegui.version import __version__ +except ImportError as e: + raise ImportError("Could not determine NiceGUI version. Is it installed?") from e + +NICEGUI_VERSION = parse(__version__) +NICEGUI_MAJOR_VERSION = NICEGUI_VERSION.major diff --git a/src/scaler/ui/v1.py b/src/scaler/ui/v1.py new file mode 100644 index 000000000..49fafb3dc --- /dev/null +++ b/src/scaler/ui/v1.py @@ -0,0 +1,74 @@ +import threading +from functools import partial + +from nicegui import ui + +from scaler.config.types.zmq import ZMQConfig +from scaler.io.sync_subscriber import ZMQSyncSubscriber +from scaler.ui.common.constants import ( + MEMORY_USAGE_UPDATE_INTERVAL, + TASK_LOG_REFRESH_INTERVAL, + TASK_STREAM_UPDATE_INTERVAL, + WORKER_PROCESSORS_REFRESH_INTERVAL, +) +from scaler.ui.common.live_display import SchedulerSection, WorkersSection +from scaler.ui.common.memory_window import MemoryChart +from scaler.ui.common.setting_page import Settings +from scaler.ui.common.task_graph import TaskStream +from scaler.ui.common.task_log import TaskLogTable +from scaler.ui.common.webui import Sections, process_scheduler_message +from scaler.ui.common.worker_processors import WorkerProcessors + + +def start_webui_v1(address: str, host: str, port: int): + tables = Sections( + scheduler_section=SchedulerSection(), + workers_section=WorkersSection(), + task_stream_section=TaskStream(), + memory_usage_section=MemoryChart(), + tasklog_section=TaskLogTable(), + worker_processors=WorkerProcessors(), + settings_section=Settings(), + ) + + with ui.tabs().classes("w-full h-full") as tabs: + live_tab = ui.tab("Live") + tasklog_tab = ui.tab("Task Log") + stream_tab = ui.tab("Worker Task Stream") + worker_processors_tab = ui.tab("Worker Processors") + settings_tab = ui.tab("Settings") + + with ui.tab_panels(tabs, 
value=live_tab).classes("w-full"): + with ui.tab_panel(live_tab): + tables.scheduler_section.draw_section() + tables.workers_section.draw_section() # type: ignore[call-arg] + + with ui.tab_panel(tasklog_tab): + tables.tasklog_section.draw_section() # type: ignore[call-arg] + ui.timer(TASK_LOG_REFRESH_INTERVAL, tables.tasklog_section.draw_section.refresh, active=True) + + with ui.tab_panel(stream_tab): + tables.task_stream_section.setup_task_stream(tables.settings_section) + ui.timer(TASK_STREAM_UPDATE_INTERVAL, tables.task_stream_section.update_plot, active=True) + + tables.memory_usage_section.setup_memory_chart(tables.settings_section) + ui.timer(MEMORY_USAGE_UPDATE_INTERVAL, tables.memory_usage_section.update_plot, active=True) + + with ui.tab_panel(worker_processors_tab): + tables.worker_processors.draw_section() # type: ignore[call-arg] + ui.timer(WORKER_PROCESSORS_REFRESH_INTERVAL, tables.worker_processors.draw_section.refresh, active=True) + + with ui.tab_panel(settings_tab): + tables.settings_section.draw_section() + + subscriber = ZMQSyncSubscriber( + address=ZMQConfig.from_string(address), + callback=partial(process_scheduler_message, tables=tables), + topic=b"", + timeout_seconds=-1, + ) + subscriber.start() + + ui_thread = threading.Thread(target=partial(ui.run, host=host, port=port, reload=False), daemon=False) + ui_thread.start() + ui_thread.join() diff --git a/src/scaler/ui/v2.py b/src/scaler/ui/v2.py new file mode 100644 index 000000000..1df398bc3 --- /dev/null +++ b/src/scaler/ui/v2.py @@ -0,0 +1,100 @@ +import threading +from typing import Optional +from nicegui import Event, app, ui # type: ignore[attr-defined] +from scaler.config.types.zmq import ZMQConfig +from scaler.io.sync_subscriber import ZMQSyncSubscriber +from scaler.protocol.python.mixins import Message +from scaler.ui.common.constants import ( + MEMORY_USAGE_UPDATE_INTERVAL, + TASK_LOG_REFRESH_INTERVAL, + TASK_STREAM_UPDATE_INTERVAL, + WORKER_PROCESSORS_REFRESH_INTERVAL, +) +from 
scaler.ui.common.live_display import SchedulerSection, WorkersSection +from scaler.ui.common.memory_window import MemoryChart +from scaler.ui.common.setting_page import Settings +from scaler.ui.common.task_graph import TaskStream +from scaler.ui.common.task_log import TaskLogTable +from scaler.ui.common.webui import Sections, process_scheduler_message +from scaler.ui.common.worker_processors import WorkerProcessors + + +class WebUI: + def __init__(self) -> None: + self.scheduler_message = Event[Message]() + self.tables: Optional[Sections] = None + + def start(self, host: str, port: int) -> None: + """Start the NiceGUI server in a separate thread.""" + started = threading.Event() + app.on_startup(started.set) + thread = threading.Thread( + target=lambda: ui.run(self.root, host=host, port=port, reload=False), # type: ignore[misc,arg-type] + daemon=True, + ) + thread.start() + if not started.wait(timeout=3.0): + raise RuntimeError("NiceGUI did not start within 3 seconds.") + + def root(self) -> None: + """Create the UI for each new visitor.""" + self.scheduler_message.subscribe(self.handle_message) + tables = Sections( + scheduler_section=SchedulerSection(), + workers_section=WorkersSection(), + task_stream_section=TaskStream(), + memory_usage_section=MemoryChart(), + tasklog_section=TaskLogTable(), + worker_processors=WorkerProcessors(), + settings_section=Settings(), + ) + self.tables = tables + + with ui.tabs().classes("w-full h-full") as tabs: + live_tab = ui.tab("Live") + tasklog_tab = ui.tab("Task Log") + stream_tab = ui.tab("Worker Task Stream") + worker_processors_tab = ui.tab("Worker Processors") + settings_tab = ui.tab("Settings") + + with ui.tab_panels(tabs, value=live_tab).classes("w-full"): + with ui.tab_panel(live_tab): + tables.scheduler_section.draw_section() + tables.workers_section.draw_section() # type: ignore[call-arg] + + with ui.tab_panel(tasklog_tab): + tables.tasklog_section.draw_section() # type: ignore[call-arg] + 
ui.timer(TASK_LOG_REFRESH_INTERVAL, tables.tasklog_section.draw_section.refresh, active=True) + + with ui.tab_panel(stream_tab): + tables.task_stream_section.setup_task_stream(tables.settings_section) + ui.timer(TASK_STREAM_UPDATE_INTERVAL, tables.task_stream_section.update_plot, active=True) + + tables.memory_usage_section.setup_memory_chart(tables.settings_section) + ui.timer(MEMORY_USAGE_UPDATE_INTERVAL, tables.memory_usage_section.update_plot, active=True) + + with ui.tab_panel(worker_processors_tab): + tables.worker_processors.draw_section() # type: ignore[call-arg] + ui.timer(WORKER_PROCESSORS_REFRESH_INTERVAL, tables.worker_processors.draw_section.refresh, active=True) + + with ui.tab_panel(settings_tab): + tables.settings_section.draw_section() + + def new_message(self, status: Message): + self.scheduler_message.emit(status) + + def handle_message(self, status: Message): + process_scheduler_message(status, self.tables) + + +def start_webui_v2(address: str, host: str, port: int): + webui = WebUI() + webui.start(host, port) + + subscriber = ZMQSyncSubscriber( + address=ZMQConfig.from_string(address), callback=webui.new_message, topic=b"", timeout_seconds=-1 + ) + subscriber.start() + + while True: + pass diff --git a/src/scaler/ui/webui.py b/src/scaler/ui/webui.py index bbcbe64c7..521f7de0f 100644 --- a/src/scaler/ui/webui.py +++ b/src/scaler/ui/webui.py @@ -1,42 +1,9 @@ -import dataclasses import logging -import threading -from functools import partial from typing import Optional, Tuple - -from nicegui import ui - -from scaler.config.types.zmq import ZMQConfig -from scaler.io.sync_subscriber import ZMQSyncSubscriber -from scaler.protocol.python.message import StateBalanceAdvice, StateScheduler, StateTask, StateWorker -from scaler.protocol.python.mixins import Message -from scaler.ui.constants import ( - MEMORY_USAGE_UPDATE_INTERVAL, - TASK_LOG_REFRESH_INTERVAL, - TASK_STREAM_UPDATE_INTERVAL, - WORKER_PROCESSORS_REFRESH_INTERVAL, -) -from 
scaler.ui.live_display import SchedulerSection, WorkersSection -from scaler.ui.memory_window import MemoryChart -from scaler.ui.setting_page import Settings -from scaler.ui.task_graph import TaskStream -from scaler.ui.task_log import TaskLogTable -from scaler.ui.worker_processors import WorkerProcessors -from scaler.utility.formatter import format_bytes, format_percentage +from scaler.ui.util import NICEGUI_MAJOR_VERSION from scaler.utility.logging.utility import setup_logger -@dataclasses.dataclass -class Sections: - scheduler_section: SchedulerSection - workers_section: WorkersSection - task_stream_section: TaskStream - memory_usage_section: MemoryChart - tasklog_section: TaskLogTable - worker_processors: WorkerProcessors - settings_section: Settings - - def start_webui( address: str, host: str, @@ -48,100 +15,13 @@ def start_webui( setup_logger(logging_paths, logging_config_file, logging_level) - tables = Sections( - scheduler_section=SchedulerSection(), - workers_section=WorkersSection(), - task_stream_section=TaskStream(), - memory_usage_section=MemoryChart(), - tasklog_section=TaskLogTable(), - worker_processors=WorkerProcessors(), - settings_section=Settings(), - ) - - with ui.tabs().classes("w-full h-full") as tabs: - live_tab = ui.tab("Live") - tasklog_tab = ui.tab("Task Log") - stream_tab = ui.tab("Worker Task Stream") - worker_processors_tab = ui.tab("Worker Processors") - settings_tab = ui.tab("Settings") - - with ui.tab_panels(tabs, value=live_tab).classes("w-full"): - with ui.tab_panel(live_tab): - tables.scheduler_section.draw_section() - tables.workers_section.draw_section() # type: ignore[call-arg] - - with ui.tab_panel(tasklog_tab): - tables.tasklog_section.draw_section() # type: ignore[call-arg] - ui.timer(TASK_LOG_REFRESH_INTERVAL, tables.tasklog_section.draw_section.refresh, active=True) - - with ui.tab_panel(stream_tab): - tables.task_stream_section.setup_task_stream(tables.settings_section) - ui.timer(TASK_STREAM_UPDATE_INTERVAL, 
tables.task_stream_section.update_plot, active=True) - - tables.memory_usage_section.setup_memory_chart(tables.settings_section) - ui.timer(MEMORY_USAGE_UPDATE_INTERVAL, tables.memory_usage_section.update_plot, active=True) - - with ui.tab_panel(worker_processors_tab): - tables.worker_processors.draw_section() # type: ignore[call-arg] - ui.timer(WORKER_PROCESSORS_REFRESH_INTERVAL, tables.worker_processors.draw_section.refresh, active=True) - - with ui.tab_panel(settings_tab): - tables.settings_section.draw_section() - - subscriber = ZMQSyncSubscriber( - address=ZMQConfig.from_string(address), - callback=partial(__show_status, tables=tables), - topic=b"", - timeout_seconds=-1, - ) - subscriber.start() - - ui_thread = threading.Thread(target=partial(ui.run, host=host, port=port, reload=False), daemon=False) - ui_thread.start() - ui_thread.join() - - -def __show_status(status: Message, tables: Sections): - if isinstance(status, StateScheduler): - __update_scheduler_state(status, tables) - return - - if isinstance(status, StateWorker): - logging.info(f"Received StateWorker update for worker {status.worker_id.decode()} with {status.state.name}") - tables.scheduler_section.handle_worker_state(status) - tables.workers_section.handle_worker_state(status) - tables.task_stream_section.handle_worker_state(status) - tables.memory_usage_section.handle_worker_state(status) - tables.tasklog_section.handle_worker_state(status) - tables.worker_processors.handle_worker_state(status) - tables.settings_section.handle_worker_state(status) - return - - if isinstance(status, StateTask): - logging.debug(f"Received StateTask update for task {status.task_id.hex()} with {status.state.name}") - tables.scheduler_section.handle_task_state(status) - tables.workers_section.handle_task_state(status) - tables.task_stream_section.handle_task_state(status) - tables.memory_usage_section.handle_task_state(status) - tables.tasklog_section.handle_task_state(status) - 
tables.worker_processors.handle_task_state(status) - tables.settings_section.handle_task_state(status) - return - - if isinstance(status, StateBalanceAdvice): - logging.debug(f"Received StateBalanceAdvice for {status.worker_id.decode()} with {len(status.task_ids)} tasks") - return - - logging.info(f"Unhandled message received: {type(status)}") - - -def __update_scheduler_state(data: StateScheduler, tables: Sections): - tables.scheduler_section.cpu = format_percentage(data.scheduler.cpu) - tables.scheduler_section.rss = format_bytes(data.scheduler.rss) - tables.scheduler_section.rss_free = format_bytes(data.rss_free) + if NICEGUI_MAJOR_VERSION < 3: + logging.info(f"Detected {NICEGUI_MAJOR_VERSION}. Using GUI v1.") + from scaler.ui.v1 import start_webui_v1 - for worker_data in data.worker_manager.workers: - worker_name = worker_data.worker_id.decode() - tables.workers_section.workers[worker_name].populate(worker_data) + start_webui_v1(address, host, port) + else: + logging.info(f"Detected {NICEGUI_MAJOR_VERSION}. Using GUI v2.") + from scaler.ui.v2 import start_webui_v2 - tables.worker_processors.update_data(data.worker_manager.workers) + start_webui_v2(address, host, port) diff --git a/src/scaler/version.txt b/src/scaler/version.txt index 7ff4a620a..4737b5812 100644 --- a/src/scaler/version.txt +++ b/src/scaler/version.txt @@ -1 +1 @@ -1.12.28 +1.12.35 diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index f3757afd7..c6fc70ba3 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -1,6 +1,8 @@ # this fetches Google Test, so it must be included first. 
if(LINUX OR APPLE) add_subdirectory(object_storage) +else() + message(WARNING "Not building OSS tests, as it's not supported on this system currently!") endif() add_subdirectory(ymq) diff --git a/tests/cpp/ymq/CMakeLists.txt b/tests/cpp/ymq/CMakeLists.txt index cfe191b12..318ff3f33 100644 --- a/tests/cpp/ymq/CMakeLists.txt +++ b/tests/cpp/ymq/CMakeLists.txt @@ -1,2 +1,28 @@ add_test_executable(test_ymq test_ymq.cpp) +target_sources(test_ymq PRIVATE + common/testing.h + common/utils.h + common/utils.cpp + + net/socket.h + + pipe/pipe.h + pipe/pipe_utils.h + pipe/pipe_reader.h + pipe/pipe_writer.h +) + +if(LINUX OR APPLE) + target_sources(test_ymq PRIVATE + pipe/pipe_utils_linux.cpp + pipe/pipe_reader_linux.cpp + pipe/pipe_writer_linux.cpp + net/socket_linux.cpp) +elseif(WIN32) + target_sources(test_ymq PRIVATE + pipe/pipe_utils_windows.cpp + pipe/pipe_reader_windows.cpp + pipe/pipe_writer_windows.cpp + net/socket_windows.cpp) +endif() diff --git a/tests/cpp/ymq/common.h b/tests/cpp/ymq/common.h deleted file mode 100644 index aee660053..000000000 --- a/tests/cpp/ymq/common.h +++ /dev/null @@ -1,591 +0,0 @@ -#pragma once - -#define PY_SSIZE_T_CLEAN -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std::chrono_literals; - -enum class TestResult : char { Success = 1, Failure = 2 }; - -inline TestResult return_failure_if_false( - bool cond, const char* msg = nullptr, const char* cond_str = nullptr, const char* file = nullptr, int line = 0) -{ - // Failure: ... 
(assertion failed) at file:line - if (!cond) { - std::cerr << "Failure"; - if (cond_str) - std::cerr << ": " << cond_str; - if (msg) - std::cerr << " (" << msg << ")"; - else - std::cerr << " (assertion failed)"; - if (file) - std::cerr << " at " << file << ":" << line; - std::cerr << '\n'; - return TestResult::Failure; - } - return TestResult::Success; -} - -// in the case that there's no msg, delegate -inline TestResult return_failure_if_false(bool cond, const char* cond_str, const char* file, int line) -{ - return return_failure_if_false(cond, nullptr, cond_str, file, line); -} - -#define RETURN_FAILURE_IF_FALSE(cond, ...) \ - if (return_failure_if_false((cond), ##__VA_ARGS__, #cond, __FILE__, __LINE__) == TestResult::Failure) \ - return TestResult::Failure; - -inline const char* check_localhost(const char* host) -{ - return std::strcmp(host, "localhost") == 0 ? "127.0.0.1" : host; -} - -inline std::string format_address(std::string host, uint16_t port) -{ - return std::format("tcp://{}:{}", check_localhost(host.c_str()), port); -} - -class OwnedFd { -public: - int fd; - - OwnedFd(int fd): fd(fd) {} - - // move-only - OwnedFd(const OwnedFd&) = delete; - OwnedFd& operator=(const OwnedFd&) = delete; - OwnedFd(OwnedFd&& other) noexcept: fd(other.fd) { other.fd = 0; } - OwnedFd& operator=(OwnedFd&& other) noexcept - { - if (this != &other) { - this->fd = other.fd; - other.fd = 0; - } - return *this; - } - - ~OwnedFd() - { - if (fd > 0 && close(fd) < 0) - std::cerr << "failed to close fd!" 
<< std::endl; - } - - size_t write(const void* data, size_t len) - { - auto n = ::write(this->fd, data, len); - if (n < 0) - throw std::system_error(errno, std::generic_category(), "failed to write to socket"); - - return n; - } - - void write_all(const char* data, size_t len) - { - for (size_t cursor = 0; cursor < len;) - cursor += this->write(data + cursor, len - cursor); - } - - void write_all(std::string data) { this->write_all(data.data(), data.length()); } - - void write_all(std::vector data) { this->write_all(data.data(), data.size()); } - - size_t read(void* buffer, size_t len) - { - auto n = ::read(this->fd, buffer, len); - if (n < 0) - throw std::system_error(errno, std::generic_category(), "failed to read from socket"); - return n; - } - - void read_exact(char* buffer, size_t len) - { - for (size_t cursor = 0; cursor < len;) - cursor += this->read(buffer + cursor, len - cursor); - } - - operator int() { return fd; } -}; - -class Socket: public OwnedFd { -public: - Socket(int fd): OwnedFd(fd) {} - - void connect(const char* host, uint16_t port, bool nowait = false) - { - sockaddr_in addr { - .sin_family = AF_INET, - .sin_port = htons(port), - .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, - .sin_zero = {0}}; - - connect: - if (::connect(this->fd, (sockaddr*)&addr, sizeof(addr)) < 0) { - if (errno == ECONNREFUSED && !nowait) { - std::this_thread::sleep_for(300ms); - goto connect; - } - - throw std::system_error(errno, std::generic_category(), "failed to connect"); - } - } - - void bind(const char* host, int port) - { - sockaddr_in addr { - .sin_family = AF_INET, - .sin_port = htons(port), - .sin_addr = {.s_addr = inet_addr(check_localhost(host))}, - .sin_zero = {0}}; - - auto status = ::bind(this->fd, (sockaddr*)&addr, sizeof(addr)); - if (status < 0) - throw std::system_error(errno, std::generic_category(), "failed to bind"); - } - - void listen(int n = 32) - { - auto status = ::listen(this->fd, n); - if (status < 0) - throw 
std::system_error(errno, std::generic_category(), "failed to listen on socket"); - } - - std::pair accept(int flags = 0) - { - sockaddr_in peer_addr {}; - socklen_t len = sizeof(peer_addr); - auto fd = ::accept4(this->fd, (sockaddr*)&peer_addr, &len, flags); - if (fd < 0) - throw std::system_error(errno, std::generic_category(), "failed to accept socket"); - - return std::make_pair(Socket(fd), peer_addr); - } - - void write_message(std::string message) - { - uint64_t header = message.length(); - this->write_all((char*)&header, 8); - this->write_all(message.data(), message.length()); - } - - std::string read_message() - { - uint64_t header = 0; - this->read_exact((char*)&header, 8); - std::vector buffer(header); - this->read_exact(buffer.data(), header); - return std::string(buffer.data(), header); - } -}; - -class TcpSocket: public Socket { -public: - TcpSocket(bool nodelay = true): Socket(0) - { - this->fd = ::socket(AF_INET, SOCK_STREAM, 0); - if (this->fd < 0) - throw std::system_error(errno, std::generic_category(), "failed to create socket"); - - int on = 1; - if (nodelay && setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); - - if (setsockopt(this->fd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set reuseaddr"); - } - - void flush() - { - int on = 1; - int off = 0; - - if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&off, sizeof(off)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); - - if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); - - if (setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&off, sizeof(off)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); - - if 
(setsockopt(this->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&on, sizeof(on)) < 0) - throw std::system_error(errno, std::generic_category(), "failed to set nodelay"); - } -}; - -inline void fork_wrapper(std::function fn, int timeout_secs, OwnedFd pipe_wr) -{ - TestResult result = TestResult::Failure; - try { - result = fn(); - } catch (const std::exception& e) { - std::cerr << "Exception: " << e.what() << std::endl; - result = TestResult::Failure; - } catch (...) { - std::cerr << "Unknown exception" << std::endl; - result = TestResult::Failure; - } - - pipe_wr.write_all((char*)&result, sizeof(TestResult)); -} - -// this function along with `wait_for_python_ready_sigwait()` -// work together to wait on a signal from the python process -// indicating that the tuntap interface has been created, and that the mitm is ready -inline void wait_for_python_ready_sigblock() -{ - sigset_t set {}; - - if (sigemptyset(&set) < 0) - throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); - - if (sigaddset(&set, SIGUSR1) < 0) - throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); - - if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) - throw std::system_error(errno, std::generic_category(), "failed to mask sigusr1"); - - std::cout << "blocked signal..." << std::endl; -} - -inline void wait_for_python_ready_sigwait(int timeout_secs) -{ - sigset_t set {}; - siginfo_t sig {}; - - if (sigemptyset(&set) < 0) - throw std::system_error(errno, std::generic_category(), "failed to create empty signal set"); - - if (sigaddset(&set, SIGUSR1) < 0) - throw std::system_error(errno, std::generic_category(), "failed to add sigusr1 to the signal set"); - - std::cout << "waiting for python to be ready..." 
<< std::endl; - timespec ts {.tv_sec = timeout_secs, .tv_nsec = 0}; - if (sigtimedwait(&set, &sig, &ts) < 0) - throw std::system_error(errno, std::generic_category(), "failed to wait on sigusr1"); - - sigprocmask(SIG_UNBLOCK, &set, nullptr); - std::cout << "signal received; python is ready" << std::endl; -} - -// run a test -// forks and runs each of the provided closures -// if `wait_for_python` is true, wait for SIGUSR1 after forking and executing the first closure -inline TestResult test( - int timeout_secs, std::vector> closures, bool wait_for_python = false) -{ - std::vector> pipes {}; - std::vector pids {}; - for (size_t i = 0; i < closures.size(); i++) { - int pipe[2] = {0}; - if (pipe2(pipe, O_NONBLOCK) < 0) { - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { - close(pipe.first); - close(pipe.second); - }); - - throw std::system_error(errno, std::generic_category(), "failed to create pipe: "); - } - pipes.push_back(std::make_pair(pipe[0], pipe[1])); - } - - for (size_t i = 0; i < closures.size(); i++) { - if (wait_for_python && i == 0) - wait_for_python_ready_sigblock(); - - auto pid = fork(); - if (pid < 0) { - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { - close(pipe.first); - close(pipe.second); - }); - - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - throw std::system_error(errno, std::generic_category(), "failed to fork"); - } - - if (pid == 0) { - // close all pipes except our write half - for (size_t j = 0; j < pipes.size(); j++) { - if (i == j) - close(pipes[i].first); - else { - close(pipes[j].first); - close(pipes[j].second); - } - } - - fork_wrapper(closures[i], timeout_secs, pipes[i].second); - std::exit(EXIT_SUCCESS); - } - - pids.push_back(pid); - - if (wait_for_python && i == 0) - wait_for_python_ready_sigwait(3); - } - - // close all write halves of the pipes - for (auto pipe: pipes) - close(pipe.second); - - std::vector pfds {}; - - OwnedFd timerfd = 
timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); - if (timerfd < 0) { - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - throw std::system_error(errno, std::generic_category(), "failed to create timerfd"); - } - - pfds.push_back({.fd = timerfd.fd, .events = POLL_IN, .revents = 0}); - for (auto pipe: pipes) - pfds.push_back({ - .fd = pipe.first, - .events = POLL_IN, - .revents = 0, - }); - - itimerspec spec { - .it_interval = - { - .tv_sec = 0, - .tv_nsec = 0, - }, - .it_value = { - .tv_sec = timeout_secs, - .tv_nsec = 0, - }}; - - if (timerfd_settime(timerfd, 0, &spec, nullptr) < 0) { - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - throw std::system_error(errno, std::generic_category(), "failed to set timerfd"); - } - - std::vector> results(pids.size(), std::nullopt); - - for (;;) { - auto n = poll(pfds.data(), pfds.size(), -1); - if (n < 0) { - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - throw std::system_error(errno, std::generic_category(), "failed to poll: "); - } - - for (auto& pfd: std::vector(pfds)) { - if (pfd.revents == 0) - continue; - - // timed out - if (pfd.fd == timerfd) { - std::cout << "Timed out!\n"; - - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - return TestResult::Failure; - } - - auto elem = std::find_if(pipes.begin(), pipes.end(), [fd = pfd.fd](auto pipe) { return pipe.first == fd; }); - auto idx = elem - pipes.begin(); - - TestResult result = TestResult::Failure; - char buffer = 0; - auto n = read(pfd.fd, &buffer, 
sizeof(TestResult)); - if (n == 0) { - std::cout << "failed to read from pipe: pipe closed unexpectedly\n"; - result = TestResult::Failure; - } else if (n < 0) { - std::cout << "failed to read from pipe: " << std::strerror(errno) << std::endl; - result = TestResult::Failure; - } else - result = (TestResult)buffer; - - // the subprocess should have exited - // check its exit status - int status; - if (waitpid(pids[idx], &status, 0) < 0) - std::cout << "failed to wait on subprocess[" << idx << "]: " << std::strerror(errno) << std::endl; - - auto exit_status = WEXITSTATUS(status); - if (WIFEXITED(status) && exit_status != EXIT_SUCCESS) { - std::cout << "subprocess[" << idx << "] exited with status " << exit_status << std::endl; - } else if (WIFSIGNALED(status)) { - std::cout << "subprocess[" << idx << "] killed by signal " << WTERMSIG(status) << std::endl; - } else { - std::cout << "subprocess[" << idx << "] completed with " - << (result == TestResult::Success ? "Success" : "Failure") << std::endl; - } - - // store the result - results[idx] = result; - - // this subprocess is done, remove its pipe from the poll fds - pfds.erase(std::remove_if(pfds.begin(), pfds.end(), [&](auto p) { return p.fd == pfd.fd; }), pfds.end()); - - auto done = std::all_of(results.begin(), results.end(), [](auto result) { return result.has_value(); }); - if (done) - goto end; // justification for goto: breaks out of two levels of loop - } - } - -end: - - std::for_each(pipes.begin(), pipes.end(), [](const auto& pipe) { close(pipe.first); }); - - if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) - return TestResult::Failure; - - return TestResult::Success; -} - -inline TestResult run_python(const char* path, std::vector argv = {}) -{ - // insert the pid at the start of the argv, this is important for signalling readiness - pid_t pid = getppid(); - auto pid_ws = std::to_wstring(pid); - argv.insert(argv.begin(), pid_ws.c_str()); - - PyStatus status; - PyConfig 
config; - PyConfig_InitPythonConfig(&config); - - status = PyConfig_SetBytesString(&config, &config.program_name, "mitm"); - if (PyStatus_Exception(status)) - goto exception; - - argv.insert(argv.begin(), L"mitm"); - status = PyConfig_SetArgv(&config, argv.size(), (wchar_t**)argv.data()); - if (PyStatus_Exception(status)) - goto exception; - - // pass argv to the script as-is - config.parse_argv = 0; - - status = Py_InitializeFromConfig(&config); - if (PyStatus_Exception(status)) - goto exception; - PyConfig_Clear(&config); - - // add the cwd to the path - { - PyObject* sysPath = PySys_GetObject("path"); - PyObject* newPath = PyUnicode_FromString("."); - PyList_Append(sysPath, newPath); - Py_DECREF(newPath); - } - - { - auto file = fopen(path, "r"); - if (!file) - throw std::system_error(errno, std::generic_category(), "failed to open python file"); - - PyRun_SimpleFile(file, path); - fclose(file); - } - - if (Py_FinalizeEx() < 0) { - std::cerr << "finalization failure" << std::endl; - return TestResult::Failure; - } - - return TestResult::Success; - -exception: - PyConfig_Clear(&config); - Py_ExitStatusException(status); - - return TestResult::Failure; -} - -// change the current working directory to the project root -// this is important for finding the python mitm script -inline void chdir_to_project_root() -{ - auto cwd = std::filesystem::current_path(); - - // if pyproject.toml is in `path`, it's the project root - for (auto path = cwd; !path.empty(); path = path.parent_path()) { - if (std::filesystem::exists(path / "pyproject.toml")) { - // change to the project root - std::filesystem::current_path(path); - return; - } - } -} - -inline TestResult run_mitm( - std::string testcase, - std::string mitm_ip, - uint16_t mitm_port, - std::string remote_ip, - uint16_t remote_port, - std::vector extra_args = {}) -{ - auto cwd = std::filesystem::current_path(); - chdir_to_project_root(); - - // we build the args for the user to make calling the function more convenient 
- std::vector args { - testcase, mitm_ip, std::to_string(mitm_port), remote_ip, std::to_string(remote_port)}; - - for (auto arg: extra_args) - args.push_back(arg); - - // we need to convert to wide strings to pass to Python - std::vector wide_args_owned {}; - - // the strings are ascii so we can just make them into wstrings - for (const auto& str: args) - wide_args_owned.emplace_back(str.begin(), str.end()); - - std::vector wide_args {}; - for (const auto& wstr: wide_args_owned) - wide_args.push_back(wstr.c_str()); - - auto result = run_python("tests/cpp/ymq/py_mitm/main.py", wide_args); - - // change back to the original working directory - std::filesystem::current_path(cwd); - return result; -} diff --git a/tests/cpp/ymq/common/testing.h b/tests/cpp/ymq/common/testing.h new file mode 100644 index 000000000..f85ef15c2 --- /dev/null +++ b/tests/cpp/ymq/common/testing.h @@ -0,0 +1,612 @@ +#pragma once + +#define PY_SSIZE_T_CLEAN + +// if on Windows and in debug mode, undefine _DEBUG before including Python.h +// this prevents issues including the debug version of the Python library +#if defined(_WIN32) && defined(_DEBUG) +#undef _DEBUG +#include +#define _DEBUG +#else +#include +#endif + +#include +#include +#include + +#ifdef __linux__ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif // __linux__ +#ifdef _WIN32 +#include +#include +#include +#include + +// the windows timer apis work in 100-nanosecond units +const LONGLONG ns_per_second = 1'000'000'000LL; +const LONGLONG ns_per_unit = 100LL; // 1 unit = 100 nanoseconds + +#define popen _popen +#define pclose _pclose +#endif // _WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/socket.h" +#include 
"tests/cpp/ymq/pipe/pipe.h" + +using namespace std::chrono_literals; + +enum class TestResult : char { Success = 1, Failure = 2 }; + +inline TestResult return_failure_if_false( + bool cond, const char* msg = nullptr, const char* cond_str = nullptr, const char* file = nullptr, int line = 0) +{ + // Failure: ... (assertion failed) at file:line + if (!cond) { + std::cerr << "Failure"; + if (cond_str) + std::cerr << ": " << cond_str; + if (msg) + std::cerr << " (" << msg << ")"; + else + std::cerr << " (assertion failed)"; + if (file) + std::cerr << " at " << file << ":" << line; + std::cerr << '\n'; + return TestResult::Failure; + } + return TestResult::Success; +} + +// in the case that there's no msg, delegate +inline TestResult return_failure_if_false(bool cond, const char* cond_str, const char* file, int line) +{ + return return_failure_if_false(cond, nullptr, cond_str, file, line); +} + +#define RETURN_FAILURE_IF_FALSE(cond, ...) \ + if (return_failure_if_false((cond), ##__VA_ARGS__, #cond, __FILE__, __LINE__) == TestResult::Failure) \ + return TestResult::Failure; + +// hEvent: unused on linux, event handle on windows +inline void fork_wrapper(std::function fn, int timeout_secs, PipeWriter pipe_wr, void* hEvent) +{ + TestResult result = TestResult::Failure; + try { + result = fn(); + } catch (const std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + result = TestResult::Failure; + } catch (...) 
{ + std::cerr << "Unknown exception" << std::endl; + result = TestResult::Failure; + } + + pipe_wr.write_all((char*)&result, sizeof(TestResult)); + +#ifdef _WIN32 + SetEvent((HANDLE)hEvent); +#endif // _WIN32 +} + +// this function along with `wait_for_python_ready_sigwait()` +// work together to wait on a signal from the python process +// indicating that the tuntap interface has been created, and that the mitm is ready +// +// hEvent is an output parameter for windows but unused on linux +inline void wait_for_python_ready_sigblock(void** hEvent) +{ +#ifdef __linux__ + sigset_t set {}; + + if (sigemptyset(&set) < 0) + raise_system_error("failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + raise_system_error("failed to add sigusr1 to the signal set"); + + if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) + raise_system_error("failed to mask sigusr1"); + +#endif // __linux__ +#ifdef _WIN32 + // TODO: implement signaling of this event in the python mitm + *hEvent = CreateEvent( + NULL, // default security attributes + FALSE, // auto-reset event + FALSE, // initial state is nonsignaled + "Global\\PythonSignal"); // name of the event + if (*hEvent == NULL) + raise_system_error("failed to create event"); +#endif // _WIN32 + + std::cout << "blocked signal..." << std::endl; +} + +// as in the above function, hEvent is unused on linux +inline void wait_for_python_ready_sigwait(void* hEvent, int timeout_secs) +{ + std::cout << "waiting for python to be ready..." 
<< std::endl; + +#ifdef __linux__ + timespec ts {.tv_sec = timeout_secs, .tv_nsec = 0}; + sigset_t set {}; + siginfo_t sig {}; + + if (sigemptyset(&set) < 0) + raise_system_error("failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + raise_system_error("failed to add sigusr1 to the signal set"); + + if (sigtimedwait(&set, &sig, &ts) < 0) + raise_system_error("failed to wait on sigusr1"); + + sigprocmask(SIG_UNBLOCK, &set, nullptr); + +#endif // __linux__ +#ifdef _WIN32 + DWORD waitResult = WaitForSingleObject(hEvent, timeout_secs * 1000); + if (waitResult != WAIT_OBJECT_0) { + raise_system_error("failed to wait on event"); + } + CloseHandle(hEvent); +#endif // _WIN32 + + std::cout << "signal received; python is ready" << std::endl; +} + +// run a test +// forks and runs each of the provided closures +// if `wait_for_python` is true, wait for SIGUSR1 after forking and executing the first closure +inline TestResult test( + int timeout_secs, std::vector> closures, bool wait_for_python = false) +{ + std::vector pipes {}; + + for (size_t i = 0; i < closures.size(); i++) + pipes.emplace_back(); + +#ifdef __linux__ + std::vector pids {}; + void* hEvent = nullptr; + for (size_t i = 0; i < closures.size(); i++) { + if (wait_for_python && i == 0) + wait_for_python_ready_sigblock(&hEvent); + + auto pid = fork(); + if (pid < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + raise_system_error("failed to fork"); + } + + if (pid == 0) { + fork_wrapper(closures[i], timeout_secs, std::move(pipes[i].writer), nullptr); + std::exit(EXIT_SUCCESS); + } + + pids.push_back(pid); + + if (wait_for_python && i == 0) + wait_for_python_ready_sigwait(&hEvent, 3); + } + + std::vector pfds {}; + + int timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); + if (timerfd < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + raise_system_error("failed to create timerfd"); + } + + 
pfds.push_back({.fd = timerfd, .events = POLL_IN, .revents = 0}); + for (const auto& pipe: pipes) + pfds.push_back({ + .fd = (int)pipe.reader.fd(), + .events = POLL_IN, + .revents = 0, + }); + + itimerspec spec { + .it_interval = + { + .tv_sec = 0, + .tv_nsec = 0, + }, + .it_value = { + .tv_sec = timeout_secs, + .tv_nsec = 0, + }}; + + if (timerfd_settime(timerfd, 0, &spec, nullptr) < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerfd); + + raise_system_error("failed to set timerfd"); + } + + std::vector> results(pids.size(), std::nullopt); + + for (;;) { + auto n = poll(pfds.data(), pfds.size(), -1); + if (n < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerfd); + + raise_system_error("failed to poll"); + } + + for (auto& pfd: std::vector(pfds)) { + if (pfd.revents == 0) + continue; + + // timed out + if (pfd.fd == timerfd) { + std::cout << "Timed out!\n"; + + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerfd); + + return TestResult::Failure; + } + + auto elem = std::find_if( + pipes.begin(), pipes.end(), [fd = pfd.fd](const auto& pipe) { return pipe.reader.fd() == fd; }); + auto idx = elem - pipes.begin(); + + TestResult result = TestResult::Failure; + char buffer = 0; + auto n = read(pfd.fd, &buffer, sizeof(TestResult)); + if (n == 0) { + std::cout << "failed to read from pipe: pipe closed unexpectedly\n"; + result = TestResult::Failure; + } else if (n < 0) { + std::cout << "failed to read from pipe: " << std::strerror(errno) << std::endl; + result = TestResult::Failure; + } else + result = (TestResult)buffer; + + // the subprocess should have exited + // check its exit status + int status; + if (waitpid(pids[idx], &status, 0) < 0) + std::cout << "failed to wait on subprocess[" << idx << "]: " << std::strerror(errno) << std::endl; + + auto exit_status = WEXITSTATUS(status); + if (WIFEXITED(status) 
&& exit_status != EXIT_SUCCESS) { + std::cout << "subprocess[" << idx << "] exited with status " << exit_status << std::endl; + } else if (WIFSIGNALED(status)) { + std::cout << "subprocess[" << idx << "] killed by signal " << WTERMSIG(status) << std::endl; + } else { + std::cout << "subprocess[" << idx << "] completed with " + << (result == TestResult::Success ? "Success" : "Failure") << std::endl; + } + + // store the result + results[idx] = result; + + // this subprocess is done, remove its pipe from the poll fds + pfds.erase( + std::remove_if(pfds.begin(), pfds.end(), [&](const auto& p) { return p.fd == pfd.fd; }), pfds.end()); + + auto done = + std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + } + +end: + close(timerfd); + + if (std::ranges::any_of(results, [](const auto& x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +#endif // __linux__ +#ifdef _WIN32 + std::vector events {}; + std::vector threads {}; + + for (size_t i = 0; i < closures.size(); i++) { + HANDLE hEvent = CreateEvent( + nullptr, // default security attributes + true, // auto-reset event + false, // initial state is nonsignaled + nullptr); // unnamed event + if (!hEvent) + raise_system_error("failed to create event"); + events.push_back(hEvent); + } + + for (size_t i = 0; i < closures.size(); i++) { + HANDLE hEvent = nullptr; + if (wait_for_python && i == 0) + wait_for_python_ready_sigblock(&hEvent); + + threads.emplace_back(fork_wrapper, closures[i], timeout_secs, std::move(pipes[i].writer), events[i]); + + if (wait_for_python && i == 0) + wait_for_python_ready_sigwait(hEvent, 3); + } + + HANDLE timer = CreateWaitableTimer(nullptr, true, nullptr); + if (!timer) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + raise_system_error("failed to create waitable 
timer"); + } + + LARGE_INTEGER expires_in = {0}; + + // negative value indicates relative time + expires_in.QuadPart = -static_cast(timeout_secs) * ns_per_second / ns_per_unit; + if (!SetWaitableTimer(timer, &expires_in, 0, nullptr, nullptr, false)) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + raise_system_error("failed to set waitable timer"); + } + + // these are the handles we're going to poll + std::vector wait_handles {timer}; + + // poll all read halves of the pipes + for (const auto& ev: events) + wait_handles.push_back(ev); + + std::vector> results(threads.size(), std::nullopt); + + for (;;) { + DWORD waitResult = WaitForMultipleObjects(wait_handles.size(), wait_handles.data(), false, INFINITE); + if (waitResult == WAIT_FAILED) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + raise_system_error("failed to wait on handles"); + } + + // the idx of the handle in the handles array + // note that index 0 is the timer + // and we adjust the handles array as tasks complete + // so we need an extra step to calculate the index in `closure`-space + size_t wait_idx = (size_t)waitResult - WAIT_OBJECT_0; + + // timed out + if (wait_idx == 0) { + std::cout << "Timed out!\n"; + std::for_each(threads.begin(), threads.end(), [](auto& t) { + t.request_stop(); + t.detach(); + }); + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + return TestResult::Failure; + } + + // find the idx + const auto& hEvent = wait_handles[wait_idx]; + auto event_it = std::find_if(events.begin(), events.end(), [hEvent](const auto& ev) { return ev == hEvent; }); + const auto idx = event_it - events.begin(); + auto& pipe = pipes[idx]; + TestResult result = TestResult::Failure; + char buffer = 0; + try { + pipe.reader.read_exact(&buffer, sizeof(TestResult)); + result = (TestResult)buffer; + } catch (const 
std::system_error& e) { + std::cout << "failed to read from pipe: " << e.what() << std::endl; + result = TestResult::Failure; + } + + std::cout << "subprocess[" << idx << "] completed with " + << (result == TestResult::Success ? "Success" : "Failure") << std::endl; + + // store the result + results[idx] = result; + + // this subprocess is done, remove its pipe from the handles + wait_handles.erase( + std::remove_if(wait_handles.begin(), wait_handles.end(), [&](const auto& h) { return h == hEvent; }), + wait_handles.end()); + auto done = std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + +end: + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + + if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +#endif // _WIN32 +} + +inline std::wstring discover_python_home(std::string command) +{ + // leverage the system's command line to get the current python prefix + FILE* pipe = popen(std::format("{} -c \"import sys; print(sys.prefix)\"", command).c_str(), "r"); + if (!pipe) + throw std::runtime_error("failed to start python process to discover prefix"); + + std::array buffer {}; + std::string output {}; + + size_t n; + while ((n = fread(buffer.data(), 1, buffer.size(), pipe)) > 0) + output.append(buffer.data(), n); + + // remove trailing whitespace + output.erase(output.find_last_not_of("\r\n") + 1); + + auto status = pclose(pipe); + if (status < 0) + throw std::runtime_error("failed to close close process"); + else if (status > 0) + throw std::runtime_error("process returned non-zero exit code: " + std::to_string(status)); + + // assume it's ascii, so we can just cast it as a wstring + return std::wstring(output.begin(), output.end()); +} + +inline void ensure_python_initialized() +{ + 
if (Py_IsInitialized()) + return; + +#ifdef _WIN32 + auto python_home = discover_python_home("python"); + Py_SetPythonHome(python_home.c_str()); +#endif // _WIN32 + + Py_Initialize(); + + // add the cwd to the path + { + PyObject* sysPath = PySys_GetObject("path"); + if (!sysPath) + throw std::runtime_error("failed to get sys.path"); + + PyObject* newPath = PyUnicode_FromString("."); + if (!newPath) + throw std::runtime_error("failed to create Python string"); + + if (PyList_Append(sysPath, newPath) < 0) { + Py_DECREF(newPath); + throw std::runtime_error("failed to append to sys.path"); + } + + Py_DECREF(newPath); + } + + // release the GIL, the caller will have to acquire it again + PyEval_SaveThread(); +} + +inline void maybe_finalize_python() +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + if (!Py_IsInitialized()) + return; + + Py_Finalize(); + + // stop compiler from complaining that it's unused + (void)gstate; +} + +inline TestResult run_python(const char* path, std::vector> argv = {}) +{ + // ensure_python_initialized(); + PyGILState_STATE gstate = PyGILState_Ensure(); + +// insert the pid at the start of the argv, this is important for signalling readiness +#ifdef __linux__ + pid_t pid = getppid(); +#endif // __linux__ +#ifdef _WIN32 + DWORD pid = GetCurrentProcessId(); +#endif // _WIN32 + + auto pid_s = std::to_string(pid); + argv.insert(argv.begin(), pid_s.c_str()); + argv.insert(argv.begin(), "mitm"); + + // set argv + { + PyObject* py_argv = PyList_New(argv.size()); + if (!py_argv) + goto exception; + + for (size_t i = 0; i < argv.size(); i++) + if (argv[i]) + PyList_SET_ITEM(py_argv, i, PyUnicode_FromString(argv[i].value().c_str())); + else + PyList_SET_ITEM(py_argv, i, Py_None); + + if (PySys_SetObject("argv", py_argv) < 0) + goto exception; + + Py_DECREF(py_argv); + } + + { + std::ifstream file(path); + std::stringstream buffer; + buffer << file.rdbuf(); + + int rc = PyRun_SimpleString(buffer.str().c_str()); + file.close(); + + if (rc < 0) + 
throw std::runtime_error("failed to run python script"); + } + + PyGILState_Release(gstate); + return TestResult::Success; + +exception: + PyGILState_Release(gstate); + return TestResult::Failure; +} + +inline TestResult run_mitm( + std::string testcase, + std::string mitm_ip, + uint16_t mitm_port, + std::string remote_ip, + uint16_t remote_port, + std::vector extra_args = {}) +{ + auto cwd = std::filesystem::current_path(); + chdir_to_project_root(); + + // we build the args for the user to make calling the function more convenient + std::vector> args { + testcase, mitm_ip, std::to_string(mitm_port), remote_ip, std::to_string(remote_port)}; + + for (auto arg: extra_args) + args.push_back(arg); + + auto result = run_python("tests/cpp/ymq/py_mitm/main.py", args); + + // change back to the original working directory + std::filesystem::current_path(cwd); + return result; +} diff --git a/tests/cpp/ymq/common/utils.cpp b/tests/cpp/ymq/common/utils.cpp new file mode 100644 index 000000000..ba5f15ba8 --- /dev/null +++ b/tests/cpp/ymq/common/utils.cpp @@ -0,0 +1,69 @@ +#include "tests/cpp/ymq/common/utils.h" + +#include +#include +#include +#include + +#ifdef __linux__ +#include +#endif // __linux__ +#ifdef _WIN32 +#include +#include +#endif // _WIN32 + +void raise_system_error(const char* msg) +{ +#ifdef __linux__ + throw std::system_error(errno, std::generic_category(), msg); +#endif // __linux__ +#ifdef _WIN32 + throw std::system_error(GetLastError(), std::generic_category(), msg); +#endif // _WIN32 +} + +void raise_socket_error(const char* msg) +{ +#ifdef __linux__ + throw std::system_error(errno, std::generic_category(), msg); +#endif // __linux__ +#ifdef _WIN32 + throw std::system_error(WSAGetLastError(), std::generic_category(), msg); +#endif // _WIN32 +} + +const char* check_localhost(const char* host) +{ + return std::strcmp(host, "localhost") == 0 ? 
"127.0.0.1" : host; +} + +std::string format_address(std::string host, uint16_t port) +{ + std::ostringstream oss; + oss << "tcp://" << check_localhost(host.c_str()) << ":" << port; + return oss.str(); +} + +// change the current working directory to the project root +// this is important for finding the python mitm script +void chdir_to_project_root() +{ + auto cwd = std::filesystem::current_path(); + + // if pyproject.toml is in `path`, it's the project root + for (auto path = cwd; !path.empty(); path = path.parent_path()) { + if (std::filesystem::exists(path / "pyproject.toml")) { + // change to the project root + std::filesystem::current_path(path); + return; + } + } +} + +unsigned short random_port(unsigned short min_port, unsigned short max_port) +{ + static thread_local std::mt19937_64 rng(std::random_device {}()); + std::uniform_int_distribution dist(min_port, max_port); + return static_cast(dist(rng)); +} diff --git a/tests/cpp/ymq/common/utils.h b/tests/cpp/ymq/common/utils.h new file mode 100644 index 000000000..2861ec5e6 --- /dev/null +++ b/tests/cpp/ymq/common/utils.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +// throw an error with the last system error code +void raise_system_error(const char* msg); + +// throw wan error with the last socket error code +void raise_socket_error(const char* msg); + +// returns its input unless it matches "localhost" in which case 127.0.0.1 is returned +const char* check_localhost(const char* host); + +// formats a tcp:// address +std::string format_address(std::string host, uint16_t port); + +// change the current working directory to the project root +// this is important for finding the python mitm script +void chdir_to_project_root(); + +unsigned short random_port(unsigned short min_port = 1024, unsigned short max_port = 65535); diff --git a/tests/cpp/ymq/net/socket.h b/tests/cpp/ymq/net/socket.h new file mode 100644 index 000000000..ffad3f3a5 --- /dev/null +++ b/tests/cpp/ymq/net/socket.h @@ -0,0 +1,52 
@@ +#pragma once +#include +#include +#include +#include + +class Socket { +public: + Socket(bool nodelay = false); + Socket(bool nodelay, long long fd); + ~Socket(); + + // move-only + Socket(Socket&&) noexcept; + Socket& operator=(Socket&&) noexcept; + Socket(const Socket&) = delete; + Socket& operator=(const Socket&) = delete; + + // try to connect, retrying up to `tries` times when the connection is refused + void try_connect(const std::string& host, short port, int tries = 10) const; + void bind(short port) const; + void listen(int backlog = 5) const; + Socket accept() const; + + // write an entire buffer + void write_all(const void* data, size_t size) const; + void write_all(std::string msg) const; + + // read exactly `size` bytes + void read_exact(void* buffer, size_t size) const; + + // write a message in the YMQ protocol + void write_message(std::string msg) const; + + // read a YMQ message + std::string read_message() const; + +private: + // the native handle for this pipe reader + // on Linux, this is a file descriptor + // on Windows, this is a SOCKET + long long _fd; + + // indicates if nodelay was set + bool _nodelay; + + // write up to `size` bytes + int write(const void* buffer, size_t size) const; + + // read up to `size` bytes + int read(void* buffer, size_t size) const; +}; diff --git a/tests/cpp/ymq/net/socket_linux.cpp b/tests/cpp/ymq/net/socket_linux.cpp new file mode 100644 index 000000000..b36f76e35 --- /dev/null +++ b/tests/cpp/ymq/net/socket_linux.cpp @@ -0,0 +1,150 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/socket.h" + +Socket::Socket(bool nodelay): _fd(-1), _nodelay(nodelay) +{ + this->_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (this->_fd < 0) + raise_socket_error("failed to create socket"); + + char on = 1; + if (this->_nodelay) + if (::setsockopt(this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) < 0) + 
raise_socket_error("failed to set nodelay"); +} + +Socket::Socket(bool nodelay, long long fd): _fd(fd), _nodelay(nodelay) +{ + char on = 1; + if (this->_nodelay) + if (::setsockopt(this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) < 0) + raise_socket_error("failed to set nodelay"); +} + +Socket::~Socket() +{ + close(this->_fd); +} + +Socket::Socket(Socket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; +} + +Socket& Socket::operator=(Socket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +void Socket::try_connect(const std::string& host, short port, int tries) const +{ + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + inet_pton(AF_INET, check_localhost(host.c_str()), &addr.sin_addr); + + for (int i = 0; i < tries; i++) { + auto code = ::connect(this->_fd, (sockaddr*)&addr, sizeof(addr)); + + if (code < 0) { + if (errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + raise_socket_error("failed to connect"); + } + + break; // success + } +} + +void Socket::bind(short port) const +{ + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = INADDR_ANY; + if (::bind(this->_fd, (sockaddr*)&addr, sizeof(addr)) < 0) + raise_socket_error("failed to bind"); +} + +void Socket::listen(int backlog) const +{ + if (::listen(this->_fd, backlog) < 0) + raise_socket_error("failed to listen"); +} + +Socket Socket::accept() const +{ + long long fd = ::accept(this->_fd, nullptr, nullptr); + if (fd < 0) + raise_socket_error("failed to accept"); + + return Socket(this->_nodelay, fd); +} + +int Socket::write(const void* buffer, size_t size) const +{ + int n = ::write(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to send"); + return n; +} + +void Socket::write_all(const void* buffer, size_t 
size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} + +void Socket::write_all(std::string msg) const +{ + this->write_all(msg.data(), msg.size()); +} + +void Socket::write_message(std::string msg) const +{ + uint64_t header = msg.length(); + this->write_all(&header, 8); + this->write_all(msg.data(), msg.length()); +} + +int Socket::read(void* buffer, size_t size) const +{ + int n = ::read(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to recv"); + return n; +} + +void Socket::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} + +std::string Socket::read_message() const +{ + uint64_t header = 0; + this->read_exact(&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); +} diff --git a/tests/cpp/ymq/net/socket_windows.cpp b/tests/cpp/ymq/net/socket_windows.cpp new file mode 100644 index 000000000..8f8323334 --- /dev/null +++ b/tests/cpp/ymq/net/socket_windows.cpp @@ -0,0 +1,153 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/socket.h" + +Socket::Socket(bool nodelay): _fd(-1), _nodelay(nodelay) +{ + this->_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (this->_fd == SOCKET_ERROR) + raise_socket_error("failed to create socket"); + + char on = 1; + if (this->_nodelay) + if (::setsockopt((SOCKET)this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) == SOCKET_ERROR) + raise_socket_error("failed to set nodelay"); +} + +Socket::Socket(bool nodelay, long long fd): _fd(fd), _nodelay(nodelay) +{ + char on = 1; + if (this->_nodelay) + if (::setsockopt((SOCKET)this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) == SOCKET_ERROR) + 
raise_socket_error("failed to set nodelay"); +} + +Socket::~Socket() +{ + ::closesocket((SOCKET)this->_fd); +} + +Socket::Socket(Socket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; +} + +Socket& Socket::operator=(Socket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +void Socket::try_connect(const std::string& host, short port, int tries) const +{ + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + inet_pton(AF_INET, check_localhost(host.c_str()), &addr.sin_addr); + + for (int i = 0; i < tries; i++) { + auto code = ::connect((SOCKET)this->_fd, (sockaddr*)&addr, sizeof(addr)); + + if (code == SOCKET_ERROR) { + if (WSAGetLastError() == WSAECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + std::printf("fpppp %d\n", WSAGetLastError()); + + raise_socket_error("failed to connect"); + } + + break; // success + } +} + +void Socket::bind(short port) const +{ + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + addr.sin_addr.s_addr = INADDR_ANY; + if (::bind((SOCKET)this->_fd, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) + raise_socket_error("failed to bind"); +} + +void Socket::listen(int backlog) const +{ + if (::listen((SOCKET)this->_fd, backlog) == SOCKET_ERROR) + raise_socket_error("failed to listen"); +} + +Socket Socket::accept() const +{ + long long fd = ::accept((SOCKET)this->_fd, nullptr, nullptr); + if (fd == SOCKET_ERROR) + raise_socket_error("failed to accept"); + + return Socket(this->_nodelay, fd); +} + +int Socket::write(const void* buffer, size_t size) const +{ + auto n = ::send((SOCKET)this->_fd, static_cast(buffer), (int)size, 0); + if (n == SOCKET_ERROR) + raise_socket_error("failed to send data"); + return n; +} + +void Socket::write_all(const void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor 
< size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} + +void Socket::write_all(std::string msg) const +{ + this->write_all(msg.data(), msg.size()); +} + +void Socket::write_message(std::string msg) const +{ + uint64_t header = msg.length(); + this->write_all(&header, 8); + this->write_all(msg.data(), msg.length()); +} + +int Socket::read(void* buffer, size_t size) const +{ + auto n = ::recv((SOCKET)this->_fd, static_cast(buffer), (int)size, 0); + if (n == SOCKET_ERROR) + raise_socket_error("failed to receive data"); + return n; +} + +void Socket::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} + +std::string Socket::read_message() const +{ + uint64_t header = 0; + this->read_exact(&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); +} diff --git a/tests/cpp/ymq/pipe/pipe.h b/tests/cpp/ymq/pipe/pipe.h new file mode 100644 index 000000000..3c19d8a43 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe.h @@ -0,0 +1,37 @@ +#pragma once + +#include "tests/cpp/ymq/pipe/pipe_reader.h" +#include "tests/cpp/ymq/pipe/pipe_utils.h" +#include "tests/cpp/ymq/pipe/pipe_writer.h" + +struct Pipe { +public: + Pipe(): reader(-1), writer(-1) + { + std::pair pair = create_pipe(); + this->reader = PipeReader(pair.first); + this->writer = PipeWriter(pair.second); + } + + ~Pipe() = default; + + // Move-only + Pipe(Pipe&& other) noexcept: reader(-1), writer(-1) + { + this->reader = std::move(other.reader); + this->writer = std::move(other.writer); + } + + Pipe& operator=(Pipe&& other) noexcept + { + this->reader = std::move(other.reader); + this->writer = std::move(other.writer); + return *this; + } + + Pipe(const Pipe&) = delete; + Pipe& operator=(const Pipe&) = delete; + + PipeReader reader; + PipeWriter writer; +}; diff --git a/tests/cpp/ymq/pipe/pipe_reader.h 
b/tests/cpp/ymq/pipe/pipe_reader.h new file mode 100644 index 000000000..9d88398d1 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_reader.h @@ -0,0 +1,34 @@ +#pragma once + +#include + +struct Pipe; + +class PipeReader { +public: + PipeReader(long long fd); + ~PipeReader(); + + // Move-only + PipeReader(PipeReader&&) noexcept; + PipeReader& operator=(PipeReader&&) noexcept; + PipeReader(const PipeReader&) = delete; + PipeReader& operator=(const PipeReader&) = delete; + + // read exactly `size` bytes + void read_exact(void* buffer, size_t size) const; + + // returns the native handle for this pipe reader + // on linux, this is a pointer to the file descriptor + // on windows, this is the HANDLE + const long long fd() const noexcept; + +private: + // the native handle for this pipe reader + // on Linux, this is a file descriptor + // on Windows, this is a HANDLE + long long _fd; + + // read up to `size` bytes + int read(void* buffer, size_t size) const; +}; diff --git a/tests/cpp/ymq/pipe/pipe_reader_linux.cpp b/tests/cpp/ymq/pipe/pipe_reader_linux.cpp new file mode 100644 index 000000000..abdf7201c --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_reader_linux.cpp @@ -0,0 +1,48 @@ +#include + +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe_reader.h" + +PipeReader::PipeReader(long long fd): _fd(fd) +{ +} + +PipeReader::~PipeReader() +{ + close(this->_fd); +} + +PipeReader::PipeReader(PipeReader&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; +} + +PipeReader& PipeReader::operator=(PipeReader&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +const long long PipeReader::fd() const noexcept +{ + return this->_fd; +} + +int PipeReader::read(void* buffer, size_t size) const +{ + ssize_t n = ::read(this->_fd, buffer, size); + if (n < 0) + raise_system_error("read"); + return n; +} + +void PipeReader::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + 
cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} diff --git a/tests/cpp/ymq/pipe/pipe_reader_windows.cpp b/tests/cpp/ymq/pipe/pipe_reader_windows.cpp new file mode 100644 index 000000000..29bb3b84b --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_reader_windows.cpp @@ -0,0 +1,48 @@ +#include + +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe_reader.h" + +PipeReader::PipeReader(long long fd): _fd(fd) +{ +} + +PipeReader::~PipeReader() +{ + CloseHandle((HANDLE)this->_fd); +} + +PipeReader::PipeReader(PipeReader&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; +} + +PipeReader& PipeReader::operator=(PipeReader&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +const long long PipeReader::fd() const noexcept +{ + return this->_fd; +} + +int PipeReader::read(void* buffer, size_t size) const +{ + DWORD bytes_read = 0; + if (!ReadFile((HANDLE)this->_fd, buffer, (DWORD)size, &bytes_read, nullptr)) + raise_system_error("failed to read"); + return bytes_read; +} + +void PipeReader::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} diff --git a/tests/cpp/ymq/pipe/pipe_utils.h b/tests/cpp/ymq/pipe/pipe_utils.h new file mode 100644 index 000000000..a15c0d9b7 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_utils.h @@ -0,0 +1,5 @@ +#include + +// create platform-specific pipe handles +// the first handle is read, the second handle is write +std::pair create_pipe(); diff --git a/tests/cpp/ymq/pipe/pipe_utils_linux.cpp b/tests/cpp/ymq/pipe/pipe_utils_linux.cpp new file mode 100644 index 000000000..bc171186d --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_utils_linux.cpp @@ -0,0 +1,13 @@ +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe.h" + +std::pair create_pipe() +{ + int fds[2]; + if (::pipe(fds) < 0) + raise_system_error("pipe"); 
+ + return std::make_pair(fds[0], fds[1]); +} diff --git a/tests/cpp/ymq/pipe/pipe_utils_windows.cpp b/tests/cpp/ymq/pipe/pipe_utils_windows.cpp new file mode 100644 index 000000000..43f33e8f7 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_utils_windows.cpp @@ -0,0 +1,19 @@ +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe.h" + +std::pair create_pipe() +{ + SECURITY_ATTRIBUTES sa {}; + sa.nLength = sizeof(sa); + sa.bInheritHandle = TRUE; + + HANDLE reader = INVALID_HANDLE_VALUE; + HANDLE writer = INVALID_HANDLE_VALUE; + + if (!CreatePipe(&reader, &writer, &sa, 0)) + raise_system_error("failed to create pipe"); + + return std::make_pair((long long)reader, (long long)writer); +} diff --git a/tests/cpp/ymq/pipe/pipe_writer.h b/tests/cpp/ymq/pipe/pipe_writer.h new file mode 100644 index 000000000..4369a74f8 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_writer.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +struct Pipe; + +class PipeWriter { +public: + PipeWriter(long long fd); + ~PipeWriter(); + + // Move-only + PipeWriter(PipeWriter&&) noexcept; + PipeWriter& operator=(PipeWriter&&) noexcept; + PipeWriter(const PipeWriter&) = delete; + PipeWriter& operator=(const PipeWriter&) = delete; + + // write `size` bytes + void write_all(const void* data, size_t size); + +private: + // the native handle for this pipe reader + // on Linux, this is a file descriptor + // on Windows, this is a HANDLE + long long _fd; + + // write up to `size` bytes + int write(const void* buffer, size_t size); +}; diff --git a/tests/cpp/ymq/pipe/pipe_writer_linux.cpp b/tests/cpp/ymq/pipe/pipe_writer_linux.cpp new file mode 100644 index 000000000..ef8cbe345 --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_writer_linux.cpp @@ -0,0 +1,43 @@ +#include + +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe_writer.h" + +PipeWriter::PipeWriter(long long fd): _fd(fd) +{ +} + +PipeWriter::~PipeWriter() +{ + close(this->_fd); +} + 
+PipeWriter::PipeWriter(PipeWriter&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; +} + +PipeWriter& PipeWriter::operator=(PipeWriter&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +int PipeWriter::write(const void* buffer, size_t size) +{ + ssize_t n = ::write(this->_fd, buffer, size); + if (n < 0) + raise_system_error("write"); + return n; +} + +void PipeWriter::write_all(const void* buffer, size_t size) +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} diff --git a/tests/cpp/ymq/pipe/pipe_writer_windows.cpp b/tests/cpp/ymq/pipe/pipe_writer_windows.cpp new file mode 100644 index 000000000..c8172ac3a --- /dev/null +++ b/tests/cpp/ymq/pipe/pipe_writer_windows.cpp @@ -0,0 +1,43 @@ +#include + +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe_writer.h" + +PipeWriter::PipeWriter(long long fd): _fd(fd) +{ +} + +PipeWriter::~PipeWriter() +{ + CloseHandle((HANDLE)this->_fd); +} + +PipeWriter::PipeWriter(PipeWriter&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; +} + +PipeWriter& PipeWriter::operator=(PipeWriter&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +int PipeWriter::write(const void* buffer, size_t size) +{ + DWORD bytes_written = 0; + if (!WriteFile((HANDLE)this->_fd, buffer, (DWORD)size, &bytes_written, nullptr)) + raise_system_error("failed to write to pipe"); + return bytes_written; +} + +void PipeWriter::write_all(const void* buffer, size_t size) +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index 89d38805a..5e47536d4 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -4,81 +4,51 @@ This script provides a framework for running MITM test cases """ +from 
scapy.config import conf + +# only load the scapy layers that we need +conf.load_layers = ["inet"] + import argparse import os +import platform import signal import subprocess from typing import List -from scapy.all import IP, TCP, TunTapInterface # type: ignore +from scapy.all import IP, TCP # type: ignore from tests.cpp.ymq.py_mitm import passthrough, randomly_drop_packets, send_rst_to_client -from tests.cpp.ymq.py_mitm.types import AbstractMITM, TCPConnection - - -def echo_call(cmd: List[str]): - print(f"+ {' '.join(cmd)}") - subprocess.check_call(cmd) - - -def create_tuntap_interface(iface_name: str, mitm_ip: str, remote_ip: str) -> TunTapInterface: - """ - Creates a TUNTAP interface and sets brings it up and adds ips using the `ip` program - - Args: - iface_name: The name of the TUNTAP interface, usually like `tun0`, `tun1`, etc. - mitm_ip: The desired ip address of the mitm. This is the ip that clients can use to connect to the mitm - remote_ip: The ip that routes to/from the tuntap interface. - packets sent to `mitm_ip` will appear to come from `remote_ip`,\ - and conversely the tuntap interface can connect/send packets - to `remote_ip`, making it a suitable ip for binding a server - - Returns: - The TUNTAP interface - """ - iface = TunTapInterface(iface_name, mode="tun") - - try: - echo_call(["sudo", "ip", "link", "set", iface_name, "up"]) - echo_call(["sudo", "ip", "addr", "add", remote_ip, "peer", mitm_ip, "dev", iface_name]) - print(f"[+] Interface {iface_name} up with IP {mitm_ip}") - except subprocess.CalledProcessError: - print("[!] Could not bring up interface. 
Run as root or set manually.") - raise - - return iface +from tests.cpp.ymq.py_mitm.mitm_types import AbstractMITM, AbstractMITMInterface, TCPConnection -def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: AbstractMITM): +def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: AbstractMITM) -> None: """ This function serves as a framework for man in the middle implementations A client connects to the MITM, then the MITM connects to a remote server The MITM sits inbetween the client and the server, manipulating the packets sent depending on the test case This function: - 1. creates a TUNTAP interface and prepares it for MITM + 1. creates an interface and prepares it for MITM 2. handles connecting clients and handling connection closes 3. delegates additional logic to a pluggable callable, `mitm` - 4. returns when both connections have terminated (via ) + 4. returns when both connections have terminated Args: pid: this is the pid of the test process, used for signaling readiness \ we send SIGUSR1 to this process when the mitm is ready - mitm_ip: The desired ip address of the mitm server + mitm_ip: The desired ip address of the mitm server \ + Windows: This parameter is ignored mitm_port: The desired port of the mitm server. \ - This is the port used to connect to the server, but the client is free to connect on any port - remote_ip: The desired remote ip for the TUNTAP interface. This is the only ip address \ - reachable by the interface and is thus the src ip for clients, and the ip that the remote server \ - must be bound to + Linux: This is the port used to connect to the server, but the client is free to connect on any port \ + Windows: This parameter is ignored + remote_ip: The remote ip for the that the remote server is bound to server_port: The port that the remote server is bound to mitm: The core logic for a MITM test case. 
This callable may maintain its own state and is responsible \ - for sending packets over the TUNTAP interface (if it doesn't, nothing will happen) + for sending packets over the interface (if it doesn't, nothing will happen) """ + interface = get_interface(mitm_ip, mitm_port, remote_ip, server_port) - tuntap = create_tuntap_interface("tun0", mitm_ip, remote_ip) - - # signal the caller that the tuntap interface has been created - if pid > 0: - os.kill(pid, signal.SIGUSR1) + signal_ready(pid) # these track information about our connections # we already know what to expect for the server connection, we are the connector @@ -86,8 +56,7 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in # the port that the mitm uses to connect to the server # we increment the port for each new connection to avoid collisions - mitm_server_port = mitm_port - server_conn = TCPConnection(mitm_ip, mitm_server_port, remote_ip, server_port) + server_conn = TCPConnection(mitm_ip, mitm_port, remote_ip, server_port) # tracks the state of each connection client_sent_fin_ack = False @@ -96,7 +65,7 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in server_closed = False while True: - pkt = tuntap.recv() + pkt = interface.recv() if not pkt.haslayer(IP) or not pkt.haslayer(TCP): continue ip = pkt[IP] @@ -108,7 +77,7 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in pretty = f"[{tcp.flags}]{(': ' + str(bytes(tcp.payload))) if tcp.payload else ''}" - if not mitm.proxy(tuntap, pkt, sender, client_conn, server_conn): + if not mitm.proxy(interface, pkt, sender, client_conn, server_conn): if sender == client_conn: print(f"[DROPPED]: -> {pretty}") elif sender == server_conn: @@ -124,15 +93,12 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in print(f"<- {pretty}") if tcp.flags == "S": # SYN from client - print("-> [S]") if sender != client_conn or client_conn is None: + 
print("-> [S]") print(f"[*] New connection from {ip.src}:{tcp.sport} to {ip.dst}:{tcp.dport}") client_conn = sender - server_conn = TCPConnection(mitm_ip, mitm_server_port, remote_ip, server_port) - - # increment the port so that the next client connection (if there is one) uses a different port - mitm_server_port += 1 + server_conn = TCPConnection(mitm_ip, mitm_port, remote_ip, server_port) if tcp.flags == "SA": # SYN-ACK from server if sender == server_conn: @@ -155,6 +121,38 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in return +def get_interface(mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int) -> AbstractMITMInterface: + """get the platform-specific mitm interface""" + + system = platform.system() + if system == "Windows": + from tests.cpp.ymq.py_mitm.windivert import WindivertMITMInterface + + return WindivertMITMInterface(mitm_ip, mitm_port, remote_ip, server_port) + elif system in ("Linux", "Darwin"): + from tests.cpp.ymq.py_mitm.tuntap import create_tuntap_interface + + return create_tuntap_interface("tun0", mitm_ip, remote_ip) + + raise RuntimeError("unsupported platform") + + +def signal_ready(pid: int) -> None: + """signal to the caller that the mitm is ready""" + + system = platform.system() + if system == "Windows": + import win32api + import win32event + + handle = win32event.OpenEvent(win32event.EVENT_MODIFY_STATE, False, "Global\\PythonSignal") + win32event.SetEvent(handle) + win32api.CloseHandle(handle) + elif system in ("Linux", "Darwin"): + if pid > 0: + os.kill(pid, signal.SIGUSR1) # type: ignore[attr-defined] + + TESTCASES = { "passthrough": passthrough, "randomly_drop_packets": randomly_drop_packets, diff --git a/tests/cpp/ymq/py_mitm/types.py b/tests/cpp/ymq/py_mitm/mitm_types.py similarity index 72% rename from tests/cpp/ymq/py_mitm/types.py rename to tests/cpp/ymq/py_mitm/mitm_types.py index 26f4b1cda..f3650f857 100644 --- a/tests/cpp/ymq/py_mitm/types.py +++ 
b/tests/cpp/ymq/py_mitm/mitm_types.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from typing import Optional -from scapy.all import IP, TCP, TunTapInterface # type: ignore +from scapy.all import IP, TCP, Packet # type: ignore @dataclasses.dataclass @@ -36,20 +36,35 @@ def rewrite(self, pkt: IP, ack: Optional[int] = None, data=None): Returns: The rewritten packet, suitable for sending over TUNTAP """ + ip = pkt[IP] tcp = pkt[TCP] return ( - IP(src=self.local_ip, dst=self.remote_ip) - / TCP(sport=self.local_port, dport=self.remote_port, flags=tcp.flags, seq=tcp.seq, ack=ack or tcp.ack) + IP(src=self.local_ip or ip.src, dst=self.remote_ip) + / TCP( + sport=self.local_port or tcp.sport, + dport=self.remote_port, + flags=tcp.flags, + seq=tcp.seq, + ack=ack or tcp.ack, + ) / bytes(data or tcp.payload) ) +class AbstractMITMInterface(ABC): + @abstractmethod + def recv(self) -> Packet: ... + + @abstractmethod + def send(self, pkt: Packet) -> None: ... + + class AbstractMITM(ABC): @abstractmethod def proxy( self, - tuntap: TunTapInterface, + interface: AbstractMITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/passthrough.py b/tests/cpp/ymq/py_mitm/passthrough.py index 265574d85..b0b99262a 100644 --- a/tests/cpp/ymq/py_mitm/passthrough.py +++ b/tests/cpp/ymq/py_mitm/passthrough.py @@ -7,13 +7,13 @@ from typing import Optional -from tests.cpp.ymq.py_mitm.types import IP, AbstractMITM, TCPConnection, TunTapInterface +from tests.cpp.ymq.py_mitm.mitm_types import IP, AbstractMITM, AbstractMITMInterface, TCPConnection class MITM(AbstractMITM): def proxy( self, - tuntap: TunTapInterface, + tuntap: AbstractMITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py index 36a7b39ca..b9c1fff1e 100644 --- a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py +++ 
b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py @@ -5,7 +5,7 @@ import random from typing import Optional -from tests.cpp.ymq.py_mitm.types import IP, AbstractMITM, TCPConnection, TunTapInterface +from tests.cpp.ymq.py_mitm.mitm_types import IP, AbstractMITM, AbstractMITMInterface, TCPConnection class MITM(AbstractMITM): @@ -25,7 +25,7 @@ def can_drop_server(self) -> bool: def proxy( self, - tuntap: TunTapInterface, + tuntap: AbstractMITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/send_rst_to_client.py b/tests/cpp/ymq/py_mitm/send_rst_to_client.py index fa6206db4..69a60db21 100644 --- a/tests/cpp/ymq/py_mitm/send_rst_to_client.py +++ b/tests/cpp/ymq/py_mitm/send_rst_to_client.py @@ -4,17 +4,18 @@ from typing import Optional -from tests.cpp.ymq.py_mitm.types import IP, TCP, AbstractMITM, TCPConnection, TunTapInterface +from tests.cpp.ymq.py_mitm.mitm_types import IP, TCP, AbstractMITM, AbstractMITMInterface, TCPConnection class MITM(AbstractMITM): def __init__(self): # count the number of psh-acks sent by the client self._client_pshack_counter = 0 + self._client_sent_identity = 0 def proxy( self, - tuntap: TunTapInterface, + tuntap: AbstractMITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], @@ -24,8 +25,13 @@ def proxy( if pkt[TCP].flags == "PA": self._client_pshack_counter += 1 + if bytes(pkt[TCP].payload).endswith(b"client") and self._client_sent_identity == 0: + self._client_sent_identity = 1 + # on the second psh-ack, send a rst instead - if self._client_pshack_counter == 2: + # if self._client_pshack_counter == 2: + elif self._client_sent_identity == 1: + self._client_sent_identity = 2 rst_pkt = IP(src=client_conn.local_ip, dst=client_conn.remote_ip) / TCP( sport=client_conn.local_port, dport=client_conn.remote_port, flags="R", seq=pkt[TCP].ack ) diff --git a/tests/cpp/ymq/py_mitm/tuntap.py b/tests/cpp/ymq/py_mitm/tuntap.py new file mode 100644 index 
000000000..43957d777 --- /dev/null +++ b/tests/cpp/ymq/py_mitm/tuntap.py @@ -0,0 +1,37 @@ +import subprocess +from typing import List + +from scapy.all import TunTapInterface # type: ignore [attr-defined] + + +def echo_call(cmd: List[str]): + print(f"+ {' '.join(cmd)}") + subprocess.check_call(cmd) + + +def create_tuntap_interface(iface_name: str, mitm_ip: str, remote_ip: str) -> TunTapInterface: + """ + Creates a TUNTAP interface and sets brings it up and adds ips using the `ip` program + + Args: + iface_name: The name of the TUNTAP interface, usually like `tun0`, `tun1`, etc. + mitm_ip: The desired ip address of the mitm. This is the ip that clients can use to connect to the mitm + remote_ip: The ip that routes to/from the tuntap interface. + packets sent to `mitm_ip` will appear to come from `remote_ip`,\ + and conversely the tuntap interface can connect/send packets + to `remote_ip`, making it a suitable ip for binding a server + + Returns: + The TUNTAP interface + """ + iface = TunTapInterface(iface_name, mode="tun") + + try: + echo_call(["sudo", "ip", "link", "set", iface_name, "up"]) + echo_call(["sudo", "ip", "addr", "add", remote_ip, "peer", mitm_ip, "dev", iface_name]) + print(f"[+] Interface {iface_name} up with IP {mitm_ip}") + except subprocess.CalledProcessError: + print("[!] Could not bring up interface. 
Run as root or set manually.") + raise + + return iface diff --git a/tests/cpp/ymq/py_mitm/windivert.py b/tests/cpp/ymq/py_mitm/windivert.py new file mode 100644 index 000000000..6904fd84c --- /dev/null +++ b/tests/cpp/ymq/py_mitm/windivert.py @@ -0,0 +1,34 @@ +import socket +from typing import Any + +import pydivert +from scapy.all import IP, Packet # type: ignore[attr-defined] + +from tests.cpp.ymq.py_mitm.mitm_types import AbstractMITMInterface + + +class WindivertMITMInterface(AbstractMITMInterface): + _windivert: pydivert.WinDivert + _binder: socket.socket + + __interface: Any + __direction: pydivert.Direction + + def __init__(self, local_ip: str, local_port: int, remote_ip: str, server_port: int): + self._binder = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._binder.bind((local_ip, local_port)) + self._windivert = pydivert.WinDivert(f"tcp.DstPort == {local_port} or tcp.SrcPort == {server_port}") + self._windivert.open() + + def recv(self) -> Packet: + windivert_packet = self._windivert.recv() + + # save these for later when we need to re-inject + self.__interface = windivert_packet.interface + self.__direction = windivert_packet.direction + + scapy_packet = IP(bytes(windivert_packet.raw)) + return scapy_packet + + def send(self, pkt: Packet) -> None: + self._windivert.send(pydivert.Packet(bytes(pkt), self.__interface, self.__direction)) diff --git a/tests/cpp/ymq/test_ymq.cpp b/tests/cpp/ymq/test_ymq.cpp index 71e1f7e17..9489820ff 100644 --- a/tests/cpp/ymq/test_ymq.cpp +++ b/tests/cpp/ymq/test_ymq.cpp @@ -9,12 +9,20 @@ // the test cases are at the bottom of this file, after the clients and servers // the documentation for each case is found on the TEST() definition -#include #include + +#ifdef __linux__ +#include #include #include #include +#endif // __linux__ +#ifdef _WIN32 +#define NOMINMAX +#include +#endif // _WIN32 + #include #include #include @@ -22,20 +30,20 @@ #include #include -#include "common.h" #include "scaler/error/error.h" #include 
"scaler/ymq/bytes.h" #include "scaler/ymq/io_context.h" -#include "scaler/ymq/io_socket.h" #include "scaler/ymq/simple_interface.h" -#include "tests/cpp/ymq/common.h" +#include "tests/cpp/ymq/common/testing.h" +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/socket.h" using namespace scaler::ymq; using namespace std::chrono_literals; -// ━━━━━━━━━━━━━━━━━━━ +// -------------------- // clients and servers -// ━━━━━━━━━━━━━━━━━━━ +// -------------------- TestResult basic_server_ymq(std::string host, uint16_t port) { @@ -66,13 +74,13 @@ TestResult basic_client_ymq(std::string host, uint16_t port) return TestResult::Success; } -TestResult basic_server_raw(std::string host, uint16_t port) +TestResult basic_server_raw(uint16_t port) { - TcpSocket socket; + Socket socket; - socket.bind(host.c_str(), port); + socket.bind(port); socket.listen(); - auto [client, _] = socket.accept(); + auto client = socket.accept(); client.write_message("server"); auto client_identity = client.read_message(); RETURN_FAILURE_IF_FALSE(client_identity == "client"); @@ -84,9 +92,9 @@ TestResult basic_server_raw(std::string host, uint16_t port) TestResult basic_client_raw(std::string host, uint16_t port) { - TcpSocket socket; + Socket socket; - socket.connect(host.c_str(), port); + socket.try_connect(host.c_str(), port); socket.write_message("client"); auto server_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); @@ -113,9 +121,9 @@ TestResult server_receives_big_message(std::string host, uint16_t port) TestResult client_sends_big_message(std::string host, uint16_t port) { - TcpSocket socket; + Socket socket; - socket.connect(host.c_str(), port); + socket.try_connect(host.c_str(), port); socket.write_message("client"); auto remote_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(remote_identity == "server"); @@ -185,9 +193,9 @@ TestResult reconnect_client_main(std::string host, uint16_t port) TestResult 
client_simulated_slow_network(const char* host, uint16_t port) { - TcpSocket socket; + Socket socket; - socket.connect(host, port); + socket.try_connect(host, port); socket.write_message("client"); auto remote_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(remote_identity == "server"); @@ -210,9 +218,9 @@ TestResult client_sends_incomplete_identity(const char* host, uint16_t port) { // open a socket, write an incomplete identity and exit { - TcpSocket socket; + Socket socket; - socket.connect(host, port); + socket.try_connect(host, port); auto server_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); @@ -226,8 +234,8 @@ TestResult client_sends_incomplete_identity(const char* host, uint16_t port) // connect again and try to send a message { - TcpSocket socket; - socket.connect(host, port); + Socket socket; + socket.try_connect(host, port); auto server_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); socket.write_message("client"); @@ -255,13 +263,15 @@ TestResult server_receives_huge_header(const char* host, uint16_t port) TestResult client_sends_huge_header(const char* host, uint16_t port) { +#ifdef __linux__ // ignore SIGPIPE so that write() returns EPIPE instead of crashing the program signal(SIGPIPE, SIG_IGN); +#endif { - TcpSocket socket; + Socket socket; - socket.connect(host, port); + socket.try_connect(host, port); socket.write_message("client"); auto server_identity = socket.read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); @@ -277,572 +287,693 @@ TestResult client_sends_huge_header(const char* host, uint16_t port) try { socket.write_all("yi er san si wu liu"); } catch (const std::system_error& e) { +#ifdef __linux__ if (e.code().value() == EPIPE) { - std::cout << "writing failed with EPIPE as expected after sending huge header, continuing...\n"; - break; // this is expected +#endif // __linux__ +#ifdef _WIN32 + if (e.code().value() == WSAECONNABORTED) 
{ +#endif // _WIN32 + std::cout << "writing failed, as expected after sending huge header, continuing...\n"; + break; // this is expected + } + + throw; // rethrow other errors } + } - throw; // rethrow other errors + if (i == 10) { + std::cout << "expected EPIPE after sending huge header\n"; + return TestResult::Failure; } } - if (i == 10) { - std::cout << "expected EPIPE after sending huge header\n"; - return TestResult::Failure; + { + Socket socket; + socket.try_connect(host, port); + socket.write_message("client"); + auto server_identity = socket.read_message(); + RETURN_FAILURE_IF_FALSE(server_identity == "server"); + socket.write_message("yi er san si wu liu"); } - } - { - TcpSocket socket; - socket.connect(host, port); - socket.write_message("client"); - auto server_identity = socket.read_message(); - RETURN_FAILURE_IF_FALSE(server_identity == "server"); - socket.write_message("yi er san si wu liu"); + return TestResult::Success; } - return TestResult::Success; -} - -TestResult server_receives_empty_messages(const char* host, uint16_t port) -{ - IOContext context(1); - - auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); - - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); - - auto result2 = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result2.has_value()); - RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); - - context.removeIOSocket(socket); - - return TestResult::Success; -} - -TestResult client_sends_empty_messages(std::string host, uint16_t port) -{ - IOContext context(1); - - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); - - auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); - RETURN_FAILURE_IF_FALSE(!error); - - auto error2 = 
syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); - RETURN_FAILURE_IF_FALSE(!error2); - - context.removeIOSocket(socket); - - return TestResult::Success; -} + TestResult server_receives_empty_messages(const char* host, uint16_t port) + { + IOContext context(1); -TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator, sem_t* sem) -{ - IOContext context(1); + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); - auto socket = - syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); - std::this_thread::sleep_for(500ms); + auto result2 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); - syncConnectSocket(socket, format_address(host, port)); + context.removeIOSocket(socket); - std::this_thread::sleep_for(500ms); + return TestResult::Success; + } - if (sem_post(sem) < 0) - throw std::system_error(errno, std::generic_category(), "failed to signal semaphore"); - sem_close(sem); + TestResult client_sends_empty_messages(std::string host, uint16_t port) + { + IOContext context(1); - auto msg = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(msg.has_value()); - RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello"); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); - context.removeIOSocket(socket); - return TestResult::Success; -} + auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); + RETURN_FAILURE_IF_FALSE(!error); -// topic: the identifier of the topic, must match what's passed to the subscribers -// sem: a semaphore 
to synchronize the publisher and subscriber processes -// n: the number of subscribers -TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic, sem_t* sem, int n) -{ - IOContext context(1); + auto error2 = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); + RETURN_FAILURE_IF_FALSE(!error2); - auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); - syncBindSocket(socket, format_address(host, port)); + context.removeIOSocket(socket); - // wait for the subscribers to be ready - for (int i = 0; i < n; i++) - if (sem_wait(sem) < 0) - throw std::system_error(errno, std::generic_category(), "failed to wait on semaphore"); - sem_close(sem); + return TestResult::Success; + } - // the topic is wrong, so no one should receive this - auto error = syncSendMessage( - socket, Message {.address = Bytes(std::format("x{}", topic)), .payload = Bytes("no one should get this")}); - RETURN_FAILURE_IF_FALSE(!error); + TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator, void* sem) + { + IOContext context(1); - // no one should receive this either - error = syncSendMessage( - socket, - Message {.address = Bytes(std::format("{}x", topic)), .payload = Bytes("no one should get this either")}); - RETURN_FAILURE_IF_FALSE(!error); + auto socket = + syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); - error = syncSendMessage(socket, Message {.address = Bytes(topic), .payload = Bytes("hello")}); - RETURN_FAILURE_IF_FALSE(!error); + std::this_thread::sleep_for(500ms); - context.removeIOSocket(socket); - return TestResult::Success; -} + syncConnectSocket(socket, format_address(host, port)); -TestResult client_close_established_connection_client(std::string host, uint16_t port) -{ - IOContext context(1); + std::this_thread::sleep_for(500ms); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - 
syncConnectSocket(socket, format_address(host, port)); +#ifdef __linux__ + if (sem_post((sem_t*)sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to signal semaphore"); + sem_close((sem_t*)sem); +#endif // __linux__ +#ifdef _WIN32 + if (!ReleaseSemaphore(sem, 1, nullptr)) + throw std::system_error(GetLastError(), std::generic_category(), "failed to signal semaphore"); +#endif // _WIN32 - auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); + auto msg = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(msg.has_value()); + RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello"); - socket->closeConnection("server"); - socket->requestStop(); + context.removeIOSocket(socket); + return TestResult::Success; + } - context.removeIOSocket(socket); - return TestResult::Success; -} + // topic: the identifier of the topic, must match what's passed to the subscribers + // sem: a semaphore to synchronize the publisher and subscriber processes + // n: the number of subscribers + TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic, void* sem, int n) + { + IOContext context(1); + + auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); + syncBindSocket(socket, format_address(host, port)); + +// wait for the subscribers to be ready +#ifdef __linux__ + for (int i = 0; i < n; i++) + if (sem_wait((sem_t*)sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to wait on semaphore"); + sem_close((sem_t*)sem); +#endif // __linux__ +#ifdef _WIN32 + for (int i = 0; i < n; i++) + if (WaitForSingleObject(sem, 3000) != WAIT_OBJECT_0) + throw std::system_error(GetLastError(), std::generic_category(), "failed to wait on semaphore"); +#endif // _WIN32 + + // the topic 
doesn't match, so no one should receive this + auto error = syncSendMessage( + socket, Message {.address = Bytes(std::format("x{}", topic)), .payload = Bytes("no one should get this")}); + RETURN_FAILURE_IF_FALSE(!error); -TestResult client_close_established_connection_server(std::string host, uint16_t port) -{ - IOContext context(1); + // no one should receive this either + error = syncSendMessage( + socket, + Message {.address = Bytes(std::format("{}x", topic)), .payload = Bytes("no one should get this either")}); + RETURN_FAILURE_IF_FALSE(!error); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "server"); - syncBindSocket(socket, format_address(host, port)); + error = syncSendMessage(socket, Message {.address = Bytes(topic), .payload = Bytes("hello")}); + RETURN_FAILURE_IF_FALSE(!error); - auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); + context.removeIOSocket(socket); + return TestResult::Success; + } - result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); - RETURN_FAILURE_IF_FALSE( - result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) + TestResult client_close_established_connection_client(std::string host, uint16_t port) + { + IOContext context(1); - context.removeIOSocket(socket); - return TestResult::Success; -} + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); -TestResult close_nonexistent_connection() -{ - IOContext context(1); + auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + 
RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + socket->closeConnection("server"); + context.requestIOSocketStop(socket); - // note: we're not connected to anything; this connection does not exist - // this should be a no-op.. - socket->closeConnection("server"); + context.removeIOSocket(socket); + return TestResult::Success; + } - context.removeIOSocket(socket); - return TestResult::Success; -} + TestResult client_close_established_connection_server(std::string host, uint16_t port) + { + IOContext context(1); -TestResult test_request_stop() -{ - IOContext context(1); + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, format_address(host, port)); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); - auto future = futureRecvMessage(socket); - socket->requestStop(); + result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); + RETURN_FAILURE_IF_FALSE( + result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) - auto result = future.wait_for(100ms); - RETURN_FAILURE_IF_FALSE(result == std::future_status::ready, "future should have completed"); + context.removeIOSocket(socket); + return TestResult::Success; + } - // the future created beore requestion stop should have been cancelled with an error - auto result2 = future.get(); - RETURN_FAILURE_IF_FALSE(!result2.has_value()); - RETURN_FAILURE_IF_FALSE(result2.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); 
+ TestResult close_nonexistent_connection() + { + IOContext context(1); - // and the same for any attempts to use the socket after it's been closed - auto result3 = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result3.has_value()); - RETURN_FAILURE_IF_FALSE(result3.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - context.removeIOSocket(socket); - return TestResult::Success; -} + // note: we're not connected to anything; this connection does not exist + // this should be a no-op.. + socket->closeConnection("server"); -TestResult client_socket_stop_before_close_connection(std::string host, uint16_t port) -{ - IOContext context(1); + context.removeIOSocket(socket); + return TestResult::Success; + } - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + TestResult test_request_stop() + { + IOContext context(1); - auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - socket->requestStop(); - socket->closeConnection("server"); + auto future = futureRecvMessage(socket); + context.requestIOSocketStop(socket); - context.removeIOSocket(socket); - return TestResult::Success; -} + auto result = future.wait_for(100ms); + RETURN_FAILURE_IF_FALSE(result == std::future_status::ready, "future should have completed"); -TestResult server_socket_stop_before_close_connection(std::string host, uint16_t port) -{ - IOContext context(1); + // the future created beore requestion stop should have been cancelled with an error + auto result2 = future.get(); + 
RETURN_FAILURE_IF_FALSE(!result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "server"); - syncBindSocket(socket, format_address(host, port)); + // and the same for any attempts to use the socket after it's been closed + auto result3 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result3.has_value()); + RETURN_FAILURE_IF_FALSE(result3.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); - auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); + context.removeIOSocket(socket); + return TestResult::Success; + } - result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); - RETURN_FAILURE_IF_FALSE( - result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) + TestResult client_socket_stop_before_close_connection(std::string host, uint16_t port) + { + IOContext context(1); - context.removeIOSocket(socket); - return TestResult::Success; -} + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, format_address(host, port)); -// ━━━━━━━━━━━━━ -// test cases -// ━━━━━━━━━━━━━ + auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); -// this is a 'basic' test which sends a single message from a client to a server -// in this variant, both the client and server are implemented using YMQ -// -// this case 
includes a _delay_ -// this is a thread sleep that happens after the client sends the message, to delay the close() of the socket -// at the moment, if this delay is missing, YMQ will not shut down correctly -TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) -{ - const auto host = "localhost"; - const auto port = 2889; + context.requestIOSocketStop(socket); + socket->closeConnection("server"); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); + context.removeIOSocket(socket); + return TestResult::Success; + } - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); -} + TestResult server_socket_stop_before_close_connection(std::string host, uint16_t port) + { + IOContext context(1); -// same as above, except YMQs protocol is directly implemented on top of a TCP socket -TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) -{ - const auto host = "localhost"; - const auto port = 2890; + auto socket = syncCreateSocket(context, IOSocketType::Connector, "server"); + syncBindSocket(socket, format_address(host, port)); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); + auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, 
TestResult::Success); -} + result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); + RETURN_FAILURE_IF_FALSE( + result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) -TEST(CcYmqTestSuite, TestBasicRawClientRawServer) -{ - const auto host = "localhost"; - const auto port = 2891; + context.removeIOSocket(socket); + return TestResult::Success; + } - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_raw(host, port); }}); + // ------------- + // test cases + // ------------- + + // this is a 'basic' test which sends a single message from a client to a server + // in this variant, both the client and server are implemented using YMQ + // + // this case includes a _delay_ + // this is a thread sleep that happens after the client sends the message, to delay the close() of the socket + // at the moment, if this delay is missing, YMQ will not shut down correctly + TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) + { + const auto host = "localhost"; + const auto port = 2889; - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); -} + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); -// this is the same as above, except that it has no delay before calling close() on the socket -TEST(CcYmqTestSuite, TestBasicRawClientRawServerNoDelay) -{ - const auto host = "localhost"; - const auto port = 2892; + // test() aggregates the results across all of the provided functions + 
EXPECT_EQ(result, TestResult::Success); + } - auto result = - test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // same as above, except YMQs protocol is directly implemented on top of a TCP socket + TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) + { + const auto host = "localhost"; + const auto port = 2890; -TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) -{ - const auto host = "localhost"; - const auto port = 2893; + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = + test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(host, port); }}); + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); + } - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); -} + TEST(CcYmqTestSuite, TestBasicRawClientRawServer) + { + const auto host = "localhost"; + const auto port = 2891; -// in this test case, the client sends a large message to the server -// YMQ should be able to handle this without issue -TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) -{ - const auto host = "localhost"; - const auto port = 2894; - - auto result = test( - 10, - {[=] { return client_sends_big_message(host, port); }, - [=] { return server_receives_big_message(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an 
optional third argument used to coordinate the execution of python (for mitm) + auto result = test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_raw(port); }}); -// this is the no-op/passthrough man in the middle test -// for this test case we use YMQ on both the client side and the server side -// the client connects to the mitm, and the mitm connects to the server -// when the mitm receives packets from the client, it forwards it to the server without changing it -// and similarly when it receives packets from the server, it forwards them to the client -// -// the mitm is implemented in Python. we pass the name of the test case, which corresponds to the Python filename, -// and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port -// this defines the address of the mitm, and the addresses that can connect to it -// for more, see the python mitm files -TEST(CcYmqTestSuite, TestMitmPassthrough) -{ - auto mitm_ip = "192.0.2.4"; - auto mitm_port = 2323; - auto remote_ip = "192.0.2.3"; - auto remote_port = 23571; - - // the Python program must be the first and only the first function passed to test() - // we must also pass `true` as the third argument to ensure that Python is fully started - // before beginning the test - auto result = test( - 20, - {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, - [=] { return basic_client_ymq(mitm_ip, mitm_port); }, - [=] { return basic_server_ymq(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); -} + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); + } -// this test uses the mitm to test the reconnect logic of YMQ by sending RST packets -TEST(CcYmqTestSuite, TestMitmReconnect) -{ - auto mitm_ip = "192.0.2.4"; - auto mitm_port = 2525; - auto remote_ip = "192.0.2.3"; - auto remote_port = 23575; - - auto result = test( - 10, - {[=] { return 
run_mitm("send_rst_to_client", mitm_ip, mitm_port, remote_ip, remote_port); }, - [=] { return reconnect_client_main(mitm_ip, mitm_port); }, - [=] { return reconnect_server_main(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); -} + // this is the same as above, except that it has no delay before calling close() on the socket + TEST(CcYmqTestSuite, TestBasicRawClientRawServerNoDelay) + { + const auto host = "localhost"; + const auto port = 2892; -// TODO: Make this more reliable, and re-enable it -// in this test, the mitm drops a random % of packets arriving from the client and server -TEST(CcYmqTestSuite, TestMitmRandomlyDropPackets) -{ - auto mitm_ip = "192.0.2.4"; - auto mitm_port = 2828; - auto remote_ip = "192.0.2.3"; - auto remote_port = 23591; - - auto result = test( - 60, - {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, - [=] { return basic_client_ymq(mitm_ip, mitm_port); }, - [=] { return basic_server_ymq(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); -} + auto result = + test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } -// in this test the client is sending a message to the server -// but we simulate a slow network connection by sending the message in segmented chunks -TEST(CcYmqTestSuite, TestSlowNetwork) -{ - const auto host = "localhost"; - const auto port = 2895; + TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) + { + const auto host = "localhost"; + const auto port = 2893; - auto result = test( - 20, {[=] { return client_simulated_slow_network(host, port); }, [=] { return basic_server_ymq(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto 
result = test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(port); }}); -// TODO: figure out why this test fails in ci sometimes, and re-enable -// -// in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects -// then a new client connection is established, and this one sends a complete identity and message -// YMQ should be able to recover from a poorly-behaved client like this -TEST(CcYmqTestSuite, TestClientSendIncompleteIdentity) -{ - const auto host = "localhost"; - const auto port = 2896; + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); + } - auto result = test( - 20, - {[=] { return client_sends_incomplete_identity(host, port); }, [=] { return basic_server_ymq(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // in this test case, the client sends a large message to the server + // YMQ should be able to handle this without issue + TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) + { + const auto host = "localhost"; + const auto port = 2894; + + auto result = test( + 10, + {[=] { return client_sends_big_message(host, port); }, + [=] { return server_receives_big_message(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } -// TODO: this should pass -// currently YMQ rejects the second connection, saying that the message is too large even when it isn't -// -// in this test, the client sends an unrealistically-large header -// it is important that YMQ checks the header size before allocating memory -// both for resilence against attacks and to guard against errors -TEST(CcYmqTestSuite, TestClientSendHugeHeader) -{ - const auto host = "localhost"; - const auto port = 2897; - - auto result = test( - 20, - {[=] { return client_sends_huge_header(host, port); }, - [=] { return server_receives_huge_header(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // this is the 
no-op/passthrough man in the middle test + // for this test case we use YMQ on both the client side and the server side + // the client connects to the mitm, and the mitm connects to the server + // when the mitm receives packets from the client, it forwards it to the server without changing it + // and similarly when it receives packets from the server, it forwards them to the client + // + // the mitm is implemented in Python. we pass the name of the test case, which corresponds to the Python filename, + // and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port + // this defines the address of the mitm, and the addresses that can connect to it + // for more, see the python mitm files + TEST(CcYmqTestSuite, TestMitmPassthrough) + { +#ifdef __linux__ + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; +#endif // __linux__ +#ifdef _WIN32 + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; +#endif // _WIN32 + auto mitm_port = random_port(); + auto remote_port = 23571; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); + } -// in this test, the client sends empty messages to the server -// there are in effect two kinds of empty messages: Bytes() and Bytes("") -// in the former case, the bytes contains a nullptr -// in the latter case, the bytes contains a zero-length allocation -// it's important that the behaviour of YMQ is known for both of these cases -TEST(CcYmqTestSuite, TestClientSendEmptyMessage) -{ - const auto host = "localhost"; - const auto port = 2898; - - auto result = test( - 
20, - {[=] { return client_sends_empty_messages(host, port); }, - [=] { return server_receives_empty_messages(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + // this is the same as the above, but both the client and server use raw sockets + TEST(CcYmqTestSuite, TestMitmPassthroughRaw) + { +#ifdef __linux__ + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; +#endif // __linux__ +#ifdef _WIN32 + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; +#endif // _WIN32 + auto mitm_port = random_port(); + auto remote_port = 23574; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_raw(mitm_ip, mitm_port); }, + [=] { return basic_server_raw(remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); + } -// this case tests the publish-subscribe pattern of YMQ -// we create one publisher and two subscribers with a common topic -// the publisher will send two messages to the wrong topic -// none of the subscribers should receive these -// and then the publisher will send a message to the correct topic -// both subscribers should receive this message -TEST(CcYmqTestSuite, TestPubSub) -{ - const auto host = "localhost"; - const auto port = 2900; - auto topic = "mytopic"; + // this test uses the mitm to test the reconnect logic of YMQ by sending RST packets + TEST(CcYmqTestSuite, TestMitmReconnect) + { +#ifdef __linux__ + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; +#endif // __linux__ +#ifdef _WIN32 + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; +#endif // _WIN32 + auto mitm_port = random_port(); + auto remote_port = 23572; + + auto result = test( + 30, + {[=] { return run_mitm("send_rst_to_client", 
mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return reconnect_client_main(mitm_ip, mitm_port); }, + [=] { return reconnect_server_main(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); + } - // allocate a semaphore to synchronize the publisher and subscriber processes - sem_t* sem = - static_cast<sem_t*>(mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + // TODO: Make this more reliable, and re-enable it + // in this test, the mitm drops a random % of packets arriving from the client and server + TEST(CcYmqTestSuite, TestMitmRandomlyDropPackets) + { +#ifdef __linux__ + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; +#endif // __linux__ +#ifdef _WIN32 + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; +#endif // _WIN32 + auto mitm_port = random_port(); + auto remote_port = 23573; + + auto result = test( + 60, + {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, + [=] { return basic_client_ymq(mitm_ip, mitm_port); }, + [=] { return basic_server_ymq(remote_ip, remote_port); }}, + true); + EXPECT_EQ(result, TestResult::Success); + } - if (sem == MAP_FAILED) - throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore"); + // in this test the client is sending a message to the server + // but we simulate a slow network connection by sending the message in segmented chunks + TEST(CcYmqTestSuite, TestSlowNetwork) + { + const auto host = "localhost"; + const auto port = 2895; - if (sem_init(sem, 1, 0) < 0) - throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore"); + auto result = test( + 20, + {[=] { return client_simulated_slow_network(host, port); }, [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } - auto result = test( - 20, - {[=] { return pubsub_publisher(host, port, topic, sem, 2); }, - [=] { return 
pubsub_subscriber(host, port, topic, 0, sem); }, - [=] { return pubsub_subscriber(host, port, topic, 1, sem); }}); + // TODO: figure out why this test fails in ci sometimes, and re-enable + // + // in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects + // then a new client connection is established, and this one sends a complete identity and message + // YMQ should be able to recover from a poorly-behaved client like this + TEST(CcYmqTestSuite, TestClientSendIncompleteIdentity) + { + const auto host = "localhost"; + const auto port = 2896; + + auto result = test( + 20, + {[=] { return client_sends_incomplete_identity(host, port); }, + [=] { return basic_server_ymq(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } - sem_destroy(sem); - munmap(sem, sizeof(sem_t)); + // TODO: this should pass + // currently YMQ rejects the second connection, saying that the message is too large even when it isn't + // + // in this test, the client sends an unrealistically-large header + // it is important that YMQ checks the header size before allocating memory + // both for resilience against attacks and to guard against errors + TEST(CcYmqTestSuite, TestClientSendHugeHeader) + { + const auto host = "localhost"; + const auto port = 2897; + + auto result = test( + 20, + {[=] { return client_sends_huge_header(host, port); }, + [=] { return server_receives_huge_header(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } - EXPECT_EQ(result, TestResult::Success); -} + // in this test, the client sends empty messages to the server + // there are in effect two kinds of empty messages: Bytes() and Bytes("") + // in the former case, the bytes contains a nullptr + // in the latter case, the bytes contains a zero-length allocation + // it's important that the behaviour of YMQ is known for both of these cases + TEST(CcYmqTestSuite, TestClientSendEmptyMessage) + { + const auto host = "localhost"; + const auto port = 
2898; + + auto result = test( + 20, + {[=] { return client_sends_empty_messages(host, port); }, + [=] { return server_receives_empty_messages(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } -// this sets the publisher with an empty topic and the subscribers with two other topics -// both subscribers should get all messages -TEST(CcYmqTestSuite, TestPubSubEmptyTopic) -{ - const auto host = "localhost"; - const auto port = 2906; + // this case tests the publish-subscribe pattern of YMQ + // we create one publisher and two subscribers with a common topic + // the publisher will send two messages to the wrong topic + // none of the subscribers should receive these + // and then the publisher will send a message to the correct topic + // both subscribers should receive this message + TEST(CcYmqTestSuite, TestPubSub) + { + const auto host = "localhost"; + const auto port = 2900; + auto topic = "mytopic"; + +// allocate a semaphore to synchronize the publisher and subscriber processes +#ifdef __linux__ + sem_t* sem = static_cast<sem_t*>( + mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + + if (sem == MAP_FAILED) + throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore"); + + if (sem_init(sem, 1, 0) < 0) + throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore"); +#endif // __linux__ +#ifdef _WIN32 + HANDLE sem = CreateSemaphore( + nullptr, // default security attributes + 0, // initial count + 2, // maximum count + nullptr); // unnamed semaphore + if (sem == nullptr) + throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore"); +#endif // _WIN32 + + auto result = test( + 20, + {[=] { return pubsub_publisher(host, port, topic, sem, 2); }, + [=] { return pubsub_subscriber(host, port, topic, 0, sem); }, + [=] { return pubsub_subscriber(host, port, topic, 1, sem); }}); + +#ifdef __linux__ + sem_destroy(sem); + 
munmap(sem, sizeof(sem_t)); +#endif // __linux__ +#ifdef _WIN32 + CloseHandle(sem); +#endif // _WIN32 + + EXPECT_EQ(result, TestResult::Success); + } - // allocate a semaphore to synchronize the publisher and subscriber processes - sem_t* sem = - static_cast<sem_t*>(mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + // this sets the publisher with an empty topic and the subscribers with two other topics + // both subscribers should get all messages + TEST(CcYmqTestSuite, TestPubSubEmptyTopic) + { + const auto host = "localhost"; + const auto port = 2906; + +// allocate a semaphore to synchronize the publisher and subscriber processes +#ifdef __linux__ + sem_t* sem = static_cast<sem_t*>( + mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0)); + + if (sem == MAP_FAILED) + throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore"); + + if (sem_init(sem, 1, 0) < 0) + throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore"); +#endif // __linux__ +#ifdef _WIN32 + HANDLE sem = CreateSemaphore( + nullptr, // default security attributes + 0, // initial count + 2, // maximum count + nullptr); // unnamed semaphore + if (sem == nullptr) + throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore"); +#endif // _WIN32 + + auto result = test( + 20, + {[=] { return pubsub_publisher(host, port, "", sem, 2); }, + [=] { return pubsub_subscriber(host, port, "abc", 0, sem); }, + [=] { return pubsub_subscriber(host, port, "def", 1, sem); }}); + +#ifdef __linux__ + sem_destroy(sem); + munmap(sem, sizeof(sem_t)); +#endif // __linux__ +#ifdef _WIN32 + CloseHandle(sem); +#endif // _WIN32 + + EXPECT_EQ(result, TestResult::Success); + } - if (sem == MAP_FAILED) - throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore"); + // in this test case, the client establishes a connection 
with the server and then explicitly closes it + TEST(CcYmqTestSuite, DISABLED_TestClientCloseEstablishedConnection) + { + const auto host = "localhost"; + const auto port = 2902; + + auto result = test( + 20, + {[=] { return client_close_established_connection_client(host, port); }, + [=] { return client_close_established_connection_server(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } - if (sem_init(sem, 1, 0) < 0) - throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore"); + // this test case is similar to the one above, except that it requests the socket stop before closing the connection + TEST(CcYmqTestSuite, TestClientSocketStopBeforeCloseConnection) + { + const auto host = "localhost"; + const auto port = 2904; + + auto result = test( + 20, + {[=] { return client_socket_stop_before_close_connection(host, port); }, + [=] { return server_socket_stop_before_close_connection(host, port); }}); + EXPECT_EQ(result, TestResult::Success); + } - auto result = test( - 20, - {[=] { return pubsub_publisher(host, port, "", sem, 2); }, - [=] { return pubsub_subscriber(host, port, "abc", 0, sem); }, - [=] { return pubsub_subscriber(host, port, "def", 1, sem); }}); + // in this test case, we try to close a connection that does not exist + TEST(CcYmqTestSuite, TestClientCloseNonexistentConnection) + { + auto result = close_nonexistent_connection(); + EXPECT_EQ(result, TestResult::Success); + } - sem_destroy(sem); - munmap(sem, sizeof(sem_t)); + // this test case verifies that requesting a socket stop causes pending and subsequent operations to be cancelled + TEST(CcYmqTestSuite, TestRequestSocketStop) + { + auto result = test_request_stop(); + EXPECT_EQ(result, TestResult::Success); + } - EXPECT_EQ(result, TestResult::Success); -} + // main -// in this test case, the client establishes a connection with the server and then explicitly closes it -TEST(CcYmqTestSuite, TestClientCloseEstablishedConnection) -{ - const auto host 
= "localhost"; - const auto port = 2902; - - auto result = test( - 20, - {[=] { return client_close_established_connection_client(host, port); }, - [=] { return client_close_established_connection_server(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + int main(int argc, char** argv) + { + ensure_python_initialized(); + +#ifdef _WIN32 + // initialize winsock + WSADATA wsaData = {}; + int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (iResult != 0) { + std::cerr << "WSAStartup failed: " << iResult << "\n"; + return 1; + } +#endif // _WIN32 -// this test case is similar to the one above, except that it requests the socket stop before closing the connection -TEST(CcYmqTestSuite, TestClientSocketStopBeforeCloseConnection) -{ - const auto host = "localhost"; - const auto port = 2904; - - auto result = test( - 20, - {[=] { return client_socket_stop_before_close_connection(host, port); }, - [=] { return server_socket_stop_before_close_connection(host, port); }}); - EXPECT_EQ(result, TestResult::Success); -} + testing::InitGoogleTest(&argc, argv); + auto result = RUN_ALL_TESTS(); -// in this test case, the we try to close a connection that does not exist -TEST(CcYmqTestSuite, TestClientCloseNonexistentConnection) -{ - auto result = close_nonexistent_connection(); - EXPECT_EQ(result, TestResult::Success); -} +#ifdef _WIN32 + WSACleanup(); +#endif // _WIN32 -// this test case verifies that requesting a socket stop causes pending and subsequent operations to be cancelled -TEST(CcYmqTestSuite, TestRequestSocketStop) -{ - auto result = test_request_stop(); - EXPECT_EQ(result, TestResult::Success); -} + maybe_finalize_python(); + return result; + } diff --git a/tests/scheduler/test_scaling.py b/tests/scheduler/test_scaling.py index 401c6abed..8526380f6 100644 --- a/tests/scheduler/test_scaling.py +++ b/tests/scheduler/test_scaling.py @@ -109,7 +109,7 @@ def test_scaling_basic(self): os.kill(scheduler.pid, signal.SIGINT) scheduler.join() - 
os.kill(object_storage.pid, signal.SIGKILL) + object_storage.kill() object_storage.join() os.kill(webhook_server.pid, signal.SIGINT)