Merged

Changes from all commits (135 commits)
c059d93
Merge "Refactor Transformations handling: replace get_T_dict with a d…
lotif Sep 22, 2025
9f0290a
Refactor clustering method handling: Introduce ClusteringMethod enum …
lotif Sep 22, 2025
80073f2
Refactor model handling: Introduce ModelType enum for improved type s…
lotif Sep 22, 2025
1d43783
Merge "Refactor model parameters: Introduce ModelParameters and RTDLP…
lotif Sep 22, 2025
f7d69ec
Merge "Refactor y condition handling: Replace string literals with Is…
lotif Sep 22, 2025
01ed9c1
Refactor Gaussian loss handling: Introduce GaussianLossType enum to r…
lotif Sep 22, 2025
0dbc6b4
Refactor scheduler handling: Introduce Scheduler enum to replace stri…
lotif Sep 22, 2025
128da65
Merge "Refactor sampler initialization: Update UniformSampler and Los…
lotif Sep 22, 2025
f158b44
Enhance metric and loss handling: Refactor loss computation in _numer…
lotif Sep 23, 2025
6503788
Transforming a lot of literals into enums
lotif Sep 29, 2025
d599334
WIP renaming RTDL, cat and num and data splits
lotif Sep 29, 2025
1950f5c
Using more data splits and adding types for gaussian parametrization
lotif Sep 30, 2025
3e8237c
Adding enum for YType
lotif Sep 30, 2025
c957dec
Merge branch 'main' into marcelo/classes-and-enums-2
lotif Sep 30, 2025
4d0707b
Renaming Scheduler to SchedulerType and moving it and GaussianLossTyp…
lotif Sep 30, 2025
b7db96e
Merge remote-tracking branch 'origin/marcelo/classes-and-enums-2' int…
lotif Sep 30, 2025
30d0a0d
WIP CR by David
lotif Sep 30, 2025
779b108
Merge branch 'main' into marcelo/classes-and-enums-2
lotif Oct 1, 2025
0e3a42a
Cont'd CR comments by David
lotif Oct 1, 2025
8600ccb
Merge remote-tracking branch 'origin/marcelo/classes-and-enums-2' int…
lotif Oct 1, 2025
bc67266
Adding TODO
lotif Oct 1, 2025
774e99b
WIp starting the breakdown
lotif Oct 1, 2025
1168e93
Renames
lotif Oct 1, 2025
bf05c3c
Last breakdown
lotif Oct 1, 2025
d90ed2c
Removing ignore
lotif Oct 1, 2025
96414be
Finished refactoring
lotif Oct 1, 2025
aafd66c
Merge branch 'main' into marcelo/remove-ignores
lotif Oct 1, 2025
594d9cd
Renamings, mostly
lotif Oct 2, 2025
3a2b203
More enums
lotif Oct 2, 2025
972947d
Adding datasplits class
lotif Oct 2, 2025
77a2249
Splitting into another function
lotif Oct 2, 2025
67368ab
Adding docstrings, removing save
lotif Oct 2, 2025
293f4d9
Renaming function
lotif Oct 2, 2025
dbe71a4
Merge branch 'marcelo/refactoring-pair-clustering' into marcelo/renam…
lotif Oct 2, 2025
3a29c46
Merge branch 'marcelo/refactoring-pair-clustering' into marcelo/remov…
lotif Oct 2, 2025
f8c9adf
One more refactor
lotif Oct 2, 2025
1c055e8
rolling back table_domain renamings
lotif Oct 2, 2025
390ad5b
Merge branch 'marcelo/renamings' into marcelo/refactor-process-pipeli…
lotif Oct 2, 2025
84fe972
Splitting the make_dataset_from_df function
lotif Oct 2, 2025
56a09ec
Fixing broken code from revert
lotif Oct 2, 2025
f35b596
CR by David
lotif Oct 2, 2025
0a9994a
CR by David
lotif Oct 3, 2025
c7ed903
Merge branch 'marcelo/refactoring-pair-clustering' into marcelo/renam…
lotif Oct 3, 2025
ba7ab5a
Merge branch 'marcelo/renamings' into marcelo/refactor-process-pipeli…
lotif Oct 3, 2025
771bde8
Merge branch 'marcelo/refactor-process-pipeline-data' into marcelo/re…
lotif Oct 3, 2025
1f4fed2
CR by David
lotif Oct 3, 2025
57bd33c
Merge branch 'marcelo/refactor-process-pipeline-data' into marcelo/re…
lotif Oct 3, 2025
e9ffe39
CR by David
lotif Oct 6, 2025
224b265
CR by David and Fatemeh
lotif Oct 6, 2025
10e9989
Merge branch 'main' into marcelo/refactoring-pair-clustering
lotif Oct 6, 2025
94f014b
CR by David and Fatemeh
lotif Oct 6, 2025
f479990
Merge remote-tracking branch 'origin/marcelo/refactoring-pair-cluster…
lotif Oct 6, 2025
7d120ac
Merge branch 'marcelo/refactoring-pair-clustering' into marcelo/renam…
lotif Oct 6, 2025
626a39a
Merge branch 'marcelo/renamings' into marcelo/refactor-process-pipeli…
lotif Oct 6, 2025
93dfd31
Merge branch 'marcelo/refactor-process-pipeline-data' into marcelo/re…
lotif Oct 6, 2025
8eb21ae
Last CR comment by David
lotif Oct 6, 2025
53c1320
Fixing merge conflicts
lotif Oct 6, 2025
e75c5ca
Merge branch 'marcelo/renamings' into marcelo/refactor-process-pipeli…
lotif Oct 6, 2025
3b472e3
Merge branch 'marcelo/refactor-process-pipeline-data' into marcelo/re…
lotif Oct 6, 2025
ea8e2db
Merge branch 'main' into marcelo/refactor-process-pipeline-data
lotif Oct 6, 2025
def3b21
Merge branch 'marcelo/refactor-process-pipeline-data' into marcelo/re…
lotif Oct 6, 2025
747d9c4
Merge branch 'main' into marcelo/refactor-make-dataset-from-df
lotif Oct 6, 2025
42cf7c1
Merge remote-tracking branch 'origin/marcelo/refactoring-pair-cluster…
lotif Oct 6, 2025
2ae4dae
Merge remote-tracking branch 'origin/marcelo/renamings' into marcelo/…
lotif Oct 6, 2025
5d2240b
Merge remote-tracking branch 'origin/marcelo/refactor-process-pipelin…
lotif Oct 6, 2025
9d1af39
Merge branch 'marcelo/refactor-make-dataset-from-df' into marcelo/rem…
lotif Oct 6, 2025
0d5f556
Removing some more ignores
lotif Oct 7, 2025
c38af00
Fixing the gaussian multinomial diffusion module
lotif Oct 8, 2025
ae09be9
Fixing the gaussian multinomial diffusion module
lotif Oct 8, 2025
1175bc6
Small docstring adjustment
lotif Oct 8, 2025
5316c48
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 8, 2025
edc60bb
CR by Fatemeh
lotif Oct 10, 2025
74d1a7a
Merge branch 'marcelo/refactor-make-dataset-from-df' into marcelo/rem…
lotif Oct 10, 2025
40574db
Removing ignores in model.py and minor refactorings
lotif Oct 10, 2025
bdd28da
Removing the rest of the ignores
lotif Oct 10, 2025
585e874
Sokme more renamings
lotif Oct 14, 2025
f80068f
Merge branch 'main' into marcelo/refactor-make-dataset-from-df
emersodb Oct 14, 2025
a7c7948
Adding TODO
lotif Oct 16, 2025
d75f587
Merge branch 'main' into marcelo/refactor-make-dataset-from-df
lotif Oct 16, 2025
8d170f9
Merge branch 'marcelo/refactor-make-dataset-from-df' into marcelo/rem…
lotif Oct 16, 2025
0cc7982
Merge branch 'main' into marcelo/remove-ignores
lotif Oct 16, 2025
2db7af6
Merge branch 'main' into marcelo/refactor-gaussian
lotif Oct 16, 2025
a8a2525
Removing some ignores from the synthesizer
lotif Oct 16, 2025
2e39c86
Removing some ignores from the synthesizer
lotif Oct 16, 2025
99a6fbe
Fixing refactoring issues
lotif Oct 16, 2025
164447d
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 16, 2025
c50763c
Merge branch 'main' into marcelo/refactor-gaussian
emersodb Oct 17, 2025
d648615
David's CR
lotif Oct 17, 2025
28e91bb
Merge remote-tracking branch 'origin/marcelo/refactor-gaussian' into …
lotif Oct 17, 2025
f61ead0
Merge branch 'main' into marcelo/remove-ignores
lotif Oct 17, 2025
9e597df
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 17, 2025
a3b70da
Better docstrings
lotif Oct 17, 2025
93ecd5a
Change I forgot to submit
lotif Oct 17, 2025
acc94e1
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 17, 2025
888d84e
Dynamically passing the device to test_model.py
lotif Oct 20, 2025
37eb0c4
Dynamically passing the device to test_model.py
lotif Oct 20, 2025
49c0b15
Adding .to(DEVICE)
lotif Oct 20, 2025
3495471
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 20, 2025
f87293d
CR by David
lotif Oct 20, 2025
490ce23
Merge branch 'main' into marcelo/refactor-gaussian
emersodb Oct 20, 2025
b863c56
This may work in the cluster
lotif Oct 20, 2025
3ec141a
Merge remote-tracking branch 'origin/marcelo/refactor-gaussian' into …
lotif Oct 20, 2025
dcaa990
Merge branch 'marcelo/refactor-gaussian' into marcelo/remove-ignores
lotif Oct 20, 2025
2e935ed
Merge branch 'main' into marcelo/remove-ignores
lotif Oct 20, 2025
8f5cb32
WIP needs docstrings
lotif Oct 21, 2025
ac2ff9b
more details
lotif Oct 21, 2025
ae8fac2
Adding docstrings
lotif Oct 21, 2025
856e280
David's CR
lotif Oct 21, 2025
f4399ab
CR by David
lotif Oct 22, 2025
914b1b4
Typo
lotif Oct 22, 2025
428da3c
Merge branch 'main' into marcelo/remove-ignores
emersodb Oct 22, 2025
f504a19
Merge branch 'marcelo/remove-ignores' into marcelo/refactor-sample
lotif Oct 22, 2025
8093d9b
Merge branch 'main' into marcelo/refactor-sample
lotif Oct 22, 2025
c9979b2
Merge branch 'main' into marcelo/refactor-conditional-sampling
lotif Oct 22, 2025
0339aab
Adding a test for synthesizing multi table data
lotif Oct 22, 2025
bd9c11e
commenting asserts
lotif Oct 22, 2025
336f392
Uploading second asset
lotif Oct 22, 2025
4846af9
Adding data and printing synthesizer test results
lotif Oct 24, 2025
a7e21ca
Adding verbose test logs
lotif Oct 24, 2025
e51ac63
Uploading synthetic data
lotif Oct 24, 2025
9647437
commenting lines that will fail test so the upload works
lotif Oct 24, 2025
23f994e
Adding check to run on github actions
lotif Oct 24, 2025
33d3025
Adding assertions and an IF condition for checking if it's running on…
lotif Oct 24, 2025
7a87d62
Uploading one more result
lotif Oct 24, 2025
8d38c0e
Putting the right shape for the cleaned tables
lotif Oct 24, 2025
569291f
using .equals instead of == for dataframes
lotif Oct 24, 2025
b06db63
equals on each key instead
lotif Oct 24, 2025
a3a439c
Uploading updated tensors for github
lotif Oct 24, 2025
503e800
Removing file upload
lotif Oct 24, 2025
ef0c2a1
Addressing coderabbit comment
lotif Oct 24, 2025
5a9bc1e
More code comments by coderabbit
lotif Oct 24, 2025
7e1e07f
Changing to log warning and making sure it appears in the github logs…
lotif Oct 24, 2025
bbfd6b7
Merge branch 'main' into marcelo/add-syth-multi-table-test
lotif Oct 27, 2025
88f1d8c
better way of getting ID columns
lotif Oct 27, 2025
b8e9990
Didn't work, applying David's suggestion instead
lotif Oct 27, 2025
6 changes: 5 additions & 1 deletion .github/workflows/integration_tests.yml
@@ -58,6 +58,10 @@ jobs:
- name: Install the project
run: uv sync --all-extras --dev

- name: Is running on CI environment (GitHub Actions)?
run: |
python -c "import os; print('Result: ', os.getenv('GITHUB_ACTIONS', 'Not set'))"
Comment from the PR author (lotif):
Adding this just so we can see if this environment variable is set in case we need to debug it later.


- name: Install dependencies and check code
run: |
uv run pytest -m "integration_test"
uv run pytest -m "integration_test" --log-cli-level=WARNING
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
@@ -60,4 +60,4 @@ jobs:

- name: Install dependencies and check code
run: |
uv run pytest -m "not integration_test"
uv run pytest -m "not integration_test" --log-cli-level=WARNING
13 changes: 6 additions & 7 deletions src/midst_toolkit/models/clavaddpm/synthesizer.py
@@ -786,7 +786,10 @@ def clava_synthesizing( # noqa: PLR0915, PLR0912
"df": child_final_df,
"keys": child_primary_keys_arr.flatten().tolist(),
}
with open(os.path.join(save_dir, "before_matching/synthetic_tables.pkl"), "wb") as file:

before_matching_dir = save_dir / "before_matching"
before_matching_dir.mkdir(parents=True, exist_ok=True)
with open(before_matching_dir / "synthetic_tables.pkl", "wb") as file:
Comment from the PR author (lotif):
Making sure the directory exists before saving a file in it.

pickle.dump(synthetic_tables, file)

synthesizing_end_time = time.time()
@@ -800,12 +803,8 @@

cleaned_tables: dict[str, pd.DataFrame] = {}
for table_key, table_val in final_tables.items():
if "account_id" in tables[table_key]["original_cols"]:
cols = tables[table_key]["original_cols"]
cols.remove("account_id")
else:
cols = tables[table_key]["original_cols"]
cleaned_tables[table_key] = pd.DataFrame(table_val[cols])
column_names = [column_name for column_name in tables[table_key]["original_cols"] if "_id" not in column_name]
cleaned_tables[table_key] = pd.DataFrame(table_val[column_names])

for cleaned_key, cleaned_val in cleaned_tables.items():
table_dir = os.path.join(
Binary file not shown.
Binary file not shown.
Binary file not shown.

This file was deleted.

@@ -0,0 +1,51 @@
{
"X_gen": [
[
-89.64320001648818,
2.329031172032132,
-122.97271923749832,
552.3706152861826,
353.47951217405426,
-63.164915493559306,
-42.27259013378604,
244.21392290993887
],
[-0.4694302555020733,
15.336690906361277,
-48.59970780139716,
-358.65097509895173,
411.39200743280094,
415.9651477725036,
-12.980662539762594,
-370.11192775534397
],
[5.009930498133295,
-220.79264470424582,
-4.129379545636459,
-188.011555249935,
218.10979023918082,
221.16927688555808,
49.89701474616661,
-194.37953943919408
],
[-1.3109146973467711,
-73.2679936874503,
-9.218660554989645,
-389.99286808084486,
-490.3925197112697,
-423.00630661809424,
432.9884292987812,
-397.6777786014056
],
[8.342572127948289,
9.36842404400312,
-72.28739585181947,
-489.4411862829012,
563.4325829362252,
568.3398615720979,
-16.894123940346486,
-504.9528775096839
]
],
"y_gen": [1, 0, 1, 0, 0]
}
1 change: 1 addition & 0 deletions tests/integration/assets/multi_table/dataset_meta.json
@@ -1,5 +1,6 @@
{
"relation_order": [
[null, "account"],
Comment from the PR author (lotif):
This relation order was missing and it is needed for the synthesizer code. This is also the reason why the assertion data had to be re-generated and re-uploaded.

["account", "trans"]
],
"tables": {
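As the author's comment above notes, the synthesizer relies on the root entry [null, "account"] being present in the relation order. The following is a minimal, purely illustrative sketch of how a relation order like this is typically walked during synthesis; the loop and names are assumptions for illustration, not the toolkit's actual API.

# Illustrative sketch only: walking a relation order where a None/null parent
# marks a root table. Not the toolkit's actual synthesis code.
relation_order = [[None, "account"], ["account", "trans"]]

for parent, child in relation_order:
    if parent is None:
        # Root tables are synthesized unconditionally.
        print(f"synthesize root table '{child}'")
    else:
        # Child tables are synthesized conditioned on their already-synthesized parent.
        print(f"synthesize '{child}' conditioned on '{parent}'")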
31 changes: 16 additions & 15 deletions tests/integration/models/clavaddpm/test_model.py
@@ -2,19 +2,22 @@
import pickle
import random
from collections.abc import Callable
from logging import WARNING
from pathlib import Path

import numpy as np
import pytest
import torch
from torch.nn import functional

from midst_toolkit.common.logger import log
from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds
from midst_toolkit.common.variables import DEVICE
from midst_toolkit.models.clavaddpm.clustering import clava_clustering
from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table
from midst_toolkit.models.clavaddpm.model import Classifier
from midst_toolkit.models.clavaddpm.train import clava_training
from tests.integration.utils import is_running_on_ci_environment


CLUSTERING_CONFIG = {
@@ -240,8 +243,8 @@ def test_load_multi_table():
},
}

assert relation_order == [["account", "trans"]]
assert dataset_meta["relation_order"] == [["account", "trans"]]
assert relation_order == [[None, "account"], ["account", "trans"]]
assert dataset_meta["relation_order"] == [[None, "account"], ["account", "trans"]]
assert dataset_meta["tables"] == {
"account": {"children": ["trans"], "parents": []},
"trans": {"children": [], "parents": ["account"]},
@@ -273,7 +276,7 @@ def test_train_single_table(tmp_path: Path):
)
x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()

with open("tests/integration/assets/single_table/assertion_data/syntetic_data.json", "r") as f:
with open("tests/integration/assets/single_table/assertion_data/synthetic_data.json", "r") as f:
expected_results = json.load(f)

model_data = dict(models[key]["diffusion"].named_parameters())
@@ -285,12 +288,10 @@
expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()}

model_layers = list(model_data.keys())
expected_model_layers = list(expected_model_data.keys())

# Adding those asserts under an if condition because they only pass on github.
# In the else block, we set a tolerance that would work across platforms
# however, it is way too high of a tolerance.
if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]):
if is_running_on_ci_environment():
Comment from the PR author (lotif):
This is a better check for those if conditions.

# if the first layer is equal with minimal tolerance, all others should be equal as well
assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers)

@@ -303,6 +304,7 @@
# Otherwise, set a tolerance that would work across platforms
# TODO: Figure out a way to set a lower tolerance
# https://app.clickup.com/t/868f43wp0
log(WARNING, "Not running on CI, assertions are made with a higher tolerance.")
assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers)

unset_all_random_seeds()
@@ -332,7 +334,7 @@ def test_train_multi_table(tmp_path: Path):
)
x_gen, y_gen = x_gen_tensor.numpy(), y_gen_tensor.numpy()

with open("tests/integration/assets/multi_table/assertion_data/syntetic_data.json", "r") as f:
with open("tests/integration/assets/multi_table/assertion_data/synthetic_data.json", "r") as f:
expected_results = json.load(f)

model_data = dict(models[1][key]["diffusion"].named_parameters())
@@ -343,13 +345,11 @@
# Making sure the expected model data is loaded on the correct device
expected_model_data = {layer: data.to(DEVICE) for layer, data in expected_model_data.items()}

model_layers = list(model_data.keys())
expected_model_layers = list(expected_model_data.keys())

# Adding those asserts under an if condition because they only pass on github.
# In the else block, we set a tolerance that would work across platforms
# however, it is way too high of a tolerance.
if torch.allclose(model_data[model_layers[0]], expected_model_data[expected_model_layers[0]]):
model_layers = list(model_data.keys())
if is_running_on_ci_environment():
# if the first layer is equal with minimal tolerance, all others should be equal as well
assert all(torch.allclose(model_data[layer], expected_model_data[layer]) for layer in model_layers)

@@ -362,6 +362,7 @@
# Otherwise, set a tolerance that would work across platforms
# TODO: Figure out a way to set a lower tolerance
# https://app.clickup.com/t/868f43wp0
log(WARNING, "Not running on CI, assertions are made with a higher tolerance.")
assert all(torch.allclose(model_data[layer], expected_model_data[layer], atol=0.1) for layer in model_layers)

classifier_scale = 1.0
@@ -382,11 +383,11 @@
).to(DEVICE)

# Adding those asserts under an if condition because they only pass on github.
# In the else block, we set a tolerance that would work across platforms
# however, it is way too high of a tolerance.
if torch.allclose(conditional_sample[0], expected_conditional_sample[0]):
if is_running_on_ci_environment():
# if the first values are equal with minimal tolerance, all others should be equal as well
assert torch.allclose(conditional_sample, expected_conditional_sample)
else:
log(WARNING, "Not running on CI, skipping detailed assertions.")

unset_all_random_seeds()

@@ -401,7 +402,7 @@ def test_clustering_reload(tmp_path: Path):
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG)

# Assert
account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster"])
account_df_no_clustering = tables["account"]["df"].drop(columns=["account_trans_cluster", "placeholder"])
account_original_df_as_float = tables["account"]["original_df"].astype(float)
assert account_df_no_clustering.equals(account_original_df_as_float)

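The is_running_on_ci_environment helper used in the checks above (imported from tests.integration.utils) is not shown in this diff. Below is a minimal sketch of how such a helper might look, assuming it simply reads the GITHUB_ACTIONS environment variable that the new workflow step prints for debugging; this is an assumption for illustration, not necessarily the toolkit's actual implementation.

import os


def is_running_on_ci_environment() -> bool:
    # GitHub Actions sets GITHUB_ACTIONS to "true" on its runners; locally the
    # variable is normally unset, so tests fall back to the higher-tolerance assertions.
    return os.getenv("GITHUB_ACTIONS", "").lower() == "true"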
103 changes: 103 additions & 0 deletions tests/integration/models/clavaddpm/test_synthesizer.py
@@ -0,0 +1,103 @@
import pickle
from copy import deepcopy
from logging import WARNING
from pathlib import Path

import pytest

from midst_toolkit.common.logger import log
from midst_toolkit.common.random import set_all_random_seeds, unset_all_random_seeds
from midst_toolkit.common.variables import DEVICE
from midst_toolkit.models.clavaddpm.clustering import clava_clustering
from midst_toolkit.models.clavaddpm.data_loaders import load_multi_table
from midst_toolkit.models.clavaddpm.synthesizer import clava_synthesizing
from midst_toolkit.models.clavaddpm.train import clava_training
from tests.integration.utils import is_running_on_ci_environment


CLUSTERING_CONFIG = {
"parent_scale": 1.0,
"num_clusters": 3,
"clustering_method": "kmeans_and_gmm",
}

DIFFUSION_CONFIG = {
"d_layers": [512, 1024, 1024, 1024, 1024, 512],
"dropout": 0.0,
"num_timesteps": 100,
"model_type": "mlp",
"iterations": 1000,
"batch_size": 24,
"lr": 0.0006,
"gaussian_loss_type": "mse",
"weight_decay": 1e-05,
"scheduler": "cosine",
"data_split_ratios": [0.99, 0.005, 0.005],
}

CLASSIFIER_CONFIG = {
"d_layers": [128, 256, 512, 1024, 512, 256, 128],
"lr": 0.0001,
"dim_t": 128,
"batch_size": 24,
"iterations": 1000,
"data_split_ratios": [0.99, 0.005, 0.005],
}

SYNTHESIZING_CONFIG = {
"general": {
"exp_name": "ensemble_attack",
"workspace_dir": None,
"sample_prefix": "",
},
"sampling": {
"batch_size": 2,
"classifier_scale": 1.0,
},
"matching": {
"num_matching_clusters": 1,
"matching_batch_size": 1,
"unique_matching": True,
"no_matching": False,
},
}


@pytest.mark.integration_test()
def test_clava_synthesize_multi_table(tmp_path: Path):
# Setup
set_all_random_seeds(seed=133742, use_deterministic_torch_algos=True, disable_torch_benchmarking=True)

# Act
tables, relation_order, _ = load_multi_table(Path("tests/integration/assets/multi_table/"))
tables, all_group_lengths_prob_dicts = clava_clustering(tables, relation_order, tmp_path, CLUSTERING_CONFIG)
models = clava_training(tables, relation_order, tmp_path, DIFFUSION_CONFIG, CLASSIFIER_CONFIG, device=DEVICE)

# TODO: Temporary, we should refactor those configs
configs = deepcopy(SYNTHESIZING_CONFIG)
configs["general"]["workspace_dir"] = str(tmp_path)

cleaned_tables, _, _ = clava_synthesizing(
tables,
relation_order,
tmp_path,
all_group_lengths_prob_dicts,
models[1],
configs,
)

# Assert
assert cleaned_tables["account"].shape == (9, 2)
assert cleaned_tables["trans"].shape == (145, 8)

if is_running_on_ci_environment():
expected_cleaned_tables = pickle.loads(
Path("tests/integration/assets/multi_table/assertion_data/cleaned_tables.pkl").read_bytes(),
)
assert cleaned_tables["account"].equals(expected_cleaned_tables["account"])
assert cleaned_tables["trans"].equals(expected_cleaned_tables["trans"])

else:
log(WARNING, "Not running on CI, skipping detailed assertions.")

unset_all_random_seeds()