Skip to content

Commit

Permalink
Use common test util functions. (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored May 31, 2024
1 parent d7516a6 commit 0793b20
Show file tree
Hide file tree
Showing 10 changed files with 33 additions and 114 deletions.
16 changes: 4 additions & 12 deletions simple/tests/stats/db_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,13 @@
from stats.db import ImportStatus
from stats.db import to_observation_tuple
from stats.db import to_triple_tuple
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode

_TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"test_data", "db")
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


_TRIPLES = [
Triple("sub1", "typeOf", object_id="StatisticalVariable"),
Triple("sub1", "pred1", object_value="objval1"),
Expand Down Expand Up @@ -120,9 +112,9 @@ def test_main_dc_db(self):
shutil.copy(mcf_file, expected_mcf_file)
return

_compare_files(self, observations_file, expected_observations_file)
_compare_files(self, tmcf_file, expected_tmcf_file)
_compare_files(self, mcf_file, expected_mcf_file)
compare_files(self, observations_file, expected_observations_file)
compare_files(self, tmcf_file, expected_tmcf_file)
compare_files(self, mcf_file, expected_mcf_file)

@mock.patch.dict(os.environ, {})
def test_get_cloud_sql_config_from_env_empty(self):
Expand Down
11 changes: 2 additions & 9 deletions simple/tests/stats/entities_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from stats.nodes import Nodes
from stats.reporter import FileImportReporter
from stats.reporter import ImportReporter
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

Expand All @@ -37,14 +38,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_triples(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from triples").fetchall()
Expand Down Expand Up @@ -86,7 +79,7 @@ def _test_import(test: unittest.TestCase, test_name: str):
shutil.copy(output_triples_path, expected_triples_path)
return

_compare_files(test, output_triples_path, expected_triples_path)
compare_files(test, output_triples_path, expected_triples_path)


class TestEntitiesImporter(unittest.TestCase):
Expand Down
13 changes: 3 additions & 10 deletions simple/tests/stats/events_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from stats.nodes import Nodes
from stats.reporter import FileImportReporter
from stats.reporter import ImportReporter
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

Expand All @@ -38,14 +39,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_observations(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from observations").fetchall()
Expand Down Expand Up @@ -104,8 +97,8 @@ def _test_import(test: unittest.TestCase, test_name: str):
shutil.copy(output_observations_path, expected_observations_path)
return

_compare_files(test, output_triples_path, expected_triples_path)
_compare_files(test, output_observations_path, expected_observations_path)
compare_files(test, output_triples_path, expected_triples_path)
compare_files(test, output_observations_path, expected_observations_path)


class TestEventsImporter(unittest.TestCase):
Expand Down
13 changes: 3 additions & 10 deletions simple/tests/stats/mcf_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from stats.nodes import Nodes
from stats.reporter import FileImportReporter
from stats.reporter import ImportReporter
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

Expand All @@ -37,14 +38,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_triples(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from triples").fetchall()
Expand Down Expand Up @@ -97,13 +90,13 @@ def _test_import(test: unittest.TestCase,
shutil.copy(output_triples_path, expected_triples_path)
return

_compare_files(test, output_triples_path, expected_triples_path)
compare_files(test, output_triples_path, expected_triples_path)
else:
if is_write_mode():
shutil.copy(output_mcf_path, expected_mcf_path)
return

_compare_files(test, output_mcf_path, expected_mcf_path)
compare_files(test, output_mcf_path, expected_mcf_path)


class TestMcfImporter(unittest.TestCase):
Expand Down
19 changes: 4 additions & 15 deletions simple/tests/stats/nl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import pandas as pd
from stats.data import Triple
import stats.nl as nl
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from tests.stats.test_util import read_triples_csv
from util.filehandler import LocalFileHandler

_TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
Expand All @@ -29,25 +31,12 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _read_triples_csv(path: str) -> list[Triple]:
df = pd.read_csv(path)
return [Triple(**kwargs) for kwargs in df.to_dict(orient='records')]


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _test_generate_nl_sentences(test: unittest.TestCase, test_name: str):
test.maxDiff = None

with tempfile.TemporaryDirectory() as temp_dir:
input_triples_path = os.path.join(_INPUT_DIR, f"{test_name}.csv")
input_triples = _read_triples_csv(input_triples_path)
input_triples = read_triples_csv(input_triples_path)

output_sentences_csv_path = os.path.join(temp_dir,
f"{test_name}_sentences.csv")
Expand All @@ -62,7 +51,7 @@ def _test_generate_nl_sentences(test: unittest.TestCase, test_name: str):
shutil.copy(output_sentences_csv_path, expected_sentences_csv_path)
return

_compare_files(test, output_sentences_csv_path, expected_sentences_csv_path)
compare_files(test, output_sentences_csv_path, expected_sentences_csv_path)


class TestData(unittest.TestCase):
Expand Down
12 changes: 2 additions & 10 deletions simple/tests/stats/nodes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,14 @@
from stats.data import StatVar
from stats.data import StatVarGroup
from stats.nodes import Nodes
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

_TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"test_data", "nodes")
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


CONFIG_DATA = {
"inputFiles": {
"a.csv": {
Expand Down Expand Up @@ -147,7 +139,7 @@ def test_triples(self):
shutil.copy(output_path, expected_path)
return

_compare_files(self, output_path, expected_path)
compare_files(self, output_path, expected_path)

def test_variable_with_no_config(self):
nodes = Nodes(CONFIG)
Expand Down
11 changes: 2 additions & 9 deletions simple/tests/stats/observations_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from stats.observations_importer import ObservationsImporter
from stats.reporter import FileImportReporter
from stats.reporter import ImportReporter
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

Expand All @@ -36,14 +37,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_observations(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from observations").fetchall()
Expand Down Expand Up @@ -94,7 +87,7 @@ def _test_import(test: unittest.TestCase,
shutil.copy(output_path, expected_path)
return

_compare_files(test, output_path, expected_path)
compare_files(test, output_path, expected_path)


class TestObservationsImporter(unittest.TestCase):
Expand Down
15 changes: 4 additions & 11 deletions simple/tests/stats/runner_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from stats.data import Observation
from stats.data import Triple
from stats.runner import Runner
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode

from util import dc_client
Expand All @@ -37,14 +38,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_observations(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from observations").fetchall()
Expand Down Expand Up @@ -109,9 +102,9 @@ def _test_runner(test: unittest.TestCase,
shutil.copy(output_nl_sentences_path, expected_nl_sentences_path)
return

_compare_files(test, output_triples_path, expected_triples_path)
_compare_files(test, output_observations_path, expected_observations_path)
_compare_files(test, output_nl_sentences_path, expected_nl_sentences_path)
compare_files(test, output_triples_path, expected_triples_path)
compare_files(test, output_observations_path, expected_observations_path)
compare_files(test, output_nl_sentences_path, expected_nl_sentences_path)


class TestRunner(unittest.TestCase):
Expand Down
26 changes: 7 additions & 19 deletions simple/tests/stats/stat_var_hierarchy_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,20 @@
from stats.stat_var_hierarchy_generator import *
from stats.stat_var_hierarchy_generator import _extract_svs
from stats.stat_var_hierarchy_generator import _generate_internal
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from tests.stats.test_util import read_triples_csv

_TEST_DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
"test_data", "stat_var_hierarchy_generator")
_INPUT_DIR = os.path.join(_TEST_DATA_DIR, "input")
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path,
message):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want, message)


def _write_triples_csv(triples: list[Triple], path: str):
pd.DataFrame(triples).to_csv(path, index=False)


def _read_triples_csv(path: str) -> list[Triple]:
df = pd.read_csv(path)
return [Triple(**kwargs) for kwargs in df.to_dict(orient='records')]


def _strip_ns(v):
return v[v.find(":") + 1:]

Expand Down Expand Up @@ -83,7 +71,7 @@ def _test_generate_internal(test: unittest.TestCase,
input_triples = _mcf_to_triples(input_mcf_path)
else:
input_triples_path = os.path.join(_INPUT_DIR, f"{test_name}.csv")
input_triples = _read_triples_csv(input_triples_path)
input_triples = read_triples_csv(input_triples_path)

vertical_specs: list[VerticalSpec] = []
if has_vertical_specs:
Expand Down Expand Up @@ -120,10 +108,10 @@ def _test_generate_internal(test: unittest.TestCase,
shutil.copy(output_triples_csv_path, expected_triples_csv_path)
return

_compare_files(test, output_svgs_json_path, expected_svgs_json_path,
f"Comparing SVGS JSON: {test_name}")
_compare_files(test, output_triples_csv_path, expected_triples_csv_path,
f"Comparing SVG TRIPLES: {test_name}")
compare_files(test, output_svgs_json_path, expected_svgs_json_path,
f"Comparing SVGS JSON: {test_name}")
compare_files(test, output_triples_csv_path, expected_triples_csv_path,
f"Comparing SVG TRIPLES: {test_name}")


class TestStatVarHierarchyGenerator(unittest.TestCase):
Expand Down
11 changes: 2 additions & 9 deletions simple/tests/stats/variable_per_row_importer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from stats.reporter import FileImportReporter
from stats.reporter import ImportReporter
from stats.variable_per_row_importer import VariablePerRowImporter
from tests.stats.test_util import compare_files
from tests.stats.test_util import is_write_mode
from util.filehandler import LocalFileHandler

Expand All @@ -36,14 +37,6 @@
_EXPECTED_DIR = os.path.join(_TEST_DATA_DIR, "expected")


def _compare_files(test: unittest.TestCase, output_path, expected_path):
with open(output_path) as gotf:
got = gotf.read()
with open(expected_path) as wantf:
want = wantf.read()
test.assertEqual(got, want)


def _write_observations(db_path: str, output_path: str):
with sqlite3.connect(db_path) as db:
rows = db.execute("select * from observations").fetchall()
Expand Down Expand Up @@ -89,7 +82,7 @@ def _test_import(test: unittest.TestCase,
shutil.copy(output_path, expected_path)
return

_compare_files(test, output_path, expected_path)
compare_files(test, output_path, expected_path)


class TestVariablePerRowImporter(unittest.TestCase):
Expand Down

0 comments on commit 0793b20

Please sign in to comment.