From b2dc47d21687ddd68d175dff8c27a1e731177363 Mon Sep 17 00:00:00 2001 From: Keyur Shah Date: Tue, 28 May 2024 09:22:18 -0700 Subject: [PATCH] Add method to get names from both sql and remote dc. (#315) --- simple/stats/db.py | 9 ++- simple/stats/schema.py | 32 +++++++++ simple/tests/stats/schema_test.py | 114 ++++++++++++++++++++++++++++++ 3 files changed, 153 insertions(+), 2 deletions(-) create mode 100644 simple/stats/schema.py create mode 100644 simple/tests/stats/schema_test.py diff --git a/simple/stats/db.py b/simple/stats/db.py index dcdb0666..af9e7a7f 100644 --- a/simple/stats/db.py +++ b/simple/stats/db.py @@ -148,6 +148,10 @@ def commit_and_close(self): def select_triples_by_subject_type(self, subject_type: str) -> list[Triple]: pass + # Returns names of the corresponding dcids. + def select_entity_names(self, dcids: list[str]) -> dict[str, str]: + pass + class MainDcDb(Db): """Generates output for main DC. @@ -217,8 +221,9 @@ def __init__(self, config: dict) -> None: def insert_triples(self, triples: list[Triple]): logging.info("Writing %s triples to [%s]", len(triples), self.engine) - self.engine.executemany(_INSERT_TRIPLES_STATEMENT, - [to_triple_tuple(triple) for triple in triples]) + if triples: + self.engine.executemany(_INSERT_TRIPLES_STATEMENT, + [to_triple_tuple(triple) for triple in triples]) def insert_observations(self, observations: list[Observation], input_file_name: str): diff --git a/simple/stats/schema.py b/simple/stats/schema.py new file mode 100644 index 00000000..7b73520c --- /dev/null +++ b/simple/stats/schema.py @@ -0,0 +1,32 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Functions to fetch schema info from custom dc db and remote dc. + +from stats import schema_constants as sc +from stats.db import Db + +from util import dc_client + + +# Gets names of the specified dcids first from db and any remaining +# ones from remote dc. +def get_schema_names(dcids: list[str], db: Db) -> dict[str, str]: + db_dcid2name = db.select_entity_names(dcids) + dcids = list(filter(lambda x: x not in db_dcid2name, dcids)) + remote_dcid2name = {} + if dcids: + remote_dcid2name = dc_client.get_property_of_entities( + dcids, sc.PREDICATE_NAME) + return remote_dcid2name | db_dcid2name diff --git a/simple/tests/stats/schema_test.py b/simple/tests/stats/schema_test.py new file mode 100644 index 00000000..432e6e71 --- /dev/null +++ b/simple/tests/stats/schema_test.py @@ -0,0 +1,114 @@ +# Copyright 2024 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import sqlite3 +import tempfile +import unittest +from unittest import mock + +from freezegun import freeze_time +import pandas as pd +from parameterized import parameterized +from stats import schema +from stats import schema_constants as sc +from stats.data import Observation +from stats.data import Triple +from stats.db import create_db +from stats.db import create_main_dc_config +from stats.db import create_sqlite_config +from stats.db import get_cloud_sql_config_from_env +from stats.db import get_sqlite_config_from_env +from stats.db import ImportStatus +from stats.db import to_observation_tuple +from stats.db import to_triple_tuple +from tests.stats.test_util import is_write_mode + + +def _to_triples(dcid2name: dict[str, str]) -> list[Triple]: + triples: list[Triple] = [] + for dcid, name in dcid2name.items(): + triples.append(Triple(dcid, sc.PREDICATE_NAME, object_value=name)) + return triples + + +class TestSchema(unittest.TestCase): + + @parameterized.expand([ + ( + "both", + { + "var1": "Variable 1" + }, + { + "prop1": "Property 1" + }, + ["var1", "prop1"], + { + "var1": "Variable 1", + "prop1": "Property 1" + }, + ), + ( + "db only", + { + "var1": "Variable 1" + }, + {}, + ["var1", "prop1"], + { + "var1": "Variable 1" + }, + ), + ( + "remote only", + {}, + { + "prop1": "Property 1" + }, + ["var1", "prop1"], + { + "prop1": "Property 1" + }, + ), + ( + "prefer db value", + { + "var1": "DB Var 1" + }, + { + "var1": "Remote Var 1" + }, + ["var1", "prop1"], + { + "var1": "DB Var 1" + }, + ), + ]) + @mock.patch("util.dc_client.get_property_of_entities") + def test_get_schema_names(self, desc: str, db_names: dict[str, str], + remote_names: dict[str, + str], input_dcids: list[str], + output_names: dict[str, str], mock_dc_client): + with tempfile.TemporaryDirectory() as temp_dir: + db_file_path = os.path.join(temp_dir, "datacommons.db") + db = create_db(create_sqlite_config(db_file_path)) + db.insert_triples(_to_triples(db_names)) + mock_dc_client.return_value = remote_names + + result = schema.get_schema_names(input_dcids, db) + + self.assertDictEqual(result, output_names, desc)