diff --git a/simple/tests/stats/db_test.py b/simple/tests/stats/db_test.py index 1df77c82..2c7ee5d1 100644 --- a/simple/tests/stats/db_test.py +++ b/simple/tests/stats/db_test.py @@ -30,21 +30,13 @@ from stats.db import ImportStatus from stats.db import to_observation_tuple from stats.db import to_triple_tuple +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "db") _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") - -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - _TRIPLES = [ Triple("sub1", "typeOf", object_id="StatisticalVariable"), Triple("sub1", "pred1", object_value="objval1"), @@ -120,9 +112,9 @@ def test_main_dc_db(self): shutil.copy(mcf_file, expected_mcf_file) return - _compare_files(self, observations_file, expected_observations_file) - _compare_files(self, tmcf_file, expected_tmcf_file) - _compare_files(self, mcf_file, expected_mcf_file) + compare_files(self, observations_file, expected_observations_file) + compare_files(self, tmcf_file, expected_tmcf_file) + compare_files(self, mcf_file, expected_mcf_file) @mock.patch.dict(os.environ, {}) def test_get_cloud_sql_config_from_env_empty(self): diff --git a/simple/tests/stats/entities_importer_test.py b/simple/tests/stats/entities_importer_test.py index 4205f9e6..71ea283f 100644 --- a/simple/tests/stats/entities_importer_test.py +++ b/simple/tests/stats/entities_importer_test.py @@ -28,6 +28,7 @@ from stats.nodes import Nodes from stats.reporter import FileImportReporter from stats.reporter import ImportReporter +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -37,14 +38,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_triples(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from triples").fetchall() @@ -86,7 +79,7 @@ def _test_import(test: unittest.TestCase, test_name: str): shutil.copy(output_triples_path, expected_triples_path) return - _compare_files(test, output_triples_path, expected_triples_path) + compare_files(test, output_triples_path, expected_triples_path) class TestEntitiesImporter(unittest.TestCase): diff --git a/simple/tests/stats/events_importer_test.py b/simple/tests/stats/events_importer_test.py index 9d76da29..69324fda 100644 --- a/simple/tests/stats/events_importer_test.py +++ b/simple/tests/stats/events_importer_test.py @@ -29,6 +29,7 @@ from stats.nodes import Nodes from stats.reporter import FileImportReporter from stats.reporter import ImportReporter +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -38,14 +39,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_observations(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from observations").fetchall() @@ -104,8 +97,8 @@ def _test_import(test: unittest.TestCase, test_name: str): shutil.copy(output_observations_path, expected_observations_path) return - _compare_files(test, output_triples_path, expected_triples_path) - _compare_files(test, output_observations_path, expected_observations_path) + compare_files(test, output_triples_path, expected_triples_path) + compare_files(test, output_observations_path, expected_observations_path) class TestEventsImporter(unittest.TestCase): diff --git a/simple/tests/stats/mcf_importer_test.py b/simple/tests/stats/mcf_importer_test.py index 2bb481c9..f81c10d1 100644 --- a/simple/tests/stats/mcf_importer_test.py +++ b/simple/tests/stats/mcf_importer_test.py @@ -28,6 +28,7 @@ from stats.nodes import Nodes from stats.reporter import FileImportReporter from stats.reporter import ImportReporter +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -37,14 +38,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_triples(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from triples").fetchall() @@ -97,13 +90,13 @@ def _test_import(test: unittest.TestCase, shutil.copy(output_triples_path, expected_triples_path) return - _compare_files(test, output_triples_path, expected_triples_path) + compare_files(test, output_triples_path, expected_triples_path) else: if is_write_mode(): shutil.copy(output_mcf_path, expected_mcf_path) return - _compare_files(test, output_mcf_path, expected_mcf_path) + compare_files(test, output_mcf_path, expected_mcf_path) class TestMcfImporter(unittest.TestCase): diff --git a/simple/tests/stats/nl_test.py b/simple/tests/stats/nl_test.py index fa902c2b..072fb736 100644 --- a/simple/tests/stats/nl_test.py +++ b/simple/tests/stats/nl_test.py @@ -20,7 +20,9 @@ import pandas as pd from stats.data import Triple import stats.nl as nl +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode +from tests.stats.test_util import read_triples_csv from util.filehandler import LocalFileHandler _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), @@ -29,25 +31,12 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _read_triples_csv(path: str) -> list[Triple]: - df = pd.read_csv(path) - return [Triple(**kwargs) for kwargs in df.to_dict(orient='records')] - - -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _test_generate_nl_sentences(test: unittest.TestCase, test_name: str): test.maxDiff = None with tempfile.TemporaryDirectory() as temp_dir: input_triples_path = os.path.join(_INPUT_DIR, f"{test_name}.csv") - input_triples = _read_triples_csv(input_triples_path) + input_triples = read_triples_csv(input_triples_path) output_sentences_csv_path = os.path.join(temp_dir, f"{test_name}_sentences.csv") @@ -62,7 +51,7 @@ def _test_generate_nl_sentences(test: unittest.TestCase, test_name: str): shutil.copy(output_sentences_csv_path, expected_sentences_csv_path) return - _compare_files(test, output_sentences_csv_path, expected_sentences_csv_path) + compare_files(test, output_sentences_csv_path, expected_sentences_csv_path) class TestData(unittest.TestCase): diff --git a/simple/tests/stats/nodes_test.py b/simple/tests/stats/nodes_test.py index b572c196..580db182 100644 --- a/simple/tests/stats/nodes_test.py +++ b/simple/tests/stats/nodes_test.py @@ -23,6 +23,7 @@ from stats.data import StatVar from stats.data import StatVarGroup from stats.nodes import Nodes +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -30,15 +31,6 @@ "test_data", "nodes") _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") - -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - CONFIG_DATA = { "inputFiles": { "a.csv": { @@ -147,7 +139,7 @@ def test_triples(self): shutil.copy(output_path, expected_path) return - _compare_files(self, output_path, expected_path) + compare_files(self, output_path, expected_path) def test_variable_with_no_config(self): nodes = Nodes(CONFIG) diff --git a/simple/tests/stats/observations_importer_test.py b/simple/tests/stats/observations_importer_test.py index e82dac61..ff434ed3 100644 --- a/simple/tests/stats/observations_importer_test.py +++ b/simple/tests/stats/observations_importer_test.py @@ -27,6 +27,7 @@ from stats.observations_importer import ObservationsImporter from stats.reporter import FileImportReporter from stats.reporter import ImportReporter +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -36,14 +37,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_observations(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from observations").fetchall() @@ -94,7 +87,7 @@ def _test_import(test: unittest.TestCase, shutil.copy(output_path, expected_path) return - _compare_files(test, output_path, expected_path) + compare_files(test, output_path, expected_path) class TestObservationsImporter(unittest.TestCase): diff --git a/simple/tests/stats/runner_test.py b/simple/tests/stats/runner_test.py index 89fe3fbb..b880a7c3 100644 --- a/simple/tests/stats/runner_test.py +++ b/simple/tests/stats/runner_test.py @@ -26,6 +26,7 @@ from stats.data import Observation from stats.data import Triple from stats.runner import Runner +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util import dc_client @@ -37,14 +38,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_observations(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from observations").fetchall() @@ -109,9 +102,9 @@ def _test_runner(test: unittest.TestCase, shutil.copy(output_nl_sentences_path, expected_nl_sentences_path) return - _compare_files(test, output_triples_path, expected_triples_path) - _compare_files(test, output_observations_path, expected_observations_path) - _compare_files(test, output_nl_sentences_path, expected_nl_sentences_path) + compare_files(test, output_triples_path, expected_triples_path) + compare_files(test, output_observations_path, expected_observations_path) + compare_files(test, output_nl_sentences_path, expected_nl_sentences_path) class TestRunner(unittest.TestCase): diff --git a/simple/tests/stats/stat_var_hierarchy_generator_test.py b/simple/tests/stats/stat_var_hierarchy_generator_test.py index 87ae819a..ffe75b56 100644 --- a/simple/tests/stats/stat_var_hierarchy_generator_test.py +++ b/simple/tests/stats/stat_var_hierarchy_generator_test.py @@ -25,7 +25,9 @@ from stats.stat_var_hierarchy_generator import * from stats.stat_var_hierarchy_generator import _extract_svs from stats.stat_var_hierarchy_generator import _generate_internal +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode +from tests.stats.test_util import read_triples_csv _TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_data", "stat_var_hierarchy_generator") @@ -33,24 +35,10 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path, - message): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want, message) - - def _write_triples_csv(triples: list[Triple], path: str): pd.DataFrame(triples).to_csv(path, index=False) -def _read_triples_csv(path: str) -> list[Triple]: - df = pd.read_csv(path) - return [Triple(**kwargs) for kwargs in df.to_dict(orient='records')] - - def _strip_ns(v): return v[v.find(":") + 1:] @@ -83,7 +71,7 @@ def _test_generate_internal(test: unittest.TestCase, input_triples = _mcf_to_triples(input_mcf_path) else: input_triples_path = os.path.join(_INPUT_DIR, f"{test_name}.csv") - input_triples = _read_triples_csv(input_triples_path) + input_triples = read_triples_csv(input_triples_path) vertical_specs: list[VerticalSpec] = [] if has_vertical_specs: @@ -120,10 +108,10 @@ def _test_generate_internal(test: unittest.TestCase, shutil.copy(output_triples_csv_path, expected_triples_csv_path) return - _compare_files(test, output_svgs_json_path, expected_svgs_json_path, - f"Comparing SVGS JSON: {test_name}") - _compare_files(test, output_triples_csv_path, expected_triples_csv_path, - f"Comparing SVG TRIPLES: {test_name}") + compare_files(test, output_svgs_json_path, expected_svgs_json_path, + f"Comparing SVGS JSON: {test_name}") + compare_files(test, output_triples_csv_path, expected_triples_csv_path, + f"Comparing SVG TRIPLES: {test_name}") class TestStatVarHierarchyGenerator(unittest.TestCase): diff --git a/simple/tests/stats/variable_per_row_importer_test.py b/simple/tests/stats/variable_per_row_importer_test.py index fb5183fc..877ab61e 100644 --- a/simple/tests/stats/variable_per_row_importer_test.py +++ b/simple/tests/stats/variable_per_row_importer_test.py @@ -27,6 +27,7 @@ from stats.reporter import FileImportReporter from stats.reporter import ImportReporter from stats.variable_per_row_importer import VariablePerRowImporter +from tests.stats.test_util import compare_files from tests.stats.test_util import is_write_mode from util.filehandler import LocalFileHandler @@ -36,14 +37,6 @@ _EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected") -def _compare_files(test: unittest.TestCase, output_path, expected_path): - with open(output_path) as gotf: - got = gotf.read() - with open(expected_path) as wantf: - want = wantf.read() - test.assertEqual(got, want) - - def _write_observations(db_path: str, output_path: str): with sqlite3.connect(db_path) as db: rows = db.execute("select * from observations").fetchall() @@ -89,7 +82,7 @@ def _test_import(test: unittest.TestCase, shutil.copy(output_path, expected_path) return - _compare_files(test, output_path, expected_path) + compare_files(test, output_path, expected_path) class TestVariablePerRowImporter(unittest.TestCase):