Skip to content

Commit

Permalink
Add method to get names from both sql and remote dc. (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
keyurva authored May 28, 2024
1 parent d98bfde commit b2dc47d
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 2 deletions.
9 changes: 7 additions & 2 deletions simple/stats/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ def commit_and_close(self):
def select_triples_by_subject_type(self, subject_type: str) -> list[Triple]:
    # Stub: the body is only `pass`, so calling it on this class returns None.
    # NOTE(review): presumably overridden by concrete db classes to return the
    # triples whose subject is of `subject_type` -- confirm against subclasses.
    pass

# Returns names of the corresponding dcids, as a mapping keyed by dcid
# (callers test `dcid in result` to find dcids that still lack a name).
def select_entity_names(self, dcids: list[str]) -> dict[str, str]:
    # Stub: the body is only `pass`, so calling it on this class returns None.
    # NOTE(review): presumably overridden by concrete db classes -- confirm.
    pass


class MainDcDb(Db):
"""Generates output for main DC.
Expand Down Expand Up @@ -217,8 +221,9 @@ def __init__(self, config: dict) -> None:

def insert_triples(self, triples: list[Triple]):
    """Write the given triples to the engine in a single batch.

    Args:
        triples: Triples to insert. An empty list is logged but results in
            no engine call.
    """
    logging.info("Writing %s triples to [%s]", len(triples), self.engine)
    # Issue exactly one batched insert, and only when there is something to
    # write: an unconditional call followed by a guarded one would insert
    # every triple twice and would also hit the engine with an empty batch.
    if triples:
        self.engine.executemany(_INSERT_TRIPLES_STATEMENT,
                                [to_triple_tuple(triple) for triple in triples])

def insert_observations(self, observations: list[Observation],
input_file_name: str):
Expand Down
32 changes: 32 additions & 0 deletions simple/stats/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2024 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Functions to fetch schema info from the custom DC database and from the remote DC.

from stats import schema_constants as sc
from stats.db import Db

from util import dc_client


def get_schema_names(dcids: list[str], db: Db) -> dict[str, str]:
    """Resolve names for dcids, consulting the db first and remote dc second.

    Names are looked up with db.select_entity_names. Only the dcids absent
    from that result are requested from the remote dc via
    dc_client.get_property_of_entities; the remote call is skipped entirely
    when no dcid is missing.

    Args:
        dcids: The dcids whose names are wanted.
        db: The db that is queried before the remote dc.

    Returns:
        A dcid -> name mapping combining both sources. If a dcid appears in
        both, the name from the db is kept.
    """
    local_names = db.select_entity_names(dcids)
    missing = [dcid for dcid in dcids if dcid not in local_names]
    remote_names = (dc_client.get_property_of_entities(
        missing, sc.PREDICATE_NAME) if missing else {})
    # The db mapping is the right-hand operand so its entries take precedence.
    return remote_names | local_names
114 changes: 114 additions & 0 deletions simple/tests/stats/schema_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2024 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import shutil
import sqlite3
import tempfile
import unittest
from unittest import mock

from freezegun import freeze_time
import pandas as pd
from parameterized import parameterized
from stats import schema
from stats import schema_constants as sc
from stats.data import Observation
from stats.data import Triple
from stats.db import create_db
from stats.db import create_main_dc_config
from stats.db import create_sqlite_config
from stats.db import get_cloud_sql_config_from_env
from stats.db import get_sqlite_config_from_env
from stats.db import ImportStatus
from stats.db import to_observation_tuple
from stats.db import to_triple_tuple
from tests.stats.test_util import is_write_mode


def _to_triples(dcid2name: dict[str, str]) -> list[Triple]:
    """Build one name triple per (dcid, name) entry, in mapping order."""
    return [
        Triple(dcid, sc.PREDICATE_NAME, object_value=name)
        for dcid, name in dcid2name.items()
    ]


class TestSchema(unittest.TestCase):
    """Exercises schema.get_schema_names with a temp db and a mocked remote dc."""

    # Each case is:
    #   (description, names stored in the db, names returned by the mocked
    #    remote call, dcids requested, expected result)
    @parameterized.expand([
        (
            "both",
            {"var1": "Variable 1"},
            {"prop1": "Property 1"},
            ["var1", "prop1"],
            {"var1": "Variable 1", "prop1": "Property 1"},
        ),
        (
            "db only",
            {"var1": "Variable 1"},
            {},
            ["var1", "prop1"],
            {"var1": "Variable 1"},
        ),
        (
            "remote only",
            {},
            {"prop1": "Property 1"},
            ["var1", "prop1"],
            {"prop1": "Property 1"},
        ),
        (
            "prefer db value",
            {"var1": "DB Var 1"},
            {"var1": "Remote Var 1"},
            ["var1", "prop1"],
            {"var1": "DB Var 1"},
        ),
    ])
    @mock.patch("util.dc_client.get_property_of_entities")
    def test_get_schema_names(self, desc: str, db_names: dict[str, str],
                              remote_names: dict[str, str],
                              input_dcids: list[str],
                              output_names: dict[str, str], mock_dc_client):
        """Seed the db, stub the remote lookup, and compare the merged names."""
        with tempfile.TemporaryDirectory() as temp_dir:
            sqlite_config = create_sqlite_config(
                os.path.join(temp_dir, "datacommons.db"))
            db = create_db(sqlite_config)
            db.insert_triples(_to_triples(db_names))
            # The patched remote call returns this mapping whatever it is asked.
            mock_dc_client.return_value = remote_names

            actual = schema.get_schema_names(input_dcids, db)

            self.assertDictEqual(actual, output_names, desc)

0 comments on commit b2dc47d

Please sign in to comment.