diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml
index 28ebb114..0cbd86b7 100644
--- a/.github/workflows/integration_tests.yml
+++ b/.github/workflows/integration_tests.yml
@@ -58,6 +58,10 @@ jobs:
       - name: Install the project
         run: uv sync --all-extras --dev

+      - name: Is running on CI environment (GitHub Actions)?
+        run: |
+          python -c "import os; print('Result: ', os.getenv('GITHUB_ACTIONS', 'Not set'))"
+
       - name: Install dependencies and check code
         run: |
-          uv run pytest -m "integration_test"
+          uv run pytest -m "integration_test" --log-cli-level=WARNING
diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml
index 2506a63f..96cdd21f 100644
--- a/.github/workflows/unit_tests.yml
+++ b/.github/workflows/unit_tests.yml
@@ -60,4 +60,4 @@ jobs:
       - name: Install dependencies and check code
         run: |
-          uv run pytest -m "not integration_test"
+          uv run pytest -m "not integration_test" --log-cli-level=WARNING
diff --git a/src/midst_toolkit/models/clavaddpm/synthesizer.py b/src/midst_toolkit/models/clavaddpm/synthesizer.py
index 0d829d4a..ad0a90f6 100644
--- a/src/midst_toolkit/models/clavaddpm/synthesizer.py
+++ b/src/midst_toolkit/models/clavaddpm/synthesizer.py
@@ -786,7 +786,10 @@ def clava_synthesizing( # noqa: PLR0915, PLR0912
             "df": child_final_df,
             "keys": child_primary_keys_arr.flatten().tolist(),
         }
-        with open(os.path.join(save_dir, "before_matching/synthetic_tables.pkl"), "wb") as file:
+
+        before_matching_dir = save_dir / "before_matching"
+        before_matching_dir.mkdir(parents=True, exist_ok=True)
+        with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file:
             pickle.dump(synthetic_tables, file)

     synthesizing_end_time = time.time()
@@ -800,12 +803,8 @@ def clava_synthesizing( # noqa: PLR0915, PLR0912
     cleaned_tables: dict[str, pd.DataFrame] = {}
     for table_key, table_val in final_tables.items():
-        if "account_id" in tables[table_key]["original_cols"]:
-            cols = tables[table_key]["original_cols"]
-            cols.remove("account_id")
-        else:
-            cols = tables[table_key]["original_cols"]
-        cleaned_tables[table_key] = pd.DataFrame(table_val[cols])
+        column_names = [column_name for column_name in tables[table_key]["original_cols"] if "_id" not in column_name]
+        cleaned_tables[table_key] = pd.DataFrame(table_val[column_names])

     for cleaned_key, cleaned_val in cleaned_tables.items():
         table_dir = os.path.join(
diff --git a/tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl b/tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl
new file mode 100644
index 00000000..4d40dd33
Binary files /dev/null and b/tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl differ
diff --git a/tests/integration/assets/multi_table/assertion_data/conditional_samples.pt b/tests/integration/assets/multi_table/assertion_data/conditional_samples.pt
index ac6d9391..9f8a51bc 100644
Binary files a/tests/integration/assets/multi_table/assertion_data/conditional_samples.pt and b/tests/integration/assets/multi_table/assertion_data/conditional_samples.pt differ
diff --git a/tests/integration/assets/multi_table/assertion_data/diffusion_parameters.pkl b/tests/integration/assets/multi_table/assertion_data/diffusion_parameters.pkl
index e562717c..47278483 100644
Binary files a/tests/integration/assets/multi_table/assertion_data/diffusion_parameters.pkl and b/tests/integration/assets/multi_table/assertion_data/diffusion_parameters.pkl differ
diff --git a/tests/integration/assets/multi_table/assertion_data/syntetic_data.json b/tests/integration/assets/multi_table/assertion_data/syntetic_data.json
deleted file mode 100644
index 127a1a9b..00000000
--- a/tests/integration/assets/multi_table/assertion_data/syntetic_data.json
+++ /dev/null
@@ -1,61 +0,0 @@
-{
-    "X_gen": [
-        [
-            20.22872584417114,
-            92.01588539758231,
-            66.09407621466602,
-            -395.99753424670297,
-            -546.8482325705069,
-            -24.101433892829142,
-            33.05203129986196,
-            -379.23488795473145
-        ],
-        [
-            2.434360110358925,
-            9.06169838526799,
-            4.003170601038821,
-            -96.79202634298802,
-            30.110703649604655,
-            88.57479030974434,
-            8.772048228691677,
-            -91.52576106648934
-        ],
-        [
-            33.85931705221509,
-            -9.106227703102576,
-            -7.762047268790178,
-            266.0193164140085,
-            284.4569535921904,
-            -36.940608327841595,
-            -5.195944146198937,
-            119.86482313461545
-        ],
-        [
-            181.37850764316246,
-            15.935259065247706,
-            23.034917774385075,
-            171.92804003039046,
-            309.3315215755009,
-            -46.31601606668806,
-            9.086928058639302,
-            145.02086683315446
-        ],
-        [
-            -50.80997244167481,
-            -64.8057892078512,
-            -207.33505176215743,
-            1159.49264042098,
-            817.0822017270992,
-            -100.08453704343738,
-            -109.28965383240923,
-            311.3319359677195
-        ]
-    ],
-    "y_gen": [
-        1,
-        1,
-        0,
-        0,
-        0
-    ]
-}
diff --git a/tests/integration/assets/multi_table/assertion_data/synthetic_data.json b/tests/integration/assets/multi_table/assertion_data/synthetic_data.json
new file mode 100644
index 00000000..cce22336
--- /dev/null
+++ b/tests/integration/assets/multi_table/assertion_data/synthetic_data.json
@@ -0,0 +1,51 @@
+{
+    "X_gen": [
+        [
+            -89.64320001648818,
+            2.329031172032132,
+            -122.97271923749832,
+            552.3706152861826,
+            353.47951217405426,
+            -63.164915493559306,
+            -42.27259013378604,
+            244.21392290993887
+        ],
+        [-0.4694302555020733,
+            15.336690906361277,
+            -48.59970780139716,
+            -358.65097509895173,
+            411.39200743280094,
+            415.9651477725036,
+            -12.980662539762594,
+            -370.11192775534397
+        ],
+        [5.009930498133295,
+            -220.79264470424582,
+            -4.129379545636459,
+            -188.011555249935,
+            218.10979023918082,
+            221.16927688555808,
+            49.89701474616661,
+            -194.37953943919408
+        ],
+        [-1.3109146973467711,
+            -73.2679936874503,
+            -9.218660554989645,
+            -389.99286808084486,
+            -490.3925197112697,
+            -423.00630661809424,
+            432.9884292987812,
+            -397.6777786014056
+        ],
+        [8.342572127948289,
+            9.36842404400312,
+            -72.28739585181947,
+            -489.4411862829012,
+            563.4325829362252,
+            568.3398615720979,
+            -16.894123940346486,
+            -504.9528775096839
+        ]
+    ],
+    "y_gen": [1, 0, 1, 0, 0]
+}
diff --git a/tests/integration/assets/multi_table/dataset_meta.json b/tests/integration/assets/multi_table/dataset_meta.json
index 713390d5..5d1cafce 100644
--- a/tests/integration/assets/multi_table/dataset_meta.json
+++ b/tests/integration/assets/multi_table/dataset_meta.json
@@ -1,5 +1,6 @@
 {
     "relation_order": [
+        [null, "account"],
         ["account", "trans"]
     ],
     "tables": {
diff --git a/tests/integration/assets/single_table/assertion_data/syntetic_data.json b/tests/integration/assets/single_table/assertion_data/synthetic_data.json
similarity index 100%
rename from tests/integration/assets/single_table/assertion_data/syntetic_data.json
rename to tests/integration/assets/single_table/assertion_data/synthetic_data.json
diff --git a/tests/integration/models/clavaddpm/test_model.py b/tests/integration/models/clavaddpm/test_model.py
index e923e962..b43e7ab0 100644
--- a/tests/integration/models/clavaddpm/test_model.py
+++ b/tests/integration/models/clavaddpm/test_model.py
@@ -2,6 +2,7 @@
 import pickle
 import random
 from collections.abc import Callable
+from logging import WARNING
 from pathlib import Path

 import numpy as np
@@ -9,12 +10,14 @@
 import torch
 from torch.nn import functional

+from midst_toolkit.common.logger import log
 from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds
 from midst_toolkit.common.variables import DEVICE
 from midst_toolkit.models.clavaddpm.clustering import clava_clustering
 from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table
 from midst_toolkit.models.clavaddpm.model import Classifier
 from midst_toolkit.models.clavaddpm.train import clava_training
+from tests.integration.utils import is_running_on_ci_environment


 CLUSTERING_CONFIG = {
@@ -240,8 +243,8 @@ def test_load_multi_table():
         },
     }

-    assert relation_order == [["account", "trans"]]
-    assert dataset_meta["relation_order"] == [["account", "trans"]]
+    assert relation_order == [[None, "account"], ["account", "trans"]]
+    assert dataset_meta["relation_order"] == [[None, "account"], ["account", "trans"]]
     assert dataset_meta["tables"] == {
         "account": {"children": ["trans"], "parents": []},
         "trans": {"children": [], "parents": ["account"]},
@@ -273,7 +276,7 @@ def test_train_single_table(tmp_path: Path):
     )
     x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()

-    with open("tests/integration/assets/single_table/assertion_data/syntetic_data.json", "r") as f:
+    with open("tests/integration/assets/single_table/assertion_data/synthetic_data.json", "r") as f:
         expected_results = json.load(f)

     model_data = dict(models[key]["diffusion"].named_parameters())
@@ -285,12 +288,10 @@ def test_train_single_table(tmp_path: Path):
     expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()}

     model_layers = list(model_data.keys())
-    expected_model_layers = list(expected_model_data.keys())
-
     # Adding those asserts under an if condition because they only pass on github.
     # In the else block, we set a tolerance that would work across platforms
     # however, it is way too high of a tolerance.
-    if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]):
+    if is_running_on_ci_environment():
         # if the first layer is equal with minimal tolerance, all others should be equal as well
         assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers)

@@ -303,6 +304,7 @@ def test_train_single_table(tmp_path: Path):
         # Otherwise, set a tolerance that would work across platforms
         # TODO: Figure out a way to set a lower tolerance
         # https://app.clickup.com/t/868f43wp0
+        log(WARNING, "Not running on CI, assertions are made with a higher tolerance.")
         assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers)

     unset_all_random_seeds()
@@ -332,7 +334,7 @@ def test_train_multi_table(tmp_path: Path):
     )
     x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()

-    with open("tests/integration/assets/multi_table/assertion_data/syntetic_data.json", "r") as f:
+    with open("tests/integration/assets/multi_table/assertion_data/synthetic_data.json", "r") as f:
         expected_results = json.load(f)

     model_data = dict(models[1][key]["diffusion"].named_parameters())
@@ -343,13 +345,11 @@ def test_train_multi_table(tmp_path: Path):
     # Making sure the expected model data is loaded on the correct device
     expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()}

-    model_layers = list(model_data.keys())
-    expected_model_layers = list(expected_model_data.keys())
-
     # Adding those asserts under an if condition because they only pass on github.
     # In the else block, we set a tolerance that would work across platforms
     # however, it is way too high of a tolerance.
-    if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]):
+    model_layers = list(model_data.keys())
+    if is_running_on_ci_environment():
         # if the first layer is equal with minimal tolerance, all others should be equal as well
         assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers)

@@ -362,6 +362,7 @@ def test_train_multi_table(tmp_path: Path):
         # Otherwise, set a tolerance that would work across platforms
         # TODO: Figure out a way to set a lower tolerance
         # https://app.clickup.com/t/868f43wp0
+        log(WARNING, "Not running on CI, assertions are made with a higher tolerance.")
         assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers)

     classifier_scale = 1.0
@@ -382,11 +383,11 @@ def test_train_multi_table(tmp_path: Path):
     ).to(DEVICE)

     # Adding those asserts under an if condition because they only pass on github.
-    # In the else block, we set a tolerance that would work across platforms
-    # however, it is way too high of a tolerance.
-    if torch.allclose(conditional_sample[0], expected_conditional_sample[0]):
+    if is_running_on_ci_environment():
         # if the first values are equal with minimal tolerance, all others should be equal as well
         assert torch.allclose(conditional_sample, expected_conditional_sample)
+    else:
+        log(WARNING, "Not running on CI, skipping detailed assertions.")

     unset_all_random_seeds()

@@ -401,7 +402,7 @@ def test_clustering_reload(tmp_path: Path):
     tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG)

     # Assert
-    account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster"])
+    account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster", "placeholder"])
     account_original_df_as_float = tables["account"]["original_df"].astype(float)
     assert account_df_no_clustering.equals(account_original_df_as_float)
diff --git a/tests/integration/models/clavaddpm/test_synthesizer.py b/tests/integration/models/clavaddpm/test_synthesizer.py
new file mode 100644
index 00000000..94f7902a
--- /dev/null
+++ b/tests/integration/models/clavaddpm/test_synthesizer.py
@@ -0,0 +1,103 @@
+import pickle
+from copy import deepcopy
+from logging import WARNING
+from pathlib import Path
+
+import pytest
+
+from midst_toolkit.common.logger import log
+from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds
+from midst_toolkit.common.variables import DEVICE
+from midst_toolkit.models.clavaddpm.clustering import clava_clustering
+from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table
+from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing
+from midst_toolkit.models.clavaddpm.train import clava_training
+from tests.integration.utils import is_running_on_ci_environment
+
+
+CLUSTERING_CONFIG = {
+    "parent_scale": 1.0,
+    "num_clusters": 3,
+    "clustering_method": "kmeans_and_gmm",
+}
+
+DIFFUSION_CONFIG = {
+    "d_layers": [512, 1024, 1024, 1024, 1024, 512],
+    "dropout": 0.0,
+    "num_timesteps": 100,
+    "model_type": "mlp",
+    "iterations": 1000,
+    "batch_size": 24,
+    "lr": 0.0006,
+    "gaussian_loss_type": "mse",
+    "weight_decay": 1e-05,
+    "scheduler": "cosine",
+    "data_split_ratios": [0.99, 0.005, 0.005],
+}
+
+CLASSIFIER_CONFIG = {
+    "d_layers": [128, 256, 512, 1024, 512, 256, 128],
+    "lr": 0.0001,
+    "dim_t": 128,
+    "batch_size": 24,
+    "iterations": 1000,
+    "data_split_ratios": [0.99, 0.005, 0.005],
+}
+
+SYNTHESIZING_CONFIG = {
+    "general": {
+        "exp_name": "ensemble_attack",
+        "workspace_dir": None,
+        "sample_prefix": "",
+    },
+    "sampling": {
+        "batch_size": 2,
+        "classifier_scale": 1.0,
+    },
+    "matching": {
+        "num_matching_clusters": 1,
+        "matching_batch_size": 1,
+        "unique_matching": True,
+        "no_matching": False,
+    },
+}
+
+
+@pytest.mark.integration_test()
+def test_clava_synthesize_multi_table(tmp_path: Path):
+    # Setup
+    set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)
+
+    # Act
+    tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/multi_table/"))
+    tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG)
+    models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)
+
+    # TODO: Temporary, we should refactor those configs
+    configs = deepcopy(SYNTHESIZING_CONFIG)
+    configs["general"]["workspace_dir"] = str(tmp_path)
+
+    cleaned_tables, _, _ = clava_synthesizing(
+        tables,
+        relation_order,
+        tmp_path,
+        all_group_lengths_prob_dicts,
+        models[1],
+        configs,
+    )
+
+    # Assert
+    assert cleaned_tables["account"].shape == (9, 2)
+    assert cleaned_tables["trans"].shape == (145, 8)
+
+    if is_running_on_ci_environment():
+        expected_cleaned_tables = pickle.loads(
+            Path("tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl").read_bytes(),
+        )
+        assert cleaned_tables["account"].equals(expected_cleaned_tables["account"])
+        assert cleaned_tables["trans"].equals(expected_cleaned_tables["trans"])
+
+    else:
+        log(WARNING, "Not running on CI, skipping detailed assertions.")
+
+    unset_all_random_seeds()
diff --git a/tests/integration/utils.py b/tests/integration/utils.py
new file mode 100644
index 00000000..80422f19
--- /dev/null
+++ b/tests/integration/utils.py
@@ -0,0 +1,11 @@
+import os
+
+
+def is_running_on_ci_environment() -> bool:
+    """
+    Check if running on a CI environment, particularly GitHub Actions.
+
+    Returns:
+        bool: True if running on a CI environment (GitHub Actions), False otherwise.
+    """
+    return os.getenv("GITHUB_ACTIONS", "false").lower() == "true"
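
Note on the synthesizer change above: save_dir is now used as a pathlib.Path and the before_matching directory is created up front, so the pickle dump no longer depends on the directory already existing. A minimal standalone sketch of the same pattern (the workspace path below is illustrative only, not taken from the patch):

    from pathlib import Path

    save_dir = Path("workspace/example_run")  # illustrative path only
    before_matching_dir = save_dir / "before_matching"

    # parents=True creates missing intermediate directories; exist_ok=True makes reruns a no-op.
    before_matching_dir.mkdir(parents=True, exist_ok=True)
    assert before_matching_dir.is_dir()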
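Note on the tolerance fallback in test_model.py: off CI the parameter checks use torch.allclose with atol=0.1. As a reminder of standard PyTorch semantics (not code from this patch), torch.allclose(a, b, rtol, atol) checks |a - b| <= atol + rtol * |b| elementwise, so atol=0.1 tolerates roughly 0.1 of absolute drift per value regardless of magnitude, while the defaults (rtol=1e-05, atol=1e-08) demand near-exact agreement:

    import torch

    a = torch.tensor([1.00, 2.00, 3.00])
    b = torch.tensor([1.05, 1.92, 3.01])  # small per-element drift

    assert torch.allclose(a, b, atol=0.1)  # passes under the loose cross-platform tolerance
    assert not torch.allclose(a, b)        # fails at the default tolerances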
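Note on the new CI check in tests/integration/utils.py: GitHub Actions sets the GITHUB_ACTIONS environment variable to "true" for every step of a workflow run, which is what the new workflow step prints and what is_running_on_ci_environment() keys off. As a rough sketch (not part of this patch; the test name is hypothetical), the same helper could also skip a test entirely via pytest's skipif marker instead of branching inside the test body:

    import pytest

    from tests.integration.utils import is_running_on_ci_environment

    # Skip the whole test off CI instead of loosening its tolerances.
    @pytest.mark.skipif(
        not is_running_on_ci_environment(),
        reason="exact-match assertions only reproduce on GitHub Actions runners",
    )
    def test_exact_reproducibility():  # hypothetical test, for illustration only
        ...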