Use PostgreSQL to store GTFS data;
Python 3.12;
Refactor and minor enhancements.
Benyamin Ginzburg committed Aug 1, 2024
1 parent 870581f commit 44c63f3
Showing 24 changed files with 1,390 additions and 348 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -107,3 +107,4 @@ venv.bak/
 Pipfile.lock
 /gtfs_data/*
 /gtfs_data/
+/.pdm-python
4 changes: 2 additions & 2 deletions Dockerfile
@@ -1,6 +1,6 @@
 FROM python:3.9-slim
 
-RUN pip install pipenv
+RUN pip install pdm
 WORKDIR /home/app
 COPY . .
 WORKDIR /home/app/israel_transport_api
@@ -9,4 +9,4 @@ ENV TZ=Asia/Jerusalem
 ENV PYTHONPATH=/home/app
 ENV DOCKER_MODE=true
 EXPOSE 8000
-CMD ["pipenv", "run", "python", "main.py"]
+CMD ["pdm", "run", "python", "main.py"]
21 changes: 0 additions & 21 deletions Pipfile

This file was deleted.

1 change: 1 addition & 0 deletions README.md
@@ -0,0 +1 @@
+# bus-api
2 changes: 1 addition & 1 deletion israel_transport_api/__version__.py
@@ -1 +1 @@
-version = '0.0.1'
+version = '0.1.1'
21 changes: 9 additions & 12 deletions israel_transport_api/config.py
@@ -1,20 +1,17 @@
-from pydantic import BaseSettings, Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
 
 
 class Env(BaseSettings):
 
-    class Config:
-        env_file = '../.env'
-        env_file_encoding = 'utf-8'
+    model_config = SettingsConfigDict(env_file='../.env', env_file_encoding='utf-8')
 
-    SIRI_URL: str = Field(..., env='SIRI_URL')
-    GTFS_URL: str = Field(..., env='GTFS_FTP_URL')
-    API_KEY: str = Field(..., env='API_KEY')
-    ROOT_PATH: str = Field('', env='ROOT_PATH')
-    DB_URL: str = Field('localhost', env='DB_URL')
-    DB_NAME: str = Field(..., env='DB_NAME')
-    SCHED_HOURS: int = Field(..., env='SCHED_HOURS')
-    SCHED_MINS: int = Field(..., env='SCHED_MINS')
+    SIRI_URL: str
+    GTFS_FTP_URL: str
+    API_KEY: str
+    ROOT_PATH: str = '/'
+    DB_DSN: str
+    SCHED_HOURS: int
+    SCHED_MINS: int
 
 
 env = Env()
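
A usage sketch, not part of the commit: with pydantic-settings, each field of Env is read from the environment (or ../.env) by name, so the new variable names are SIRI_URL, GTFS_FTP_URL, API_KEY, ROOT_PATH, DB_DSN, SCHED_HOURS and SCHED_MINS. The values below are placeholders, assuming a local PostgreSQL instance.

import os

# Hypothetical values; pydantic-settings matches them to Env's fields by name.
os.environ.update({
    'SIRI_URL': 'https://example.org/siri',
    'GTFS_FTP_URL': 'gtfs.example.org',
    'API_KEY': 'dummy-key',
    'DB_DSN': 'postgresql://postgres:postgres@localhost:5432/gtfs',
    'SCHED_HOURS': '3',
    'SCHED_MINS': '0',
})

from israel_transport_api.config import env  # Env() is instantiated at import time

print(env.DB_DSN, env.ROOT_PATH)  # ROOT_PATH falls back to its '/' default
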
1 change: 0 additions & 1 deletion israel_transport_api/gtfs/__init__.py
@@ -1,4 +1,3 @@
-from israel_transport_api.gtfs.repository.stops_repository import init_db
 from .gtfs_retriever import init_gtfs_data
 from .router import stops_router, routes_router
 
124 changes: 15 additions & 109 deletions israel_transport_api/gtfs/gtfs_retriever.py
@@ -1,20 +1,15 @@
 import asyncio
-import csv
 import ftplib
 import io
 import logging
 import pathlib
 import zipfile
-from typing import TextIO
 
-from shapely.geometry import Point
+from psycopg.connection_async import AsyncConnection
 
 from israel_transport_api.config import env
-from israel_transport_api.gtfs.exceptions import GtfsFileNotFound
-from israel_transport_api.gtfs.models import Route
-from israel_transport_api.gtfs.repository import stops_repository, routes_repository
-from israel_transport_api.gtfs.utils import parse_route_long_name, parse_stop_description
-from israel_transport_api.gtfs.repository import db_updater
+from israel_transport_api.gtfs.repository import db_loader
+from israel_transport_api.gtfs.repository.init_sql import init_db
 
 GTFS_FP = pathlib.Path(__file__).parent.parent.parent / 'gtfs_data'
 GTFS_FILES = [
@@ -34,9 +29,9 @@
 
 
 async def _download_gtfs_data_from_ftp() -> io.BytesIO:
-    logger.debug(f'Trying to establish ftp connection with {env.GTFS_URL}...')
+    logger.debug(f'Trying to establish ftp connection with {env.GTFS_FTP_URL}...')
 
-    ftp = ftplib.FTP(env.GTFS_URL)
+    ftp = ftplib.FTP(env.GTFS_FTP_URL)
     ftp.login()
 
     bio = io.BytesIO()
@@ -60,119 +55,30 @@ async def _download_gtfs_data():
     logger.debug('Done.')
 
 
-def _process_stops_file(fio: TextIO) -> list[list]:
-    reader = csv.reader(fio)
-    next(reader, None)  # skip headers
-
-    stops_to_save = []
-    for row in reader:
-        try:
-            street, city, platform, floor = parse_stop_description(row[3])
-        except (ValueError, IndexError):
-            msg = f'Failed to parse stop description. Row: {row}'
-            # logger.exception(msg)
-            continue
-
-        stops_to_save.append([
-            int(row[0]),
-            int(row[1]),
-            row[2],
-            street,
-            city,
-            platform,
-            floor,
-            Point(float(row[4]), float(row[5])),
-            row[6],
-            int(row[7]) if row[7] else None,
-            row[8]
-        ])
-    return stops_to_save
-
-
-def _process_routes_file(fio: TextIO) -> list[Route]:
-    reader = csv.reader(fio)
-    next(reader, None)  # skip headers
-    routes = []
-
-    for row in reader:
-        try:
-            from_stop_name, from_city, to_stop_name, to_city = parse_route_long_name(row[3])
-        except (ValueError, IndexError):
-            msg = f'Failed to parse route long name. Row: {row}'
-            logger.exception(msg)
-            continue
-
-        routes.append(Route(
-            id=(row[0]),
-            agency_id=row[1],
-            short_name=row[2],
-            from_stop_name=from_stop_name,
-            to_stop_name=to_stop_name,
-            from_city=from_city,
-            to_city=to_city,
-            description=row[4],
-            type=row[5],
-            color=row[6]
-        ))
-    return routes
-
-
-async def _store_db_data(session):
-    logger.debug('Loading stops to database...')
-    fp = GTFS_FP / 'stops.txt'
-    if not fp.exists():
-        raise GtfsFileNotFound('File stops.txt not found!')
-
-    with open(fp, 'r', encoding='utf-8') as file:
-        stops = _process_stops_file(file)
-        await db_updater.update_stops(stops, session)
-        # await stops_repository.save_stops(stops_to_save)
-    logger.debug('Done.')
-
-
-def _load_memory_data():
-    logger.debug('Loading routes to memory storage...')
-    fp = GTFS_FP / 'routes.txt'
-    if not fp.exists():
-        raise GtfsFileNotFound('File routes.txt not found!')
-
-    with open(fp, 'r', encoding='utf-8') as file:
-
-        routes_repository.save_route(route)
-    logger.debug('Done.')
+async def _store_db_data(session, force_load: bool = False):
+    await init_db(session)
+    await db_loader.load_agencies(session, force_load)
+    await db_loader.load_stops(session, force_load)
+    await db_loader.load_routes(session, force_load)
+    await db_loader.load_trips(session, force_load)
+    await db_loader.load_stop_times(session, force_load)
 
 
-async def init_gtfs_data(force_download: bool = False):
+async def init_gtfs_data(conn: AsyncConnection, force_download: bool = False):
     logger.info(f'Data initialization started {"with" if force_download else "without"} downloading files...')
     if force_download or (not (GTFS_FP / 'routes.txt').exists() or not (GTFS_FP / 'stops.txt').exists()):
         n_retries = 5
         for i in range(n_retries, 0, -1):
             try:
                 logger.debug(f'Trying to download data, {i} tries remain...')
                 await _download_gtfs_data()
-                await _store_db_data()
+                await _store_db_data(conn)
             except Exception as e:
                 logger.exception(e)
                 await asyncio.sleep(10)
             else:
                 break
 
-    _load_memory_data()
-
-    count = await stops_repository.get_stops_count()
-    logger.info(f'There are {count} documents in stops collection.')
-    if count == 0:
-        await _store_db_data()
+    await _store_db_data(conn, force_download)
 
     logger.info('Data initialization completed!')
-
-
-async def test():
-    import psycopg
-    conn = await psycopg.AsyncConnection.connect('host=localhost port=5432 dbname=gtfs user=postgres password=')
-    await _store_db_data(conn)
-
-import sys
-if sys.platform == "win32":
-    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-asyncio.run(test())
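
A wiring sketch, not part of the commit, mirroring the deleted test() helper: it assumes a local PostgreSQL database and uses psycopg 3's async connection together with the new init_gtfs_data(conn, force_download) signature.

import asyncio

import psycopg

from israel_transport_api.gtfs import init_gtfs_data


async def main() -> None:
    # Hypothetical DSN, in the same style as the removed test() helper.
    conn = await psycopg.AsyncConnection.connect('host=localhost port=5432 dbname=gtfs user=postgres password=')
    async with conn:
        # force_download=True re-downloads the GTFS archive before loading the tables.
        await init_gtfs_data(conn, force_download=True)


asyncio.run(main())
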
63 changes: 45 additions & 18 deletions israel_transport_api/gtfs/models.py
@@ -1,40 +1,67 @@
-from abc import ABC
-from typing import Optional, Tuple
-
-from odmantic import Field, Model, EmbeddedModel
 from pydantic import BaseModel
 
 
 class Route(BaseModel):
-    id: str
-    agency_id: str
+    id: int
+    agency_id: int
     short_name: str
     from_stop_name: str
     to_stop_name: str
     from_city: str
     to_city: str
     description: str
-    type: str
+    type: int
     color: str
 
+    @classmethod
+    def from_row(cls, row: list) -> 'Route':
+        return cls(
+            id=row[0],
+            agency_id=row[1],
+            short_name=row[2],
+            from_stop_name=row[3],
+            to_stop_name=row[4],
+            from_city=row[5],
+            to_city=row[6],
+            description=row[7],
+            type=row[8],
+            color=row[9]
+        )
 
 
-class StopLocation(EmbeddedModel, ABC):
+class StopLocation(BaseModel):
     type: str = 'Point'
-    coordinates: Tuple[float, float]
+    coordinates: tuple[float, float]
 
 
-class Stop(Model, ABC):
-    id: int = Field(..., primary_field=True)
+class Stop(BaseModel):
+    id: int
     code: int
     name: str
-    city: str
-    street: Optional[str] = None
-    floor: Optional[str] = None
-    platform: Optional[str] = None
+    city: str | None = None
+    street: str | None = None
+    floor: str | None = None
+    platform: int | None = None
     location: StopLocation
-    location_type: str
-    parent_station_id: Optional[str] = None
-    zone_id: Optional[str] = None
+    location_type: int
+    parent_station_id: int | None = None
+    zone_id: int | None = None
 
+    @classmethod
+    def from_row(cls, row: list) -> 'Stop':
+        return cls(
+            id=row[0],
+            code=row[2],
+            name=row[6],
+            city=row[1],
+            street=row[9],
+            floor=row[3],
+            platform=row[8],
+            location=StopLocation(coordinates=(row[11], row[12])),
+            location_type=row[5],
+            parent_station_id=row[7],
+            zone_id=row[10]
+        )
 
-    class Config:
-        collection = 'stops'
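
A small sketch, not part of the commit, of the new from_row constructor; the row below is made up but follows the column order the classmethod expects (id, agency_id, short_name, from_stop_name, to_stop_name, from_city, to_city, description, type, color).

from israel_transport_api.gtfs.models import Route

# Hypothetical routes row; integer id/agency_id/type match the new field types.
row = [480, 3, '480', 'Central Station', 'Arlozorov Terminal', 'Jerusalem', 'Tel Aviv', '', 3, '']
route = Route.from_row(row)
print(route.short_name, route.from_city, route.to_city)
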
