Skip to content
29 changes: 15 additions & 14 deletions altair/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@

from typing import TYPE_CHECKING

from altair.datasets._loader import Loader as Loader
from altair.datasets._loader import Loader

if TYPE_CHECKING:
import sys
Expand All @@ -92,6 +92,8 @@
from altair.datasets._loader import _Load
from altair.datasets._typing import Dataset, Extension

__all__ = ["Loader", "data", "load", "url"]


load: _Load[Any, Any]
"""
Expand Down Expand Up @@ -166,18 +168,17 @@ def url(
return url


def __getattr__(name):
if name == "data":
from altair.datasets._data import data
if not TYPE_CHECKING:
Comment thread
dangotbanned marked this conversation as resolved.

return data
elif name == "load":
from altair.datasets._loader import load
def __getattr__(name):
if name == "data":
from altair.datasets._data import data

return load
elif name == "__all__":
# Define __all__ dynamically to avoid ruff errors
return ["Loader", "data", "load", "url"]
else:
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
return data
elif name == "load":
from altair.datasets._loader import load

return load
else:
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
2 changes: 1 addition & 1 deletion altair/datasets/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def schema_kwds(self, meta: Metadata, /) -> dict[str, Any]:
# For pyarrow CSV reading, use the schema as intended
# This will fail for non-ISO date formats, but that's the correct behavior
# Users can handle this by using a different backend or converting dates manually
return {"convert_options": ConvertOptions(column_types=schema)} # pyright: ignore[reportCallIssue]
return {"convert_options": ConvertOptions(column_types=schema)}
elif suffix == ".parquet":
return {"schema": schema}

Expand Down
18 changes: 7 additions & 11 deletions altair/datasets/_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import typing as t
from typing import Generic, final, overload

from narwhals.stable.v1.typing import IntoDataFrameT

from altair.datasets import _reader
from altair.datasets._reader import IntoFrameT
from altair.datasets._reader import IntoDataFrameT, IntoLazyFrameT

if t.TYPE_CHECKING:
import sys
Expand All @@ -30,7 +28,7 @@
__all__ = ["Loader", "load"]


class Loader(Generic[IntoDataFrameT, IntoFrameT]):
class Loader(Generic[IntoDataFrameT, IntoLazyFrameT]):
"""
Load example datasets *remotely* from `vega-datasets`_, with caching.

Expand All @@ -46,7 +44,7 @@ class Loader(Generic[IntoDataFrameT, IntoFrameT]):
https://github.com/vega/vega-datasets
"""

_reader: Reader[IntoDataFrameT, IntoFrameT]
_reader: Reader[IntoDataFrameT, IntoLazyFrameT]

@overload
@classmethod
Expand All @@ -58,13 +56,11 @@ def from_backend(
@classmethod
def from_backend(
cls, backend_name: Literal["pandas", "pandas[pyarrow]"], /
) -> Loader[pd.DataFrame, pd.DataFrame]: ...
) -> Loader[pd.DataFrame]: ...

@overload
@classmethod
def from_backend(
cls, backend_name: Literal["pyarrow"], /
) -> Loader[pa.Table, pa.Table]: ...
def from_backend(cls, backend_name: Literal["pyarrow"], /) -> Loader[pa.Table]: ...

@classmethod
def from_backend(
Expand Down Expand Up @@ -130,7 +126,7 @@ def from_backend(
return cls.from_reader(_reader._from_backend(backend_name))

@classmethod
def from_reader(cls, reader: Reader[IntoDataFrameT, IntoFrameT], /) -> Self:
def from_reader(cls, reader: Reader[IntoDataFrameT, IntoLazyFrameT], /) -> Self:
obj = cls.__new__(cls)
obj._reader = reader
return obj
Expand Down Expand Up @@ -294,7 +290,7 @@ def __repr__(self) -> str:


@final
class _Load(Loader[IntoDataFrameT, IntoFrameT]):
class _Load(Loader[IntoDataFrameT, IntoLazyFrameT]):
@overload
def __call__( # pyright: ignore[reportOverlappingOverload]
self,
Expand Down
82 changes: 49 additions & 33 deletions altair/datasets/_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,17 @@
from importlib.util import find_spec
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload
from urllib.request import build_opener as _build_opener

from narwhals.stable import v1 as nw
from narwhals.stable.v1.typing import IntoDataFrameT, IntoExpr
from packaging.requirements import Requirement

from altair.datasets import _readimpl
from altair.datasets._cache import CsvCache, DatasetCache, SchemaCache, _iter_metadata
from altair.datasets._constraints import is_parquet
from altair.datasets._exceptions import AltairDatasetsError, module_not_found
from altair.datasets._readimpl import IntoFrameT, is_available
from altair.datasets._readimpl import IntoDataFrameT, IntoLazyFrameT, is_available

if TYPE_CHECKING:
import sys
Expand All @@ -46,6 +45,7 @@
import pandas as pd
import polars as pl
import pyarrow as pa
from narwhals.stable.v1.typing import IntoExpr

from altair.datasets._readimpl import BaseImpl, R, Read, Scan
from altair.datasets._typing import Dataset, Extension, Metadata
Expand Down Expand Up @@ -91,6 +91,12 @@
_Ibis,
_PySpark,
)
# Implementations that can materialize datasets eagerly (in-memory frames);
# `_is_eager_allowed` narrows `nw.Implementation` to this alias at runtime.
_EagerAllowedImpl: TypeAlias = Literal[
    nw.Implementation.PANDAS,
    nw.Implementation.POLARS,
    nw.Implementation.PYARROW,
]
# String backend names matching the eager implementations above.
# NOTE(review): assumes `_Pandas`/`_Polars`/`_PyArrow` are the backend-name
# Literal aliases defined earlier in this module — confirm against full file.
_EagerAllowed: TypeAlias = Literal[_Pandas, _Polars, _PyArrow]

_SupportProfile: TypeAlias = Mapping[
Literal["supported", "unsupported"], "Sequence[Dataset]"
Expand All @@ -113,7 +119,7 @@
"""


class Reader(Generic[IntoDataFrameT, IntoFrameT]):
class Reader(Generic[IntoDataFrameT, IntoLazyFrameT]):
"""
Modular file reader, targeting remote & local tabular resources.

Expand All @@ -124,7 +130,7 @@ class Reader(Generic[IntoDataFrameT, IntoFrameT]):
_read: Sequence[Read[IntoDataFrameT]]
"""Eager file read functions."""

_scan: Sequence[Scan[IntoFrameT]]
_scan: Sequence[Scan[IntoLazyFrameT]]
"""Lazy file read functions."""

_name: str
Expand All @@ -134,7 +140,7 @@ class Reader(Generic[IntoDataFrameT, IntoFrameT]):
Otherwise, has no concrete meaning.
"""

_implementation: nw.Implementation
_implementation: _EagerAllowedImpl
"""
Corresponding `narwhals implementation`_.

Expand All @@ -150,9 +156,9 @@ class Reader(Generic[IntoDataFrameT, IntoFrameT]):
def __init__(
self,
read: Sequence[Read[IntoDataFrameT]],
scan: Sequence[Scan[IntoFrameT]],
scan: Sequence[Scan[IntoLazyFrameT]],
name: str,
implementation: nw.Implementation,
implementation: _EagerAllowedImpl,
) -> None:
self._read = read
self._scan = scan
Expand All @@ -173,7 +179,7 @@ def __repr__(self) -> str:
def read_fn(self, meta: Metadata, /) -> Callable[..., IntoDataFrameT]:
return self._solve(meta, self._read)

def scan_fn(self, meta: Metadata | Path | str, /) -> Callable[..., IntoFrameT]:
def scan_fn(self, meta: Metadata | Path | str, /) -> Callable[..., IntoLazyFrameT]:
meta = meta if isinstance(meta, Mapping) else {"suffix": _into_suffix(meta)}
return self._solve(meta, self._scan)

Expand Down Expand Up @@ -330,13 +336,13 @@ def _merge_kwds(self, meta: Metadata, kwds: dict[str, Any], /) -> Mapping[str, A
return kwds

@property
def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]:
def _metadata_frame(self) -> nw.LazyFrame[IntoLazyFrameT]:
fp = self._metadata_path
return nw.from_native(self.scan_fn(fp)(fp)).lazy()

def _scan_metadata(
self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
) -> nw.LazyFrame[IntoFrameT]:
) -> nw.LazyFrame[IntoLazyFrameT]:
if predicates or constraints:
return self._metadata_frame.filter(*predicates, **constraints)
return self._metadata_frame
Expand Down Expand Up @@ -373,7 +379,7 @@ def _dataset_names(
)


class _NoParquetReader(Reader[IntoDataFrameT, IntoFrameT]):
class _NoParquetReader(Reader[IntoDataFrameT]):
def __repr__(self) -> str:
return f"{super().__repr__()}\ncsv_cache\n {self.csv_cache!r}"

Expand All @@ -384,8 +390,8 @@ def csv_cache(self) -> CsvCache:
return self._csv_cache

@property
def _metadata_frame(self) -> nw.LazyFrame[IntoFrameT]:
data = cast("dict[str, Any]", self.csv_cache.rotated)
def _metadata_frame(self) -> nw.LazyFrame[Any]:
data = self.csv_cache.rotated
impl = self._implementation
return nw.maybe_convert_dtypes(nw.from_dict(data, backend=impl)).lazy()

Expand All @@ -397,31 +403,28 @@ def reader(
*,
name: str | None = ...,
implementation: nw.Implementation = ...,
) -> Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]]: ...
) -> Reader[IntoDataFrameT]: ...


@overload
def reader(
read_fns: Sequence[Read[IntoDataFrameT]],
scan_fns: Sequence[Scan[IntoFrameT]],
scan_fns: Sequence[Scan[IntoLazyFrameT]],
*,
name: str | None = ...,
implementation: nw.Implementation = ...,
) -> Reader[IntoDataFrameT, IntoFrameT]: ...
) -> Reader[IntoDataFrameT, IntoLazyFrameT]: ...


def reader(
read_fns: Sequence[Read[IntoDataFrameT]],
scan_fns: Sequence[Scan[IntoFrameT]] = (),
scan_fns: Sequence[Scan[IntoLazyFrameT]] = (),
*,
name: str | None = None,
implementation: nw.Implementation = nw.Implementation.UNKNOWN,
) -> (
Reader[IntoDataFrameT, IntoFrameT]
| Reader[IntoDataFrameT, nw.LazyFrame[IntoDataFrameT]]
):
) -> Reader[IntoDataFrameT, IntoLazyFrameT] | Reader[IntoDataFrameT]:
name = name or Counter(el._inferred_package for el in read_fns).most_common(1)[0][0]
if implementation is nw.Implementation.UNKNOWN:
if not _is_eager_allowed(implementation):
implementation = _into_implementation(Requirement(name))
if scan_fns:
return Reader(read_fns, scan_fns, name, implementation)
Expand Down Expand Up @@ -456,9 +459,9 @@ def infer_backend(
@overload
def _from_backend(name: _Polars, /) -> Reader[pl.DataFrame, pl.LazyFrame]: ...
@overload
def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame, pd.DataFrame]: ...
def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame]: ...
@overload
def _from_backend(name: _PyArrow, /) -> Reader[pa.Table, pa.Table]: ...
def _from_backend(name: _PyArrow, /) -> Reader[pa.Table]: ...


# FIXME: The order this is defined in makes splitting the module complicated
Expand Down Expand Up @@ -516,15 +519,28 @@ def _into_constraints(
return m


def _is_eager_allowed(impl: nw.Implementation, /) -> TypeIs[_EagerAllowedImpl]:
    """
    Narrow ``impl`` to the implementations that support eager loading.

    Returns ``True`` only for the pandas, polars, and pyarrow backends,
    informing the type checker that ``impl`` is ``_EagerAllowedImpl``.
    """
    eager_impls = (
        nw.Implementation.PANDAS,
        nw.Implementation.POLARS,
        nw.Implementation.PYARROW,
    )
    return impl in eager_impls


def _into_implementation(
backend: _NwSupport | _PandasAny | Requirement, /
) -> nw.Implementation:
primary = _import_guarded(backend)
backend: _NwSupport | _PandasAny | nw.Implementation | Requirement, /
) -> _EagerAllowedImpl:
req = (
Requirement(str(backend)) if isinstance(backend, nw.Implementation) else backend
)
primary = _import_guarded(req)
impl = nw.Implementation.from_backend(primary)
if impl is not nw.Implementation.UNKNOWN:
return impl
msg = f"Package {primary!r} is not supported by `narwhals`."
raise ValueError(msg)
if not _is_eager_allowed(impl):
if impl is nw.Implementation.UNKNOWN:
msg = f"Package {primary!r} is not supported by `narwhals`."
raise ValueError(msg)
raise NotImplementedError(impl)
return impl


def _into_suffix(obj: Path | str, /) -> Any:
Expand All @@ -539,7 +555,7 @@ def _into_suffix(obj: Path | str, /) -> Any:

def _steal_eager_parquet(
read_fns: Sequence[Read[IntoDataFrameT]], /
) -> Sequence[Scan[nw.LazyFrame[IntoDataFrameT]]] | None:
) -> Sequence[Scan[Any]] | None:
if convertable := next((rd for rd in read_fns if rd.include <= is_parquet), None):
return (_readimpl.into_scan(convertable),)
return None
Expand Down
Loading