From 25d6db92bf1f22cfedf956de9ba7643606b87409 Mon Sep 17 00:00:00 2001 From: Ale Sanchez Date: Mon, 16 Aug 2021 10:59:39 +0200 Subject: [PATCH] fix: Fixes across package to allow CLI usage --- covid_data/__init__.py | 2 ++ covid_data/__main__.py | 4 +++ covid_data/commands/scrap_country.py | 40 ++++++++++++++---------- covid_data/{covid_data => covid_data.py} | 4 +-- covid_data/db/queries.py | 40 ++++++++++++------------ covid_data/scrappers/__init__.py | 0 covid_data/scrappers/france.py | 5 +-- covid_data/scrappers/spain.py | 1 - setup.py | 2 +- 9 files changed, 53 insertions(+), 45 deletions(-) create mode 100644 covid_data/__main__.py rename covid_data/{covid_data => covid_data.py} (86%) create mode 100644 covid_data/scrappers/__init__.py diff --git a/covid_data/__init__.py b/covid_data/__init__.py index 469d8da..d744f48 100644 --- a/covid_data/__init__.py +++ b/covid_data/__init__.py @@ -1,4 +1,6 @@ +import covid_data # noqa: F401 from covid_data import db as db # noqa: F401 from covid_data import types as types # noqa: F401 +from covid_data.covid_data import entrypoint as script_entrypoint # noqa: F401 from covid_data.db import queries as queries # noqa: F401 from covid_data.utils import places as places_utils # noqa: F401 diff --git a/covid_data/__main__.py b/covid_data/__main__.py new file mode 100644 index 0000000..fc72779 --- /dev/null +++ b/covid_data/__main__.py @@ -0,0 +1,4 @@ +from covid_data import script_entrypoint + +if __name__ == "__main__": + script_entrypoint() diff --git a/covid_data/commands/scrap_country.py b/covid_data/commands/scrap_country.py index 35bd8c5..d0b8247 100644 --- a/covid_data/commands/scrap_country.py +++ b/covid_data/commands/scrap_country.py @@ -2,9 +2,10 @@ from contextlib import ExitStack from datetime import datetime from importlib import import_module -from typing import Any +from typing import Any, List import click + from covid_data.db import close_db, get_db @@ -19,10 +20,6 @@ def main(country: str, check: bool = False, start_date: str = ""): """Scrap cases of chosen COUNTRY. To check available countries to scrap use --check""" with ExitStack() as stack: - db = get_db() - - stack.push(close_db(db)) - base_path = os.path.join(os.path.dirname(__file__), "../scrappers") files = os.listdir(base_path) @@ -30,22 +27,31 @@ def main(country: str, check: bool = False, start_date: str = ""): click.echo("Available countries are:") for file_name in files: - if file_name.startswith("test"): + if ( + file_name.startswith("test") + or file_name.startswith("__") + or os.path.isdir(os.path.join(base_path, file_name)) + ): continue - handler_module, _ = os.path.splitext(file_name) + if check: + click.echo(f"\t{file_name.replace('.py', '').capitalize()}") + else: + handler_module, _ = os.path.splitext(file_name) + + module = import_module(f".{handler_module}", "covid_data.scrappers") - module = import_module(f".{handler_module}", "covid_data.scrappers") + if not hasattr(module, "scrap"): + continue - if not hasattr(module, "scrap"): - continue + if handler_module == country.lower(): + db = get_db() - if check: - click.echo(f"\t{file_name.replace('.py', '').capitalize()}") - elif handler_module == country.lower(): - args: list[Any] = [db] + stack.push(close_db(db)) + + args: List[Any] = [db] - if start_date: - args.append(datetime.strptime(start_date, "%d/%m/%Y")) + if start_date: + args.append(datetime.strptime(start_date, "%d/%m/%Y")) - module.scrap(*args) # type: ignore + module.scrap(*args) # type: ignore diff --git a/covid_data/covid_data b/covid_data/covid_data.py similarity index 86% rename from covid_data/covid_data rename to covid_data/covid_data.py index a79edd5..06c418d 100644 --- a/covid_data/covid_data +++ b/covid_data/covid_data.py @@ -17,7 +17,7 @@ def cli(): pass -if __name__ == "__main__": +def entrypoint(): base_path = os.path.join(os.path.dirname(__file__), "commands") files = os.listdir(base_path) @@ -27,7 +27,7 @@ def cli(): handler_module, _ = os.path.splitext(file_name) - module = import_module(f".{handler_module}", "commands") + module = import_module(f".{handler_module}", "covid_data.commands") if not hasattr(module, "main"): continue diff --git a/covid_data/db/queries.py b/covid_data/db/queries.py index df914ff..778243c 100644 --- a/covid_data/db/queries.py +++ b/covid_data/db/queries.py @@ -64,8 +64,8 @@ def get_place_by_property( def get_countries_id_by_alpha2( - alpha2_codes: list[str], engine: connection -) -> list[int]: + alpha2_codes: List[str], engine: connection +) -> List[int]: with engine.cursor() as cur: cur: cursor @@ -206,21 +206,21 @@ def get_cases_by_country( def get_cases_by_province( - provinces_id: list[int], engine: connection, case_type: CaseType = None + provinces_id: List[int], engine: connection, case_type: CaseType = None ) -> List[Dict]: return get_cases_by_filters(engine, provinces_id=provinces_id, case_type=case_type) def get_cases_by_filters_query( - countries_id: list[int] = None, - provinces_id: list[int] = None, + countries_id: List[int] = None, + provinces_id: List[int] = None, date: datetime = None, date_lte: datetime = None, date_gte: datetime = None, case_type: CaseType = None, - aggregation: list[Aggregations] = [], + aggregation: List[Aggregations] = [], limit: int = None, - sort: list[str] = [], + sort: List[str] = [], ) -> Dict[str, Any]: params = [] constraints = [] @@ -315,13 +315,13 @@ def get_cases_by_filters_query( def get_cases_by_filters( engine: connection, - countries_id: list[int] = None, - provinces_id: list[int] = None, + countries_id: List[int] = None, + provinces_id: List[int] = None, date: datetime = None, date_lte: datetime = None, date_gte: datetime = None, case_type: CaseType = None, - aggregation: list[Aggregations] = [], + aggregation: List[Aggregations] = [], limit: int = None, sort: list = [], ) -> List[Dict]: @@ -354,7 +354,7 @@ def get_cum_cases_by_date( date_lte: datetime = None, date_gte: datetime = None, case_type: CaseType = None, - countries: list[int] = None, + countries: List[int] = None, ) -> List[Dict]: params = [] @@ -418,13 +418,13 @@ def get_cum_cases_by_date( def get_cum_cases_by_date_country( engine: connection, country_id: int, - provinces_id: list[int] = [], + provinces_id: List[int] = [], date: datetime = None, date_lte: datetime = None, date_gte: datetime = None, case_type: CaseType = None, ) -> List[Dict]: - params: list[Any] = [country_id] + params: List[Any] = [country_id] inner_query = sql.SQL( ( @@ -490,7 +490,7 @@ def get_cum_cases_by_country( date_lte: datetime = None, date_gte: datetime = None, case_type: CaseType = None, - countries_id: list[int] = [], + countries_id: List[int] = [], ) -> List[Dict]: params = [] @@ -557,9 +557,9 @@ def get_cum_cases_by_province( date_gte: datetime = None, case_type: CaseType = None, country_id: int = None, - provinces_id: list[str] = None, + provinces_id: List[str] = None, ) -> List[Dict]: - params: list[Any] = [country_id] + params: List[Any] = [country_id] inner_query = sql.SQL( ( @@ -650,8 +650,8 @@ def create_case( def get_all_countries( - engine: connection, name: str = None, near: list[float] = [] -) -> list[dict]: + engine: connection, name: str = None, near: List[float] = [] +) -> List[dict]: with engine.cursor() as cur: cur: cursor @@ -675,7 +675,7 @@ def get_all_countries( return cur.fetchall() # type: ignore -def get_all_provinces(engine: connection) -> list[dict]: +def get_all_provinces(engine: connection) -> List[dict]: with engine.cursor() as cur: cur: cursor @@ -695,7 +695,7 @@ def get_all_provinces(engine: connection) -> list[dict]: return cur.fetchall() # type: ignore -def get_provinces_by_country(engine: connection, country_id: int) -> list[dict]: +def get_provinces_by_country(engine: connection, country_id: int) -> List[dict]: with engine.cursor() as cur: cur: cursor diff --git a/covid_data/scrappers/__init__.py b/covid_data/scrappers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/covid_data/scrappers/france.py b/covid_data/scrappers/france.py index 663bd71..cdff393 100644 --- a/covid_data/scrappers/france.py +++ b/covid_data/scrappers/france.py @@ -1,7 +1,6 @@ import datetime import json import logging -import os import click import requests @@ -48,11 +47,9 @@ def scrap(engine: connection, start_date: datetime.datetime = START_DATE) -> Non if response.status_code > 399: message = f"Error fetching data for date {curr_date}" logger.error(message) - click.echo(message) logger.error(response) - click.echo(response) - raise ClickException(message) + continue data = json.loads(response.text) diff --git a/covid_data/scrappers/spain.py b/covid_data/scrappers/spain.py index 5406186..e1206b1 100644 --- a/covid_data/scrappers/spain.py +++ b/covid_data/scrappers/spain.py @@ -1,6 +1,5 @@ import json import logging -import os from datetime import datetime import click diff --git a/setup.py b/setup.py index 86405fc..ccf3578 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ name="covid_data", entry_points={ "console_scripts": [ - "covid_data = covid_data:covid_data", + "covid-data = covid_data:script_entrypoint", ], }, packages=find_packages(),