diff --git a/.github/actions/run-test/action.yml b/.github/actions/run-test/action.yml index 52bdca95c..49770009d 100644 --- a/.github/actions/run-test/action.yml +++ b/.github/actions/run-test/action.yml @@ -54,10 +54,22 @@ runs: shell: bash run: | uv pip install --system -r examples/applications/requirements_applications.txt + uv pip install --system -r examples/ray_compat/requirements.txt for example in "./examples"/*.py; do echo "Running $example" python $example done + readarray -t skip_tests < examples/ray_compat/skip_tests.txt + for example in "./examples/ray_compat"/*.py; do + filename=$(basename "$example") + if [[ " ${skip_tests[*]} " =~ [[:space:]]${filename}[[:space:]] ]]; then + echo "Skipping $example" + continue + fi + + echo "Running $example" + python $example + done for example in "./examples/applications"/*.py; do if python -c 'import sys; sys.exit(not sys.version_info <= (3, 10))'; then if [ "$example" = "./examples/applications/yfinance_historical_price.py" ]; then diff --git a/docs/source/index.rst b/docs/source/index.rst index d35ea7b64..79880cc9d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -26,6 +26,7 @@ Content tutorials/quickstart tutorials/features + tutorials/compatibility/ray tutorials/configuration tutorials/examples tutorials/development/devcontainer diff --git a/docs/source/tutorials/compatibility/ray.rst b/docs/source/tutorials/compatibility/ray.rst new file mode 100644 index 000000000..8a2ac3d5d --- /dev/null +++ b/docs/source/tutorials/compatibility/ray.rst @@ -0,0 +1,129 @@ +.. _ray_compatibility: + +Ray Compatibility Layer +======================= + +Scaler is a lightweight distributed computation engine similar to Ray. Scaler supports many of the same concepts as Ray including +remote functions (known as tasks in Scaler), futures, cluster object storage, labels (known as capabilities in Scaler), and it comes with comparable monitoring tools. + +Unlike Ray, Scaler supports both local clusters and also easily integrates with multiple cloud providers out of the box, including AWS EC2 and IBM Symphony, +with more providers planned for the future. You can view our `roadmap on GitHub `_ +for details on upcoming cloud integrations. + +Scaler provides a compatibility layer that allows developers familiar with the `Ray `_ API to adopt Scaler with minimal code changes. + +Quickstart +---------- + +To start using Scaler's Ray compatibility layer, ensure you have `opengris-scaler `_ installed in your Python environment. + +Then import ``scaler.compat.ray`` in your application after importing ``ray`` and before using any Ray APIs. + +This import patches the ``ray`` module, allowing you to use Ray's API as you normally would. + +.. code-block:: python + + import ray + import scaler.compat.ray + + # existing Ray app + +This will start a new local scheduler and cluster combo. To use an existing cluster, pass the address of the scheduler to ``scaler_init()``: + +.. code-block:: python + + import ray + from scaler.compat.ray import scaler_init + + # connects to an existing cluster + # when an address is provided, a local cluster is not started + scaler_init(address="tcp://:") + + # existing Ray app + +You can also provide scheduler and cluster configuration options to ``scaler_init()`` to configure the locally created cluster: + +.. 
code-block:: python + + import ray + from scaler.compat.ray import scaler_init + + # overrides the number of workers in the implicitly-created local cluster (defaults to number of CPU cores) + scaler_init(n_workers=5) + + # existing Ray app + +Remote Function Limitations +--------------------------- + +``ray.remote()`` accepts many parameters, but Scaler's compatibility layer only supports ``num_returns``. Other parameters will be ignored. + +Shutting Down +------------- + +The implicitly-created local cluster runs in separate processes held in module-level globals, and is not shut down automatically. +This can keep your program running after your main code has finished, so it is important to call ``ray.shutdown()`` +when you are done if you are using the implicit local cluster. + +A Note about the Actor Model +---------------------------- + +Ray supports a powerful actor model that allows for stateful computation. This is currently not supported by Scaler, but is planned for a future release. + +This documentation will be updated when actor support is added. For now, please view our `roadmap on GitHub `_ for more details. + +Decorating a class with ``@ray.remote`` will raise a ``NotImplementedError``. + +Full Examples +------------- + +See `the examples directory `_ for complete Scaler Ray compatibility layer examples, including: + +* ``basic_local_cluster.py``: Demonstrates using the Scaler Ray compatibility layer with the implicitly-created local cluster. +* ``basic_remote_cluster.py``: Demonstrates using the Scaler Ray compatibility layer with an existing remote cluster. +* ``batch_prediction.py``: Demonstrates using Scaler's Ray compatibility layer for batch prediction, copied from Ray Core's documentation. +* ``highly_parallel.py``: Demonstrates highly parallel computations, copied from Ray Core's documentation. +* ``map_reduce.py``: Demonstrates a MapReduce pattern using Scaler's Ray compatibility layer, copied from Ray Core's documentation. +* ``plot_hyperparameter.py``: Demonstrates hyperparameter tuning and plotting, copied from Ray Core's documentation. +* ``web_crawler.py``: Demonstrates a web crawling example, copied from Ray Core's documentation. + +Supported APIs +-------------- + +The compatibility layer supports a subset of Ray Core's API. + +Below is a comprehensive list of the supported APIs. Functions and classes not in this list are not supported (see the Unsupported APIs section below). + +Core API +~~~~~~~~ + +* ``@ray.remote``: Only supports remote functions. Decorating a class with ``@ray.remote`` will raise a ``NotImplementedError``. +* ``ray.shutdown()`` +* ``ray.is_initialized()`` +* ``ray.get()`` +* ``ray.put()`` +* ``ray.wait()`` +* ``ray.cancel()`` + +Ray Utilities (`ray.util`) +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +* ``ray.util.as_completed()`` +* ``ray.util.map_unordered()`` + +Unsupported APIs +~~~~~~~~~~~~~~~~ + +The following APIs are not supported by the Scaler Ray compatibility layer. + +Some functions are no-ops or return a mock object, while others raise a ``NotImplementedError`` exception. + +* ``ray.init()``: No-op. Use ``scaler_init()`` from ``scaler.compat.ray`` instead. +* ``ray.get_actor()``: Returns a mock object. +* ``ray.method()``: Raises ``NotImplementedError``. +* ``ray.actor``: Raises ``NotImplementedError``. +* ``ray.runtime_context``: Raises ``NotImplementedError``. +* ``ray.cross_language``: Raises ``NotImplementedError``. +* ``ray.get_gpu_ids()``: Raises ``NotImplementedError``. +* ``ray.get_runtime_context()``: Raises ``NotImplementedError``. +* ``ray.kill()``: Raises ``NotImplementedError``.
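+
+As a minimal, illustrative sketch of the supported surface above (it assumes the implicitly-created local cluster; the function name is only an example), ``num_returns`` can be used to split a task's result into separate object references:
+
+.. code-block:: python
+
+   import ray
+   import scaler.compat.ray  # noqa: F401  # patches the ray module
+
+   @ray.remote(num_returns=2)
+   def min_and_max(values):
+       # the function must return an indexable object (e.g. a tuple) when num_returns > 1
+       return min(values), max(values)
+
+   lo_ref, hi_ref = min_and_max.remote([3, 1, 4, 1, 5])
+   assert ray.get(lo_ref) == 1
+   assert ray.get(hi_ref) == 5
+
+   # shut down the implicitly-created local cluster
+   ray.shutdown()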
diff --git a/examples/ray_compat/basic_local_cluster.py b/examples/ray_compat/basic_local_cluster.py new file mode 100644 index 000000000..c10d38dcb --- /dev/null +++ b/examples/ray_compat/basic_local_cluster.py @@ -0,0 +1,26 @@ +"""This is a basic example showing the minimal changes needed to start using Scaler for a Ray application.""" + +import ray + +# this patches the ray module +import scaler.compat.ray # noqa: F401 + + +def main(): + # the scaler is implicitly initialized here + # see basic_remote_cluster.py for more advanced usage + @ray.remote + def my_function(): + return 1 + + # this is executed by the local scaler cluster + future = my_function.remote() + assert ray.get(future) == 1 + + # the implicitly-created cluster is globally-scoped + # so we need to shut it down explicitly + ray.shutdown() + + +if __name__ == "__main__": + main() diff --git a/examples/ray_compat/basic_remote_cluster.py b/examples/ray_compat/basic_remote_cluster.py new file mode 100644 index 000000000..a8ff4c5fb --- /dev/null +++ b/examples/ray_compat/basic_remote_cluster.py @@ -0,0 +1,35 @@ +"""This is a basic example showing the minimal changes needed to use Scaler's Ray compatibility layer with an existing remote cluster.""" + +import ray + +from scaler.cluster.combo import SchedulerClusterCombo + +# this patches the ray module +from scaler.compat.ray import scaler_init + + +# this is an example and we don't have a real remote cluster here +# so for demonstration purposes we just start a local cluster +def start_remote_cluster() -> SchedulerClusterCombo: + return SchedulerClusterCombo(n_workers=1) + + +def main(address: str): + # explicitly init the scaler, + # providing the address of the remote scheduler + scaler_init(address=address) + + @ray.remote + def my_function(): + return 1 + + # this is executed by the remote scaler cluster + future = my_function.remote() + assert ray.get(future) == 1 + + +if __name__ == "__main__": + combo = start_remote_cluster() + main(combo.get_address()) + + combo.shutdown() diff --git a/examples/ray_compat/batch_prediction.py b/examples/ray_compat/batch_prediction.py new file mode 100644 index 000000000..65cf1d65b --- /dev/null +++ b/examples/ray_compat/batch_prediction.py @@ -0,0 +1,78 @@ +""" +This example was copied from https://docs.ray.io/en/latest/ray-core/examples/batch_prediction.html + +Like in `highly_parallel.py`, only minimal changes are needed for the example to work on Scaler. +""" + + +import numpy as np +import pandas as pd +import pyarrow.parquet as pq +import ray + +# changed line 1/2 +import scaler.compat.ray # noqa: F401 + + +def load_model(): + # A dummy model. + def model(batch: pd.DataFrame) -> pd.DataFrame: + # Dummy payload so copying the model will actually copy some data + # across nodes. + model.payload = np.zeros(100_000_000) # type: ignore[attr-defined] + return pd.DataFrame({"score": batch["passenger_count"] % 2 == 0}) + + return model + + +def main(): + @ray.remote + def make_prediction(model, shard_path): + df = pq.read_table(shard_path).to_pandas() + result = model(df) + + # Write out the prediction result. + # NOTE: unless the driver will have to further process the + # result (other than simply writing out to storage system), + # writing out at remote task is recommended, as it can avoid + # congesting or overloading the driver. + # ... + + # Here we just return the size of the result in this example. + return len(result) + + # 12 files, one for each remote task.
+ input_files = [ + f"s3://anonymous@air-example-data/ursa-labs-taxi-data/downsampled_2009_full_year_data.parquet" + f"/fe41422b01c04169af2a65a83b753e0f_{i:06d}.parquet" + for i in range(12) + ] + + # ray.put() the model just once to local object store, and then pass the + # reference to the remote tasks. + model = load_model() + model_ref = ray.put(model) + + result_refs = [] + + # Launch all prediction tasks. + for file in input_files: + # Launch a prediction task by passing model reference and shard file to it. + # NOTE: it would be highly inefficient if you are passing the model itself + # like make_prediction.remote(model, file), which in order to pass the model + # to remote node will ray.put(model) for each task, potentially overwhelming + # the local object store and causing out-of-disk error. + result_refs.append(make_prediction.remote(model_ref, file)) + + results = ray.get(result_refs) + + # Let's check prediction output size. + for r in results: + print("Prediction output size:", r) + + +if __name__ == "__main__": + main() + + # changed line 2/2 + ray.shutdown() diff --git a/examples/ray_compat/highly_parallel.py b/examples/ray_compat/highly_parallel.py new file mode 100644 index 000000000..779dc0467 --- /dev/null +++ b/examples/ray_compat/highly_parallel.py @@ -0,0 +1,50 @@ +""" +This example was copied from https://docs.ray.io/en/latest/ray-core/examples/highly_parallel.html + +Only one or two changes are needed to make this example work on Scaler. +First is to import the compatibility layer, this patches Ray Core's API. +The second is to call `ray.shutdown()`, necessary only if using a local cluster. +""" + +import random +import time +from fractions import Fraction + +import ray + +# this is one of only two changed lines +import scaler.compat.ray # noqa: F401 + +# Let's start Ray +ray.init(address="auto") + + +def main(): + @ray.remote + def pi4_sample(sample_count): + """pi4_sample runs sample_count experiments, and returns the + fraction of time it was inside the circle. + """ + in_count = 0 + for i in range(sample_count): + x = random.random() + y = random.random() + if x * x + y * y <= 1: + in_count += 1 + return Fraction(in_count, sample_count) + + SAMPLE_COUNT = 1000 * 1000 + start = time.time() + future = pi4_sample.remote(sample_count=SAMPLE_COUNT) # type: ignore[call-arg] + pi4 = ray.get(future) # noqa: F841 + end = time.time() + dur = end - start + print(f"Running {SAMPLE_COUNT} tests took {dur} seconds") + + +if __name__ == "__main__": + main() + + # this is the second changed line + # we need to explicitly shut down the implicit cluster + ray.shutdown() diff --git a/examples/ray_compat/map_reduce.py b/examples/ray_compat/map_reduce.py new file mode 100644 index 000000000..5972272ce --- /dev/null +++ b/examples/ray_compat/map_reduce.py @@ -0,0 +1,71 @@ +""" +This example was copied from https://docs.ray.io/en/latest/ray-core/examples/map_reduce.html + +Like in `highly_parallel.py`, only minimal changes are needed for the example to work on Scaler. 
+""" + +import subprocess + +import ray + +# changed line 1/2 +import scaler.compat.ray # noqa: F401 + +zen_of_python = subprocess.check_output(["python", "-c", "import this"]) +corpus = zen_of_python.split() + +num_partitions = 3 +chunk = len(corpus) // num_partitions +partitions = [corpus[i * chunk : (i + 1) * chunk] for i in range(num_partitions)] + + +def map_function(document): + for word in document.lower().split(): + yield word, 1 + + +def main(): + @ray.remote + def apply_map(corpus, num_partitions=3): + map_results = [list() for _ in range(num_partitions)] # type: ignore[var-annotated] + for document in corpus: + for result in map_function(document): + first_letter = result[0].decode("utf-8")[0] + word_index = ord(first_letter) % num_partitions + map_results[word_index].append(result) + return map_results + + map_results = [apply_map.options(num_returns=num_partitions).remote(data, num_partitions) for data in partitions] + + for i in range(num_partitions): + mapper_results = ray.get(map_results[i]) + for j, result in enumerate(mapper_results): + print(f"Mapper {i}, return value {j}: {result[:2]}") + + @ray.remote + def apply_reduce(*results): + reduce_results = dict() + for res in results: + for key, value in res: + if key not in reduce_results: + reduce_results[key] = 0 + reduce_results[key] += value + + return reduce_results + + outputs = [] + for i in range(num_partitions): + outputs.append(apply_reduce.remote(*[partition[i] for partition in map_results])) + + counts = {k: v for output in ray.get(outputs) for k, v in output.items()} + + sorted_counts = sorted(counts.items(), key=lambda item: item[1], reverse=True) + for count in sorted_counts: + print(f"{count[0].decode('utf-8')}: {count[1]}") + + +if __name__ == "__main__": + main() + + # changed line 2/2 + ray.shutdown() diff --git a/examples/ray_compat/plot_hyperparameter.py b/examples/ray_compat/plot_hyperparameter.py new file mode 100644 index 000000000..4693394fc --- /dev/null +++ b/examples/ray_compat/plot_hyperparameter.py @@ -0,0 +1,178 @@ +""" +This example was copied from https://docs.ray.io/en/latest/ray-core/examples/plot_hyperparameter.html + +Like in `highly_parallel.py`, only minimal changes are needed for the example to work on Scaler. +""" + +import os + +import numpy as np +import ray +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from filelock import FileLock +from torchvision import datasets, transforms + +# changed line 1/2 +import scaler.compat.ray # noqa: F401 + +ray.init() + +# The number of sets of random hyperparameters to try. +num_evaluations = 10 + + +# A function for generating random hyperparameters. +def generate_hyperparameters(): + return { + "learning_rate": 10 ** np.random.uniform(-5, 1), + "batch_size": np.random.randint(1, 100), + "momentum": np.random.uniform(0, 1), + } + + +def get_data_loaders(batch_size): + mnist_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) + + # We add FileLock here because multiple workers will want to + # download data, and this may cause overwrites since + # DataLoader is not threadsafe. 
+ with FileLock(os.path.expanduser("~/data.lock")): + train_loader = torch.utils.data.DataLoader( + datasets.MNIST("~/data", train=True, download=True, transform=mnist_transforms), + batch_size=batch_size, + shuffle=True, + ) + test_loader = torch.utils.data.DataLoader( + datasets.MNIST("~/data", train=False, transform=mnist_transforms), batch_size=batch_size, shuffle=True + ) + return train_loader, test_loader + + +class ConvNet(nn.Module): + """Simple two layer Convolutional Neural Network.""" + + def __init__(self): + super(ConvNet, self).__init__() + self.conv1 = nn.Conv2d(1, 3, kernel_size=3) + self.fc = nn.Linear(192, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 3)) + x = x.view(-1, 192) + x = self.fc(x) + return F.log_softmax(x, dim=1) + + +def train(model, optimizer, train_loader, device=torch.device("cpu")): + """Optimize the model with one pass over the data. + + Cuts off at 1024 samples to simplify training. + """ + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if batch_idx * len(data) > 1024: + return + data, target = data.to(device), target.to(device) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + + +def test(model, test_loader, device=torch.device("cpu")): + """Checks the validation accuracy of the model. + + Cuts off at 512 samples for simplicity. + """ + model.eval() + correct = 0 + total = 0 + with torch.no_grad(): + for batch_idx, (data, target) in enumerate(test_loader): + if batch_idx * len(data) > 512: + break + data, target = data.to(device), target.to(device) + outputs = model(data) + _, predicted = torch.max(outputs.data, 1) + total += target.size(0) + correct += (predicted == target).sum().item() + + return correct / total + + +def main(): + @ray.remote + def evaluate_hyperparameters(config): + model = ConvNet() + train_loader, test_loader = get_data_loaders(config["batch_size"]) + optimizer = optim.SGD(model.parameters(), lr=config["learning_rate"], momentum=config["momentum"]) + train(model, optimizer, train_loader) + return test(model, test_loader) + + # Keep track of the best hyperparameters and the best accuracy. + best_hyperparameters = None + best_accuracy = 0 + # A list holding the object refs for all of the experiments that we have + # launched but have not yet been processed. + remaining_ids = [] + # A dictionary mapping an experiment's object ref to its hyperparameters. + # hyerparameters used for that experiment. + hyperparameters_mapping = {} + + # Randomly generate sets of hyperparameters and launch a task to evaluate it. + for i in range(num_evaluations): + hyperparameters = generate_hyperparameters() + accuracy_id = evaluate_hyperparameters.remote(hyperparameters) + remaining_ids.append(accuracy_id) + hyperparameters_mapping[accuracy_id] = hyperparameters + + # Fetch and print the results of the tasks in the order that they complete. + while remaining_ids: + # Use ray.wait to get the object ref of the first task that completes. + done_ids, remaining_ids = ray.wait(remaining_ids) + # There is only one return result by default. 
+ result_id = done_ids[0] + + hyperparameters = hyperparameters_mapping[result_id] + accuracy = ray.get(result_id) + print( + """We achieve accuracy {:.3}% with + learning_rate: {:.2} + batch_size: {} + momentum: {:.2} + """.format( + 100 * accuracy, + hyperparameters["learning_rate"], + hyperparameters["batch_size"], + hyperparameters["momentum"], + ) + ) + if accuracy > best_accuracy: + best_hyperparameters = hyperparameters + best_accuracy = accuracy + + # Record the best performing set of hyperparameters. + print( + """Best accuracy over {} trials was {:.3} with + learning_rate: {:.2} + batch_size: {} + momentum: {:.2} + """.format( + num_evaluations, + 100 * best_accuracy, + best_hyperparameters["learning_rate"], + best_hyperparameters["batch_size"], + best_hyperparameters["momentum"], + ) + ) + + +if __name__ == "__main__": + main() + + # changed line 2/2 + ray.shutdown() diff --git a/examples/ray_compat/readme.md b/examples/ray_compat/readme.md new file mode 100644 index 000000000..d6da86a7d --- /dev/null +++ b/examples/ray_compat/readme.md @@ -0,0 +1,18 @@ +## Ray Compatibility Layer Examples + +Examples in this directory demonstrate how to use Scaler's Ray compatibility layer. + +- `basic_local_cluster.py` + Shows how to use Scaler's Ray compatibility layer with the implicitly-created local cluster +- `basic_remote_cluster.py` + Shows how to use Scaler's Ray compatibility layer with a remote cluster +- `highly_parallel.py` + This example is copied from the ["Highly Parallel" page of Ray Core's Documentation](https://docs.ray.io/en/latest/ray-core/examples/highly_parallel.html) +- `batch_prediction.py` + This example is copied from the ["Batch Prediction" page of Ray Core's Documentation](https://docs.ray.io/en/latest/ray-core/examples/batch_prediction.html) +- `map_reduce.py` + This example is copied from the ["MapReduce" page of Ray Core's Documentation](https://docs.ray.io/en/latest/ray-core/examples/map_reduce.html) +- `plot_hyperparameter.py` + This example is copied from the ["Hyperparameter Tuning" page of Ray Core's Documentation](https://docs.ray.io/en/latest/ray-core/examples/plot_hyperparameter.html) +- `web_crawler.py` + This example is copied from the ["Web Crawler" page of Ray Core's Documentation](https://docs.ray.io/en/latest/ray-core/examples/web_crawler.html) diff --git a/examples/ray_compat/requirements.txt b/examples/ray_compat/requirements.txt new file mode 100644 index 000000000..d46f9f885 --- /dev/null +++ b/examples/ray_compat/requirements.txt @@ -0,0 +1,9 @@ +ray[default] +beautifulsoup4==4.11.1 +pyarrow==17.0.0 +pandas==2.0.3 +filelock==3.16.1 + +# the test that uses these is disabled +# torch==2.5.1 +# torchvision==0.19.1 diff --git a/examples/ray_compat/skip_tests.txt b/examples/ray_compat/skip_tests.txt new file mode 100644 index 000000000..30659a34c --- /dev/null +++ b/examples/ray_compat/skip_tests.txt @@ -0,0 +1 @@ +plot_hyperparameter.py diff --git a/examples/ray_compat/web_crawler.py b/examples/ray_compat/web_crawler.py new file mode 100644 index 000000000..ae1fd2489 --- /dev/null +++ b/examples/ray_compat/web_crawler.py @@ -0,0 +1,67 @@ +""" +This example was copied from https://docs.ray.io/en/latest/ray-core/examples/web_crawler.html + +Like in `highly_parallel.py`, only minimal changes are needed for the example to work on Scaler. 
+""" + +import time + +import ray +import requests +from bs4 import BeautifulSoup + +# changed line 1/2 +import scaler.compat.ray # noqa: F401 + + +def extract_links(elements, base_url, max_results=100): + links = [] + for e in elements: + url = e["href"] + if "https://" not in url: + url = base_url + url + if base_url in url: + links.append(url) + return set(links[:max_results]) + + +def find_links(start_url, base_url, depth=2): + if depth == 0: + return set() + + page = requests.get(start_url) + soup = BeautifulSoup(page.content, "html.parser") + elements = soup.find_all("a", href=True) + links = extract_links(elements, base_url) + + for url in links: + new_links = find_links(url, base_url, depth - 1) + links = links.union(new_links) + return links + + +base = "https://docs.ray.io/en/latest/" +docs = base + "index.html" + + +def main(): + @ray.remote + def find_links_task(start_url, base_url, depth=2): + return find_links(start_url, base_url, depth) + + start = time.time() + find_links(docs, base) + serial_elapsed = time.time() - start + + start = time.time() + [find_links_task.remote(f"{base}{lib}/index.html", base) for lib in ["", "", "", "rllib", "tune", "serve"]] + parallel_elapsed = time.time() - start + + print(f"serial time: {serial_elapsed:.2}s; parallel time: {parallel_elapsed:.2}s") + + +if __name__ == "__main__": + main() + + # changed line 2/2 + ray.shutdown() diff --git a/examples/readme.md b/examples/readme.md index c682f2436..be996f185 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -19,3 +19,5 @@ Ensure that the scheduler and cluster are set up before running clients. Shows how to send a basic task to scheduler - `task_capabilities.py` Shows how to use capabilities to route task to various workers +- `ray_compat/` + Shows how to use Scaler's Ray compatibility layer diff --git a/pyproject.toml b/pyproject.toml index 08b87b949..e61385073 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ uvloop = [ ] gui = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", - "nicegui[plotly]==3.4.0; python_version >= '3.9'", + "nicegui[plotly]==3.4.1; python_version >= '3.9'", ] graphblas = [ "python-graphblas", @@ -52,7 +52,7 @@ aws = [ ] all = [ "nicegui[plotly]==2.24.2; python_version == '3.8'", - "nicegui[plotly]==3.4.0; python_version >= '3.9'", + "nicegui[plotly]==3.4.1; python_version >= '3.9'", "python-graphblas", "numpy==1.24.4; python_version == '3.8'", "numpy==2.0.2; python_version == '3.9'", diff --git a/src/cpp/scaler/CMakeLists.txt b/src/cpp/scaler/CMakeLists.txt index 18fafc02c..e79526c94 100644 --- a/src/cpp/scaler/CMakeLists.txt +++ b/src/cpp/scaler/CMakeLists.txt @@ -1,7 +1,7 @@ if(LINUX OR APPLE) add_subdirectory(object_storage) else() - message(WARNING "Not building OSS, as it's not supported on this system currently!") + message(WARNING "Not building OSS and its Python Interface, as it's not supported on this system currently!") endif() if(LINUX OR WIN32) diff --git a/src/cpp/scaler/ymq/CMakeLists.txt b/src/cpp/scaler/ymq/CMakeLists.txt index 415f2afa8..24037c4e4 100644 --- a/src/cpp/scaler/ymq/CMakeLists.txt +++ b/src/cpp/scaler/ymq/CMakeLists.txt @@ -115,7 +115,7 @@ endif() if(WIN32) # ymq python ======================================================================================================= - + message(WARNING "Not building Python Interface for YMQ, as it's not supported on this system currently!") target_link_libraries(ymq_objs PRIVATE "ws2_32") target_compile_definitions(ymq_objs PRIVATE _WINSOCKAPI_=) # Yes, trailing equal to 
guarantee empty def endif() diff --git a/src/cpp/scaler/ymq/pymod_ymq/bytes.h b/src/cpp/scaler/ymq/pymod_ymq/bytes.h index ba30efe82..096b73b67 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/bytes.h +++ b/src/cpp/scaler/ymq/pymod_ymq/bytes.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // First-party #include "scaler/ymq/bytes.h" diff --git a/src/cpp/scaler/ymq/pymod_ymq/python.h b/src/cpp/scaler/ymq/pymod_ymq/compatibility.h similarity index 96% rename from src/cpp/scaler/ymq/pymod_ymq/python.h rename to src/cpp/scaler/ymq/pymod_ymq/compatibility.h index 1c3f1b48f..8e1c862b6 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/python.h +++ b/src/cpp/scaler/ymq/pymod_ymq/compatibility.h @@ -1,4 +1,6 @@ #pragma once +// NOTE: This file is needed because we support backward compatibility with +// Python 3.8. It will be removed once we drop support for Python 3.8. #define PY_SSIZE_T_CLEAN diff --git a/src/cpp/scaler/ymq/pymod_ymq/exception.h b/src/cpp/scaler/ymq/pymod_ymq/exception.h index 226867b42..d942f90e4 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/exception.h +++ b/src/cpp/scaler/ymq/pymod_ymq/exception.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // First-party #include "scaler/error/error.h" diff --git a/src/cpp/scaler/ymq/pymod_ymq/io_context.h b/src/cpp/scaler/ymq/pymod_ymq/io_context.h index ca9a8c012..0589fdedf 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/io_context.h +++ b/src/cpp/scaler/ymq/pymod_ymq/io_context.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // C++ #include diff --git a/src/cpp/scaler/ymq/pymod_ymq/io_socket.h b/src/cpp/scaler/ymq/pymod_ymq/io_socket.h index 0a1f202c5..983a97caf 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/io_socket.h +++ b/src/cpp/scaler/ymq/pymod_ymq/io_socket.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // C++ #include diff --git a/src/cpp/scaler/ymq/pymod_ymq/message.h b/src/cpp/scaler/ymq/pymod_ymq/message.h index a17f80d7c..3ba070445 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/message.h +++ b/src/cpp/scaler/ymq/pymod_ymq/message.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // First-party #include "scaler/ymq/pymod_ymq/bytes.h" diff --git a/src/cpp/scaler/ymq/pymod_ymq/ymq.h b/src/cpp/scaler/ymq/pymod_ymq/ymq.h index e9c996430..aee702ba5 100644 --- a/src/cpp/scaler/ymq/pymod_ymq/ymq.h +++ b/src/cpp/scaler/ymq/pymod_ymq/ymq.h @@ -1,7 +1,7 @@ #pragma once // Python -#include "scaler/ymq/pymod_ymq/python.h" +#include "scaler/ymq/pymod_ymq/compatibility.h" // C #include diff --git a/src/scaler/compat/__init__.py b/src/scaler/compat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/scaler/compat/ray.py b/src/scaler/compat/ray.py new file mode 100644 index 000000000..98d670eb8 --- /dev/null +++ b/src/scaler/compat/ray.py @@ -0,0 +1,524 @@ +""" +This module provides a compatibility layer for Scaler that mimics the Ray interface. +It allows users familiar with Ray's API to interact with Scaler in a similar fashion, +including remote function execution, object referencing, and waiting for task completion.
+""" + +import concurrent.futures +import inspect +from typing import Any, Callable, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union, cast +from unittest.mock import Mock, patch + +import psutil +from typing_extensions import ParamSpec + +from scaler.client.client import Client +from scaler.client.future import ScalerFuture +from scaler.client.object_reference import ObjectReference +from scaler.client.serializer.default import DefaultSerializer +from scaler.client.serializer.mixins import Serializer +from scaler.cluster.combo import SchedulerClusterCombo +from scaler.config.defaults import ( + DEFAULT_CLIENT_TIMEOUT_SECONDS, + DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS, + DEFAULT_HARD_PROCESSOR_SUSPEND, + DEFAULT_HEARTBEAT_INTERVAL_SECONDS, + DEFAULT_IO_THREADS, + DEFAULT_LOAD_BALANCE_SECONDS, + DEFAULT_LOAD_BALANCE_TRIGGER_TIMES, + DEFAULT_LOGGING_LEVEL, + DEFAULT_LOGGING_PATHS, + DEFAULT_MAX_NUMBER_OF_TASKS_WAITING, + DEFAULT_OBJECT_RETENTION_SECONDS, + DEFAULT_PER_WORKER_QUEUE_SIZE, + DEFAULT_TASK_TIMEOUT_SECONDS, + DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, + DEFAULT_WORKER_DEATH_TIMEOUT, + DEFAULT_WORKER_TIMEOUT_SECONDS, +) +from scaler.scheduler.allocate_policy.allocate_policy import AllocatePolicy + + +def _not_implemented(*args, **kwargs) -> None: + raise NotImplementedError + + +def _no_op(*_args, **_kwargs) -> None: + pass + + +class NotImplementedMock(Mock): + def __getattr__(self, _name): + raise NotImplementedError("this module is not supported in Scaler compatibility layer.") + + +# We patch the underlying init and shutdown functions to prevent the real +# ray from being initialized or shut down, which can cause issues with our +# compatibility layer, especially during atexit. +patch("ray._private.worker.init", new=_no_op).start() +patch("ray._private.worker.shutdown", new=_no_op).start() + +# Prevent the import of a module that causes issues during ray shutdown +# by replacing it with a mock object. This module contains a class decorated +# with @ray.remote, which is not yet supported by the Scaler compatibility layer. 
+ +try: + patch("ray.experimental.channel.cpu_communicator", new=Mock()).start() +except AttributeError: + pass # this doesn't exist on old versions of Ray + +patch("ray.dag.compiled_dag_node", new=Mock()).start() + + +# Add no-op, mock, or not-implemented patches for various ray functions and modules +patch("ray.init", new=_no_op).start() +patch("ray.method", new=_not_implemented).start() +patch("ray.actor", new=Mock()).start() +patch("ray.runtime_context", new=NotImplementedMock()).start() +patch("ray.cross_language", new=NotImplementedMock()).start() +patch("ray.get_actor", new=_no_op).start() +patch("ray.get_gpu_ids", new=_not_implemented).start() +patch("ray.get_runtime_context", new=_not_implemented).start() +patch("ray.kill", new=_not_implemented).start() + + +combo: Optional[SchedulerClusterCombo] = None +client: Optional[Client] = None + + +def scaler_init( + address: Optional[str] = None, + *, + n_workers: Optional[int] = psutil.cpu_count(), + object_storage_address: Optional[str] = None, + monitor_address: Optional[str] = None, + per_worker_capabilities: Optional[Dict[str, int]] = None, + worker_io_threads: int = DEFAULT_IO_THREADS, + scheduler_io_threads: int = DEFAULT_IO_THREADS, + max_number_of_tasks_waiting: int = DEFAULT_MAX_NUMBER_OF_TASKS_WAITING, + heartbeat_interval_seconds: int = DEFAULT_HEARTBEAT_INTERVAL_SECONDS, + client_timeout_seconds: int = DEFAULT_CLIENT_TIMEOUT_SECONDS, + worker_timeout_seconds: int = DEFAULT_WORKER_TIMEOUT_SECONDS, + object_retention_seconds: int = DEFAULT_OBJECT_RETENTION_SECONDS, + task_timeout_seconds: int = DEFAULT_TASK_TIMEOUT_SECONDS, + death_timeout_seconds: int = DEFAULT_WORKER_DEATH_TIMEOUT, + load_balance_seconds: int = DEFAULT_LOAD_BALANCE_SECONDS, + load_balance_trigger_times: int = DEFAULT_LOAD_BALANCE_TRIGGER_TIMES, + garbage_collect_interval_seconds: int = DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS, + trim_memory_threshold_bytes: int = DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, + per_worker_task_queue_size: int = DEFAULT_PER_WORKER_QUEUE_SIZE, + hard_processor_suspend: bool = DEFAULT_HARD_PROCESSOR_SUSPEND, + protected: bool = True, + allocate_policy: AllocatePolicy = AllocatePolicy.even, + event_loop: str = "builtin", + logging_paths: Tuple[str, ...] = DEFAULT_LOGGING_PATHS, + logging_level: str = DEFAULT_LOGGING_LEVEL, + logging_config_file: Optional[str] = None, + # client-specific options + profiling: bool = False, + timeout_seconds: int = DEFAULT_CLIENT_TIMEOUT_SECONDS, + serializer: Serializer = DefaultSerializer(), + stream_output: bool = False, +) -> None: + """ + Initializes Scaler's Ray compatibility layer. + + If `address` is provided, we connect to an existing Scaler cluster. + Otherwise, it starts a new local cluster with the specified configuration. + Several client-specific options can also be set, and shared options are passed to both client and cluster. + + Args: + address: The address of the Scaler scheduler to connect to. + n_workers: The number of workers to start in the local cluster. + Defaults to the number of CPU cores. + **kwargs: Other Scaler cluster configuration options. 
+ """ + global client, combo + + if client is not None: + raise RuntimeError("Cannot initialize scaler twice") + + if address is None: + combo = SchedulerClusterCombo( + n_workers=n_workers, + object_storage_address=object_storage_address, + monitor_address=monitor_address, + per_worker_capabilities=per_worker_capabilities, + worker_io_threads=worker_io_threads, + scheduler_io_threads=scheduler_io_threads, + max_number_of_tasks_waiting=max_number_of_tasks_waiting, + heartbeat_interval_seconds=heartbeat_interval_seconds, + client_timeout_seconds=client_timeout_seconds, + worker_timeout_seconds=worker_timeout_seconds, + object_retention_seconds=object_retention_seconds, + task_timeout_seconds=task_timeout_seconds, + death_timeout_seconds=death_timeout_seconds, + load_balance_seconds=load_balance_seconds, + load_balance_trigger_times=load_balance_trigger_times, + garbage_collect_interval_seconds=garbage_collect_interval_seconds, + trim_memory_threshold_bytes=trim_memory_threshold_bytes, + per_worker_task_queue_size=per_worker_task_queue_size, + hard_processor_suspend=hard_processor_suspend, + protected=protected, + allocate_policy=allocate_policy, + event_loop=event_loop, + logging_paths=logging_paths, + logging_level=logging_level, + logging_config_file=logging_config_file, + ) + + address = combo.get_address() + + client = Client( + address=address, + profiling=profiling, + timeout_seconds=timeout_seconds, + heartbeat_interval_seconds=heartbeat_interval_seconds, + serializer=serializer, + stream_output=stream_output, + object_storage_address=object_storage_address, + ) + + +def shutdown() -> None: + """ + Disconnects the client and shuts down the local cluster if one was created. + + Mimics the behavior of `ray.shutdown()`. + """ + global client, combo + + if client: + client.disconnect() + if combo: + combo.shutdown() + + client = None + combo = None + + +patch("ray.shutdown", new=shutdown).start() + + +def is_initialized() -> bool: + """Checks if the Scaler client has been initialized.""" + return client is not None + + +patch("ray.is_initialized", new=is_initialized).start() + + +def ensure_init(): + """ + This is an internal function that ensures the Scaler client is initialized, calling `init()` with + default parameters if it is not. + """ + if not is_initialized(): + scaler_init() + + +T = TypeVar("T") +P = ParamSpec("P") +V = TypeVar("V") + + +class RayObjectReference(Generic[T]): + """ + A wrapper around a ScalerFuture to provide an API similar to a Ray ObjectRef. + + This class allows treating results of asynchronous Scaler tasks in a way + that is compatible with the Ray API. + """ + + _future: ScalerFuture + + # the index into the return value for num_results > 1 + _index: Optional[int] + + def __init__(self, future: ScalerFuture, index: Optional[int] = None) -> None: + """ + Initializes the RayObjectReference with a ScalerFuture. + + Args: + future: The ScalerFuture instance to wrap. + """ + self._future = future + self._index = index + + def get(self) -> T: + """ + Retrieves the result of the future, blocking until it's available. + + Returns: + The result of the completed future. 
+ """ + obj = self._future.result() + + if self._index is None: + return obj + + try: + return obj[self._index] + except TypeError as e: + raise TypeError("num_returns can only be used on a function that returns an indexable object") from e + + def cancel(self) -> None: + """Attempts to cancel the future.""" + self._future.cancel() + + +def unwrap_ray_object_reference(maybe_ref: Union[T, RayObjectReference[T]]) -> T: + """ + Helper to get the result if the input is a RayObjectReference. + + If the input is a `RayObjectReference`, its result is returned. + Otherwise, the input is returned as is. This is used to transparently + handle passing of object references as arguments to remote functions. + + Args: + maybe_ref: The object to unwrap. + + Returns: + The result of the reference or the original object. + """ + if isinstance(maybe_ref, RayObjectReference): + return maybe_ref.get() + return maybe_ref + + +def _wrap_remote_fn(fn: Callable[P, T], client: Client) -> Callable[P, T]: + # this function forwards the implicit client to the worker and enables nesting + def forward_client(*args: P.args, **kwargs: P.kwargs) -> T: + import scaler.compat.ray + + scaler.compat.ray.client = client + return fn(*args, **kwargs) + + forward_client.__signature__ = inspect.signature(fn) # type: ignore[attr-defined] + + return forward_client + + +class RayRemote(Generic[P, T]): + """ + A wrapper for a function to make it "remote," similar to a Ray remote function. + + This class is typically instantiated by the `@ray.remote` decorator. + """ + + _fn: Callable[P, T] + + _num_returns: int + + def __init__(self, fn: Callable[P, T], num_returns: int = 1, **kwargs) -> None: + """ + Initializes the remote function wrapper. + + Args: + fn: The Python function to be executed remotely. + num_returns: The number of object refs returned by a call to this remote function. + **kwargs: This is provided for callsite compatibility. All additional keyword arguments are ignored. + """ + + self._set_options(num_returns=num_returns) + self._fn = fn + + def _set_options(self, **kwargs) -> None: + if "num_returns" in kwargs: + self._set_num_returns(kwargs["num_returns"]) + + def _set_num_returns(self, num_returns: Any) -> None: + if not isinstance(num_returns, int): + raise ValueError("num_returns must be an integer") + + if num_returns <= 0: + raise ValueError("num_returns must be > 0") + + self._num_returns = num_returns + + def remote(self, *args: P.args, **kwargs: P.kwargs) -> Union[RayObjectReference, List[RayObjectReference]]: + """ + Executes the wrapped function remotely. + + Args: + *args: Positional arguments for the remote function. + **kwargs: Keyword arguments for the remote function. + + Returns: + A RayObjectReference that can be used to retrieve the result, + or a list of RayObjectReferences if num_returns > 1. 
+ """ + if not is_initialized(): + raise RuntimeError("Scaler is not initialized") + + # Ray supports passing object references into other remote functions + # so we must take special care to get their values + processed_args = [unwrap_ray_object_reference(arg) for arg in args] + processed_kwargs = {k: unwrap_ray_object_reference(v) for k, v in kwargs.items()} + + future = client.submit(_wrap_remote_fn(self._fn, client), *processed_args, **processed_kwargs) + + if self._num_returns == 1: + return RayObjectReference(future) + + return [RayObjectReference(future, index=i) for i in range(self._num_returns)] + + def options(self, *args, **kwargs) -> "RayRemote[P, T]": + return RayRemote(self._fn, *args, **kwargs) + + +def get(ref: Union[RayObjectReference[T], List[RayObjectReference[Any]]]) -> Union[T, List[Any]]: + """ + Retrieves the result from one or more RayObjectReferences. + + This function blocks until the results are available. Mimics `ray.get()`. + + Args: + ref: A single RayObjectReference or a list of them. + + Returns: + The result of the reference or a list of results. + """ + if isinstance(ref, List): + return [get(x) for x in ref] + if isinstance(ref, RayObjectReference): + return ref.get() + + raise RuntimeError(f"Unknown type [{type(ref)}] passed to ray.get()") + + +patch("ray.get", new=get).start() + + +def put(obj: Any) -> ObjectReference: + """ + Stores an object in the Scaler object store. Mimics `ray.put()`. + + Args: + obj: The Python object to be stored. + + Returns: + An ObjectReference that can be used to retrieve the object. + """ + return client.send_object(obj) + + +patch("ray.put", new=put).start() + + +def remote(*args, **kwargs) -> Union[RayRemote, Callable]: + """ + A decorator that creates a `RayRemote` instance from a regular function. + + Mimics the behavior of `@ray.remote`. This decorator can be used with or without arguments, + e.g., `@ray.remote` or `@ray.remote(num_cpus=1)`. + + All arguments passed to the decorator are ignored. + + Returns: + A RayRemote instance that can be called with `.remote()`, or a decorator + that produces a RayRemote instance. + """ + ensure_init() + + def _decorator(fn: Callable) -> RayRemote: + if isinstance(fn, type): # Check if 'fn' is a class + raise NotImplementedError( + "Decorating classes with @ray.remote is not yet supported in Scaler compatibility layer." + ) + return RayRemote(fn, **kwargs) + + if len(args) == 1 and callable(args[0]) and not kwargs: + # This is the case: @ray.remote + return _decorator(args[0]) + else: + # This is the case: @ray.remote(...) + return _decorator + + +patch("ray.remote", new=remote).start() + + +def cancel(ref: RayObjectReference) -> None: + """ + Attempts to cancel the execution of a task. Mimics `ray.cancel()`. + + Args: + ref: The RayObjectReference corresponding to the task to be canceled. + """ + ref.cancel() + + +patch("ray.cancel", new=cancel).start() + + +class _RayUtil: + def as_completed(self, refs: List[RayObjectReference[T]]) -> Iterator[RayObjectReference[T]]: + """ + Returns an iterator that yields object references as they are completed. + Mimics `ray.util.as_completed()`. 
+ """ + future_to_ref = {ref._future: ref for ref in refs} + for future in concurrent.futures.as_completed(future_to_ref.keys()): + yield future_to_ref[cast(ScalerFuture, future)] + + # python3.8 cannot handle giving real type hints to `fn` + def map_unordered(self, fn: RayRemote, values: List[V]) -> Iterator[T]: + """ + Applies a remote function to each value in a list and yields the results + as they become available. Mimics `ray.util.map_unordered()`. + + The function `fn` must be a @ray.remote decorated function, with `num_returns=1`. + """ + if not hasattr(fn, "remote") or not callable(fn.remote): + raise TypeError("The function passed to map_unordered must be a @ray.remote function.") + + if fn._num_returns > 1: + raise TypeError("map_unordered only supports remote functions with num_returns=1") + + refs = [cast(RayObjectReference, fn.remote(v)) for v in values] + for ref in self.as_completed(refs): + yield ref.get() + + +patch("ray.util", new=_RayUtil()).start() + + +def wait( + refs: List[RayObjectReference[T]], *, num_returns: Optional[int] = 1, timeout: Optional[float] = None +) -> Tuple[List[RayObjectReference[T]], List[RayObjectReference[T]]]: + """ + Waits for a number of object references to be ready. Mimics `ray.wait()`. + + Args: + refs: A list of RayObjectReferences to wait on. + num_returns: The number of references to wait for. If None, waits for all. + timeout: The maximum time in seconds to wait. + + Returns: + A tuple containing two lists: the list of ready references and the + list of remaining, not-ready references. + """ + + if num_returns is not None and num_returns > len(refs): + raise ValueError("num_returns cannot be greater than the number of provided object references") + + if num_returns is not None and num_returns <= 0: + return [], list(refs) + + future_to_ref = {ref._future: ref for ref in refs} + done = set() + + try: + for future in concurrent.futures.as_completed((ref._future for ref in refs), timeout=timeout): + done.add(future_to_ref[cast(ScalerFuture, future)]) + + if num_returns is not None and len(done) == num_returns: + break + except concurrent.futures.TimeoutError: + pass + + return list(done), [ref for ref in refs if ref not in done] + + +patch("ray.wait", new=wait).start() diff --git a/src/scaler/utility/queues/async_priority_queue.py b/src/scaler/utility/queues/async_priority_queue.py index 7b846e65b..f0d444989 100644 --- a/src/scaler/utility/queues/async_priority_queue.py +++ b/src/scaler/utility/queues/async_priority_queue.py @@ -1,7 +1,8 @@ -import heapq -import sys from asyncio import Queue -from typing import Any, Dict, List, Tuple, Union +from dataclasses import dataclass +from typing import Any, Dict, Tuple, Union + +from sortedcontainers import SortedDict PriorityType = Union[int, Tuple["PriorityType", ...]] @@ -9,62 +10,74 @@ class AsyncPriorityQueue(Queue): """A subclass of Queue; retrieves entries in priority order (lowest first). - Entries are typically list of the form: [priority, data]. + Input entries are typically list of the form: [priority, data]. 
""" + @dataclass(frozen=True) + class MapKey: + priority: int + count: int + + def __lt__(self, other): + return (self.priority, self.count) < (other.priority, other.count) + + def __hash__(self): + return hash((self.priority, self.count)) + + @dataclass + class LocatorValue: + map_key: "AsyncPriorityQueue.MapKey" + data: bytes + def __len__(self): return len(self._queue) def _init(self, maxsize): - self._queue: List[List] = [] - self._locator: Dict[bytes, List] = {} + self._locator: Dict[bytes, AsyncPriorityQueue.LocatorValue] = {} + self._queue: Dict[AsyncPriorityQueue.MapKey, bytes] = SortedDict() + self._item_counter: int = 0 def _put(self, item): if not isinstance(item, list): item = list(item) - heapq.heappush(self._queue, item) - self._locator[item[1]] = item + priority, data = item + map_key = AsyncPriorityQueue.MapKey(priority=priority, count=self._item_counter) + self._locator[data] = AsyncPriorityQueue.LocatorValue(map_key=map_key, data=data) + self._queue[map_key] = data + self._item_counter += 1 def _get(self): - priority, data = heapq.heappop(self._queue) + map_key, data = self._queue.popitem(0) # type: ignore[call-arg] self._locator.pop(data) - return priority, data + return map_key.priority, data def remove(self, data): - # this operation is O(n), first change priority to -1 and pop from top of the heap, mark it as invalid - # entry in the heap is not good idea as those invalid, entry will never get removed, so we used heapq internal - # function _siftdown to maintain min heap invariant - item = self._locator.pop(data) - i = self._queue.index(item) # O(n) - item[0] = self.__to_lowest_priority(item[0]) - heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined] - assert heapq.heappop(self._queue) == item + loc_value = self._locator.pop(data) + self._queue.pop(loc_value.map_key) def decrease_priority(self, data): - # this operation should be O(n), mark it as invalid entry in the heap is not good idea as those invalid - # entry will never get removed, so we used heapq internal function _siftdown to maintain min heap invariant - item = self._locator[data] - i = self._queue.index(item) # O(n) - item[0] = self.__to_lower_priority(item[0]) - heapq._siftdown(self._queue, 0, i) # type: ignore[attr-defined] + # Decrease the priority *value* of an item in the queue, effectively move data closer to the front. + # Notes: + # - *priority* in the signature means the priority *value* of the item. + # - Time complexity is O(log n) due to the underlying SortedDict structure. 
+ + loc_value = self._locator[data] + map_key = AsyncPriorityQueue.MapKey(priority=loc_value.map_key.priority - 1, count=self._item_counter) + new_loc_value = AsyncPriorityQueue.LocatorValue(map_key=map_key, data=data) + self._locator[data] = new_loc_value + self._queue.pop(loc_value.map_key) + self._queue[map_key] = data + self._item_counter += 1 def max_priority_item(self) -> Tuple[PriorityType, Any]: - """output the Tuple of top priority number and top priority item""" - item = heapq.heappop(self._queue) - heapq.heappush(self._queue, item) - return item[0], item[1] - - @classmethod - def __to_lowest_priority(cls, original_priority: PriorityType) -> PriorityType: - if isinstance(original_priority, tuple): - return tuple(cls.__to_lowest_priority(value) for value in original_priority) - else: - return -sys.maxsize - 1 - - @classmethod - def __to_lower_priority(cls, original_priority: PriorityType) -> PriorityType: - if isinstance(original_priority, tuple): - return tuple(cls.__to_lower_priority(value) for value in original_priority) - else: - return original_priority - 1 + """Return the current item at the front of the queue without removing it from the queue. + + Notes: + - This is a "peek" operation; it does not modify the queue. + - For items with the same priority, insertion order determines which item is returned first. + - *priority* means the priority value stored in the queue (lowest value first). + - Time complexity is O(1), since we only peek at the head. + """ + loc_value = self._queue.peekitem(0) # type: ignore[attr-defined] + return (loc_value[0].priority, loc_value[1]) diff --git a/src/scaler/utility/queues/async_sorted_priority_queue.py b/src/scaler/utility/queues/async_sorted_priority_queue.py deleted file mode 100644 index 39bd99079..000000000 --- a/src/scaler/utility/queues/async_sorted_priority_queue.py +++ /dev/null @@ -1,45 +0,0 @@ -from asyncio import Queue -from typing import Any, Dict - -from scaler.utility.queues.async_priority_queue import AsyncPriorityQueue - - -class AsyncSortedPriorityQueue(Queue): - """A subclass of Queue; retrieves entries in priority order (lowest first), and then by adding order. - - Entries are typically list of the form: [priority number, data]. - """ - - def __len__(self): - return len(self._queue) - - def _init(self, maxsize: int): - self._queue = AsyncPriorityQueue() - - # Keeps an item count to assign monotonic integer to queued items, so to also keep the priority queue sorted by - # adding order. - # See https://docs.python.org/3/library/heapq.html#priority-queue-implementation-notes.
- self._item_counter: int = 0 - self._data_to_item_id: Dict[Any, int] = dict() - - def _put(self, item) -> None: - priority, data = item - - if data in self._data_to_item_id: - raise ValueError(f"item `{data}` already in the queue") - - item_id = self._item_counter - self._item_counter += 1 - - self._queue._put([priority, (item_id, data)]) - self._data_to_item_id[data] = item_id - - def _get(self): - priority, (_, data) = self._queue._get() - self._data_to_item_id.pop(data) - - return [priority, data] - - def remove(self, data: Any) -> None: - item_id = self._data_to_item_id.pop(data) - self._queue.remove((item_id, data)) diff --git a/src/scaler/version.txt b/src/scaler/version.txt index 8c64af98f..74ac4fd74 100644 --- a/src/scaler/version.txt +++ b/src/scaler/version.txt @@ -1 +1 @@ -1.12.41 +1.12.44 diff --git a/src/scaler/worker/agent/task_manager.py b/src/scaler/worker/agent/task_manager.py index cc6115310..a8f62890f 100644 --- a/src/scaler/worker/agent/task_manager.py +++ b/src/scaler/worker/agent/task_manager.py @@ -6,7 +6,7 @@ from scaler.utility.identifiers import TaskID from scaler.utility.metadata.task_flags import retrieve_task_flags_from_task from scaler.utility.mixins import Looper -from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue +from scaler.utility.queues.async_priority_queue import AsyncPriorityQueue from scaler.worker.agent.mixins import ProcessorManager, TaskManager _SUSPENDED_TASKS_PRIORITY = 1 @@ -29,7 +29,7 @@ def __init__(self, task_timeout_seconds: int): # 4. Task(priority=0) # # We want to execute the tasks in this order: 2-3-1-4. - self._queued_task_ids = AsyncSortedPriorityQueue() + self._queued_task_ids = AsyncPriorityQueue() self._processing_task_ids: Set[TaskID] = set() # Tasks associated with a processor, including suspended tasks diff --git a/src/scaler/worker_adapter/symphony/task_manager.py b/src/scaler/worker_adapter/symphony/task_manager.py index 73f5cf910..271a9ed35 100644 --- a/src/scaler/worker_adapter/symphony/task_manager.py +++ b/src/scaler/worker_adapter/symphony/task_manager.py @@ -13,7 +13,7 @@ from scaler.utility.identifiers import ObjectID, TaskID from scaler.utility.metadata.task_flags import retrieve_task_flags_from_task from scaler.utility.mixins import Looper -from scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue +from scaler.utility.queues.async_priority_queue import AsyncPriorityQueue from scaler.utility.serialization import serialize_failure from scaler.worker.agent.mixins import HeartbeatManager, TaskManager from scaler.worker_adapter.symphony.callback import SessionCallback @@ -40,7 +40,7 @@ def __init__(self, base_concurrency: int, service_name: str): self._serializers: Dict[bytes, Serializer] = dict() - self._queued_task_id_queue = AsyncSortedPriorityQueue() + self._queued_task_id_queue = AsyncPriorityQueue() self._queued_task_ids: Set[bytes] = set() self._acquiring_task_ids: Set[TaskID] = set() # tasks contesting the semaphore diff --git a/tests/compat/test_ray_compat.py b/tests/compat/test_ray_compat.py new file mode 100644 index 000000000..e97181bd1 --- /dev/null +++ b/tests/compat/test_ray_compat.py @@ -0,0 +1,229 @@ +import time +import unittest + +import numpy as np +import ray +from numpy import random + +from scaler.cluster.combo import SchedulerClusterCombo + +# this patches ray +from scaler.compat.ray import scaler_init + + +class TestRayCompat(unittest.TestCase): + def tearDown(self): + ray.shutdown() + + def test_basic(self) -> None: + ray.init() + + 
@ray.remote + def remote_fn() -> int: + return 7 + + ref = remote_fn.remote() + + self.assertEqual(ray.get(ref), 7) + + # https://docs.ray.io/en/latest/ray-core/walkthrough.html#running-a-task + def test_ray_example_square(self) -> None: + # Define the square task. + @ray.remote + def square(x): + return x * x + + # Launch four parallel square tasks. + futures = [square.remote(i) for i in range(4)] + + # Retrieve results. + self.assertEqual(ray.get(futures), [0, 1, 4, 9]) + + # https://docs.ray.io/en/latest/ray-core/walkthrough.html#passing-objects + def test_ray_example_numpy(self) -> None: + # Define a task that sums the values in a matrix. + @ray.remote + def sum_matrix(matrix): + return np.sum(matrix) + + # Call the task with a literal argument value. + print(ray.get(sum_matrix.remote(np.ones((100, 100))))) + # -> 10000.0 + + # Put a large array into the object store. + matrix_ref = ray.put(np.ones((1000, 1000))) + + # Call the task with the object reference as an argument. + self.assertEqual(ray.get(sum_matrix.remote(matrix_ref)), 1000000.0) + + ray.shutdown() + + # https://docs.ray.io/en/latest/ray-core/tasks/nested-tasks.html#nested-remote-functions + def test_ray_example_nested_simple(self) -> None: + @ray.remote + def f(): + return 1 + + @ray.remote + def g(): + # Call f 4 times and return the resulting object refs. + return [f.remote() for _ in range(4)] + + @ray.remote + def h(): + # Call f 4 times, block until those 4 tasks finish, + # retrieve the results, and return the values. + return ray.get([f.remote() for _ in range(4)]) + + self.assertEqual(ray.get(h.remote()), [1, 1, 1, 1]) + + # https://docs.ray.io/en/latest/ray-core/patterns/nested-tasks.html#code-example + def test_ray_example_nested_quicksort(self) -> None: + def partition(collection): + # Use the last element as the pivot + pivot = collection.pop() + greater, lesser = [], [] + for element in collection: + if element > pivot: + greater.append(element) + else: + lesser.append(element) + return lesser, pivot, greater + + def quick_sort(collection): + if len(collection) <= 200000: # magic number + return sorted(collection) + else: + lesser, pivot, greater = partition(collection) + lesser = quick_sort(lesser) + greater = quick_sort(greater) + return lesser + [pivot] + greater + + @ray.remote + def quick_sort_distributed(collection): + # Tiny tasks are an antipattern. + # Thus, in our example we have a "magic number" to + # toggle when distributed recursion should be used vs + # when the sorting should be done in place. The rule + # of thumb is that the duration of an individual task + # should be at least 1 second. 
+ if len(collection) <= 200000: # magic number + return sorted(collection) + else: + lesser, pivot, greater = partition(collection) + lesser = quick_sort_distributed.remote(lesser) + greater = quick_sort_distributed.remote(greater) + return ray.get(lesser) + [pivot] + ray.get(greater) + + for size in [200000, 4000000, 8000000]: + unsorted = random.randint(1000000, size=(size)).tolist() + s = time.time() + sequential_sorted = quick_sort(unsorted[:]) + print(f"Sequential execution: {(time.time() - s):.3f}") + s = time.time() + distributed_sorted = ray.get(quick_sort_distributed.remote(unsorted)) + print(f"Distributed execution: {(time.time() - s):.3f}") + print("--" * 10) + + print(len(sequential_sorted), len(distributed_sorted)) + + self.assertEqual( + sequential_sorted, + distributed_sorted, + msg=f"Expected sequential and distributed sorts to match for {size} element case", + ) + + def test_ray_passing_refs(self) -> None: + @ray.remote + def my_function() -> int: + return 1 + + @ray.remote + def function_with_an_argument(value: int) -> int: + return value + 1 + + obj_ref1 = my_function.remote() + self.assertEqual(ray.get(obj_ref1), 1) + + # You can pass an object ref as an argument to another Ray task. + obj_ref2 = function_with_an_argument.remote(obj_ref1) + self.assertEqual(ray.get(obj_ref2), 2) + + def test_ray_wait_timeout(self) -> None: + @ray.remote + def sleep(secs: int) -> None: + time.sleep(secs) + + refs = [sleep.remote(x) for x in (2, 10)] + ready, remaining = ray.wait(refs, timeout=5) + + self.assertEqual(ready, refs[:1]) + self.assertEqual(remaining, refs[1:]) + + def test_ray_wait_no_timeout(self) -> None: + @ray.remote + def sleep(secs: int) -> None: + time.sleep(secs) + + refs = [sleep.remote(x) for x in (2, 10)] + ready, remaining = ray.wait(refs, num_returns=2, timeout=None) + + self.assertCountEqual(ready, refs) + self.assertEqual(remaining, []) + + def test_ray_wait_num_returns(self) -> None: + @ray.remote + def sleep(secs: int) -> None: + time.sleep(secs) + + refs = [sleep.remote(x) for x in (2, 10)] + ready, remaining = ray.wait(refs, num_returns=1, timeout=None) + + self.assertEqual(ready, refs[:1]) + self.assertEqual(remaining, refs[1:]) + + def test_ray_util_as_completed(self) -> None: + @ray.remote + def sleep(secs: int) -> int: + time.sleep(secs) + return secs + + refs = [sleep.remote(x) for x in (2, 1, 3)] + completed_refs = [] + for ref in ray.util.as_completed(refs): + completed_refs.append(ref) + + # The order of completion should be 1, 2, 3 + self.assertEqual(ray.get(completed_refs[0]), 1) + self.assertEqual(ray.get(completed_refs[1]), 2) + self.assertEqual(ray.get(completed_refs[2]), 3) + + def test_ray_util_map_unordered(self) -> None: + @ray.remote + def square(x: int) -> int: + time.sleep(random.uniform(0, 0.1)) + return x * x + + values = list(range(10)) + results = [] + for result in ray.util.map_unordered(square, values): + results.append(result) + + self.assertEqual(sorted(results), [x * x for x in values]) + + def test_ray_external_cluster(self) -> None: + combo = SchedulerClusterCombo(n_workers=1) + + # explicitly init scaler's ray interface, passing the address of an existing cluster + scaler_init(address=combo.get_address()) + + @ray.remote + def random() -> int: + return 7 + + self.assertEqual(ray.get(random.remote()), 7) + + def test_ray_actor_not_implemented(self) -> None: + with self.assertRaises(NotImplementedError): + # Any access to ray.actor should raise NotImplementedError + _ = ray.actor.ActorClass diff --git 
a/tests/core/test_async_priority_queue.py b/tests/core/test_async_priority_queue.py index 55bd232e5..5af90392c 100644 --- a/tests/core/test_async_priority_queue.py +++ b/tests/core/test_async_priority_queue.py @@ -11,7 +11,7 @@ def setUp(self) -> None: setup_logger() logging_test_name(self) - def test_async_priority_queue(self): + def test_async_priority_queue_basic(self): async def async_test(): queue = AsyncPriorityQueue() await queue.put((5, 5)) @@ -38,3 +38,113 @@ async def async_test(): self.assertTrue(queue.empty()) asyncio.run(async_test()) + + def test_stable_insertion(self): + async def async_test(): + queue = AsyncPriorityQueue() + + await queue.put((1, 4)) + await queue.put((1, 3)) + await queue.put((1, 2)) + + # Stability: insertion order preserved + self.assertEqual(await queue.get(), (1, 4)) + self.assertEqual(await queue.get(), (1, 3)) + self.assertEqual(await queue.get(), (1, 2)) + + asyncio.run(async_test()) + + def test_decrease_priority_reorders_correctly(self): + async def async_test(): + queue = AsyncPriorityQueue() + + await queue.put((5, "x")) + await queue.put((1, "y")) + await queue.put((3, "z")) + + queue.decrease_priority("x") + queue.decrease_priority("x") + queue.decrease_priority("x") + # "x" has priority 2 after decrease + + self.assertEqual(await queue.get(), (1, "y")) + self.assertEqual(await queue.get(), (2, "x")) + self.assertEqual(await queue.get(), (3, "z")) + + asyncio.run(async_test()) + + def test_remove(self): + async def async_test(): + queue = AsyncPriorityQueue() + + await queue.put((1, "a")) + await queue.put((2, "b")) + await queue.put((3, "c")) + await queue.put((4, "d")) + + queue.remove("b") + queue.remove("d") + + self.assertEqual(queue.qsize(), 2) + self.assertEqual(await queue.get(), (1, "a")) + self.assertEqual(await queue.get(), (3, "c")) + self.assertTrue(queue.empty()) + + asyncio.run(async_test()) + + def test_max_priority_item(self): + async def async_test(): + queue = AsyncPriorityQueue() + + await queue.put((10, "low")) + await queue.put((1, "high")) + await queue.put((5, "mid")) + + priority, data = queue.max_priority_item() + self.assertEqual(priority, 1) + self.assertEqual(data, "high") + + # Ensure peek does not remove + self.assertEqual(queue.qsize(), 3) + + asyncio.run(async_test()) + + def test_interleaved_put_get(self): + async def async_test(): + queue = AsyncPriorityQueue() + + await queue.put((2, "b")) + self.assertEqual(await queue.get(), (2, "b")) + + await queue.put((3, "c")) + await queue.put((1, "a")) + + self.assertEqual(await queue.get(), (1, "a")) + + await queue.put((0, "z")) + self.assertEqual(await queue.get(), (0, "z")) + self.assertEqual(await queue.get(), (3, "c")) + + self.assertTrue(queue.empty()) + + asyncio.run(async_test()) + + def test_len(self): + async def async_test(): + queue = AsyncPriorityQueue() + + self.assertEqual(len(queue), 0) + self.assertEqual(queue.qsize(), 0) + + await queue.put((1, 1)) + await queue.put((2, 2)) + + self.assertEqual(len(queue), 2) + self.assertEqual(queue.qsize(), 2) + + await queue.get() + + self.assertEqual(len(queue), 1) + self.assertEqual(queue.qsize(), 1) + + asyncio.run(async_test()) diff --git a/tests/core/test_async_sorted_priority_queue.py b/tests/core/test_async_sorted_priority_queue.py deleted file mode 100644 index d5987a732..000000000 --- a/tests/core/test_async_sorted_priority_queue.py +++ /dev/null @@ -1,39 +0,0 @@ -import asyncio -import unittest - -from scaler.utility.logging.utility import setup_logger -from 
scaler.utility.queues.async_sorted_priority_queue import AsyncSortedPriorityQueue -from tests.utility.utility import logging_test_name - - -class TestSortedPriorityQueue(unittest.TestCase): - def setUp(self) -> None: - setup_logger() - logging_test_name(self) - - def test_sorted_priority_queue(self): - async def async_test(): - queue = AsyncSortedPriorityQueue() - - await queue.put([2, 3]) - await queue.put([3, 5]) - await queue.put([1, 1]) - await queue.put([3, 6]) - await queue.put([2, 4]) - await queue.put([-3, 0]) # supports negative priorities - await queue.put([1, 2]) - - queue.remove(4) - self.assertEqual(queue.qsize(), 6) - - self.assertEqual(await queue.get(), [-3, 0]) - self.assertEqual(await queue.get(), [1, 1]) - self.assertEqual(await queue.get(), [1, 2]) - self.assertEqual(await queue.get(), [2, 3]) - self.assertEqual(await queue.get(), [3, 5]) - self.assertEqual(await queue.get(), [3, 6]) - self.assertEqual(queue.qsize(), 0) - self.assertTrue(not queue) - self.assertTrue(queue.empty()) - - asyncio.run(async_test()) diff --git a/tests/cpp/ymq/CMakeLists.txt b/tests/cpp/ymq/CMakeLists.txt index 318ff3f33..6ad01f46e 100644 --- a/tests/cpp/ymq/CMakeLists.txt +++ b/tests/cpp/ymq/CMakeLists.txt @@ -2,10 +2,16 @@ add_test_executable(test_ymq test_ymq.cpp) target_sources(test_ymq PRIVATE common/testing.h + common/testing.cpp common/utils.h common/utils.cpp - net/socket.h + net/address.h + net/address.cpp + net/i_socket.h + net/socket_utils.h + net/tcp_socket.h + net/uds_socket.h pipe/pipe.h pipe/pipe_utils.h @@ -18,11 +24,18 @@ if(LINUX OR APPLE) pipe/pipe_utils_linux.cpp pipe/pipe_reader_linux.cpp pipe/pipe_writer_linux.cpp - net/socket_linux.cpp) + common/utils_linux.cpp + common/testing_linux.cpp + net/socket_utils_linux.cpp + net/tcp_socket_linux.cpp + net/uds_socket.cpp) elseif(WIN32) target_sources(test_ymq PRIVATE pipe/pipe_utils_windows.cpp pipe/pipe_reader_windows.cpp pipe/pipe_writer_windows.cpp - net/socket_windows.cpp) + common/testing_windows.cpp + common/utils_windows.cpp + net/socket_utils_windows.cpp + net/tcp_socket_windows.cpp) endif() diff --git a/tests/cpp/ymq/common/testing.cpp b/tests/cpp/ymq/common/testing.cpp new file mode 100644 index 000000000..640f60557 --- /dev/null +++ b/tests/cpp/ymq/common/testing.cpp @@ -0,0 +1,211 @@ +#define PY_SSIZE_T_CLEAN + +// on Windows and in debug mode, undefine _DEBUG before including Python.h +// this prevents issues including the debug version of the Python library +#if defined(_WIN32) && defined(_DEBUG) +#undef _DEBUG +#include +#define _DEBUG +#else +#include +#endif + +#include + +#include +#include +#include +#include +#include +#include + +#include "tests/cpp/ymq/common/testing.h" +#include "tests/cpp/ymq/common/utils.h" + +#ifdef _WIN32 +#define popen _popen +#define pclose _pclose +#endif + +TestResult return_failure_if_false(bool cond, const char* msg, const char* condStr, const char* file, int line) +{ + // Failure: ... 
(assertion failed) at file:line + if (!cond) { + std::cerr << "Failure"; + if (condStr) + std::cerr << ": " << condStr; + if (msg) + std::cerr << " (" << msg << ")"; + else + std::cerr << " (assertion failed)"; + if (file) + std::cerr << " at " << file << ":" << line; + std::cerr << '\n'; + return TestResult::Failure; + } + return TestResult::Success; +} + +// in the case that there's no msg, delegate +TestResult return_failure_if_false(bool cond, const char* condStr, const char* file, int line) +{ + return return_failure_if_false(cond, nullptr, condStr, file, line); +} + +std::wstring discover_python_home(std::string command) +{ + // leverage the system's command line to get the current python prefix + FILE* pipe = popen(std::format("{} -c \"import sys; print(sys.prefix)\"", command).c_str(), "r"); + if (!pipe) + throw std::runtime_error("failed to start python process to discover prefix"); + + std::array buffer {}; + std::string output {}; + + size_t n = 0; + while ((n = fread(buffer.data(), 1, buffer.size(), pipe)) > 0) + output.append(buffer.data(), n); + + // remove trailing whitespace + output.erase(output.find_last_not_of("\r\n") + 1); + + auto status = pclose(pipe); + if (status < 0) + throw std::runtime_error("failed to close close process"); + else if (status > 0) + throw std::runtime_error("process returned non-zero exit code: " + std::to_string(status)); + + // assume it's ascii, so we can just cast it as a wstring + return std::wstring(output.begin(), output.end()); +} + +void ensure_python_initialized() +{ + if (Py_IsInitialized()) + return; + + ensure_python_home(); + Py_Initialize(); + + // add the cwd to the path + { + PyObject* sysPath = PySys_GetObject("path"); + if (!sysPath) + throw std::runtime_error("failed to get sys.path"); + + PyObject* newPath = PyUnicode_FromString("."); + if (!newPath) + throw std::runtime_error("failed to create Python string"); + + if (PyList_Append(sysPath, newPath) < 0) { + Py_DECREF(newPath); + throw std::runtime_error("failed to append to sys.path"); + } + + Py_DECREF(newPath); + } + + // release the GIL, the caller will have to acquire it again + PyEval_SaveThread(); +} + +void maybe_finalize_python() +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + if (!Py_IsInitialized()) + return; + + Py_Finalize(); + + // stop compiler from complaining that it's unused + (void)gstate; +} + +TestResult run_python(const char* path, std::vector> argv) +{ + PyGILState_STATE gstate = PyGILState_Ensure(); + + auto pidStr = std::to_string(get_listener_pid()); + argv.insert(argv.begin(), pidStr.c_str()); + argv.insert(argv.begin(), "mitm"); + + // set argv + { + PyObject* pyArgv = PyList_New(argv.size()); + if (!pyArgv) + goto exception; + + for (size_t i = 0; i < argv.size(); i++) + if (argv[i]) + PyList_SET_ITEM(pyArgv, i, PyUnicode_FromString(argv[i].value().c_str())); + else + PyList_SET_ITEM(pyArgv, i, Py_None); + + if (PySys_SetObject("argv", pyArgv) < 0) + goto exception; + + Py_DECREF(pyArgv); + } + + { + std::ifstream file(path); + std::stringstream buffer; + buffer << file.rdbuf(); + + int rc = PyRun_SimpleString(buffer.str().c_str()); + file.close(); + + if (rc < 0) + throw std::runtime_error("failed to run python script"); + } + + PyGILState_Release(gstate); + return TestResult::Success; + +exception: + PyGILState_Release(gstate); + return TestResult::Failure; +} + +TestResult run_mitm( + std::string testCase, + std::string mitmIp, + uint16_t mitmPort, + std::string remoteIp, + uint16_t remotePort, + std::vector extraArgs) +{ + auto cwd = 
std::filesystem::current_path(); + chdir_to_project_root(); + + // we build the args for the user to make calling the function more convenient + std::vector> args { + testCase, mitmIp, std::to_string(mitmPort), remoteIp, std::to_string(remotePort)}; + + for (auto arg: extraArgs) + args.push_back(arg); + + auto result = run_python("tests/cpp/ymq/py_mitm/main.py", args); + + // change back to the original working directory + std::filesystem::current_path(cwd); + return result; +} + +void test_wrapper(std::function fn, int timeoutSecs, PipeWriter pipeWr, void* hEvent) +{ + TestResult result = TestResult::Failure; + try { + result = fn(); + } catch (const std::exception& e) { + std::cerr << "Exception: " << e.what() << std::endl; + result = TestResult::Failure; + } catch (...) { + std::cerr << "Unknown exception" << std::endl; + result = TestResult::Failure; + } + + pipeWr.write_all((char*)&result, sizeof(TestResult)); + + signal_event(hEvent); +} diff --git a/tests/cpp/ymq/common/testing.h b/tests/cpp/ymq/common/testing.h index f85ef15c2..714fc9a1d 100644 --- a/tests/cpp/ymq/common/testing.h +++ b/tests/cpp/ymq/common/testing.h @@ -1,612 +1,59 @@ #pragma once -#define PY_SSIZE_T_CLEAN - -// if on Windows and in debug mode, undefine _DEBUG before including Python.h -// this prevents issues including the debug version of the Python library -#if defined(_WIN32) && defined(_DEBUG) -#undef _DEBUG -#include -#define _DEBUG -#else -#include -#endif - -#include -#include -#include - -#ifdef __linux__ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#endif // __linux__ -#ifdef _WIN32 -#include -#include -#include -#include - -// the windows timer apis work in 100-nanosecond units -const LONGLONG ns_per_second = 1'000'000'000LL; -const LONGLONG ns_per_unit = 100LL; // 1 unit = 100 nanoseconds - -#define popen _popen -#define pclose _pclose -#endif // _WIN32 - -#include -#include -#include #include -#include -#include -#include -#include -#include -#include -#include #include -#include #include #include -#include -#include -#include -#include - -#include "tests/cpp/ymq/common/utils.h" -#include "tests/cpp/ymq/net/socket.h" -#include "tests/cpp/ymq/pipe/pipe.h" -using namespace std::chrono_literals; +#include "tests/cpp/ymq/pipe/pipe_writer.h" enum class TestResult : char { Success = 1, Failure = 2 }; -inline TestResult return_failure_if_false( - bool cond, const char* msg = nullptr, const char* cond_str = nullptr, const char* file = nullptr, int line = 0) -{ - // Failure: ... (assertion failed) at file:line - if (!cond) { - std::cerr << "Failure"; - if (cond_str) - std::cerr << ": " << cond_str; - if (msg) - std::cerr << " (" << msg << ")"; - else - std::cerr << " (assertion failed)"; - if (file) - std::cerr << " at " << file << ":" << line; - std::cerr << '\n'; - return TestResult::Failure; - } - return TestResult::Success; -} +TestResult return_failure_if_false( + bool cond, const char* msg = nullptr, const char* condStr = nullptr, const char* file = nullptr, int line = 0); // in the case that there's no msg, delegate -inline TestResult return_failure_if_false(bool cond, const char* cond_str, const char* file, int line) -{ - return return_failure_if_false(cond, nullptr, cond_str, file, line); -} +TestResult return_failure_if_false(bool cond, const char* condStr, const char* file, int line); #define RETURN_FAILURE_IF_FALSE(cond, ...) 
\ if (return_failure_if_false((cond), ##__VA_ARGS__, #cond, __FILE__, __LINE__) == TestResult::Failure) \ return TestResult::Failure; -// hEvent: unused on linux, event handle on windows -inline void fork_wrapper(std::function fn, int timeout_secs, PipeWriter pipe_wr, void* hEvent) -{ - TestResult result = TestResult::Failure; - try { - result = fn(); - } catch (const std::exception& e) { - std::cerr << "Exception: " << e.what() << std::endl; - result = TestResult::Failure; - } catch (...) { - std::cerr << "Unknown exception" << std::endl; - result = TestResult::Failure; - } - - pipe_wr.write_all((char*)&result, sizeof(TestResult)); +void signal_event(void* hEvent); -#ifdef _WIN32 - SetEvent((HANDLE)hEvent); -#endif // _WIN32 -} +// hEvent: unused on linux, event handle on windows +void test_wrapper(std::function fn, int timeoutSecs, PipeWriter pipeWr, void* hEvent); // this function along with `wait_for_python_ready_sigwait()` // work together to wait on a signal from the python process // indicating that the tuntap interface has been created, and that the mitm is ready // // hEvent is an output parameter for windows but unused on linux -inline void wait_for_python_ready_sigblock(void** hEvent) -{ -#ifdef __linux__ - sigset_t set {}; - - if (sigemptyset(&set) < 0) - raise_system_error("failed to create empty signal set"); - - if (sigaddset(&set, SIGUSR1) < 0) - raise_system_error("failed to add sigusr1 to the signal set"); - - if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) - raise_system_error("failed to mask sigusr1"); - -#endif // __linux__ -#ifdef _WIN32 - // TODO: implement signaling of this event in the python mitm - *hEvent = CreateEvent( - NULL, // default security attributes - FALSE, // auto-reset event - FALSE, // initial state is nonsignaled - "Global\\PythonSignal"); // name of the event - if (*hEvent == NULL) - raise_system_error("failed to create event"); -#endif // _WIN32 - - std::cout << "blocked signal..." << std::endl; -} +void wait_for_python_ready_sigblock(void** hEvent); // as in the above function, hEvent is unused on linux -inline void wait_for_python_ready_sigwait(void* hEvent, int timeout_secs) -{ - std::cout << "waiting for python to be ready..." 
<< std::endl; - -#ifdef __linux__ - timespec ts {.tv_sec = timeout_secs, .tv_nsec = 0}; - sigset_t set {}; - siginfo_t sig {}; - - if (sigemptyset(&set) < 0) - raise_system_error("failed to create empty signal set"); - - if (sigaddset(&set, SIGUSR1) < 0) - raise_system_error("failed to add sigusr1 to the signal set"); - - if (sigtimedwait(&set, &sig, &ts) < 0) - raise_system_error("failed to wait on sigusr1"); - - sigprocmask(SIG_UNBLOCK, &set, nullptr); - -#endif // __linux__ -#ifdef _WIN32 - DWORD waitResult = WaitForSingleObject(hEvent, timeout_secs * 1000); - if (waitResult != WAIT_OBJECT_0) { - raise_system_error("failed to wait on event"); - } - CloseHandle(hEvent); -#endif // _WIN32 - - std::cout << "signal received; python is ready" << std::endl; -} +void wait_for_python_ready_sigwait(void* hEvent, int timeoutSecs); // run a test // forks and runs each of the provided closures // if `wait_for_python` is true, wait for SIGUSR1 after forking and executing the first closure -inline TestResult test( - int timeout_secs, std::vector> closures, bool wait_for_python = false) -{ - std::vector pipes {}; - - for (size_t i = 0; i < closures.size(); i++) - pipes.emplace_back(); - -#ifdef __linux__ - std::vector pids {}; - void* hEvent = nullptr; - for (size_t i = 0; i < closures.size(); i++) { - if (wait_for_python && i == 0) - wait_for_python_ready_sigblock(&hEvent); - - auto pid = fork(); - if (pid < 0) { - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - raise_system_error("failed to fork"); - } - - if (pid == 0) { - fork_wrapper(closures[i], timeout_secs, std::move(pipes[i].writer), nullptr); - std::exit(EXIT_SUCCESS); - } - - pids.push_back(pid); - - if (wait_for_python && i == 0) - wait_for_python_ready_sigwait(&hEvent, 3); - } - - std::vector pfds {}; - - int timerfd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); - if (timerfd < 0) { - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - - raise_system_error("failed to create timerfd"); - } - - pfds.push_back({.fd = timerfd, .events = POLL_IN, .revents = 0}); - for (const auto& pipe: pipes) - pfds.push_back({ - .fd = (int)pipe.reader.fd(), - .events = POLL_IN, - .revents = 0, - }); - - itimerspec spec { - .it_interval = - { - .tv_sec = 0, - .tv_nsec = 0, - }, - .it_value = { - .tv_sec = timeout_secs, - .tv_nsec = 0, - }}; - - if (timerfd_settime(timerfd, 0, &spec, nullptr) < 0) { - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - close(timerfd); - - raise_system_error("failed to set timerfd"); - } - - std::vector> results(pids.size(), std::nullopt); - - for (;;) { - auto n = poll(pfds.data(), pfds.size(), -1); - if (n < 0) { - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - close(timerfd); - - raise_system_error("failed to poll"); - } - - for (auto& pfd: std::vector(pfds)) { - if (pfd.revents == 0) - continue; - - // timed out - if (pfd.fd == timerfd) { - std::cout << "Timed out!\n"; - - std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); - close(timerfd); - - return TestResult::Failure; - } - - auto elem = std::find_if( - pipes.begin(), pipes.end(), [fd = pfd.fd](const auto& pipe) { return pipe.reader.fd() == fd; }); - auto idx = elem - pipes.begin(); - - TestResult result = TestResult::Failure; - char buffer = 0; - auto n = read(pfd.fd, &buffer, sizeof(TestResult)); - if (n == 0) { - std::cout << "failed to read from pipe: pipe closed 
unexpectedly\n"; - result = TestResult::Failure; - } else if (n < 0) { - std::cout << "failed to read from pipe: " << std::strerror(errno) << std::endl; - result = TestResult::Failure; - } else - result = (TestResult)buffer; - - // the subprocess should have exited - // check its exit status - int status; - if (waitpid(pids[idx], &status, 0) < 0) - std::cout << "failed to wait on subprocess[" << idx << "]: " << std::strerror(errno) << std::endl; - - auto exit_status = WEXITSTATUS(status); - if (WIFEXITED(status) && exit_status != EXIT_SUCCESS) { - std::cout << "subprocess[" << idx << "] exited with status " << exit_status << std::endl; - } else if (WIFSIGNALED(status)) { - std::cout << "subprocess[" << idx << "] killed by signal " << WTERMSIG(status) << std::endl; - } else { - std::cout << "subprocess[" << idx << "] completed with " - << (result == TestResult::Success ? "Success" : "Failure") << std::endl; - } - - // store the result - results[idx] = result; - - // this subprocess is done, remove its pipe from the poll fds - pfds.erase( - std::remove_if(pfds.begin(), pfds.end(), [&](const auto& p) { return p.fd == pfd.fd; }), pfds.end()); - - auto done = - std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); - if (done) - goto end; // justification for goto: breaks out of two levels of loop - } - } - -end: - close(timerfd); - - if (std::ranges::any_of(results, [](const auto& x) { return x == TestResult::Failure; })) - return TestResult::Failure; - - return TestResult::Success; -#endif // __linux__ -#ifdef _WIN32 - std::vector events {}; - std::vector threads {}; - - for (size_t i = 0; i < closures.size(); i++) { - HANDLE hEvent = CreateEvent( - nullptr, // default security attributes - true, // auto-reset event - false, // initial state is nonsignaled - nullptr); // unnamed event - if (!hEvent) - raise_system_error("failed to create event"); - events.push_back(hEvent); - } - - for (size_t i = 0; i < closures.size(); i++) { - HANDLE hEvent = nullptr; - if (wait_for_python && i == 0) - wait_for_python_ready_sigblock(&hEvent); - - threads.emplace_back(fork_wrapper, closures[i], timeout_secs, std::move(pipes[i].writer), events[i]); - - if (wait_for_python && i == 0) - wait_for_python_ready_sigwait(hEvent, 3); - } - - HANDLE timer = CreateWaitableTimer(nullptr, true, nullptr); - if (!timer) { - std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); - raise_system_error("failed to create waitable timer"); - } - - LARGE_INTEGER expires_in = {0}; - - // negative value indicates relative time - expires_in.QuadPart = -static_cast(timeout_secs) * ns_per_second / ns_per_unit; - if (!SetWaitableTimer(timer, &expires_in, 0, nullptr, nullptr, false)) { - std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); - CloseHandle(timer); - raise_system_error("failed to set waitable timer"); - } - - // these are the handles we're going to poll - std::vector wait_handles {timer}; - - // poll all read halves of the pipes - for (const auto& ev: events) - wait_handles.push_back(ev); - - std::vector> results(threads.size(), std::nullopt); - - for (;;) { - DWORD waitResult = WaitForMultipleObjects(wait_handles.size(), wait_handles.data(), false, INFINITE); - if (waitResult == WAIT_FAILED) { - std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); - CloseHandle(timer); - raise_system_error("failed to wait on handles"); - } - - // the idx of the handle in the handles array - // note 
that index 0 is the timer - // and we adjust the handles array as tasks complete - // so we need an extra step to calculate the index in `closure`-space - size_t wait_idx = (size_t)waitResult - WAIT_OBJECT_0; - - // timed out - if (wait_idx == 0) { - std::cout << "Timed out!\n"; - std::for_each(threads.begin(), threads.end(), [](auto& t) { - t.request_stop(); - t.detach(); - }); - std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); - CloseHandle(timer); - return TestResult::Failure; - } - - // find the idx - const auto& hEvent = wait_handles[wait_idx]; - auto event_it = std::find_if(events.begin(), events.end(), [hEvent](const auto& ev) { return ev == hEvent; }); - const auto idx = event_it - events.begin(); - auto& pipe = pipes[idx]; - TestResult result = TestResult::Failure; - char buffer = 0; - try { - pipe.reader.read_exact(&buffer, sizeof(TestResult)); - result = (TestResult)buffer; - } catch (const std::system_error& e) { - std::cout << "failed to read from pipe: " << e.what() << std::endl; - result = TestResult::Failure; - } - - std::cout << "subprocess[" << idx << "] completed with " - << (result == TestResult::Success ? "Success" : "Failure") << std::endl; - - // store the result - results[idx] = result; - - // this subprocess is done, remove its pipe from the handles - wait_handles.erase( - std::remove_if(wait_handles.begin(), wait_handles.end(), [&](const auto& h) { return h == hEvent; }), - wait_handles.end()); - auto done = std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); - if (done) - goto end; // justification for goto: breaks out of two levels of loop - } - -end: - std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); - CloseHandle(timer); - - if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) - return TestResult::Failure; - - return TestResult::Success; -#endif // _WIN32 -} - -inline std::wstring discover_python_home(std::string command) -{ - // leverage the system's command line to get the current python prefix - FILE* pipe = popen(std::format("{} -c \"import sys; print(sys.prefix)\"", command).c_str(), "r"); - if (!pipe) - throw std::runtime_error("failed to start python process to discover prefix"); - - std::array buffer {}; - std::string output {}; - - size_t n; - while ((n = fread(buffer.data(), 1, buffer.size(), pipe)) > 0) - output.append(buffer.data(), n); - - // remove trailing whitespace - output.erase(output.find_last_not_of("\r\n") + 1); - - auto status = pclose(pipe); - if (status < 0) - throw std::runtime_error("failed to close close process"); - else if (status > 0) - throw std::runtime_error("process returned non-zero exit code: " + std::to_string(status)); - - // assume it's ascii, so we can just cast it as a wstring - return std::wstring(output.begin(), output.end()); -} - -inline void ensure_python_initialized() -{ - if (Py_IsInitialized()) - return; - -#ifdef _WIN32 - auto python_home = discover_python_home("python"); - Py_SetPythonHome(python_home.c_str()); -#endif // _WIN32 - - Py_Initialize(); - - // add the cwd to the path - { - PyObject* sysPath = PySys_GetObject("path"); - if (!sysPath) - throw std::runtime_error("failed to get sys.path"); - - PyObject* newPath = PyUnicode_FromString("."); - if (!newPath) - throw std::runtime_error("failed to create Python string"); - - if (PyList_Append(sysPath, newPath) < 0) { - Py_DECREF(newPath); - throw std::runtime_error("failed to append to sys.path"); - } - - 
Py_DECREF(newPath); - } - - // release the GIL, the caller will have to acquire it again - PyEval_SaveThread(); -} - -inline void maybe_finalize_python() -{ - PyGILState_STATE gstate = PyGILState_Ensure(); - if (!Py_IsInitialized()) - return; - - Py_Finalize(); - - // stop compiler from complaining that it's unused - (void)gstate; -} - -inline TestResult run_python(const char* path, std::vector> argv = {}) -{ - // ensure_python_initialized(); - PyGILState_STATE gstate = PyGILState_Ensure(); - -// insert the pid at the start of the argv, this is important for signalling readiness -#ifdef __linux__ - pid_t pid = getppid(); -#endif // __linux__ -#ifdef _WIN32 - DWORD pid = GetCurrentProcessId(); -#endif // _WIN32 - - auto pid_s = std::to_string(pid); - argv.insert(argv.begin(), pid_s.c_str()); - argv.insert(argv.begin(), "mitm"); - - // set argv - { - PyObject* py_argv = PyList_New(argv.size()); - if (!py_argv) - goto exception; - - for (size_t i = 0; i < argv.size(); i++) - if (argv[i]) - PyList_SET_ITEM(py_argv, i, PyUnicode_FromString(argv[i].value().c_str())); - else - PyList_SET_ITEM(py_argv, i, Py_None); - - if (PySys_SetObject("argv", py_argv) < 0) - goto exception; - - Py_DECREF(py_argv); - } - - { - std::ifstream file(path); - std::stringstream buffer; - buffer << file.rdbuf(); - - int rc = PyRun_SimpleString(buffer.str().c_str()); - file.close(); - - if (rc < 0) - throw std::runtime_error("failed to run python script"); - } - - PyGILState_Release(gstate); - return TestResult::Success; - -exception: - PyGILState_Release(gstate); - return TestResult::Failure; -} +TestResult test(int timeoutSecs, std::vector> closures, bool waitForPython = false); -inline TestResult run_mitm( - std::string testcase, - std::string mitm_ip, - uint16_t mitm_port, - std::string remote_ip, - uint16_t remote_port, - std::vector extra_args = {}) -{ - auto cwd = std::filesystem::current_path(); - chdir_to_project_root(); +std::wstring discover_python_home(std::string command); - // we build the args for the user to make calling the function more convenient - std::vector> args { - testcase, mitm_ip, std::to_string(mitm_port), remote_ip, std::to_string(remote_port)}; +void ensure_python_home(); +void ensure_python_initialized(); +void maybe_finalize_python(); - for (auto arg: extra_args) - args.push_back(arg); +// get the pid of the process waiting to be signaled by Python +int get_listener_pid(); - auto result = run_python("tests/cpp/ymq/py_mitm/main.py", args); +TestResult run_python(const char* path, std::vector> argv = {}); - // change back to the original working directory - std::filesystem::current_path(cwd); - return result; -} +TestResult run_mitm( + std::string testCase, + std::string mitmIp, + uint16_t mitmPort, + std::string remoteIp, + uint16_t remotePort, + std::vector extraArgs = {}); diff --git a/tests/cpp/ymq/common/testing_linux.cpp b/tests/cpp/ymq/common/testing_linux.cpp new file mode 100644 index 000000000..532018cdb --- /dev/null +++ b/tests/cpp/ymq/common/testing_linux.cpp @@ -0,0 +1,222 @@ +#define PY_SSIZE_T_CLEAN + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "tests/cpp/ymq/common/testing.h" +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe.h" + +void ensure_python_home() +{ + // no-op +} + +void signal_event(void* hEvent) +{ + // no-op +} + +int get_listener_pid() +{ + return getppid(); +} + +void wait_for_python_ready_sigblock(void** 
hEvent) +{ + sigset_t set {}; + + if (sigemptyset(&set) < 0) + raise_system_error("failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + raise_system_error("failed to add sigusr1 to the signal set"); + + if (sigprocmask(SIG_BLOCK, &set, nullptr) < 0) + raise_system_error("failed to mask sigusr1"); + + std::cout << "blocked signal..." << std::endl; +} + +void wait_for_python_ready_sigwait(void* hEvent, int timeoutSecs) +{ + std::cout << "waiting for python to be ready..." << std::endl; + + timespec ts {.tv_sec = timeoutSecs, .tv_nsec = 0}; + sigset_t set {}; + siginfo_t sig {}; + + if (sigemptyset(&set) < 0) + raise_system_error("failed to create empty signal set"); + + if (sigaddset(&set, SIGUSR1) < 0) + raise_system_error("failed to add sigusr1 to the signal set"); + + if (sigtimedwait(&set, &sig, &ts) < 0) + raise_system_error("failed to wait on sigusr1"); + + sigprocmask(SIG_UNBLOCK, &set, nullptr); + + std::cout << "signal received; python is ready" << std::endl; +} + +TestResult test(int timeoutSecs, std::vector> closures, bool waitForPython) +{ + std::vector pipes {}; + + for (size_t i = 0; i < closures.size(); i++) + pipes.emplace_back(); + + std::vector pids {}; + void* hEvent = nullptr; + for (size_t i = 0; i < closures.size(); i++) { + if (waitForPython && i == 0) + wait_for_python_ready_sigblock(&hEvent); + + auto pid = fork(); + if (pid < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + raise_system_error("failed to fork"); + } + + if (pid == 0) { + test_wrapper(closures[i], timeoutSecs, std::move(pipes[i].writer), nullptr); + std::exit(EXIT_SUCCESS); + } + + pids.push_back(pid); + + if (waitForPython && i == 0) + wait_for_python_ready_sigwait(&hEvent, 3); + } + + std::vector pollFds {}; + + int timerFd = timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK); + if (timerFd < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + + raise_system_error("failed to create timerfd"); + } + + pollFds.push_back({.fd = timerFd, .events = POLL_IN, .revents = 0}); + for (const auto& pipe: pipes) + pollFds.push_back({ + .fd = (int)pipe.reader.fd(), + .events = POLL_IN, + .revents = 0, + }); + + itimerspec spec { + .it_interval = + { + .tv_sec = 0, + .tv_nsec = 0, + }, + .it_value = { + .tv_sec = timeoutSecs, + .tv_nsec = 0, + }}; + + if (timerfd_settime(timerFd, 0, &spec, nullptr) < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerFd); + + raise_system_error("failed to set timerfd"); + } + + std::vector> results(pids.size(), std::nullopt); + + for (;;) { + auto n = poll(pollFds.data(), pollFds.size(), -1); + if (n < 0) { + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerFd); + + raise_system_error("failed to poll"); + } + + for (auto& pfd: std::vector(pollFds)) { + if (pfd.revents == 0) + continue; + + // timed out + if (pfd.fd == timerFd) { + std::cout << "Timed out!\n"; + + std::for_each(pids.begin(), pids.end(), [](const auto& pid) { kill(pid, SIGKILL); }); + close(timerFd); + + return TestResult::Failure; + } + + auto elem = std::find_if( + pipes.begin(), pipes.end(), [fd = pfd.fd](const auto& pipe) { return pipe.reader.fd() == fd; }); + auto idx = elem - pipes.begin(); + + TestResult result = TestResult::Failure; + char buffer = 0; + auto n = read(pfd.fd, &buffer, sizeof(TestResult)); + if (n == 0) { + std::cout << "failed to read from pipe: pipe closed unexpectedly\n"; + 
result = TestResult::Failure; + } else if (n < 0) { + std::cout << "failed to read from pipe: " << std::strerror(errno) << std::endl; + result = TestResult::Failure; + } else + result = (TestResult)buffer; + + // the subprocess should have exited + // check its exit status + int status; + if (waitpid(pids[idx], &status, 0) < 0) + std::cout << "failed to wait on subprocess[" << idx << "]: " << std::strerror(errno) << std::endl; + + auto exitStatus = WEXITSTATUS(status); + if (WIFEXITED(status) && exitStatus != EXIT_SUCCESS) { + std::cout << "subprocess[" << idx << "] exited with status " << exitStatus << std::endl; + } else if (WIFSIGNALED(status)) { + std::cout << "subprocess[" << idx << "] killed by signal " << WTERMSIG(status) << std::endl; + } else { + std::cout << "subprocess[" << idx << "] completed with " + << (result == TestResult::Success ? "Success" : "Failure") << std::endl; + } + + // store the result + results[idx] = result; + + // this subprocess is done, remove its pipe from the poll fds + pollFds.erase( + std::remove_if(pollFds.begin(), pollFds.end(), [&](const auto& p) { return p.fd == pfd.fd; }), pollFds.end()); + + auto done = + std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + } + +end: + close(timerFd); + + if (std::ranges::any_of(results, [](const auto& x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +} diff --git a/tests/cpp/ymq/common/testing_windows.cpp b/tests/cpp/ymq/common/testing_windows.cpp new file mode 100644 index 000000000..f39a8117c --- /dev/null +++ b/tests/cpp/ymq/common/testing_windows.cpp @@ -0,0 +1,195 @@ +#define PY_SSIZE_T_CLEAN + +// on Windows and in debug mode, undefine _DEBUG before including Python.h +// this prevents issues including the debug version of the Python library +#ifdef _DEBUG +#undef _DEBUG +#include +#define _DEBUG +#else +#include +#endif + +#include +#include +#include +#include + +#include +#include +#include + +#include "tests/cpp/ymq/common/testing.h" +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/pipe/pipe.h" + +// the windows timer apis work in 100-nanosecond units +const LONGLONG ns_per_second = 1'000'000'000LL; +const LONGLONG ns_per_unit = 100LL; // 1 unit = 100 nanoseconds + +void ensure_python_home() +{ + auto pythonHome = discover_python_home("python"); + Py_SetPythonHome(pythonHome.c_str()); +} + +int get_listener_pid() +{ + return GetCurrentProcessId(); +} + +void signal_event(void* hEvent) +{ + SetEvent((HANDLE)hEvent); +} + +void wait_for_python_ready_sigblock(void** hEvent) +{ + *hEvent = CreateEvent( + NULL, // default security attributes + false, // auto-reset event + false, // initial state is nonsignaled + "Global\\PythonSignal"); // name of the event + if (*hEvent == NULL) + raise_system_error("failed to create event"); + + std::cout << "blocked signal..." << std::endl; +} + +void wait_for_python_ready_sigwait(void* hEvent, int timeoutSecs) +{ + std::cout << "waiting for python to be ready..." 
<< std::endl; + + DWORD waitResult = WaitForSingleObject(hEvent, timeoutSecs * 1000); + + if (waitResult != WAIT_OBJECT_0) { + raise_system_error("failed to wait on event"); + } + + CloseHandle(hEvent); + + std::cout << "signal received; python is ready" << std::endl; +} + +TestResult test(int timeoutSecs, std::vector> closures, bool waitForPython) +{ + std::vector pipes {}; + + for (size_t i = 0; i < closures.size(); i++) + pipes.emplace_back(); + + std::vector events {}; + std::vector threads {}; + + for (size_t i = 0; i < closures.size(); i++) { + HANDLE hEvent = CreateEvent( + nullptr, // default security attributes + true, // auto-reset event + false, // initial state is nonsignaled + nullptr); // unnamed event + if (!hEvent) + raise_system_error("failed to create event"); + events.push_back(hEvent); + } + + for (size_t i = 0; i < closures.size(); i++) { + HANDLE hEvent = nullptr; + if (waitForPython && i == 0) + wait_for_python_ready_sigblock(&hEvent); + + threads.emplace_back(test_wrapper, closures[i], timeoutSecs, std::move(pipes[i].writer), events[i]); + + if (waitForPython && i == 0) + wait_for_python_ready_sigwait(hEvent, 3); + } + + HANDLE timer = CreateWaitableTimer(nullptr, true, nullptr); + if (!timer) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + raise_system_error("failed to create waitable timer"); + } + + LARGE_INTEGER expiresIn = {0}; + + // negative value indicates relative time + expiresIn.QuadPart = -static_cast(timeoutSecs) * ns_per_second / ns_per_unit; + if (!SetWaitableTimer(timer, &expiresIn, 0, nullptr, nullptr, false)) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + raise_system_error("failed to set waitable timer"); + } + + // these are the handles we're going to poll + std::vector waitHandles {timer}; + + // poll all read halves of the pipes + for (const auto& ev: events) + waitHandles.push_back(ev); + + std::vector> results(threads.size(), std::nullopt); + + for (;;) { + DWORD waitResult = WaitForMultipleObjects(waitHandles.size(), waitHandles.data(), false, INFINITE); + if (waitResult == WAIT_FAILED) { + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + raise_system_error("failed to wait on handles"); + } + + // the idx of the handle in the handles array + // note that index 0 is the timer + // and we adjust the handles array as tasks complete + // so we need an extra step to calculate the index in `closure`-space + size_t waitIdx = (size_t)waitResult - WAIT_OBJECT_0; + + // timed out + if (waitIdx == 0) { + std::cout << "Timed out!\n"; + std::for_each(threads.begin(), threads.end(), [](auto& t) { + t.request_stop(); + t.detach(); + }); + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + return TestResult::Failure; + } + + // find the idx + const auto& hEvent = waitHandles[waitIdx]; + auto eventIt = std::find_if(events.begin(), events.end(), [hEvent](const auto& ev) { return ev == hEvent; }); + const auto idx = eventIt - events.begin(); + auto& pipe = pipes[idx]; + TestResult result = TestResult::Failure; + char buffer = 0; + try { + pipe.reader.read_exact(&buffer, sizeof(TestResult)); + result = (TestResult)buffer; + } catch (const std::system_error& e) { + std::cout << "failed to read from pipe: " << e.what() << std::endl; + result = TestResult::Failure; + } + + std::cout << "subprocess[" << idx << "] completed with " + << 
(result == TestResult::Success ? "Success" : "Failure") << std::endl; + + // store the result + results[idx] = result; + + // this subprocess is done, remove its pipe from the handles + waitHandles.erase( + std::remove_if(waitHandles.begin(), waitHandles.end(), [&](const auto& h) { return h == hEvent; }), + waitHandles.end()); + auto done = std::all_of(results.begin(), results.end(), [](const auto& result) { return result.has_value(); }); + if (done) + goto end; // justification for goto: breaks out of two levels of loop + } + +end: + std::for_each(events.begin(), events.end(), [](const auto& ev) { CloseHandle(ev); }); + CloseHandle(timer); + + if (std::ranges::any_of(results, [](auto x) { return x == TestResult::Failure; })) + return TestResult::Failure; + + return TestResult::Success; +} diff --git a/tests/cpp/ymq/common/utils.cpp b/tests/cpp/ymq/common/utils.cpp index ba5f15ba8..66a4a7781 100644 --- a/tests/cpp/ymq/common/utils.cpp +++ b/tests/cpp/ymq/common/utils.cpp @@ -1,49 +1,7 @@ #include "tests/cpp/ymq/common/utils.h" -#include #include #include -#include - -#ifdef __linux__ -#include -#endif // __linux__ -#ifdef _WIN32 -#include -#include -#endif // _WIN32 - -void raise_system_error(const char* msg) -{ -#ifdef __linux__ - throw std::system_error(errno, std::generic_category(), msg); -#endif // __linux__ -#ifdef _WIN32 - throw std::system_error(GetLastError(), std::generic_category(), msg); -#endif // _WIN32 -} - -void raise_socket_error(const char* msg) -{ -#ifdef __linux__ - throw std::system_error(errno, std::generic_category(), msg); -#endif // __linux__ -#ifdef _WIN32 - throw std::system_error(WSAGetLastError(), std::generic_category(), msg); -#endif // _WIN32 -} - -const char* check_localhost(const char* host) -{ - return std::strcmp(host, "localhost") == 0 ? 
"127.0.0.1" : host; -} - -std::string format_address(std::string host, uint16_t port) -{ - std::ostringstream oss; - oss << "tcp://" << check_localhost(host.c_str()) << ":" << port; - return oss.str(); -} // change the current working directory to the project root // this is important for finding the python mitm script @@ -61,9 +19,9 @@ void chdir_to_project_root() } } -unsigned short random_port(unsigned short min_port, unsigned short max_port) +unsigned short random_port(unsigned short minPort, unsigned short maxPort) { static thread_local std::mt19937_64 rng(std::random_device {}()); - std::uniform_int_distribution dist(min_port, max_port); + std::uniform_int_distribution dist(minPort, maxPort); return static_cast(dist(rng)); } diff --git a/tests/cpp/ymq/common/utils.h b/tests/cpp/ymq/common/utils.h index 2861ec5e6..f612792d6 100644 --- a/tests/cpp/ymq/common/utils.h +++ b/tests/cpp/ymq/common/utils.h @@ -9,14 +9,8 @@ void raise_system_error(const char* msg); // throw wan error with the last socket error code void raise_socket_error(const char* msg); -// returns its input unless it matches "localhost" in which case 127.0.0.1 is returned -const char* check_localhost(const char* host); - -// formats a tcp:// address -std::string format_address(std::string host, uint16_t port); - // change the current working directory to the project root // this is important for finding the python mitm script void chdir_to_project_root(); -unsigned short random_port(unsigned short min_port = 1024, unsigned short max_port = 65535); +unsigned short random_port(unsigned short minPort = 1024, unsigned short maxPort = 65535); diff --git a/tests/cpp/ymq/common/utils_linux.cpp b/tests/cpp/ymq/common/utils_linux.cpp new file mode 100644 index 000000000..74262d4b9 --- /dev/null +++ b/tests/cpp/ymq/common/utils_linux.cpp @@ -0,0 +1,13 @@ +#include + +#include + +void raise_system_error(const char* msg) +{ + throw std::system_error(errno, std::generic_category(), msg); +} + +void raise_socket_error(const char* msg) +{ + throw std::system_error(errno, std::generic_category(), msg); +} diff --git a/tests/cpp/ymq/common/utils_windows.cpp b/tests/cpp/ymq/common/utils_windows.cpp new file mode 100644 index 000000000..2677c32b0 --- /dev/null +++ b/tests/cpp/ymq/common/utils_windows.cpp @@ -0,0 +1,14 @@ +#include +#include + +#include + +void raise_system_error(const char* msg) +{ + throw std::system_error(GetLastError(), std::generic_category(), msg); +} + +void raise_socket_error(const char* msg) +{ + throw std::system_error(WSAGetLastError(), std::generic_category(), msg); +} \ No newline at end of file diff --git a/tests/cpp/ymq/net/address.cpp b/tests/cpp/ymq/net/address.cpp new file mode 100644 index 000000000..0705f27b9 --- /dev/null +++ b/tests/cpp/ymq/net/address.cpp @@ -0,0 +1,29 @@ +#include "tests/cpp/ymq/net/address.h" + +#include + +Address parseAddress(const std::string& address_str) +{ + if (address_str.rfind("tcp://", 0) == 0) { // Check if string starts with "tcp://" + std::string_view remaining = address_str; + remaining.remove_prefix(6); // Remove "tcp://" + + size_t colon_pos = remaining.find(':'); + if (colon_pos == std::string_view::npos) { + throw std::runtime_error("Invalid TCP address format: missing port"); + } + + std::string host(remaining.substr(0, colon_pos)); + uint16_t port = (uint16_t)std::stoi(std::string(remaining.substr(colon_pos + 1))); + return Address {"tcp", host, port, ""}; + } + + if (address_str.rfind("ipc://", 0) == 0) { // Check if string starts with "ipc://" + std::string_view 
remaining = address_str; + remaining.remove_prefix(6); // Remove "ipc://" + std::string path(remaining); + return Address {"ipc", "", 0, path}; + } + + throw std::runtime_error("Invalid address format: " + address_str); +} diff --git a/tests/cpp/ymq/net/address.h b/tests/cpp/ymq/net/address.h new file mode 100644 index 000000000..8eb528bea --- /dev/null +++ b/tests/cpp/ymq/net/address.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +struct Address { + std::string protocol; + std::string host; + uint16_t port; + std::string path; +}; + +Address parseAddress(const std::string& address_str); diff --git a/tests/cpp/ymq/net/i_socket.h b/tests/cpp/ymq/net/i_socket.h new file mode 100644 index 000000000..600a90727 --- /dev/null +++ b/tests/cpp/ymq/net/i_socket.h @@ -0,0 +1,24 @@ +#pragma once +#include +#include +#include +#include + +class ISocket { +public: + virtual ~ISocket() = default; + + virtual void try_connect(const std::string& address, int tries = 10) const = 0; + virtual void bind(const std::string& address) const = 0; + virtual void listen(int backlog = 5) const = 0; + virtual std::unique_ptr accept() const = 0; + + virtual void write_all(const void* data, size_t size) const = 0; + virtual void write_all(std::string msg) const = 0; + + virtual void read_exact(void* buffer, size_t size) const = 0; + + virtual void write_message(std::string msg) const = 0; + + virtual std::string read_message() const = 0; +}; diff --git a/tests/cpp/ymq/net/socket_linux.cpp b/tests/cpp/ymq/net/socket_linux.cpp index b36f76e35..1ae5803bd 100644 --- a/tests/cpp/ymq/net/socket_linux.cpp +++ b/tests/cpp/ymq/net/socket_linux.cpp @@ -55,7 +55,7 @@ void Socket::try_connect(const std::string& host, short port, int tries) const sockaddr_in addr {}; addr.sin_family = AF_INET; addr.sin_port = htons(port); - inet_pton(AF_INET, check_localhost(host.c_str()), &addr.sin_addr); + inet_pton(AF_INET, host.c_str(), &addr.sin_addr); for (int i = 0; i < tries; i++) { auto code = ::connect(this->_fd, (sockaddr*)&addr, sizeof(addr)); diff --git a/tests/cpp/ymq/net/socket_utils.h b/tests/cpp/ymq/net/socket_utils.h new file mode 100644 index 000000000..e049e9151 --- /dev/null +++ b/tests/cpp/ymq/net/socket_utils.h @@ -0,0 +1,6 @@ +#include + +#include "tests/cpp/ymq/net/i_socket.h" + +std::unique_ptr connect_socket(std::string& address_str); +std::unique_ptr bind_socket(std::string& address_str); diff --git a/tests/cpp/ymq/net/socket_utils_linux.cpp b/tests/cpp/ymq/net/socket_utils_linux.cpp new file mode 100644 index 000000000..365a6414c --- /dev/null +++ b/tests/cpp/ymq/net/socket_utils_linux.cpp @@ -0,0 +1,41 @@ +#include +#include + +#include "tests/cpp/ymq/net/address.h" +#include "tests/cpp/ymq/net/socket_utils.h" +#include "tests/cpp/ymq/net/tcp_socket.h" +#include "tests/cpp/ymq/net/uds_socket.h" + +std::unique_ptr connect_socket(std::string& address_str) +{ + auto address = parseAddress(address_str); + + if (address.protocol == "tcp") { + auto socket = std::make_unique(); + socket->try_connect(address_str); + return socket; + } else if (address.protocol == "ipc") { + auto socket = std::make_unique(); + socket->try_connect(address_str); + return socket; + } + + throw std::runtime_error(std::format("Unsupported protocol for raw client: '{}'", address.protocol)); +} + +std::unique_ptr bind_socket(std::string& address_str) +{ + auto address = parseAddress(address_str); + + if (address.protocol == "tcp") { + auto socket = std::make_unique(); + socket->bind(address_str); + return socket; + } else if 
(address.protocol == "ipc") { + auto socket = std::make_unique(); + socket->bind(address_str); + return socket; + } + + throw std::runtime_error(std::format("Unsupported protocol for raw server: '{}'", address.protocol)); +} diff --git a/tests/cpp/ymq/net/socket_utils_windows.cpp b/tests/cpp/ymq/net/socket_utils_windows.cpp new file mode 100644 index 000000000..9b27c401d --- /dev/null +++ b/tests/cpp/ymq/net/socket_utils_windows.cpp @@ -0,0 +1,31 @@ +#include +#include + +#include "tests/cpp/ymq/net/address.h" +#include "tests/cpp/ymq/net/socket_utils.h" +#include "tests/cpp/ymq/net/tcp_socket.h" + +std::unique_ptr connect_socket(std::string& address_str) +{ + auto address = parseAddress(address_str); + + if (address.protocol == "tcp") { + auto socket = std::make_unique(); + socket->try_connect(address_str); + return socket; + } + + throw std::runtime_error(std::format("Unsupported protocol for raw client: '{}'", address.protocol)); +} + +std::unique_ptr bind_socket(std::string& address_str) +{ + auto address = parseAddress(address_str); + + if (address.protocol == "tcp") { + auto socket = std::make_unique(); + socket->bind(address_str); + return socket; + } + throw std::runtime_error(std::format("Unsupported protocol for raw server: '{}'", address.protocol)); +} diff --git a/tests/cpp/ymq/net/socket_windows.cpp b/tests/cpp/ymq/net/socket_windows.cpp index 8f8323334..0f4eea7b8 100644 --- a/tests/cpp/ymq/net/socket_windows.cpp +++ b/tests/cpp/ymq/net/socket_windows.cpp @@ -56,7 +56,7 @@ void Socket::try_connect(const std::string& host, short port, int tries) const sockaddr_in addr {}; addr.sin_family = AF_INET; addr.sin_port = htons(port); - inet_pton(AF_INET, check_localhost(host.c_str()), &addr.sin_addr); + inet_pton(AF_INET, host.c_str(), &addr.sin_addr); for (int i = 0; i < tries; i++) { auto code = ::connect((SOCKET)this->_fd, (sockaddr*)&addr, sizeof(addr)); @@ -67,8 +67,6 @@ void Socket::try_connect(const std::string& host, short port, int tries) const continue; } - std::printf("fpppp %d\n", WSAGetLastError()); - raise_socket_error("failed to connect"); } diff --git a/tests/cpp/ymq/net/tcp_socket.h b/tests/cpp/ymq/net/tcp_socket.h new file mode 100644 index 000000000..d74de05d0 --- /dev/null +++ b/tests/cpp/ymq/net/tcp_socket.h @@ -0,0 +1,50 @@ +#pragma once +#include +#include +#include +#include + +#include "address.h" +#include "i_socket.h" + +class TCPSocket: public ISocket { +public: + TCPSocket(bool nodelay = false); + TCPSocket(bool nodelay, long long fd); + ~TCPSocket(); + + // move-only + TCPSocket(TCPSocket&&) noexcept; + TCPSocket& operator=(TCPSocket&&) noexcept; + TCPSocket(const TCPSocket&) = delete; + TCPSocket& operator=(const TCPSocket&) = delete; + + void try_connect(const std::string& address, int tries = 10) const override; + void bind(const std::string& address) const override; + void listen(int backlog = 5) const override; + std::unique_ptr accept() const override; + + void write_all(const void* data, size_t size) const override; + void write_all(std::string msg) const override; + + void read_exact(void* buffer, size_t size) const override; + + void write_message(std::string msg) const override; + + std::string read_message() const override; + +private: + // the native handle for this pipe reader + // on Linux, this is a file descriptor + // on Windows, this is a SOCKET + long long _fd; + + // indicates if nodelay was set + bool _nodelay; + + // write up to `size` bytes + int write(const void* buffer, size_t size) const; + + // read up to `size` bytes + int 
read(void* buffer, size_t size) const; +}; diff --git a/tests/cpp/ymq/net/tcp_socket_linux.cpp b/tests/cpp/ymq/net/tcp_socket_linux.cpp new file mode 100644 index 000000000..1c51f3ede --- /dev/null +++ b/tests/cpp/ymq/net/tcp_socket_linux.cpp @@ -0,0 +1,160 @@ +#include +#include +#include +#include + +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/tcp_socket.h" + +TCPSocket::TCPSocket(bool nodelay): _fd(-1), _nodelay(nodelay) +{ + this->_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (this->_fd < 0) + raise_socket_error("failed to create socket"); + + char on = 1; + if (this->_nodelay) + if (::setsockopt(this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) < 0) + raise_socket_error("failed to set nodelay"); +} + +TCPSocket::TCPSocket(bool nodelay, long long fd): _fd(fd), _nodelay(nodelay) +{ + char on = 1; + if (this->_nodelay) + if (::setsockopt(this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) < 0) + raise_socket_error("failed to set nodelay"); +} + +TCPSocket::~TCPSocket() +{ + close(this->_fd); +} + +TCPSocket::TCPSocket(TCPSocket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; +} + +TCPSocket& TCPSocket::operator=(TCPSocket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +void TCPSocket::try_connect(const std::string& address_str, int tries) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "tcp") { + throw std::runtime_error("Unsupported protocol for TCPSocket: " + address.protocol); + } + + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(address.port); + inet_pton(AF_INET, address.host.c_str(), &addr.sin_addr); + + for (int i = 0; i < tries; i++) { + auto code = ::connect(this->_fd, (sockaddr*)&addr, sizeof(addr)); + + if (code < 0) { + if (errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + raise_socket_error("failed to connect"); + } + + break; // success + } +} + +void TCPSocket::bind(const std::string& address_str) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "tcp") { + throw std::runtime_error("Unsupported protocol for TCPSocket: " + address.protocol); + } + + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(address.port); + addr.sin_addr.s_addr = INADDR_ANY; + if (::bind(this->_fd, (sockaddr*)&addr, sizeof(addr)) < 0) + raise_socket_error("failed to bind"); +} + +void TCPSocket::listen(int backlog) const +{ + if (::listen(this->_fd, backlog) < 0) + raise_socket_error("failed to listen"); +} + +std::unique_ptr TCPSocket::accept() const +{ + long long fd = ::accept(this->_fd, nullptr, nullptr); + if (fd < 0) + raise_socket_error("failed to accept"); + + return std::make_unique(this->_nodelay, fd); +} + +int TCPSocket::write(const void* buffer, size_t size) const +{ + int n = ::write(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to send"); + return n; +} + +void TCPSocket::write_all(const void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} + +void TCPSocket::write_all(std::string msg) const +{ + this->write_all(msg.data(), msg.size()); +} + +void TCPSocket::write_message(std::string msg) const +{ + uint64_t header = msg.length(); + this->write_all(&header, 8); + 
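+    // framing: an 8-byte length header (host byte order, written above) immediately
+    // followed by the payload bytes; read_message() below mirrors this by reading the
+    // header first, then exactly that many bytes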
this->write_all(msg.data(), msg.length()); +} + +int TCPSocket::read(void* buffer, size_t size) const +{ + int n = ::read(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to recv"); + return n; +} + +void TCPSocket::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} + +std::string TCPSocket::read_message() const +{ + uint64_t header = 0; + this->read_exact(&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); +} diff --git a/tests/cpp/ymq/net/tcp_socket_windows.cpp b/tests/cpp/ymq/net/tcp_socket_windows.cpp new file mode 100644 index 000000000..c0be0ab08 --- /dev/null +++ b/tests/cpp/ymq/net/tcp_socket_windows.cpp @@ -0,0 +1,161 @@ +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" +#include "tests/cpp/ymq/net/tcp_socket.h" + +TCPSocket::TCPSocket(bool nodelay): _fd(-1), _nodelay(nodelay) +{ + this->_fd = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + if (this->_fd == SOCKET_ERROR) + raise_socket_error("failed to create socket"); + + char on = 1; + if (this->_nodelay) + if (::setsockopt((SOCKET)this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) == SOCKET_ERROR) + raise_socket_error("failed to set nodelay"); +} + +TCPSocket::TCPSocket(bool nodelay, long long fd): _fd(fd), _nodelay(nodelay) +{ + char on = 1; + if (this->_nodelay) + if (::setsockopt((SOCKET)this->_fd, IPPROTO_TCP, TCP_NODELAY, (const char*)&on, sizeof(on)) == SOCKET_ERROR) + raise_socket_error("failed to set nodelay"); +} + +TCPSocket::~TCPSocket() +{ + ::closesocket((SOCKET)this->_fd); +} + +TCPSocket::TCPSocket(TCPSocket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; +} + +TCPSocket& TCPSocket::operator=(TCPSocket&& other) noexcept +{ + this->_nodelay = other._nodelay; + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +void TCPSocket::try_connect(const std::string& address_str, int tries) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "tcp") { + throw std::runtime_error("Unsupported protocol for TCPSocket: " + address.protocol); + } + + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(address.port); + inet_pton(AF_INET, address.host.c_str(), &addr.sin_addr); + + for (int i = 0; i < tries; i++) { + auto code = ::connect((SOCKET)this->_fd, (sockaddr*)&addr, sizeof(addr)); + + if (code == SOCKET_ERROR) { + if (WSAGetLastError() == WSAECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + raise_socket_error("failed to connect"); + } + + break; // success + } +} + +void TCPSocket::bind(const std::string& address_str) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "tcp") { + throw std::runtime_error("Unsupported protocol for TCPSocket: " + address.protocol); + } + + sockaddr_in addr {}; + addr.sin_family = AF_INET; + addr.sin_port = htons(address.port); + addr.sin_addr.s_addr = INADDR_ANY; + if (::bind((SOCKET)this->_fd, (sockaddr*)&addr, sizeof(addr)) == SOCKET_ERROR) + raise_socket_error("failed to bind"); +} + +void TCPSocket::listen(int backlog) const +{ + if (::listen((SOCKET)this->_fd, backlog) == SOCKET_ERROR) + raise_socket_error("failed to listen"); +} + +std::unique_ptr TCPSocket::accept() const +{ + long long fd = 
::accept((SOCKET)this->_fd, nullptr, nullptr); + if (fd == SOCKET_ERROR) + raise_socket_error("failed to accept"); + + return std::make_unique(this->_nodelay, fd); +} + +int TCPSocket::write(const void* buffer, size_t size) const +{ + auto n = ::send((SOCKET)this->_fd, static_cast(buffer), (int)size, 0); + if (n == SOCKET_ERROR) + raise_socket_error("failed to send data"); + return n; +} + +void TCPSocket::write_all(const void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} + +void TCPSocket::write_all(std::string msg) const +{ + this->write_all(msg.data(), msg.size()); +} + +void TCPSocket::write_message(std::string msg) const +{ + uint64_t header = msg.length(); + this->write_all(&header, 8); + this->write_all(msg.data(), msg.length()); +} + +int TCPSocket::read(void* buffer, size_t size) const +{ + auto n = ::recv((SOCKET)this->_fd, static_cast(buffer), (int)size, 0); + if (n == SOCKET_ERROR) + raise_socket_error("failed to receive data"); + return n; +} + +void TCPSocket::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} + +std::string TCPSocket::read_message() const +{ + uint64_t header = 0; + this->read_exact(&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); +} diff --git a/tests/cpp/ymq/net/uds_socket.cpp b/tests/cpp/ymq/net/uds_socket.cpp new file mode 100644 index 000000000..e3a4ccc4c --- /dev/null +++ b/tests/cpp/ymq/net/uds_socket.cpp @@ -0,0 +1,149 @@ +#include "tests/cpp/ymq/net/uds_socket.h" + +#include +#include +#include + +#include +#include +#include +#include + +#include "tests/cpp/ymq/common/utils.h" + +UDSSocket::UDSSocket(): _fd(-1) +{ + this->_fd = ::socket(AF_UNIX, SOCK_STREAM, 0); + if (this->_fd < 0) + raise_socket_error("failed to create socket"); +} + +UDSSocket::UDSSocket(long long fd): _fd(fd) +{ +} + +UDSSocket::~UDSSocket() +{ + close(this->_fd); +} + +UDSSocket::UDSSocket(UDSSocket&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; +} + +UDSSocket& UDSSocket::operator=(UDSSocket&& other) noexcept +{ + this->_fd = other._fd; + other._fd = -1; + return *this; +} + +void UDSSocket::try_connect(const std::string& address_str, int tries) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "ipc") { + throw std::runtime_error("Unsupported protocol for UDSSocket: " + address.protocol); + } + + sockaddr_un addr {}; + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, address.path.c_str(), sizeof(addr.sun_path) - 1); + + for (int i = 0; i < tries; i++) { + auto code = ::connect(this->_fd, (sockaddr*)&addr, sizeof(addr)); + + if (code < 0) { + if (errno == ENOENT || errno == ECONNREFUSED) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + continue; + } + + raise_socket_error("failed to connect"); + } + + break; // success + } +} + +void UDSSocket::bind(const std::string& address_str) const +{ + auto address = parseAddress(address_str); + if (address.protocol != "ipc") { + throw std::runtime_error("Unsupported protocol for UDSSocket: " + address.protocol); + } + + sockaddr_un addr {}; + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, address.path.c_str(), sizeof(addr.sun_path) - 1); + ::unlink(address.path.c_str()); + if (::bind(this->_fd, (sockaddr*)&addr, sizeof(addr)) < 0) + raise_socket_error("failed to bind"); 
+} + +void UDSSocket::listen(int backlog) const +{ + if (::listen(this->_fd, backlog) < 0) + raise_socket_error("failed to listen"); +} + +std::unique_ptr UDSSocket::accept() const +{ + long long fd = ::accept(this->_fd, nullptr, nullptr); + if (fd < 0) + raise_socket_error("failed to accept"); + + return std::make_unique(fd); +} + +int UDSSocket::write(const void* buffer, size_t size) const +{ + int n = ::write(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to send"); + return n; +} + +void UDSSocket::write_all(const void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->write((char*)buffer + cursor, size - cursor); +} + +void UDSSocket::write_all(std::string msg) const +{ + this->write_all(msg.data(), msg.size()); +} + +void UDSSocket::write_message(std::string msg) const +{ + uint64_t header = msg.length(); + this->write_all(&header, 8); + this->write_all(msg.data(), msg.length()); +} + +int UDSSocket::read(void* buffer, size_t size) const +{ + int n = ::read(this->_fd, buffer, size); + if (n < 0) + raise_socket_error("failed to recv"); + return n; +} + +void UDSSocket::read_exact(void* buffer, size_t size) const +{ + size_t cursor = 0; + while (cursor < size) + cursor += (size_t)this->read((char*)buffer + cursor, size - cursor); +} + +std::string UDSSocket::read_message() const +{ + uint64_t header = 0; + this->read_exact(&header, 8); + std::vector buffer(header); + this->read_exact(buffer.data(), header); + return std::string(buffer.data(), header); +} diff --git a/tests/cpp/ymq/net/uds_socket.h b/tests/cpp/ymq/net/uds_socket.h new file mode 100644 index 000000000..1a242480b --- /dev/null +++ b/tests/cpp/ymq/net/uds_socket.h @@ -0,0 +1,44 @@ +#pragma once +#include +#include +#include +#include + +#include "address.h" +#include "i_socket.h" + +class UDSSocket: public ISocket { +public: + UDSSocket(); + UDSSocket(long long fd); + ~UDSSocket(); + + // move-only + UDSSocket(UDSSocket&&) noexcept; + UDSSocket& operator=(UDSSocket&&) noexcept; + UDSSocket(const UDSSocket&) = delete; + UDSSocket& operator=(const UDSSocket&) = delete; + + void try_connect(const std::string& address, int tries = 10) const override; + void bind(const std::string& address) const override; + void listen(int backlog = 5) const override; + std::unique_ptr accept() const override; + + void write_all(const void* data, size_t size) const override; + void write_all(std::string msg) const override; + + void read_exact(void* buffer, size_t size) const override; + + void write_message(std::string msg) const override; + + std::string read_message() const override; + +private: + long long _fd; + + // write up to `size` bytes + int write(const void* buffer, size_t size) const; + + // read up to `size` bytes + int read(void* buffer, size_t size) const; +}; diff --git a/tests/cpp/ymq/py_mitm/main.py b/tests/cpp/ymq/py_mitm/main.py index 5f5d2115b..f349ce211 100644 --- a/tests/cpp/ymq/py_mitm/main.py +++ b/tests/cpp/ymq/py_mitm/main.py @@ -13,16 +13,14 @@ import os import platform import signal -import subprocess -from typing import List from scapy.all import IP, TCP # type: ignore from tests.cpp.ymq.py_mitm import passthrough, randomly_drop_packets, send_rst_to_client -from tests.cpp.ymq.py_mitm.mitm_types import AbstractMITM, AbstractMITMInterface, TCPConnection +from tests.cpp.ymq.py_mitm.mitm_types import MITM, MITMInterface, TCPConnection -def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: AbstractMITM) -> None: 
+def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int, mitm: MITM) -> None: """ This function serves as a framework for man in the middle implementations A client connects to the MITM, then the MITM connects to a remote server @@ -121,7 +119,7 @@ def main(pid: int, mitm_ip: str, mitm_port: int, remote_ip: str, server_port: in return -def get_interface(mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int) -> AbstractMITMInterface: +def get_interface(mitm_ip: str, mitm_port: int, remote_ip: str, server_port: int) -> MITMInterface: """get the platform-specific mitm interface""" system = platform.system() @@ -154,9 +152,9 @@ def signal_ready(pid: int) -> None: TESTCASES = { - "passthrough": passthrough, - "randomly_drop_packets": randomly_drop_packets, - "send_rst_to_client": send_rst_to_client, + "passthrough": passthrough.PassthroughMITM, + "randomly_drop_packets": randomly_drop_packets.RandomlyDropPacketsMITM, + "send_rst_to_client": send_rst_to_client.SendRSTToClientMITM, } if __name__ == "__main__": @@ -168,8 +166,8 @@ def signal_ready(pid: int) -> None: parser.add_argument("remote_ip", type=str, help="The desired remote ip for the TUNTAP interface") parser.add_argument("server_port", type=int, help="The port that the remote server is bound to") - args, unknown = parser.parse_known_args() + args, extra = parser.parse_known_args() - module = TESTCASES[args.testcase] + mitm = TESTCASES[args.testcase] - main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, module.MITM(*unknown)) + main(args.pid, args.mitm_ip, args.mitm_port, args.remote_ip, args.server_port, mitm(*extra)) diff --git a/tests/cpp/ymq/py_mitm/mitm_types.py b/tests/cpp/ymq/py_mitm/mitm_types.py index 16eb8d122..fdc90fa0f 100644 --- a/tests/cpp/ymq/py_mitm/mitm_types.py +++ b/tests/cpp/ymq/py_mitm/mitm_types.py @@ -52,7 +52,7 @@ def rewrite(self, pkt: IP, ack: Optional[int] = None, data=None): ) -class AbstractMITMInterface(ABC): +class MITMInterface(ABC): @abstractmethod def recv(self) -> Packet: ... @@ -62,11 +62,11 @@ def send(self, pkt: Packet) -> None: ... 
-class AbstractMITM(ABC): +class MITM(ABC): @abstractmethod def proxy( self, - interface: AbstractMITMInterface, + interface: MITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/passthrough.py b/tests/cpp/ymq/py_mitm/passthrough.py index b0b99262a..1886321f2 100644 --- a/tests/cpp/ymq/py_mitm/passthrough.py +++ b/tests/cpp/ymq/py_mitm/passthrough.py @@ -7,13 +7,13 @@ from typing import Optional -from tests.cpp.ymq.py_mitm.mitm_types import IP, AbstractMITM, AbstractMITMInterface, TCPConnection +from tests.cpp.ymq.py_mitm.mitm_types import IP, MITM, MITMInterface, TCPConnection -class MITM(AbstractMITM): +class PassthroughMITM(MITM): def proxy( self, - tuntap: AbstractMITMInterface, + tuntap: MITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py index b9c1fff1e..a9d3df367 100644 --- a/tests/cpp/ymq/py_mitm/randomly_drop_packets.py +++ b/tests/cpp/ymq/py_mitm/randomly_drop_packets.py @@ -5,10 +5,10 @@ import random from typing import Optional -from tests.cpp.ymq.py_mitm.mitm_types import IP, AbstractMITM, AbstractMITMInterface, TCPConnection +from tests.cpp.ymq.py_mitm.mitm_types import IP, MITM, MITMInterface, TCPConnection -class MITM(AbstractMITM): +class RandomlyDropPacketsMITM(MITM): def __init__(self, drop_percentage: str): self._drop_percentage = float(drop_percentage) self._consecutive_drop_limit = 3 @@ -25,7 +25,7 @@ def can_drop_server(self) -> bool: def proxy( self, - tuntap: AbstractMITMInterface, + tuntap: MITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/send_rst_to_client.py b/tests/cpp/ymq/py_mitm/send_rst_to_client.py index 69a60db21..3829ae03d 100644 --- a/tests/cpp/ymq/py_mitm/send_rst_to_client.py +++ b/tests/cpp/ymq/py_mitm/send_rst_to_client.py @@ -4,10 +4,10 @@ from typing import Optional -from tests.cpp.ymq.py_mitm.mitm_types import IP, TCP, AbstractMITM, AbstractMITMInterface, TCPConnection +from tests.cpp.ymq.py_mitm.mitm_types import IP, MITM, TCP, MITMInterface, TCPConnection -class MITM(AbstractMITM): +class SendRSTToClientMITM(MITM): def __init__(self): # count the number of psh-acks sent by the client self._client_pshack_counter = 0 @@ -15,7 +15,7 @@ def __init__(self): def proxy( self, - tuntap: AbstractMITMInterface, + tuntap: MITMInterface, pkt: IP, sender: TCPConnection, client_conn: Optional[TCPConnection], diff --git a/tests/cpp/ymq/py_mitm/windivert.py b/tests/cpp/ymq/py_mitm/windivert.py index 6904fd84c..dd494e7f5 100644 --- a/tests/cpp/ymq/py_mitm/windivert.py +++ b/tests/cpp/ymq/py_mitm/windivert.py @@ -4,10 +4,10 @@ import pydivert from scapy.all import IP, Packet # type: ignore[attr-defined] -from tests.cpp.ymq.py_mitm.mitm_types import AbstractMITMInterface +from tests.cpp.ymq.py_mitm.mitm_types import MITMInterface -class WindivertMITMInterface(AbstractMITMInterface): +class WindivertMITMInterface(MITMInterface): _windivert: pydivert.WinDivert _binder: socket.socket diff --git a/tests/cpp/ymq/test_ymq.cpp b/tests/cpp/ymq/test_ymq.cpp index 9489820ff..4a6718b62 100644 --- a/tests/cpp/ymq/test_ymq.cpp +++ b/tests/cpp/ymq/test_ymq.cpp @@ -8,7 +8,6 @@ // // the test cases are at the bottom of this file, after the clients and servers // the documentation for each case is found on the TEST() definition - #include #ifdef __linux__ @@ -22,9 +21,10 @@ #define NOMINMAX #include #endif 
// _WIN32 - #include #include +#include +#include #include #include #include @@ -36,21 +36,45 @@ #include "scaler/ymq/simple_interface.h" #include "tests/cpp/ymq/common/testing.h" #include "tests/cpp/ymq/common/utils.h" -#include "tests/cpp/ymq/net/socket.h" +#include "tests/cpp/ymq/net/socket_utils.h" using namespace scaler::ymq; using namespace std::chrono_literals; +// a test suite that's parameterized by transport protocol (e.g. "tcp", "ipc") +class CcYmqTestSuiteParametrized: public ::testing::TestWithParam<std::string> { +protected: + std::string GetAddress(int port) + { + const std::string& transport = GetParam(); + if (transport == "tcp") { + return std::format("tcp://127.0.0.1:{}", port); + } +#ifdef __linux__ + if (transport == "ipc") { + // using a unique path for each test based on port + const char* runner_temp = std::getenv("RUNNER_TEMP"); + if (runner_temp) { + return std::format("ipc://{}/ymq-test-{}.ipc", runner_temp, port); + } + return std::format("ipc:///tmp/ymq-test-{}.ipc", port); + } +#endif + // GTest should not select this for unsupported platforms, but as a fallback, + // return something that will cause tests to fail clearly. + return "invalid-transport"; + } +}; + // -------------------- // clients and servers // -------------------- - -TestResult basic_server_ymq(std::string host, uint16_t port) +TestResult basic_server_ymq(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + syncBindSocket(socket, address); auto result = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(result.has_value()); @@ -61,12 +85,12 @@ TestResult basic_server_ymq(std::string host, uint16_t port) return TestResult::Success; } -TestResult basic_client_ymq(std::string host, uint16_t port) +TestResult basic_client_ymq(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + syncConnectSocket(socket, address); auto result = syncSendMessage(socket, {.address = Bytes("server"), .payload = Bytes("yi er san si wu liu")}); context.removeIOSocket(socket); @@ -74,41 +98,39 @@ TestResult basic_client_ymq(std::string host, uint16_t port) return TestResult::Success; } -TestResult basic_server_raw(uint16_t port) +TestResult basic_server_raw(std::string address_str) { - Socket socket; + auto socket = bind_socket(address_str); - socket.bind(port); - socket.listen(); - auto client = socket.accept(); - client.write_message("server"); - auto client_identity = client.read_message(); + socket->listen(5); // Default backlog + auto client = socket->accept(); + client->write_message("server"); + auto client_identity = client->read_message(); RETURN_FAILURE_IF_FALSE(client_identity == "client"); - auto msg = client.read_message(); + auto msg = client->read_message(); RETURN_FAILURE_IF_FALSE(msg == "yi er san si wu liu"); return TestResult::Success; } -TestResult basic_client_raw(std::string host, uint16_t port) +TestResult basic_client_raw(std::string address_str) { - Socket socket; + auto socket = connect_socket(address_str); - socket.try_connect(host.c_str(), port); - socket.write_message("client"); - auto server_identity = socket.read_message(); + socket->write_message("client"); + auto server_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); - socket.write_message("yi er san si wu liu"); + socket->write_message("yi er san si wu liu"); return
TestResult::Success; } -TestResult server_receives_big_message(std::string host, uint16_t port) +TestResult server_receives_big_message(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + syncBindSocket(socket, address); auto result = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(result.has_value()); @@ -119,26 +141,25 @@ TestResult server_receives_big_message(std::string host, uint16_t port) return TestResult::Success; } -TestResult client_sends_big_message(std::string host, uint16_t port) +TestResult client_sends_big_message(std::string address_str) { - Socket socket; + auto socket = connect_socket(address_str); - socket.try_connect(host.c_str(), port); - socket.write_message("client"); - auto remote_identity = socket.read_message(); + socket->write_message("client"); + auto remote_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(remote_identity == "server"); std::string msg(500'000'000, '.'); - socket.write_message(msg); + socket->write_message(msg); return TestResult::Success; } -TestResult reconnect_server_main(std::string host, uint16_t port) +TestResult reconnect_server_main(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + syncBindSocket(socket, address); auto result = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(result.has_value()); @@ -152,12 +173,12 @@ TestResult reconnect_server_main(std::string host, uint16_t port) return TestResult::Success; } -TestResult reconnect_client_main(std::string host, uint16_t port) +TestResult reconnect_client_main(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + syncConnectSocket(socket, address); // create the recv future in advance, this remains active between reconnects auto future = futureRecvMessage(socket); @@ -191,66 +212,63 @@ TestResult reconnect_client_main(std::string host, uint16_t port) return TestResult::Failure; } -TestResult client_simulated_slow_network(const char* host, uint16_t port) +TestResult client_simulated_slow_network(std::string address) { - Socket socket; + auto socket = connect_socket(address); - socket.try_connect(host, port); - socket.write_message("client"); - auto remote_identity = socket.read_message(); + socket->write_message("client"); + auto remote_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(remote_identity == "server"); std::string message = "yi er san si wu liu"; uint64_t header = message.length(); - socket.write_all((char*)&header, 4); + socket->write_all((char*)&header, 4); std::this_thread::sleep_for(2s); - socket.write_all((char*)&header + 4, 4); + socket->write_all((char*)&header + 4, 4); std::this_thread::sleep_for(3s); - socket.write_all(message.data(), header / 2); + socket->write_all(message.data(), header / 2); std::this_thread::sleep_for(2s); - socket.write_all(message.data() + header / 2, header - header / 2); + socket->write_all(message.data() + header / 2, header - header / 2); return TestResult::Success; } -TestResult client_sends_incomplete_identity(const char* host, uint16_t port) +TestResult client_sends_incomplete_identity(std::string address) { // open a socket, write an incomplete identity and exit { - Socket socket; + auto socket = connect_socket(address); - socket.try_connect(host, 
port); - - auto server_identity = socket.read_message(); + auto server_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); // write incomplete identity and exit std::string identity = "client"; uint64_t header = identity.length(); - socket.write_all((char*)&header, 8); - socket.write_all(identity.data(), identity.length() - 2); + socket->write_all((char*)&header, 8); + socket->write_all(identity.data(), identity.length() - 2); } // connect again and try to send a message { - Socket socket; - socket.try_connect(host, port); - auto server_identity = socket.read_message(); + auto socket = connect_socket(address); + + auto server_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); - socket.write_message("client"); - socket.write_message("yi er san si wu liu"); + socket->write_message("client"); + socket->write_message("yi er san si wu liu"); } return TestResult::Success; } -TestResult server_receives_huge_header(const char* host, uint16_t port) +TestResult server_receives_huge_header(std::string address) { IOContext context(1); auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + syncBindSocket(socket, address); auto result = syncRecvMessage(socket); RETURN_FAILURE_IF_FALSE(result.has_value()); @@ -261,44 +279,42 @@ TestResult server_receives_huge_header(const char* host, uint16_t port) return TestResult::Success; } -TestResult client_sends_huge_header(const char* host, uint16_t port) +TestResult client_sends_huge_header(std::string address) { #ifdef __linux__ // ignore SIGPIPE so that write() returns EPIPE instead of crashing the program signal(SIGPIPE, SIG_IGN); + + int expected_error = EPIPE; #endif +#ifdef _WIN32 + int expected_error = WSAECONNABORTED; +#endif // _WIN32 { - Socket socket; + auto socket = connect_socket(address); - socket.try_connect(host, port); - socket.write_message("client"); - auto server_identity = socket.read_message(); + socket->write_message("client"); + auto server_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); // write the huge header uint64_t header = std::numeric_limits::max(); - socket.write_all((char*)&header, 8); + socket->write_all((char*)&header, 8); size_t i = 0; for (; i < 10; i++) { std::this_thread::sleep_for(1s); try { - socket.write_all("yi er san si wu liu"); + socket->write_all("yi er san si wu liu"); } catch (const std::system_error& e) { -#ifdef __linux__ - if (e.code().value() == EPIPE) { -#endif // __linux__ -#ifdef _WIN32 - if (e.code().value() == WSAECONNABORTED) { -#endif // _WIN32 - std::cout << "writing failed, as expected after sending huge header, continuing...\n"; - break; // this is expected - } - - throw; // rethrow other errors + if (e.code().value() == expected_error) { + std::cout << "writing failed, as expected after sending huge header, continuing...\n"; + break; // this is expected } + + throw; // rethrow other errors } if (i == 10) { @@ -308,672 +324,674 @@ TestResult client_sends_huge_header(const char* host, uint16_t port) } { - Socket socket; - socket.try_connect(host, port); - socket.write_message("client"); - auto server_identity = socket.read_message(); + auto socket = connect_socket(address); + + socket->write_message("client"); + auto server_identity = socket->read_message(); RETURN_FAILURE_IF_FALSE(server_identity == "server"); - socket.write_message("yi er san si wu liu"); + socket->write_message("yi er san si wu liu"); } return 
TestResult::Success; } +} - TestResult server_receives_empty_messages(const char* host, uint16_t port) - { - IOContext context(1); +TestResult server_receives_empty_messages(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, address); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == ""); - auto result2 = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result2.has_value()); - RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); + auto result2 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2->payload.as_string() == ""); - context.removeIOSocket(socket); + context.removeIOSocket(socket); - return TestResult::Success; - } + return TestResult::Success; +} - TestResult client_sends_empty_messages(std::string host, uint16_t port) - { - IOContext context(1); +TestResult client_sends_empty_messages(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, address); - auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); - RETURN_FAILURE_IF_FALSE(!error); + auto error = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes()}); + RETURN_FAILURE_IF_FALSE(!error); - auto error2 = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); - RETURN_FAILURE_IF_FALSE(!error2); + auto error2 = syncSendMessage(socket, Message {.address = Bytes(), .payload = Bytes("")}); + RETURN_FAILURE_IF_FALSE(!error2); - context.removeIOSocket(socket); + context.removeIOSocket(socket); - return TestResult::Success; - } + return TestResult::Success; +} - TestResult pubsub_subscriber(std::string host, uint16_t port, std::string topic, int differentiator, void* sem) - { - IOContext context(1); +TestResult pubsub_subscriber(std::string address, std::string topic, int differentiator, void* sem) +{ + IOContext context(1); - auto socket = - syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); + auto socket = + syncCreateSocket(context, IOSocketType::Unicast, std::format("{}_subscriber_{}", topic, differentiator)); - std::this_thread::sleep_for(500ms); + std::this_thread::sleep_for(500ms); - syncConnectSocket(socket, format_address(host, port)); + syncConnectSocket(socket, address); - std::this_thread::sleep_for(500ms); + std::this_thread::sleep_for(500ms); #ifdef __linux__ - if (sem_post((sem_t*)sem) < 0) - throw std::system_error(errno, std::generic_category(), "failed to signal semaphore"); - sem_close((sem_t*)sem); + if (sem_post((sem_t*)sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to signal semaphore"); + sem_close((sem_t*)sem); #endif // __linux__ #ifdef _WIN32 - if (!ReleaseSemaphore(sem, 1, nullptr)) - throw std::system_error(GetLastError(), std::generic_category(), "failed to signal semaphore"); + if 
(!ReleaseSemaphore(sem, 1, nullptr)) + throw std::system_error(GetLastError(), std::generic_category(), "failed to signal semaphore"); #endif // _WIN32 + auto msg = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(msg.has_value()); + RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello"); - auto msg = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(msg.has_value()); - RETURN_FAILURE_IF_FALSE(msg->payload.as_string() == "hello"); - - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - // topic: the identifier of the topic, must match what's passed to the subscribers - // sem: a semaphore to synchronize the publisher and subscriber processes - // n: the number of subscribers - TestResult pubsub_publisher(std::string host, uint16_t port, std::string topic, void* sem, int n) - { - IOContext context(1); +// topic: the identifier of the topic, must match what's passed to the subscribers +// sem: a semaphore to synchronize the publisher and subscriber processes +// n: the number of subscribers +TestResult pubsub_publisher(std::string address, std::string topic, void* sem, int n) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); - syncBindSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Multicast, "publisher"); + syncBindSocket(socket, address); // wait for the subscribers to be ready #ifdef __linux__ - for (int i = 0; i < n; i++) - if (sem_wait((sem_t*)sem) < 0) - throw std::system_error(errno, std::generic_category(), "failed to wait on semaphore"); - sem_close((sem_t*)sem); + for (int i = 0; i < n; i++) + if (sem_wait((sem_t*)sem) < 0) + throw std::system_error(errno, std::generic_category(), "failed to wait on semaphore"); + sem_close((sem_t*)sem); #endif // __linux__ #ifdef _WIN32 - for (int i = 0; i < n; i++) - if (WaitForSingleObject(sem, 3000) != WAIT_OBJECT_0) - throw std::system_error(GetLastError(), std::generic_category(), "failed to wait on semaphore"); + for (int i = 0; i < n; i++) + if (WaitForSingleObject(sem, 3000) != WAIT_OBJECT_0) + throw std::system_error(GetLastError(), std::generic_category(), "failed to wait on semaphore"); #endif // _WIN32 - // the topic doesn't match, so no one should receive this - auto error = syncSendMessage( - socket, Message {.address = Bytes(std::format("x{}", topic)), .payload = Bytes("no one should get this")}); - RETURN_FAILURE_IF_FALSE(!error); + auto error = syncSendMessage( + socket, Message {.address = Bytes(std::format("x{}", topic)), .payload = Bytes("no one should get this")}); + RETURN_FAILURE_IF_FALSE(!error); - // no one should receive this either - error = syncSendMessage( - socket, - Message {.address = Bytes(std::format("{}x", topic)), .payload = Bytes("no one should get this either")}); - RETURN_FAILURE_IF_FALSE(!error); + // no one should receive this either + error = syncSendMessage( + socket, + Message {.address = Bytes(std::format("{}x", topic)), .payload = Bytes("no one should get this either")}); + RETURN_FAILURE_IF_FALSE(!error); - error = syncSendMessage(socket, Message {.address = Bytes(topic), .payload = Bytes("hello")}); - RETURN_FAILURE_IF_FALSE(!error); + error = syncSendMessage(socket, Message {.address = Bytes(topic), .payload = Bytes("hello")}); + RETURN_FAILURE_IF_FALSE(!error); - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - 
TestResult client_close_established_connection_client(std::string host, uint16_t port) - { - IOContext context(1); +TestResult client_close_established_connection_client(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, address); - auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); + auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); - socket->closeConnection("server"); - context.requestIOSocketStop(socket); + socket->closeConnection("server"); + context.requestIOSocketStop(socket); - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - TestResult client_close_established_connection_server(std::string host, uint16_t port) - { - IOContext context(1); +TestResult client_close_established_connection_server(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); - syncBindSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Binder, "server"); + syncBindSocket(socket, address); - auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); + auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); - result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); - RETURN_FAILURE_IF_FALSE( - result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) + result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); + RETURN_FAILURE_IF_FALSE( + result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - TestResult close_nonexistent_connection() - { - IOContext context(1); +TestResult close_nonexistent_connection() +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - // note: we're not connected to anything; this connection does not exist - // this should be a no-op.. - socket->closeConnection("server"); + // note: we're not connected to anything; this connection does not exist + // this should be a no-op.. 
+ socket->closeConnection("server"); - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - TestResult test_request_stop() - { - IOContext context(1); +TestResult test_request_stop() +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - auto future = futureRecvMessage(socket); - context.requestIOSocketStop(socket); + auto future = futureRecvMessage(socket); + context.requestIOSocketStop(socket); - auto result = future.wait_for(100ms); - RETURN_FAILURE_IF_FALSE(result == std::future_status::ready, "future should have completed"); + auto result = future.wait_for(100ms); + RETURN_FAILURE_IF_FALSE(result == std::future_status::ready, "future should have completed"); - // the future created beore requestion stop should have been cancelled with an error - auto result2 = future.get(); - RETURN_FAILURE_IF_FALSE(!result2.has_value()); - RETURN_FAILURE_IF_FALSE(result2.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); + // the future created beore requestion stop should have been cancelled with an error + auto result2 = future.get(); + RETURN_FAILURE_IF_FALSE(!result2.has_value()); + RETURN_FAILURE_IF_FALSE(result2.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); - // and the same for any attempts to use the socket after it's been closed - auto result3 = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result3.has_value()); - RETURN_FAILURE_IF_FALSE(result3.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); + // and the same for any attempts to use the socket after it's been closed + auto result3 = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result3.has_value()); + RETURN_FAILURE_IF_FALSE(result3.error()._errorCode == scaler::ymq::Error::ErrorCode::IOSocketStopRequested); - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - TestResult client_socket_stop_before_close_connection(std::string host, uint16_t port) - { - IOContext context(1); +TestResult client_socket_stop_before_close_connection(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); - syncConnectSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "client"); + syncConnectSocket(socket, address); - auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); + auto error = syncSendMessage(socket, Message {.address = Bytes("server"), .payload = Bytes("0")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "1"); - context.requestIOSocketStop(socket); - socket->closeConnection("server"); + context.requestIOSocketStop(socket); + socket->closeConnection("server"); - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - TestResult server_socket_stop_before_close_connection(std::string host, uint16_t port) - 
{ - IOContext context(1); +TestResult server_socket_stop_before_close_connection(std::string address) +{ + IOContext context(1); - auto socket = syncCreateSocket(context, IOSocketType::Connector, "server"); - syncBindSocket(socket, format_address(host, port)); + auto socket = syncCreateSocket(context, IOSocketType::Connector, "server"); + syncBindSocket(socket, address); - auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); - RETURN_FAILURE_IF_FALSE(!error); - auto result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(result.has_value()); - RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); + auto error = syncSendMessage(socket, Message {.address = Bytes("client"), .payload = Bytes("1")}); + RETURN_FAILURE_IF_FALSE(!error); + auto result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(result.has_value()); + RETURN_FAILURE_IF_FALSE(result->payload.as_string() == "0"); - result = syncRecvMessage(socket); - RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); - RETURN_FAILURE_IF_FALSE( - result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) + result = syncRecvMessage(socket); + RETURN_FAILURE_IF_FALSE(!result.has_value(), "expected recv message to fail"); + RETURN_FAILURE_IF_FALSE( + result.error()._errorCode == scaler::ymq::Error::ErrorCode::ConnectorSocketClosedByRemoteEnd) - context.removeIOSocket(socket); - return TestResult::Success; - } + context.removeIOSocket(socket); + return TestResult::Success; +} - // ------------- - // test cases - // ------------- - - // this is a 'basic' test which sends a single message from a client to a server - // in this variant, both the client and server are implemented using YMQ - // - // this case includes a _delay_ - // this is a thread sleep that happens after the client sends the message, to delay the close() of the socket - // at the moment, if this delay is missing, YMQ will not shut down correctly - TEST(CcYmqTestSuite, TestBasicYMQClientYMQServer) - { - const auto host = "localhost"; - const auto port = 2889; +// ------------- +// test cases +// ------------- +// this is a 'basic' test which sends a single message from a client to a server +// in this variant, both the client and server are implemented using YMQ +// +// this case includes a _delay_ +// this is a thread sleep that happens after the client sends the message, to delay the close() of the socket +// at the moment, if this delay is missing, YMQ will not shut down correctly +TEST_P(CcYmqTestSuiteParametrized, TestBasicYMQClientYMQServer) +{ + const auto address = GetAddress(2889); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_ymq(host, port); }}); + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = test(10, {[=] { return basic_client_ymq(address); }, [=] { return basic_server_ymq(address); }}); - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); - } + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} - // same as above, except YMQs protocol is directly implemented on top of a 
TCP socket - TEST(CcYmqTestSuite, TestBasicRawClientYMQServer) - { - const auto host = "localhost"; - const auto port = 2890; +// same as above, except YMQs protocol is directly implemented on top of a TCP socket +TEST_P(CcYmqTestSuiteParametrized, TestBasicRawClientYMQServer) +{ + const auto address = GetAddress(2891); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = - test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = test(10, {[=] { return basic_client_raw(address); }, [=] { return basic_server_ymq(address); }}); - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); - } + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} - TEST(CcYmqTestSuite, TestBasicRawClientRawServer) - { - const auto host = "localhost"; - const auto port = 2891; +TEST_P(CcYmqTestSuiteParametrized, TestBasicRawClientRawServer) +{ + const auto address = GetAddress(2892); - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_raw(port); }}); + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = test(10, {[=] { return basic_client_raw(address); }, [=] { return basic_server_raw(address); }}); - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); - } + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} - // this is the same as above, except that it has no delay before calling close() on the socket - TEST(CcYmqTestSuite, TestBasicRawClientRawServerNoDelay) - { - const auto host = "localhost"; - const auto port = 2892; +// this is the same as above, except that it has no delay before calling close() on the socket +TEST_P(CcYmqTestSuiteParametrized, TestBasicRawClientRawServerNoDelay) +{ + const auto address = GetAddress(2893); - auto result = - test(10, {[=] { return basic_client_raw(host, port); }, [=] { return basic_server_ymq(host, port); }}); - EXPECT_EQ(result, TestResult::Success); - } + auto result = test(10, {[=] { return basic_client_raw(address); }, [=] { return basic_server_ymq(address); }}); - TEST(CcYmqTestSuite, TestBasicDelayYMQClientRawServer) - { - const auto host = "localhost"; - const auto port = 2893; + EXPECT_EQ(result, TestResult::Success); +} - // this is the test harness, it accepts a timeout, a list of functions to run, - // and an optional third argument used to coordinate the execution of python (for mitm) - auto result = test(10, {[=] { return basic_client_ymq(host, port); }, [=] { return basic_server_raw(port); }}); +TEST_P(CcYmqTestSuiteParametrized, TestBasicDelayYMQClientRawServer) +{ + const auto address = GetAddress(2894); - // test() aggregates the results across all of the provided functions - EXPECT_EQ(result, TestResult::Success); 
- } + // this is the test harness, it accepts a timeout, a list of functions to run, + // and an optional third argument used to coordinate the execution of python (for mitm) + auto result = test(10, {[=] { return basic_client_ymq(address); }, [=] { return basic_server_raw(address); }}); - // in this test case, the client sends a large message to the server - // YMQ should be able to handle this without issue - TEST(CcYmqTestSuite, TestClientSendBigMessageToServer) - { - const auto host = "localhost"; - const auto port = 2894; - - auto result = test( - 10, - {[=] { return client_sends_big_message(host, port); }, - [=] { return server_receives_big_message(host, port); }}); - EXPECT_EQ(result, TestResult::Success); - } + // test() aggregates the results across all of the provided functions + EXPECT_EQ(result, TestResult::Success); +} - // this is the no-op/passthrough man in the middle test - // for this test case we use YMQ on both the client side and the server side - // the client connects to the mitm, and the mitm connects to the server - // when the mitm receives packets from the client, it forwards it to the server without changing it - // and similarly when it receives packets from the server, it forwards them to the client - // - // the mitm is implemented in Python. we pass the name of the test case, which corresponds to the Python filename, - // and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port - // this defines the address of the mitm, and the addresses that can connect to it - // for more, see the python mitm files - TEST(CcYmqTestSuite, TestMitmPassthrough) - { +// in this test case, the client sends a large message to the server +// YMQ should be able to handle this without issue +TEST_P(CcYmqTestSuiteParametrized, TestClientSendBigMessageToServer) +{ + const auto address = GetAddress(2895); + + auto result = test( + 10, {[=] { return client_sends_big_message(address); }, [=] { return server_receives_big_message(address); }}); + + EXPECT_EQ(result, TestResult::Success); +} + +// this is the no-op/passthrough man in the middle test +// for this test case we use YMQ on both the client side and the server side +// the client connects to the mitm, and the mitm connects to the server +// when the mitm receives packets from the client, it forwards it to the server without changing it +// and similarly when it receives packets from the server, it forwards them to the client +// +// the mitm is implemented in Python. 
we pass the name of the test case, which corresponds to the Python filename, +// and a list of arguments, which are: mitm ip, mitm port, remote ip, remote port +// this defines the address of the mitm, and the addresses that can connect to it +// for more, see the python mitm files +TEST(CcYmqTestSuite, TestMitmPassthrough) +{ #ifdef __linux__ - auto mitm_ip = "192.0.2.4"; - auto remote_ip = "192.0.2.3"; + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; #endif // __linux__ #ifdef _WIN32 - auto mitm_ip = "127.0.0.1"; - auto remote_ip = "127.0.0.1"; + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; #endif // _WIN32 - auto mitm_port = random_port(); - auto remote_port = 23571; - - // the Python program must be the first and only the first function passed to test() - // we must also pass `true` as the third argument to ensure that Python is fully started - // before beginning the test - auto result = test( - 20, - {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, - [=] { return basic_client_ymq(mitm_ip, mitm_port); }, - [=] { return basic_server_ymq(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); - } + auto mitm_port = random_port(); + auto remote_port = 23571; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_ymq(std::format("tcp://{}:{}", mitm_ip, mitm_port)); }, + [=] { return basic_server_ymq(std::format("tcp://{}:{}", remote_ip, remote_port)); }}, + true); + + EXPECT_EQ(result, TestResult::Success); +} - // this is the same as the above, but both the client and server use raw sockets - TEST(CcYmqTestSuite, TestMitmPassthroughRaw) - { +// this is the same as the above, but both the client and server use raw sockets +TEST(CcYmqTestSuite, TestMitmPassthroughRaw) +{ #ifdef __linux__ - auto mitm_ip = "192.0.2.4"; - auto remote_ip = "192.0.2.3"; + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; #endif // __linux__ #ifdef _WIN32 - auto mitm_ip = "127.0.0.1"; - auto remote_ip = "127.0.0.1"; + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; #endif // _WIN32 - auto mitm_port = random_port(); - auto remote_port = 23574; - - // the Python program must be the first and only the first function passed to test() - // we must also pass `true` as the third argument to ensure that Python is fully started - // before beginning the test - auto result = test( - 20, - {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, - [=] { return basic_client_raw(mitm_ip, mitm_port); }, - [=] { return basic_server_raw(remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); - } + auto mitm_port = random_port(); + auto remote_port = 23574; + + // the Python program must be the first and only the first function passed to test() + // we must also pass `true` as the third argument to ensure that Python is fully started + // before beginning the test + auto result = test( + 20, + {[=] { return run_mitm("passthrough", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return basic_client_ymq(std::format("tcp://{}:{}", mitm_ip, mitm_port)); }, + [=] { return basic_server_ymq(std::format("tcp://{}:{}", remote_ip, remote_port)); }}, + true); + EXPECT_EQ(result, 
TestResult::Success); +} - // this test uses the mitm to test the reconnect logic of YMQ by sending RST packets - TEST(CcYmqTestSuite, TestMitmReconnect) - { +// this test uses the mitm to test the reconnect logic of YMQ by sending RST packets +TEST(CcYmqTestSuite, TestMitmReconnect) +{ #ifdef __linux__ - auto mitm_ip = "192.0.2.4"; - auto remote_ip = "192.0.2.3"; + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; #endif // __linux__ #ifdef _WIN32 - auto mitm_ip = "127.0.0.1"; - auto remote_ip = "127.0.0.1"; + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; #endif // _WIN32 - auto mitm_port = random_port(); - auto remote_port = 23572; - - auto result = test( - 30, - {[=] { return run_mitm("send_rst_to_client", mitm_ip, mitm_port, remote_ip, remote_port); }, - [=] { return reconnect_client_main(mitm_ip, mitm_port); }, - [=] { return reconnect_server_main(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); - } + auto mitm_port = random_port(); + auto remote_port = 23572; - // TODO: Make this more reliable, and re-enable it - // in this test, the mitm drops a random % of packets arriving from the client and server - TEST(CcYmqTestSuite, TestMitmRandomlyDropPackets) - { + auto result = test( + 30, + {[=] { return run_mitm("send_rst_to_client", mitm_ip, mitm_port, remote_ip, remote_port); }, + [=] { return reconnect_client_main(std::format("tcp://{}:{}", mitm_ip, mitm_port)); }, + [=] { return reconnect_server_main(std::format("tcp://{}:{}", remote_ip, remote_port)); }}, + true); + + EXPECT_EQ(result, TestResult::Success); +} + +// in this test, the mitm drops a random % of packets arriving from the client and server +TEST(CcYmqTestSuite, TestMitmRandomlyDropPackets) +{ #ifdef __linux__ - auto mitm_ip = "192.0.2.4"; - auto remote_ip = "192.0.2.3"; + auto mitm_ip = "192.0.2.4"; + auto remote_ip = "192.0.2.3"; #endif // __linux__ #ifdef _WIN32 - auto mitm_ip = "127.0.0.1"; - auto remote_ip = "127.0.0.1"; + auto mitm_ip = "127.0.0.1"; + auto remote_ip = "127.0.0.1"; #endif // _WIN32 - auto mitm_port = random_port(); - auto remote_port = 23573; - - auto result = test( - 60, - {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, - [=] { return basic_client_ymq(mitm_ip, mitm_port); }, - [=] { return basic_server_ymq(remote_ip, remote_port); }}, - true); - EXPECT_EQ(result, TestResult::Success); - } + auto mitm_port = random_port(); + auto remote_port = 23573; - // in this test the client is sending a message to the server - // but we simulate a slow network connection by sending the message in segmented chunks - TEST(CcYmqTestSuite, TestSlowNetwork) - { - const auto host = "localhost"; - const auto port = 2895; + auto result = test( + 60, + {[=] { return run_mitm("randomly_drop_packets", mitm_ip, mitm_port, remote_ip, remote_port, {"0.3"}); }, + [=] { return basic_client_ymq(std::format("tcp://{}:{}", mitm_ip, mitm_port)); }, + [=] { return basic_server_ymq(std::format("tcp://{}:{}", remote_ip, remote_port)); }}, + true); - auto result = test( - 20, - {[=] { return client_simulated_slow_network(host, port); }, [=] { return basic_server_ymq(host, port); }}); - EXPECT_EQ(result, TestResult::Success); - } + EXPECT_EQ(result, TestResult::Success); +} - // TODO: figure out why this test fails in ci sometimes, and re-enable - // - // in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects - // then a new client connection is established, and this one 
-    // YMQ should be able to recover from a poorly-behaved client like this
-    TEST(CcYmqTestSuite, TestClientSendIncompleteIdentity)
-    {
-        const auto host = "localhost";
-        const auto port = 2896;
-
-        auto result = test(
-            20,
-            {[=] { return client_sends_incomplete_identity(host, port); },
-             [=] { return basic_server_ymq(host, port); }});
-        EXPECT_EQ(result, TestResult::Success);
-    }
+// in this test the client is sending a message to the server
+// but we simulate a slow network connection by sending the message in segmented chunks
+TEST_P(CcYmqTestSuiteParametrized, TestSlowNetwork)
+{
+    const auto address = GetAddress(2905);
-    // TODO: this should pass
-    // currently YMQ rejects the second connection, saying that the message is too large even when it isn't
-    //
-    // in this test, the client sends an unrealistically-large header
-    // it is important that YMQ checks the header size before allocating memory
-    // both for resilence against attacks and to guard against errors
-    TEST(CcYmqTestSuite, TestClientSendHugeHeader)
-    {
-        const auto host = "localhost";
-        const auto port = 2897;
-
-        auto result = test(
-            20,
-            {[=] { return client_sends_huge_header(host, port); },
-             [=] { return server_receives_huge_header(host, port); }});
-        EXPECT_EQ(result, TestResult::Success);
-    }
+    auto result =
+        test(20, {[=] { return client_simulated_slow_network(address); }, [=] { return basic_server_ymq(address); }});
-    // in this test, the client sends empty messages to the server
-    // there are in effect two kinds of empty messages: Bytes() and Bytes("")
-    // in the former case, the bytes contains a nullptr
-    // in the latter case, the bytes contains a zero-length allocation
-    // it's important that the behaviour of YMQ is known for both of these cases
-    TEST(CcYmqTestSuite, TestClientSendEmptyMessage)
-    {
-        const auto host = "localhost";
-        const auto port = 2898;
-
-        auto result = test(
-            20,
-            {[=] { return client_sends_empty_messages(host, port); },
-             [=] { return server_receives_empty_messages(host, port); }});
-        EXPECT_EQ(result, TestResult::Success);
-    }
+    EXPECT_EQ(result, TestResult::Success);
+}
-    // this case tests the publish-subscribe pattern of YMQ
-    // we create one publisher and two subscribers with a common topic
-    // the publisher will send two messages to the wrong topic
-    // none of the subscribers should receive these
-    // and then the publisher will send a message to the correct topic
-    // both subscribers should receive this message
-    TEST(CcYmqTestSuite, TestPubSub)
-    {
-        const auto host = "localhost";
-        const auto port = 2900;
-        auto topic = "mytopic";
+// in this test, a client connects to the YMQ server but only partially sends its identity and then disconnects
+// then a new client connection is established, and this one sends a complete identity and message
+// YMQ should be able to recover from a poorly-behaved client like this
+TEST_P(CcYmqTestSuiteParametrized, TestClientSendIncompleteIdentity)
+{
+    const auto address = GetAddress(2896);
-// allocate a semaphore to synchronize the publisher and subscriber processes
-#ifdef __linux__
-        sem_t* sem = static_cast<sem_t*>(
-            mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
+    auto result = test(
+        20, {[=] { return client_sends_incomplete_identity(address); }, [=] { return basic_server_ymq(address); }});
-        if (sem == MAP_FAILED)
-            throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore");
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// in this test, the client sends an unrealistically-large header
+// it is important that YMQ checks the header size before allocating memory
+// both for resilience against attacks and to guard against errors
+TEST_P(CcYmqTestSuiteParametrized, TestClientSendHugeHeader)
+{
+    const auto address = GetAddress(2897);
-        if (sem_init(sem, 1, 0) < 0)
-            throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore");
+    auto result = test(
+        20, {[=] { return client_sends_huge_header(address); }, [=] { return server_receives_huge_header(address); }});
+
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// in this test, the client sends empty messages to the server
+// there are in effect two kinds of empty messages: Bytes() and Bytes("")
+// in the former case, the bytes contains a nullptr
+// in the latter case, the bytes contains a zero-length allocation
+// it's important that the behaviour of YMQ is known for both of these cases
+TEST_P(CcYmqTestSuiteParametrized, TestClientSendEmptyMessage)
+{
+    const auto address = GetAddress(2898);
+
+    auto result = test(
+        20,
+        {[=] { return client_sends_empty_messages(address); },
+         [=] { return server_receives_empty_messages(address); }});
+
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// this case tests the publish-subscribe pattern of YMQ
+// we create one publisher and two subscribers with a common topic
+// the publisher will send two messages to the wrong topic
+// none of the subscribers should receive these
+// and then the publisher will send a message to the correct topic
+// both subscribers should receive this message
+TEST_P(CcYmqTestSuiteParametrized, TestPubSub)
+{
+    const auto address = GetAddress(2900);
+    auto topic = "mytopic";
+
+    // allocate a semaphore to synchronize the publisher and subscriber processes
+#ifdef __linux__
+    sem_t* sem =
+        static_cast<sem_t*>(mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
+    if (sem == MAP_FAILED)
+        throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore");
+    if (sem_init(sem, 1, 0) < 0)
+        throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore");
 #endif // __linux__
 #ifdef _WIN32
-        HANDLE sem = CreateSemaphore(
-            nullptr, // default security attributes
-            0, // initial count
-            2, // maximum count
-            nullptr); // unnamed semaphore
-        if (sem == nullptr)
-            throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore");
-#endif // _WIN32
-        auto result = test(
-            20,
-            {[=] { return pubsub_publisher(host, port, topic, sem, 2); },
-             [=] { return pubsub_subscriber(host, port, topic, 0, sem); },
-             [=] { return pubsub_subscriber(host, port, topic, 1, sem); }});
+    HANDLE sem = CreateSemaphore(
+        nullptr,  // default security attributes
+        0,        // initial count
+        2,        // maximum count
+        nullptr); // unnamed semaphore
+    if (sem == nullptr)
+        throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore");
+#endif // _WIN32
+    auto result = test(
+        20,
+        {[=] { return pubsub_publisher(address, topic, sem, 2); },
+         [=] { return pubsub_subscriber(address, topic, 0, sem); },
+         [=] { return pubsub_subscriber(address, topic, 1, sem); }});
 #ifdef __linux__
-        sem_destroy(sem);
-        munmap(sem, sizeof(sem_t));
+    sem_destroy(sem);
+    munmap(sem, sizeof(sem_t));
 #endif // __linux__
 #ifdef _WIN32
-        CloseHandle(sem);
+    CloseHandle(sem);
 #endif // _WIN32
+    EXPECT_EQ(result, TestResult::Success);
+}
-        EXPECT_EQ(result, TestResult::Success);
-    }
-
-    // this sets the publisher with an empty topic and the subscribers with two other topics
-    // both subscribers should get all messages
-    TEST(CcYmqTestSuite, TestPubSubEmptyTopic)
-    {
-        const auto host = "localhost";
-        const auto port = 2906;
+// this sets the publisher with an empty topic and the subscribers with two other topics
+// both subscribers should get all messages
+TEST_P(CcYmqTestSuiteParametrized, TestPubSubEmptyTopic)
+{
+    const auto address = GetAddress(2906);
-// allocate a semaphore to synchronize the publisher and subscriber processes
+    // allocate a semaphore to synchronize the publisher and subscriber processes
 #ifdef __linux__
-        sem_t* sem = static_cast<sem_t*>(
-            mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
-
-        if (sem == MAP_FAILED)
-            throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore");
+    sem_t* sem =
+        static_cast<sem_t*>(mmap(nullptr, sizeof(sem_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0));
+    if (sem == MAP_FAILED)
+        throw std::system_error(errno, std::generic_category(), "failed to map shared memory for semaphore");
+    if (sem_init(sem, 1, 0) < 0)
+        throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore");
-        if (sem_init(sem, 1, 0) < 0)
-            throw std::system_error(errno, std::generic_category(), "failed to initialize semaphore");
 #endif // __linux__
 #ifdef _WIN32
-        HANDLE sem = CreateSemaphore(
-            nullptr, // default security attributes
-            0, // initial count
-            2, // maximum count
-            nullptr); // unnamed semaphore
-        if (sem == nullptr)
-            throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore");
-#endif // _WIN32
-        auto result = test(
-            20,
-            {[=] { return pubsub_publisher(host, port, "", sem, 2); },
-             [=] { return pubsub_subscriber(host, port, "abc", 0, sem); },
-             [=] { return pubsub_subscriber(host, port, "def", 1, sem); }});
+    HANDLE sem = CreateSemaphore(
+        nullptr,  // default security attributes
+        0,        // initial count
+        2,        // maximum count
+        nullptr); // unnamed semaphore
+    if (sem == nullptr)
+        throw std::system_error(GetLastError(), std::generic_category(), "failed to create semaphore");
+#endif // _WIN32
+    auto result = test(
+        20,
+        {[=] { return pubsub_publisher(address, "", sem, 2); },
+         [=] { return pubsub_subscriber(address, "abc", 0, sem); },
+         [=] { return pubsub_subscriber(address, "def", 1, sem); }});
 #ifdef __linux__
-        sem_destroy(sem);
-        munmap(sem, sizeof(sem_t));
+    sem_destroy(sem);
+    munmap(sem, sizeof(sem_t));
 #endif // __linux__
 #ifdef _WIN32
-        CloseHandle(sem);
+    CloseHandle(sem);
 #endif // _WIN32
+    EXPECT_EQ(result, TestResult::Success);
+}
-        EXPECT_EQ(result, TestResult::Success);
-    }
+// in this test case, the client establishes a connection with the server and then explicitly closes it
+TEST_P(CcYmqTestSuiteParametrized, DISABLED_TestClientCloseEstablishedConnection)
-    // in this test case, the client establishes a connection with the server and then explicitly closes it
-    TEST(CcYmqTestSuite, DISABLED_TestClientCloseEstablishedConnection)
-    {
-        const auto host = "localhost";
-        const auto port = 2902;
-
-        auto result = test(
-            20,
-            {[=] { return client_close_established_connection_client(host, port); },
-             [=] { return client_close_established_connection_server(host, port); }});
-        EXPECT_EQ(result, TestResult::Success);
-    }
+{
+    const auto address = GetAddress(2902);
-    // this test case is similar to the one above, except that it requests the socket stop before closing the connection
-    TEST(CcYmqTestSuite, TestClientSocketStopBeforeCloseConnection)
-    {
-        const auto host = "localhost";
-        const auto port = 2904;
-
-        auto result = test(
-            20,
-            {[=] { return client_socket_stop_before_close_connection(host, port); },
-             [=] { return server_socket_stop_before_close_connection(host, port); }});
-        EXPECT_EQ(result, TestResult::Success);
-    }
+    auto result = test(
+        20,
+        {[=] { return client_close_established_connection_client(address); },
+         [=] { return client_close_established_connection_server(address); }});
-    // in this test case, the we try to close a connection that does not exist
-    TEST(CcYmqTestSuite, TestClientCloseNonexistentConnection)
-    {
-        auto result = close_nonexistent_connection();
-        EXPECT_EQ(result, TestResult::Success);
-    }
+    EXPECT_EQ(result, TestResult::Success);
+}
-    // this test case verifies that requesting a socket stop causes pending and subsequent operations to be cancelled
-    TEST(CcYmqTestSuite, TestRequestSocketStop)
-    {
-        auto result = test_request_stop();
-        EXPECT_EQ(result, TestResult::Success);
-    }
+// this test case is similar to the one above, except that it requests the socket stop before closing the connection
+TEST_P(CcYmqTestSuiteParametrized, TestClientSocketStopBeforeCloseConnection)
-    // main
+{
+    const auto address = GetAddress(2904);
-    int main(int argc, char** argv)
-    {
-        ensure_python_initialized();
+    auto result = test(
+        20,
+        {[=] { return client_socket_stop_before_close_connection(address); },
+         [=] { return server_socket_stop_before_close_connection(address); }});
+
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// in this test case, we try to close a connection that does not exist
+TEST(CcYmqTestSuite, TestClientCloseNonexistentConnection)
+{
+    auto result = close_nonexistent_connection();
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// this test case verifies that requesting a socket stop causes pending and subsequent operations to be cancelled
+TEST(CcYmqTestSuite, TestRequestSocketStop)
+{
+    auto result = test_request_stop();
+    EXPECT_EQ(result, TestResult::Success);
+}
+
+// the transports to parametrize the suite with; ipc is only available on Linux
+std::vector<std::string> GetTransports()
+{
+    std::vector<std::string> transports;
+    transports.push_back("tcp");
+#ifdef __linux__
+    transports.push_back("ipc");
+#endif
+    return transports;
+}
+
+// parametrize the test with tcp and ipc addresses
+INSTANTIATE_TEST_SUITE_P(
+    YMQTransport,
+    CcYmqTestSuiteParametrized,
+    ::testing::ValuesIn(GetTransports()),
+    [](const testing::TestParamInfo<std::string>& info) {
+        // use tcp/ipc as suffix for test names
+        return info.param;
+    });
+
+// main
+int main(int argc, char** argv)
+{
+    ensure_python_initialized();
 #ifdef _WIN32
-        // initialize winsock
-        WSADATA wsaData = {};
-        int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
-        if (iResult != 0) {
-            std::cerr << "WSAStartup failed: " << iResult << "\n";
-            return 1;
-        }
+    // initialize winsock
+    WSADATA wsaData = {};
+    int iResult = WSAStartup(MAKEWORD(2, 2), &wsaData);
+    if (iResult != 0) {
+        std::cerr << "WSAStartup failed: " << iResult << "\n";
+        return 1;
+    }
 #endif // _WIN32
-        testing::InitGoogleTest(&argc, argv);
-        auto result = RUN_ALL_TESTS();
+    testing::InitGoogleTest(&argc, argv);
+    auto result = RUN_ALL_TESTS();
 #ifdef _WIN32
-        WSACleanup();
+    WSACleanup();
 #endif // _WIN32
-        maybe_finalize_python();
-        return result;
-    }
+    maybe_finalize_python();
+    return result;
+}
diff --git a/tests/pymod_ymq/test_types.py b/tests/pymod_ymq/test_types.py
index e91313080..6c1dcca59 100644
--- a/tests/pymod_ymq/test_types.py
+++ b/tests/pymod_ymq/test_types.py
@@ -70,10 +70,6 @@ def test_io_context(self):
         ctx = ymq.IOContext(num_threads=3)
         self.assertEqual(ctx.num_threads, 3)

-    # TODO: backporting to 3.8 broke this somehow
-    # it causes a segmentation fault
-    # re-enable once fixed
-    @unittest.skip("causes segmentation fault")
     def test_io_socket(self):
         # check that we can't create io socket instances directly
         self.assertRaises(TypeError, lambda: ymq.IOSocket())  # type: ignore