Skip to content

Commit

Permalink
fix: Fixes across package to allow CLI usage
Browse files Browse the repository at this point in the history
  • Loading branch information
alesanmed committed Aug 16, 2021
1 parent 849e19e commit 25d6db9
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 45 deletions.
2 changes: 2 additions & 0 deletions covid_data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import covid_data # noqa: F401
from covid_data import db as db # noqa: F401
from covid_data import types as types # noqa: F401
from covid_data.covid_data import entrypoint as script_entrypoint # noqa: F401
from covid_data.db import queries as queries # noqa: F401
from covid_data.utils import places as places_utils # noqa: F401
4 changes: 4 additions & 0 deletions covid_data/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from covid_data import script_entrypoint

if __name__ == "__main__":
script_entrypoint()
40 changes: 23 additions & 17 deletions covid_data/commands/scrap_country.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from contextlib import ExitStack
from datetime import datetime
from importlib import import_module
from typing import Any
from typing import Any, List

import click

from covid_data.db import close_db, get_db


Expand All @@ -19,33 +20,38 @@
def main(country: str, check: bool = False, start_date: str = ""):
"""Scrap cases of chosen COUNTRY. To check available countries to scrap use --check"""
with ExitStack() as stack:
db = get_db()

stack.push(close_db(db))

base_path = os.path.join(os.path.dirname(__file__), "../scrappers")
files = os.listdir(base_path)

if check:
click.echo("Available countries are:")

for file_name in files:
if file_name.startswith("test"):
if (
file_name.startswith("test")
or file_name.startswith("__")
or os.path.isdir(os.path.join(base_path, file_name))
):
continue

handler_module, _ = os.path.splitext(file_name)
if check:
click.echo(f"\t{file_name.replace('.py', '').capitalize()}")
else:
handler_module, _ = os.path.splitext(file_name)

module = import_module(f".{handler_module}", "covid_data.scrappers")

module = import_module(f".{handler_module}", "covid_data.scrappers")
if not hasattr(module, "scrap"):
continue

if not hasattr(module, "scrap"):
continue
if handler_module == country.lower():
db = get_db()

if check:
click.echo(f"\t{file_name.replace('.py', '').capitalize()}")
elif handler_module == country.lower():
args: list[Any] = [db]
stack.push(close_db(db))

args: List[Any] = [db]

if start_date:
args.append(datetime.strptime(start_date, "%d/%m/%Y"))
if start_date:
args.append(datetime.strptime(start_date, "%d/%m/%Y"))

module.scrap(*args) # type: ignore
module.scrap(*args) # type: ignore
4 changes: 2 additions & 2 deletions covid_data/covid_data → covid_data/covid_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def cli():
pass


if __name__ == "__main__":
def entrypoint():
base_path = os.path.join(os.path.dirname(__file__), "commands")
files = os.listdir(base_path)

Expand All @@ -27,7 +27,7 @@ def cli():

handler_module, _ = os.path.splitext(file_name)

module = import_module(f".{handler_module}", "commands")
module = import_module(f".{handler_module}", "covid_data.commands")

if not hasattr(module, "main"):
continue
Expand Down
40 changes: 20 additions & 20 deletions covid_data/db/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def get_place_by_property(


def get_countries_id_by_alpha2(
alpha2_codes: list[str], engine: connection
) -> list[int]:
alpha2_codes: List[str], engine: connection
) -> List[int]:
with engine.cursor() as cur:
cur: cursor

Expand Down Expand Up @@ -206,21 +206,21 @@ def get_cases_by_country(


def get_cases_by_province(
provinces_id: list[int], engine: connection, case_type: CaseType = None
provinces_id: List[int], engine: connection, case_type: CaseType = None
) -> List[Dict]:
return get_cases_by_filters(engine, provinces_id=provinces_id, case_type=case_type)


def get_cases_by_filters_query(
countries_id: list[int] = None,
provinces_id: list[int] = None,
countries_id: List[int] = None,
provinces_id: List[int] = None,
date: datetime = None,
date_lte: datetime = None,
date_gte: datetime = None,
case_type: CaseType = None,
aggregation: list[Aggregations] = [],
aggregation: List[Aggregations] = [],
limit: int = None,
sort: list[str] = [],
sort: List[str] = [],
) -> Dict[str, Any]:
params = []
constraints = []
Expand Down Expand Up @@ -315,13 +315,13 @@ def get_cases_by_filters_query(

def get_cases_by_filters(
engine: connection,
countries_id: list[int] = None,
provinces_id: list[int] = None,
countries_id: List[int] = None,
provinces_id: List[int] = None,
date: datetime = None,
date_lte: datetime = None,
date_gte: datetime = None,
case_type: CaseType = None,
aggregation: list[Aggregations] = [],
aggregation: List[Aggregations] = [],
limit: int = None,
sort: list = [],
) -> List[Dict]:
Expand Down Expand Up @@ -354,7 +354,7 @@ def get_cum_cases_by_date(
date_lte: datetime = None,
date_gte: datetime = None,
case_type: CaseType = None,
countries: list[int] = None,
countries: List[int] = None,
) -> List[Dict]:
params = []

Expand Down Expand Up @@ -418,13 +418,13 @@ def get_cum_cases_by_date(
def get_cum_cases_by_date_country(
engine: connection,
country_id: int,
provinces_id: list[int] = [],
provinces_id: List[int] = [],
date: datetime = None,
date_lte: datetime = None,
date_gte: datetime = None,
case_type: CaseType = None,
) -> List[Dict]:
params: list[Any] = [country_id]
params: List[Any] = [country_id]

inner_query = sql.SQL(
(
Expand Down Expand Up @@ -490,7 +490,7 @@ def get_cum_cases_by_country(
date_lte: datetime = None,
date_gte: datetime = None,
case_type: CaseType = None,
countries_id: list[int] = [],
countries_id: List[int] = [],
) -> List[Dict]:
params = []

Expand Down Expand Up @@ -557,9 +557,9 @@ def get_cum_cases_by_province(
date_gte: datetime = None,
case_type: CaseType = None,
country_id: int = None,
provinces_id: list[str] = None,
provinces_id: List[str] = None,
) -> List[Dict]:
params: list[Any] = [country_id]
params: List[Any] = [country_id]

inner_query = sql.SQL(
(
Expand Down Expand Up @@ -650,8 +650,8 @@ def create_case(


def get_all_countries(
engine: connection, name: str = None, near: list[float] = []
) -> list[dict]:
engine: connection, name: str = None, near: List[float] = []
) -> List[dict]:
with engine.cursor() as cur:
cur: cursor

Expand All @@ -675,7 +675,7 @@ def get_all_countries(
return cur.fetchall() # type: ignore


def get_all_provinces(engine: connection) -> list[dict]:
def get_all_provinces(engine: connection) -> List[dict]:
with engine.cursor() as cur:
cur: cursor

Expand All @@ -695,7 +695,7 @@ def get_all_provinces(engine: connection) -> list[dict]:
return cur.fetchall() # type: ignore


def get_provinces_by_country(engine: connection, country_id: int) -> list[dict]:
def get_provinces_by_country(engine: connection, country_id: int) -> List[dict]:
with engine.cursor() as cur:
cur: cursor

Expand Down
Empty file.
5 changes: 1 addition & 4 deletions covid_data/scrappers/france.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import datetime
import json
import logging
import os

import click
import requests
Expand Down Expand Up @@ -48,11 +47,9 @@ def scrap(engine: connection, start_date: datetime.datetime = START_DATE) -> Non
if response.status_code > 399:
message = f"Error fetching data for date {curr_date}"
logger.error(message)
click.echo(message)
logger.error(response)
click.echo(response)

raise ClickException(message)
continue

data = json.loads(response.text)

Expand Down
1 change: 0 additions & 1 deletion covid_data/scrappers/spain.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
import os
from datetime import datetime

import click
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
name="covid_data",
entry_points={
"console_scripts": [
"covid_data = covid_data:covid_data",
"covid-data = covid_data:script_entrypoint",
],
},
packages=find_packages(),
Expand Down

0 comments on commit 25d6db9

Please sign in to comment.