-
Notifications
You must be signed in to change notification settings - Fork 1
Add tests for the multi-table synthesizer code #69
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c059d93
9f0290a
80073f2
1d43783
f7d69ec
01ed9c1
0dbc6b4
128da65
f158b44
6503788
d599334
1950f5c
3e8237c
c957dec
4d0707b
b7db96e
30d0a0d
779b108
0e3a42a
8600ccb
bc67266
774e99b
1168e93
bf05c3c
d90ed2c
96414be
aafd66c
594d9cd
3a2b203
972947d
77a2249
67368ab
293f4d9
dbe71a4
3a29c46
f8c9adf
1c055e8
390ad5b
84fe972
56a09ec
f35b596
0a9994a
c7ed903
ba7ab5a
771bde8
1f4fed2
57bd33c
e9ffe39
224b265
10e9989
94f014b
f479990
7d120ac
626a39a
93dfd31
8eb21ae
53c1320
e75c5ca
3b472e3
ea8e2db
def3b21
747d9c4
42cf7c1
2ae4dae
5d2240b
9d1af39
0d5f556
c38af00
ae09be9
1175bc6
5316c48
edc60bb
74d1a7a
40574db
bdd28da
585e874
f80068f
a7c7948
d75f587
8d170f9
0cc7982
2db7af6
a8a2525
2e39c86
99a6fbe
164447d
c50763c
d648615
28e91bb
f61ead0
9e597df
a3b70da
93ecd5a
acc94e1
888d84e
37eb0c4
49c0b15
3495471
f87293d
490ce23
b863c56
3ec141a
dcaa990
2e935ed
8f5cb32
ac2ff9b
ae8fac2
856e280
f4399ab
914b1b4
428da3c
f504a19
8093d9b
c9979b2
0339aab
bd9c11e
336f392
4846af9
a7e21ca
e51ac63
9647437
23f994e
33d3025
7a87d62
8d38c0e
569291f
b06db63
a3a439c
503e800
ef0c2a1
5a9bc1e
7e1e07f
bbfd6b7
88f1d8c
b8e9990
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -786,7 +786,10 @@ def clava_synthesizing( # noqa: PLR0915, PLR0912 | |
| "df": child_final_df, | ||
| "keys": child_primary_keys_arr.flatten().tolist(), | ||
| } | ||
| with open(os.path.join(save_dir, "before_matching/synthetic_tables.pkl"), "wb") as file: | ||
|
|
||
| before_matching_dir = save_dir / "before_matching" | ||
| before_matching_dir.mkdir(parents=True, exist_ok=True) | ||
| with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Making sure the directory exists before saving a file in it. |
||
| pickle.dump(synthetic_tables, file) | ||
|
|
||
| synthesizing_end_time = time.time() | ||
|
|
@@ -800,12 +803,8 @@ def clava_synthesizing( # noqa: PLR0915, PLR0912 | |
|
|
||
| cleaned_tables: dict[str, pd.DataFrame] = {} | ||
| for table_key, table_val in final_tables.items(): | ||
| if "account_id" in tables[table_key]["original_cols"]: | ||
| cols = tables[table_key]["original_cols"] | ||
| cols.remove("account_id") | ||
| else: | ||
| cols = tables[table_key]["original_cols"] | ||
| cleaned_tables[table_key] = pd.DataFrame(table_val[cols]) | ||
| column_names = [column_name for column_name in tables[table_key]["original_cols"] if "_id" not in column_name] | ||
| cleaned_tables[table_key] = pd.DataFrame(table_val[column_names]) | ||
|
|
||
| for cleaned_key, cleaned_val in cleaned_tables.items(): | ||
| table_dir = os.path.join( | ||
|
|
||
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| { | ||
lotif marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "X_gen": [ | ||
| [ | ||
| -89.64320001648818, | ||
| 2.329031172032132, | ||
| -122.97271923749832, | ||
| 552.3706152861826, | ||
| 353.47951217405426, | ||
| -63.164915493559306, | ||
| -42.27259013378604, | ||
| 244.21392290993887 | ||
| ], | ||
| [-0.4694302555020733, | ||
| 15.336690906361277, | ||
| -48.59970780139716, | ||
| -358.65097509895173, | ||
| 411.39200743280094, | ||
| 415.9651477725036, | ||
| -12.980662539762594, | ||
| -370.11192775534397 | ||
| ], | ||
| [5.009930498133295, | ||
| -220.79264470424582, | ||
| -4.129379545636459, | ||
| -188.011555249935, | ||
| 218.10979023918082, | ||
| 221.16927688555808, | ||
| 49.89701474616661, | ||
| -194.37953943919408 | ||
| ], | ||
| [-1.3109146973467711, | ||
| -73.2679936874503, | ||
| -9.218660554989645, | ||
| -389.99286808084486, | ||
| -490.3925197112697, | ||
| -423.00630661809424, | ||
| 432.9884292987812, | ||
| -397.6777786014056 | ||
| ], | ||
| [8.342572127948289, | ||
| 9.36842404400312, | ||
| -72.28739585181947, | ||
| -489.4411862829012, | ||
| 563.4325829362252, | ||
| 568.3398615720979, | ||
| -16.894123940346486, | ||
| -504.9528775096839 | ||
| ] | ||
| ], | ||
| "y_gen": [1, 0, 1, 0, 0] | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| { | ||
| "relation_order": [ | ||
| [null, "account"], | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This relation order was missing and it is needed for the synthesizer code. This is also the reason why the assertion data had to be re-generated and re-uploaded. |
||
| ["account", "trans"] | ||
| ], | ||
| "tables": { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,19 +2,22 @@ | |
| import pickle | ||
| import random | ||
| from collections.abc import Callable | ||
| from logging import WARNING | ||
| from pathlib import Path | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
| from torch.nn import functional | ||
|
|
||
| from midst_toolkit.common.logger import log | ||
| from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds | ||
| from midst_toolkit.common.variables import DEVICE | ||
| from midst_toolkit.models.clavaddpm.clustering import clava_clustering | ||
| from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table | ||
| from midst_toolkit.models.clavaddpm.model import Classifier | ||
| from midst_toolkit.models.clavaddpm.train import clava_training | ||
| from tests.integration.utils import is_running_on_ci_environment | ||
|
|
||
|
|
||
| CLUSTERING_CONFIG = { | ||
|
|
@@ -240,8 +243,8 @@ def test_load_multi_table(): | |
| }, | ||
| } | ||
|
|
||
| assert relation_order == [["account", "trans"]] | ||
| assert dataset_meta["relation_order"] == [["account", "trans"]] | ||
| assert relation_order == [[None, "account"], ["account", "trans"]] | ||
| assert dataset_meta["relation_order"] == [[None, "account"], ["account", "trans"]] | ||
| assert dataset_meta["tables"] == { | ||
| "account": {"children": ["trans"], "parents": []}, | ||
| "trans": {"children": [], "parents": ["account"]}, | ||
|
|
@@ -273,7 +276,7 @@ def test_train_single_table(tmp_path: Path): | |
| ) | ||
| x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy() | ||
|
|
||
| with open("tests/integration/assets/single_table/assertion_data/syntetic_data.json", "r") as f: | ||
| with open("tests/integration/assets/single_table/assertion_data/synthetic_data.json", "r") as f: | ||
| expected_results = json.load(f) | ||
|
|
||
| model_data = dict(models[key]["diffusion"].named_parameters()) | ||
|
|
@@ -285,12 +288,10 @@ def test_train_single_table(tmp_path: Path): | |
| expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()} | ||
|
|
||
| model_layers = list(model_data.keys()) | ||
| expected_model_layers = list(expected_model_data.keys()) | ||
|
|
||
| # Adding those asserts under an if condition because they only pass on github. | ||
| # In the else block, we set a tolerance that would work across platforms | ||
| # however, it is way too high of a tolerance. | ||
| if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]): | ||
| if is_running_on_ci_environment(): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is a better check for those assertions (the ones that only pass on CI). |
||
| # if the first layer is equal with minimal tolerance, all others should be equal as well | ||
| assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers) | ||
|
|
||
|
|
@@ -303,6 +304,7 @@ def test_train_single_table(tmp_path: Path): | |
| # Otherwise, set a tolerance that would work across platforms | ||
| # TODO: Figure out a way to set a lower tolerance | ||
| # https://app.clickup.com/t/868f43wp0 | ||
| log(WARNING, "Not running on CI, assertions are made with a higher tolerance.") | ||
| assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers) | ||
|
|
||
| unset_all_random_seeds() | ||
|
|
@@ -332,7 +334,7 @@ def test_train_multi_table(tmp_path: Path): | |
| ) | ||
| x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy() | ||
|
|
||
| with open("tests/integration/assets/multi_table/assertion_data/syntetic_data.json", "r") as f: | ||
| with open("tests/integration/assets/multi_table/assertion_data/synthetic_data.json", "r") as f: | ||
| expected_results = json.load(f) | ||
|
|
||
| model_data = dict(models[1][key]["diffusion"].named_parameters()) | ||
|
|
@@ -343,13 +345,11 @@ def test_train_multi_table(tmp_path: Path): | |
| # Making sure the expected model data is loaded on the correct device | ||
| expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()} | ||
|
|
||
| model_layers = list(model_data.keys()) | ||
| expected_model_layers = list(expected_model_data.keys()) | ||
|
|
||
| # Adding those asserts under an if condition because they only pass on github. | ||
| # In the else block, we set a tolerance that would work across platforms | ||
| # however, it is way too high of a tolerance. | ||
| if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]): | ||
| model_layers = list(model_data.keys()) | ||
| if is_running_on_ci_environment(): | ||
| # if the first layer is equal with minimal tolerance, all others should be equal as well | ||
| assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers) | ||
|
|
||
|
|
@@ -362,6 +362,7 @@ def test_train_multi_table(tmp_path: Path): | |
| # Otherwise, set a tolerance that would work across platforms | ||
| # TODO: Figure out a way to set a lower tolerance | ||
| # https://app.clickup.com/t/868f43wp0 | ||
| log(WARNING, "Not running on CI, assertions are made with a higher tolerance.") | ||
| assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers) | ||
|
|
||
| classifier_scale = 1.0 | ||
|
|
@@ -382,11 +383,11 @@ def test_train_multi_table(tmp_path: Path): | |
| ).to(DEVICE) | ||
|
|
||
| # Adding those asserts under an if condition because they only pass on github. | ||
| # In the else block, we set a tolerance that would work across platforms | ||
| # however, it is way too high of a tolerance. | ||
| if torch.allclose(conditional_sample[0], expected_conditional_sample[0]): | ||
| if is_running_on_ci_environment(): | ||
| # if the first values are equal with minimal tolerance, all others should be equal as well | ||
| assert torch.allclose(conditional_sample, expected_conditional_sample) | ||
| else: | ||
| log(WARNING, "Not running on CI, skipping detailed assertions.") | ||
|
|
||
| unset_all_random_seeds() | ||
|
|
||
|
|
@@ -401,7 +402,7 @@ def test_clustering_reload(tmp_path: Path): | |
| tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) | ||
|
|
||
| # Assert | ||
| account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster"]) | ||
| account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster", "placeholder"]) | ||
| account_original_df_as_float = tables["account"]["original_df"].astype(float) | ||
| assert account_df_no_clustering.equals(account_original_df_as_float) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| import pickle | ||
| from copy import deepcopy | ||
| from logging import WARNING | ||
| from pathlib import Path | ||
|
|
||
| import pytest | ||
|
|
||
| from midst_toolkit.common.logger import log | ||
| from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds | ||
| from midst_toolkit.common.variables import DEVICE | ||
| from midst_toolkit.models.clavaddpm.clustering import clava_clustering | ||
| from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table | ||
| from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing | ||
| from midst_toolkit.models.clavaddpm.train import clava_training | ||
| from tests.integration.utils import is_running_on_ci_environment | ||
|
|
||
|
|
||
| CLUSTERING_CONFIG = { | ||
| "parent_scale": 1.0, | ||
| "num_clusters": 3, | ||
| "clustering_method": "kmeans_and_gmm", | ||
| } | ||
|
|
||
| DIFFUSION_CONFIG = { | ||
| "d_layers": [512, 1024, 1024, 1024, 1024, 512], | ||
| "dropout": 0.0, | ||
| "num_timesteps": 100, | ||
| "model_type": "mlp", | ||
| "iterations": 1000, | ||
| "batch_size": 24, | ||
| "lr": 0.0006, | ||
| "gaussian_loss_type": "mse", | ||
| "weight_decay": 1e-05, | ||
| "scheduler": "cosine", | ||
| "data_split_ratios": [0.99, 0.005, 0.005], | ||
| } | ||
|
|
||
| CLASSIFIER_CONFIG = { | ||
| "d_layers": [128, 256, 512, 1024, 512, 256, 128], | ||
| "lr": 0.0001, | ||
| "dim_t": 128, | ||
| "batch_size": 24, | ||
| "iterations": 1000, | ||
| "data_split_ratios": [0.99, 0.005, 0.005], | ||
| } | ||
|
|
||
| SYNTHESIZING_CONFIG = { | ||
| "general": { | ||
| "exp_name": "ensemble_attack", | ||
| "workspace_dir": None, | ||
| "sample_prefix": "", | ||
| }, | ||
| "sampling": { | ||
| "batch_size": 2, | ||
| "classifier_scale": 1.0, | ||
| }, | ||
| "matching": { | ||
| "num_matching_clusters": 1, | ||
| "matching_batch_size": 1, | ||
| "unique_matching": True, | ||
| "no_matching": False, | ||
| }, | ||
| } | ||
|
|
||
|
|
||
| @pytest.mark.integration_test() | ||
| def test_clava_synthesize_multi_table(tmp_path: Path): | ||
| # Setup | ||
| set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True) | ||
|
|
||
| # Act | ||
| tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/multi_table/")) | ||
| tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG) | ||
| models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE) | ||
|
|
||
| # TODO: Temporary, we should refactor those configs | ||
| configs = deepcopy(SYNTHESIZING_CONFIG) | ||
| configs["general"]["workspace_dir"] = str(tmp_path) | ||
|
|
||
| cleaned_tables, _, _ = clava_synthesizing( | ||
| tables, | ||
| relation_order, | ||
| tmp_path, | ||
| all_group_lengths_prob_dicts, | ||
| models[1], | ||
| configs, | ||
| ) | ||
|
|
||
| # Assert | ||
| assert cleaned_tables["account"].shape == (9, 2) | ||
| assert cleaned_tables["trans"].shape == (145, 8) | ||
|
|
||
| if is_running_on_ci_environment(): | ||
| expected_cleaned_tables = pickle.loads( | ||
| Path("tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl").read_bytes(), | ||
| ) | ||
| assert cleaned_tables["account"].equals(expected_cleaned_tables["account"]) | ||
| assert cleaned_tables["trans"].equals(expected_cleaned_tables["trans"]) | ||
lotif marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| else: | ||
| log(WARNING, "Not running on CI, skipping detailed assertions.") | ||
|
|
||
| unset_all_random_seeds() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding this just so we can see if this environment variable is set in case we need to debug it later.