Skip to content

Commit 042f659

Browse files
committed
WIP Convert get_shapes to use SQLAlchemy ORM, query builder
1 parent 9e68de8 commit 042f659

File tree

4 files changed

+52
-16
lines changed

4 files changed

+52
-16
lines changed

_shared_utils/requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ pytest-recording (>=0.13.4,<0.14.0)
1313
pytest-unordered (>=0.7.0,<0.8.0)
1414
quarto==0.1.0
1515
quarto-cli==1.6.40
16+
sqlalchemy==1.4.46
17+
sqlalchemy-bigquery==1.11.0
1618
vegafusion==2.0.2
1719
vl-convert-python>=1.6.0
1820
movingpandas==0.22.4

_shared_utils/shared_utils/gtfs_utils_v2.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from shared_utils import DBSession, schedule_rt_utils
1818
from shared_utils.models.dim_gtfs_dataset import DimGtfsDataset
1919
from shared_utils.models.fct_daily_schedule_feeds import FctDailyScheduleFeeds
20+
from shared_utils.models.fct_daily_scheduled_shapes import FctDailyScheduledShapes
2021
from shared_utils.models.fct_scheduled_trips import FctScheduledTrips
2122
from siuba import *
2223
from sqlalchemy import and_, create_engine, func, or_, select
@@ -407,7 +408,7 @@ def get_shapes(
407408
get_df: bool = True,
408409
crs: str = geography_utils.WGS84,
409410
custom_filtering: dict = None,
410-
) -> gpd.GeoDataFrame:
411+
) -> Union[gpd.GeoDataFrame | sqlalchemy.sql.selectable.Select]:
411412
"""
412413
Query fct_daily_scheduled_shapes.
413414
@@ -416,21 +417,15 @@ def get_shapes(
416417
"""
417418
check_operator_feeds(operator_feeds)
418419

419-
# If pt_array is not kept in the final, we still need it
420-
# to turn this into a gdf
421-
if "pt_array" not in shape_cols:
422-
shape_cols_with_geom = shape_cols + ["pt_array"]
423-
elif shape_cols:
424-
shape_cols_with_geom = shape_cols[:]
420+
search_conditions = [
421+
FctDailyScheduledShapes.service_date == selected_date,
422+
FctDailyScheduledShapes.feed_key.in_(operator_feeds),
423+
]
425424

426-
tables = _get_tables()
425+
for k, v in (custom_filtering or {}).items():
426+
search_conditions.append(getattr(FctDailyScheduledShapes, k).in_(v))
427427

428-
shapes = (
429-
getattr(tables, "test_shared_utils").fct_daily_scheduled_shapes()
430-
>> filter_date(selected_date, date_col="service_date")
431-
>> filter_operator(operator_feeds, include_name=False)
432-
>> filter_custom_col(custom_filtering)
433-
)
428+
shapes = select(FctDailyScheduledShapes).where(and_(*search_conditions))
434429

435430
if get_df:
436431
shapes = shapes >> collect()
@@ -442,7 +437,12 @@ def get_shapes(
442437
return shapes_gdf
443438

444439
else:
445-
return shapes >> subset_cols(shape_cols_with_geom)
440+
columns = {FctDailyScheduledShapes.pt_array}
441+
442+
for column in shape_cols:
443+
columns.add(getattr(FctDailyScheduledShapes, column))
444+
445+
return shapes.with_only_columns(*list(columns))
446446

447447

448448
def get_stops(
Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,27 @@
1-
# mart_gtfs.fct_daily_scheduled_shapes
1+
from shared_utils.models.base import get_table_name
2+
from sqlalchemy import Boolean, Column, Date, DateTime, Integer, String
3+
from sqlalchemy.orm import declarative_base, declared_attr
4+
from sqlalchemy_bigquery import GEOGRAPHY
5+
6+
Base = declarative_base()
7+
8+
9+
class FctDailyScheduledShapes(Base):
10+
dataset = "mart_gtfs"
11+
table = "fct_daily_scheduled_shapes"
12+
13+
@declared_attr
14+
def __tablename__(cls):
15+
return get_table_name(cls.dataset, cls.table)
16+
17+
key = Column(String, primary_key=True)
18+
feed_key = Column(String)
19+
service_date = Column(Date)
20+
shape_id = Column(String)
21+
shape_array_key = Column(String)
22+
feed_timezone = Column(String)
23+
n_trips = Column(Integer)
24+
shape_first_departure_datetime_pacific = Column(DateTime)
25+
shape_last_arrival_datetime_pacific = Column(DateTime)
26+
contains_warning_duplicate_trip_primary_key = Column(Boolean)
27+
pt_array = Column(GEOGRAPHY)

_shared_utils/tests/shared_utils/test_gtfs_utils_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,3 +731,11 @@ def test_get_shapes_custom_filtering(self):
731731
def test_get_shapes_no_operator_feeds(self):
732732
with pytest.raises(ValueError, match="Supply list of feed keys or operator names!"):
733733
get_shapes(selected_date="2025-09-19")
734+
735+
def test_get_shapes_get_df_false(self):
736+
result = get_shapes(
737+
selected_date="2025-09-01", operator_feeds=["3ea60aa240ddc543da5415ccc759fd6d"], get_df=False
738+
)
739+
740+
assert isinstance(result, sqlalchemy.sql.selectable.Select)
741+
# assert result select includes pt_array

0 commit comments

Comments
 (0)