Skip to content

Commit 32eba99

Browse files
Merge pull request #13 from Clearbox-AI/develop
Develop in main
2 parents 4c481d9 + 463e770 commit 32eba99

6 files changed

Lines changed: 191 additions & 47 deletions

File tree

README.md

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,16 @@ Below is a code snippet example for the usage of the library:
8080
# Import the necessary modules from the SURE library
8181
from sure import Preprocessor, report
8282
from sure.utility import (compute_statistical_metrics, compute_mutual_info,
83-
compute_utility_metrics_class)
83+
compute_utility_metrics_class,
84+
detection,
85+
query_power)
8486
from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test,
8587
adversary_dataset, membership_inference_test)
8688

8789
# Assuming real_data, valid_data and synth_data are three pandas DataFrames
8890

8991
# Preprocessor initialization and query execution on the real, synthetic and validation datasets
90-
preprocessor = Preprocessor(real_data, get_discarded_info=False, num_fill_null='forward', scaling='standardize')
92+
preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')
9193

9294
real_data_preprocessed = preprocessor.transform(real_data)
9395
valid_data_preprocessed = preprocessor.transform(valid_data)
@@ -115,6 +117,12 @@ dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_tra
115117
dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid)
116118
share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
117119

120+
# Detection Score
121+
detection_score = detection(real_data, synth_data, preprocessor)
122+
123+
# Query Power
124+
query_power_score = query_power(real_data, synth_data, preprocessor)
125+
118126
# ML privacy attack sandbox initialization and simulation
119127
adversary_df = adversary_dataset(real_data_preprocessed, valid_data_preprocessed)
120128
# The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not

docs/source/doc_2.md

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,19 @@ Follow the step-by-step guide to test the library using the provided [instructio
2323
# Import the necessary modules from the SURE library
2424
from sure import Preprocessor, report
2525
from sure.utility import (compute_statistical_metrics, compute_mutual_info,
26-
compute_utility_metrics_class)
26+
compute_utility_metrics_class,
27+
detection,
28+
query_power)
2729
from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test,
2830
adversary_dataset, membership_inference_test)
2931

3032
# Assuming real_data, valid_data and synth_data are three pandas DataFrames
3133

32-
# Real dataset - Preprocessor initialization and query exacution
33-
preprocessor = Preprocessor(real_data, get_discarded_info=False)
34-
real_data_preprocessed = preprocessor.transform(real_data, num_fill_null='forward', scaling='standardize')
35-
36-
# Validation dataset - Preprocessor initialization and query exacution
37-
preprocessor = Preprocessor(valid_data, get_discarded_info=False)
38-
valid_data_preprocessed = preprocessor.transform(valid_data, num_fill_null='forward', scaling='standardize')
39-
40-
# Synthetic dataset - Preprocessor initialization and query exacution
41-
preprocessor = Preprocessor(synth_data, get_discarded_info=False)
42-
synth_data_preprocessed = preprocessor.transform(synth_data, num_fill_null='forward', scaling='standardize')
34+
# Preprocessor initialization and query execution on the real, synthetic and validation datasets
35+
preprocessor = Preprocessor(real_data)
36+
real_data_preprocessed = preprocessor.transform(real_data)
37+
valid_data_preprocessed = preprocessor.transform(valid_data)
38+
synth_data_preprocessed = preprocessor.transform(synth_data)
4339

4440
# Statistical properties and mutual information
4541
num_features_stats, cat_features_stats, temporal_feat_stats = compute_statistical_metrics(real_data, synth_data)
@@ -63,6 +59,12 @@ dcr_zero_synth_train = number_of_dcr_equal_to_zero("synth_train", dcr_synth_tra
6359
dcr_zero_synth_valid = number_of_dcr_equal_to_zero("synth_val", dcr_synth_valid)
6460
share = validation_dcr_test(dcr_synth_train, dcr_synth_valid)
6561

62+
# Detection Score
63+
detection_score = detection(real_data, synth_data, preprocessor)
64+
65+
# Query Power
66+
query_power_score = query_power(real_data, synth_data, preprocessor)
67+
6668
# ML privacy attack sandbox initialization and simulation
6769
adversary_df = adversary_dataset(real_data_preprocessed, valid_data_preprocessed)
6870
# The function adversary_dataset adds a column "privacy_test_is_training" to the adversary dataset, indicating whether the record was part of the training set or not

examples/sure_test.ipynb

Lines changed: 70 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,11 @@
6464
"\n",
6565
"from sure import Preprocessor, report\n",
6666
"from sure.utility import (compute_statistical_metrics, compute_mutual_info,\n",
67-
"\t\t\t compute_utility_metrics_class)\n",
67+
"\t\t\t \t\t\t compute_utility_metrics_class,\n",
68+
"\t\t\t\t\t\t detection,\n",
69+
"\t\t\t\t\t\t query_power)\n",
6870
"from sure.privacy import (distance_to_closest_record, dcr_stats, number_of_dcr_equal_to_zero, validation_dcr_test, \n",
69-
"\t\t\t adversary_dataset, membership_inference_test)"
71+
"\t\t\t adversary_dataset, membership_inference_test)"
7072
]
7173
},
7274
{
@@ -111,7 +113,7 @@
111113
"outputs": [],
112114
"source": [
113115
"# Preprocessor initialization and query execution on the real, synthetic and validation datasets\n",
114-
"preprocessor = Preprocessor(real_data, get_discarded_info=False, num_fill_null='forward', scaling='standardize')\n",
116+
"preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='standardize')\n",
115117
"\n",
116118
"real_data_preprocessed = preprocessor.transform(real_data)\n",
117119
"valid_data_preprocessed = preprocessor.transform(valid_data)\n",
@@ -129,7 +131,8 @@
129131
"cell_type": "markdown",
130132
"metadata": {},
131133
"source": [
132-
"#### 2.1 Statistical properties and mutual information"
134+
"#### 2.1 Statistical properties and mutual information\n",
135+
"These functions compute general statistical features, the correlation matrices and the difference between the correlation matrix of the real and synthetic dataset."
133136
]
134137
},
135138
{
@@ -147,7 +150,10 @@
147150
"cell_type": "markdown",
148151
"metadata": {},
149152
"source": [
150-
"#### 2.2 ML utility - Train on Synthetic Test on Real"
153+
"#### 2.2 ML utility - Train on Synthetic Test on Real\n",
154+
"The `compute_utility_metrics_class` trains multiple machine learning classification models on the synthetic dataset and evaluates their performance on the validation set.\n",
155+
"\n",
156+
"For comparison, it also trains the same models on the original training set and evaluates them on the same validation set. This allows a direct comparison between models trained on synthetic data and those trained on real data."
151157
]
152158
},
153159
{
@@ -168,6 +174,63 @@
168174
"TSTR_metrics = compute_utility_metrics_class(X_train, X_synth, X_test, y_train, y_synth, y_test)"
169175
]
170176
},
177+
{
178+
"cell_type": "markdown",
179+
"metadata": {},
180+
"source": [
181+
"#### 2.3 Detection Score\n",
182+
"Computes the detection score by training an XGBoost model to differentiate between original and synthetic data. \n",
183+
"\n",
184+
"The lower the model's accuracy, the higher the quality of the synthetic data.\n",
185+
"\n",
186+
"\n",
187+
"The detection score is computed as\n",
188+
"\n",
189+
"detection_score = 2*(1 - ROC_AUC)\n",
190+
"\n",
191+
"So if ROC_AUC<=0.5 the synthetic dataset is considered indistinguishable from the real dataset (detection score =1)\n"
192+
]
193+
},
194+
{
195+
"cell_type": "code",
196+
"execution_count": null,
197+
"metadata": {},
198+
"outputs": [],
199+
"source": [
200+
"detection_score = detection(real_data, synth_data, preprocessor)\n",
201+
"print(\"Detection accuracy: \", detection_score[\"accuracy\"])\n",
202+
"print(\"Detection ROC_AUC: \", detection_score[\"ROC_AUC\"])\n",
203+
"print(\"Detection score: \", detection_score[\"score\"])\n",
204+
"print(\"Detection feature importances: \", detection_score[\"feature_importances\"])"
205+
]
206+
},
207+
{
208+
"cell_type": "markdown",
209+
"metadata": {},
210+
"source": [
211+
"#### Query Power\n",
212+
"Generates and runs queries to compare the original and synthetic datasets.\n",
213+
"\n",
214+
"This method creates random queries that filter data from both datasets.\n",
215+
"\n",
216+
"The similarity between the sizes of the filtered results is used to score the quality of the synthetic data."
217+
]
218+
},
219+
{
220+
"cell_type": "code",
221+
"execution_count": null,
222+
"metadata": {},
223+
"outputs": [],
224+
"source": [
225+
"query_power_score = query_power(real_data, synth_data, preprocessor)\n",
226+
"\n",
227+
"print(\"Query Power score: \", query_power_score[\"score\"])\n",
228+
"for query in query_power_score[\"queries\"]:\n",
229+
" print(\"\\n\", query[\"text\"])\n",
230+
" print(\"Query result on real: \", query[\"original_df\"])\n",
231+
" print(\"Query result on synthetic: \", query[\"synthetic_df\"])"
232+
]
233+
},
171234
{
172235
"cell_type": "markdown",
173236
"metadata": {},
@@ -278,9 +341,9 @@
278341
],
279342
"metadata": {
280343
"kernelspec": {
281-
"display_name": "projects",
344+
"display_name": "Python (test_sure)",
282345
"language": "python",
283-
"name": "python3"
346+
"name": "test_sure"
284347
},
285348
"language_info": {
286349
"codemirror_mode": {

sure/distance_metrics/distance.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def distance_to_closest_record(
8585
y_dataframe: pd.DataFrame | pl.DataFrame | pl.LazyFrame = None,
8686
feature_weights: np.ndarray | List = None,
8787
parallel: bool = True,
88-
save_output: bool = True,
88+
save_data: bool = True,
8989
path_to_json: str = ""
9090
) -> np.ndarray:
9191
"""
@@ -120,8 +120,10 @@ def distance_to_closest_record(
120120
If None, each feature weight is 1.0
121121
parallel : Boolean, optional
122122
Whether to enable the parallelization to compute Gower matrix, by default True
123-
save_output : bool
123+
save_data : bool
124124
If True, saves the DCR information into the JSON file used to generate the final report.
125+
path_to_json : str
126+
Path to the JSON file used to generate the final report.
125127
126128
Returns
127129
-------
@@ -263,12 +265,13 @@ def distance_to_closest_record(
263265
weight_sum,
264266
fill_diagonal,
265267
)
266-
if save_output:
268+
if save_data:
267269
_save_to_json("dcr_"+dcr_name, dcr, path_to_json)
268270
return dcr
269271

270272
def dcr_stats(dcr_name: str,
271273
distances_to_closest_record: np.ndarray,
274+
save_data: bool = True,
272275
path_to_json: str = "") -> Dict:
273276
"""
274277
This function returns the statistics for an array containing DCR computed previously.
@@ -284,6 +287,10 @@ def dcr_stats(dcr_name: str,
284287
distances_to_closest_record : np.ndarray
285288
A 1D-array containing the Distance to the Closest Record for each row of a dataframe
286289
shape (dataframe rows, )
290+
save_data : bool
291+
If True, saves the DCR information into the JSON file used to generate the final report.
292+
path_to_json : str
293+
Path to the JSON file used to generate the final report.
287294
288295
Returns
289296
-------
@@ -303,11 +310,13 @@ def dcr_stats(dcr_name: str,
303310
"75%": dcr_percentiles[3].item(),
304311
"max": dcr_percentiles[4].item(),
305312
}
306-
_save_to_json("dcr_"+dcr_name+"_stats", dcr_stats, path_to_json)
313+
if save_data:
314+
_save_to_json("dcr_"+dcr_name+"_stats", dcr_stats, path_to_json)
307315
return dcr_stats
308316

309317
def number_of_dcr_equal_to_zero(dcr_name: str,
310318
distances_to_closest_record: np.ndarray,
319+
save_data: bool = True,
311320
path_to_json: str = "") -> int_type:
312321
"""
313322
Return the number of 0s in the given DCR array, that is the number of duplicates/clones detected.
@@ -317,6 +326,10 @@ def number_of_dcr_equal_to_zero(dcr_name: str,
317326
distances_to_closest_record : np.ndarray
318327
A 1D-array containing the Distance to the Closest Record for each row of a dataframe
319328
shape (dataframe rows, )
329+
save_data : bool
330+
If True, saves the DCR information into the JSON file used to generate the final report.
331+
path_to_json : str
332+
Path to the JSON file used to generate the final report.
320333
321334
Returns
322335
-------
@@ -327,14 +340,16 @@ def number_of_dcr_equal_to_zero(dcr_name: str,
327340
raise TypeError("dcr_name must be one of the following:\n -\"synth_train\"\n -\"synth_val\"\n -\"other\"")
328341

329342
zero_values_mask = distances_to_closest_record == 0.0
330-
_save_to_json("dcr_"+dcr_name+"_num_of_zeros", zero_values_mask.sum(), path_to_json)
343+
if save_data:
344+
_save_to_json("dcr_"+dcr_name+"_num_of_zeros", zero_values_mask.sum(), path_to_json)
331345
return zero_values_mask.sum()
332346

333347
# def dcr_histogram(
334348
# dcr_name: str,
335349
# distances_to_closest_record: np.ndarray,
336350
# bins: int = 20,
337351
# scale_to_100: bool = True,
352+
# save_data: bool = True,
338353
# path_to_json: str = ""
339354
# ) -> Dict:
340355
# """
@@ -394,12 +409,14 @@ def number_of_dcr_equal_to_zero(dcr_name: str,
394409
# "bins_edge_without_zero": bins_without_zero.tolist(),
395410
# }
396411

397-
# _save_to_json("dcr_"+dcr_name+"_hist", dcr_hist, path_to_json)
412+
# if save_data:
413+
# _save_to_json("dcr_"+dcr_name+"_hist", dcr_hist, path_to_json)
398414
# return dcr_hist
399415

400416
def validation_dcr_test(
401417
dcr_synth_train: np.ndarray,
402418
dcr_synth_validation: np.ndarray,
419+
save_data: bool = True,
403420
path_to_json: str = ""
404421
) -> float_type:
405422
"""
@@ -416,6 +433,10 @@ def validation_dcr_test(
416433
dcr_synth_validation : np.ndarray
417434
A 1D-array containing the Distance to the Closest Record for each row of the synthetic
418435
dataset wrt the validation dataset, shape (synthetic rows, )
436+
save_data : bool
437+
If True, saves the DCR information into the JSON file used to generate the final report.
438+
path_to_json : str
439+
Path to the JSON file used to generate the final report.
419440
420441
Returns
421442
-------
@@ -455,5 +476,6 @@ def validation_dcr_test(
455476
percentage = synth_dcr_smaller_than_holdout_dcr_sum / number_of_rows * 100
456477

457478
dcr_validation = {"percentage": round(percentage,4), "warnings": warnings}
458-
_save_to_json("dcr_validation", dcr_validation, path_to_json)
479+
if save_data:
480+
_save_to_json("dcr_validation", dcr_validation, path_to_json)
459481
return dcr_validation

sure/privacy/privacy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def membership_inference_test(
9090
synthetic_dataset: pd.DataFrame | pl.DataFrame | pl.LazyFrame,
9191
adversary_guesses_ground_truth: np.ndarray | pd.DataFrame | pl.DataFrame | pl.LazyFrame | pl.Series,
9292
parallel: bool = True,
93+
save_data = True,
9394
path_to_json: str = ""
9495
):
9596
"""
@@ -105,6 +106,8 @@ def membership_inference_test(
105106
Ground truth labels indicating whether a sample is from the original training dataset or not.
106107
parallel : bool, optional
107108
Whether to use parallel processing for distance calculations, by default True.
109+
save_data : bool
110+
If True, saves the DCR information into the JSON file used to generate the final report, by default True.
108111
path_to_json : str, optional
109112
Path to save the attack output as a JSON file. If empty, the output is not saved, by default "".
110113
@@ -157,5 +160,6 @@ def membership_inference_test(
157160
"membership_inference_mean_risk_score": membership_inference_mean_risk_score,
158161
}
159162

160-
_save_to_json("MIA_attack", attack_output, path_to_json)
163+
if save_data:
164+
_save_to_json("MIA_attack", attack_output, path_to_json)
161165
return attack_output

0 commit comments

Comments
 (0)