Skip to content

Commit

Permalink
GTFS enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
Benyamin Ginzburg committed Jun 26, 2021
1 parent c1ea700 commit 86bc61f
Show file tree
Hide file tree
Showing 16 changed files with 217 additions and 177 deletions.
74 changes: 0 additions & 74 deletions bus_api/gtfs/sql.py

This file was deleted.

6 changes: 0 additions & 6 deletions bus_api/siri/client.py

This file was deleted.

83 changes: 0 additions & 83 deletions bus_api/siri/models.py

This file was deleted.

File renamed without changes.
2 changes: 1 addition & 1 deletion bus_api/config.py → israel_transport_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

ROOT_PATH = os.getenv("ROOT_PATH")

DB_DSN = DSN = f'dbname={os.getenv("DB_NAME", "bus_api")} ' \
DB_DSN = DSN = f'dbname={os.getenv("DB_NAME", "israel_transport_api")} ' \
f'user={os.getenv("DB_USER"), "bus_api_admin"} ' \
f'password={os.getenv("DB_PASS"), ""} ' \
f'host={os.getenv("DB_HOST"), "localhost"} ' \
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions israel_transport_api/gtfs/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class GtfsFileNotFound(Exception):
pass
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import asyncio
import csv
import io
import os
import ftplib
import zipfile
import logging
import tempfile
import pathlib
from typing import Dict, Tuple

import aioftp

from bus_api.config import GTFS_URL
from bus_api.misс import scheduler, daily_trigger
from israel_transport_api.config import GTFS_URL
from israel_transport_api.gtfs.exceptions import GtfsFileNotFound
from israel_transport_api.gtfs.models import Route

logging.basicConfig(level=logging.DEBUG)
GTFS_FP = '../gtfs_data'
GTFS_FP = pathlib.Path(__file__).parent.parent.parent / 'gtfs_data'
GTFS_FILES = [
'agency.txt',
'calendar.txt',
Expand All @@ -27,6 +27,7 @@
]

logger = logging.getLogger(__name__)
ROUTES: Dict[Tuple[int, int], Route] = {}


async def download_gtfs_data() -> io.BytesIO:
Expand All @@ -49,10 +50,31 @@ async def save_gtfs_data():

logger.debug(f'Saving files to {GTFS_FP}...')
with zipfile.ZipFile(gtfs_data_io) as zip_file:
if not os.path.exists(GTFS_FP):
os.mkdir(GTFS_FP)
if not GTFS_FP.exists():
GTFS_FP.mkdir()

zip_file.extractall(GTFS_FP)
logger.debug('Done.')


scheduler.add_job(save_gtfs_data, trigger=daily_trigger)
async def _store_db_data():
...


def _store_memory_data():
fp = GTFS_FP / 'routes.txt'
if not fp.exists():
raise GtfsFileNotFound('File routes.txt not found!')

with open(fp, 'r', encoding='utf-8') as file:
reader = csv.reader(file)
...


async def load_gtfs_data():
if not (GTFS_FP / 'routes.txt').exists():
await save_gtfs_data()
# scheduler.add_job(save_gtfs_data, trigger=daily_trigger)

_store_memory_data()

13 changes: 13 additions & 0 deletions israel_transport_api/gtfs/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel


class Route:
id: int
agency_id: int
short_name: str
long_name: str
description: str
type: str
color: str


15 changes: 15 additions & 0 deletions israel_transport_api/gtfs/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
init_query = '''
CREATE TABLE IF NOT EXISTS stops
(
id INT PRIMARY KEY,
code INT UNIQUE,
name TEXT NOT NULL,
street TEXT NULL,
city TEXT NOT NULL,
platform TEXT NULL,
floor TEXT NULL,
location POINT NOT NULL,
zone_id TEXT NULL
);
'''
25 changes: 25 additions & 0 deletions israel_transport_api/gtfs/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Tuple, List, Optional


def parse_stop_description(s: str) -> List[Optional[str]]:
parts = s.split(':')
values = []
for value in parts[1:-1]:
value = value.rsplit(' ', maxsplit=1)[0].strip()
if value:
values.append(value)
else:
values.append(None)
values.append(parts[-1].strip() or None)

return values


def parse_route_long_description(s: str) -> Tuple[str, str, str, str]:
from_, to = s.split('<->')
*from_stop_name, from_city = from_.split('-')
from_stop_name = ' - '.join(from_stop_name)

to_stop_name, to_city = to.split('-')[:2]

return from_stop_name, from_city, to_stop_name, to_city
6 changes: 3 additions & 3 deletions bus_api/main.py → israel_transport_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import betterlogging as logging
from fastapi import FastAPI

from bus_api.config import ROOT_PATH
from bus_api.misс import scheduler
from bus_api.gtfs.gtfs_retriever import save_gtfs_data
from israel_transport_api.config import ROOT_PATH
from israel_transport_api.misс import scheduler
from israel_transport_api.gtfs.gtfs_retriever import save_gtfs_data


logging.basic_colorized_config(level=logging.INFO)
Expand Down
File renamed without changes.
File renamed without changes.
51 changes: 51 additions & 0 deletions israel_transport_api/siri/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import logging
from typing import List

from httpx import AsyncClient
from pydantic import parse_obj_as

from israel_transport_api.config import SIRI_URL, API_KEY
from israel_transport_api.siri.models import MonitoredStopVisit

logging.basicConfig(level=logging.DEBUG)


class SiriClient:
http_client: AsyncClient

def __init__(self):
self.http_client = AsyncClient()

async def _make_request(self, station_id: int) -> List[MonitoredStopVisit]:
params = {
'Key': API_KEY,
'MonitoringRef': station_id
}

resp = await self.http_client.get(SIRI_URL, params=params)
raw_data: dict = resp.json()
raw_stop_data: List[dict] = raw_data.get('Siri', {}).get('ServiceDelivery', {}).get('StopMonitoringDelivery', [])

if len(raw_stop_data) == 0:
print('no data')
raise ValueError()

if raw_stop_data[0]['Status'] != 'true':
print('error', raw_stop_data)
raise ValueError()

parsed_data = parse_obj_as(List[MonitoredStopVisit], raw_stop_data[0]['MonitoredStopVisit'])
return parsed_data




import asyncio


async def m():
c = SiriClient()
resp = await c._make_request('32372')
print(resp)

asyncio.run(m())
Loading

0 comments on commit 86bc61f

Please sign in to comment.