Skip to content

Commit da41251

Browse files
committed
Add type hints to altair.utils.schemapi
1 parent 0217b2e commit da41251

File tree

1 file changed

+60
-52
lines changed

1 file changed

+60
-52
lines changed

altair/utils/schemapi.py

+60-52
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@
44
import contextlib
55
import inspect
66
import json
7+
from __future__ import annotations
8+
from typing import Any, Callable, DefaultDict, Dict, FrozenSet, Iterable, Iterator, List, Literal, Mapping, Optional, Sequence, Set, Tuple, Type, TypeAlias, TypeVar, Union
79

810
import jsonschema
11+
from jsonschema.validators import RefResolver
912
import numpy as np
13+
import numpy.typing as npt
1014
import pandas as pd
1115

1216

@@ -16,20 +20,22 @@
1620
# Individual schema classes can override this by setting the
1721
# class-level _class_is_valid_at_instantiation attribute to False
1822
DEBUG_MODE = True
23+
GenericT = TypeVar("GenericT")
24+
T = TypeVar("T", bound="SchemaBase")
25+
AltairObj: TypeAlias = Union["SchemaBase", List[Any], Tuple[Any], npt.NDArray[Any], Dict[Any, Any], np.number[Any], pd.Timestamp, np.datetime64]
1926

20-
21-
def enable_debug_mode():
27+
def enable_debug_mode() -> None:
2228
global DEBUG_MODE
2329
DEBUG_MODE = True
2430

2531

26-
def disable_debug_mode():
32+
def disable_debug_mode() -> None:
2733
global DEBUG_MODE
2834
DEBUG_MODE = True
2935

3036

3137
@contextlib.contextmanager
32-
def debug_mode(arg):
38+
def debug_mode(arg: bool) -> Iterator[Optional[bool]]:
3339
global DEBUG_MODE
3440
original = DEBUG_MODE
3541
DEBUG_MODE = arg
@@ -39,9 +45,9 @@ def debug_mode(arg):
3945
DEBUG_MODE = original
4046

4147

42-
def _subclasses(cls):
48+
def _subclasses(cls: Type[GenericT]) -> Iterable[Type[GenericT]]:
4349
"""Breadth-first sequence of all classes which inherit from cls."""
44-
seen = set()
50+
seen: Set[Type[GenericT]] = set()
4551
current_set = {cls}
4652
while current_set:
4753
seen |= current_set
@@ -50,7 +56,7 @@ def _subclasses(cls):
5056
yield cls
5157

5258

53-
def _todict(obj, validate, context):
59+
def _todict(obj: AltairObj, validate: Union[bool, str], context: Optional[Dict[Any, Any]]) -> Union[AltairObj, Dict[Any, Any], str, float]:
5460
"""Convert an object to a dict representation."""
5561
if isinstance(obj, SchemaBase):
5662
return obj.to_dict(validate=validate, context=context)
@@ -72,9 +78,9 @@ def _todict(obj, validate, context):
7278
return obj
7379

7480

75-
def _resolve_references(schema, root=None):
81+
def _resolve_references(schema: Mapping[str, Any], root: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
7682
"""Resolve schema references."""
77-
resolver = jsonschema.RefResolver.from_schema(root or schema)
83+
resolver: RefResolver = jsonschema.RefResolver.from_schema(root or schema)
7884
while "$ref" in schema:
7985
with resolver.resolving(schema["$ref"]) as resolved:
8086
schema = resolved
@@ -84,12 +90,12 @@ def _resolve_references(schema, root=None):
8490
class SchemaValidationError(jsonschema.ValidationError):
8591
"""A wrapper for jsonschema.ValidationError with friendlier traceback"""
8692

87-
def __init__(self, obj, err):
93+
def __init__(self, obj: Any, err: jsonschema.ValidationError) -> None:
8894
super(SchemaValidationError, self).__init__(**self._get_contents(err))
8995
self.obj = obj
9096

9197
@staticmethod
92-
def _get_contents(err):
98+
def _get_contents(err: jsonschema.ValidationError) -> Dict[str, Any]:
9399
"""Get a dictionary with the contents of a ValidationError"""
94100
try:
95101
# works in jsonschema 2.3 or later
@@ -104,7 +110,7 @@ def _get_contents(err):
104110
contents = {key: getattr(err, key) for key in spec.args[1:]}
105111
return contents
106112

107-
def __str__(self):
113+
def __str__(self) -> str:
108114
cls = self.obj.__class__
109115
schema_path = ["{}.{}".format(cls.__module__, cls.__name__)]
110116
schema_path.extend(self.schema_path)
@@ -128,12 +134,12 @@ class UndefinedType(object):
128134

129135
__instance = None
130136

131-
def __new__(cls, *args, **kwargs):
137+
def __new__(cls, *args: Any, **kwargs: Any) -> UndefinedType:
132138
if not isinstance(cls.__instance, cls):
133139
cls.__instance = object.__new__(cls, *args, **kwargs)
134140
return cls.__instance
135141

136-
def __repr__(self):
142+
def __repr__(self) -> Literal["Undefined"]:
137143
return "Undefined"
138144

139145

@@ -147,12 +153,12 @@ class SchemaBase(object):
147153
the _rootschema class attribute) which is used for validation.
148154
"""
149155

150-
_schema = None
151-
_rootschema = None
156+
_schema: Optional[Mapping[str, Any]] = None
157+
_rootschema: Optional[Mapping[str, Any]] = None
152158
_class_is_valid_at_instantiation = True
153159
_validator = jsonschema.Draft7Validator
154160

155-
def __init__(self, *args, **kwds):
161+
def __init__(self, *args: Any, **kwds: Any) -> None:
156162
# Two valid options for initialization, which should be handled by
157163
# derived classes:
158164
# - a single arg with no kwds, for, e.g. {'type': 'string'}
@@ -176,7 +182,7 @@ def __init__(self, *args, **kwds):
176182
if DEBUG_MODE and self._class_is_valid_at_instantiation:
177183
self.to_dict(validate=True)
178184

179-
def copy(self, deep=True, ignore=()):
185+
def copy(self: T, deep: Union[bool, Sequence[Any]] = True, ignore: Sequence[Any] = ()) -> T:
180186
"""Return a copy of the object
181187
182188
Parameters
@@ -191,7 +197,7 @@ def copy(self, deep=True, ignore=()):
191197
only stored by reference.
192198
"""
193199

194-
def _shallow_copy(obj):
200+
def _shallow_copy(obj: T) -> T:
195201
if isinstance(obj, SchemaBase):
196202
return obj.copy(deep=False)
197203
elif isinstance(obj, list):
@@ -201,7 +207,7 @@ def _shallow_copy(obj):
201207
else:
202208
return obj
203209

204-
def _deep_copy(obj, ignore=()):
210+
def _deep_copy(obj: T, ignore: Sequence[Any] = ()) -> T:
205211
if isinstance(obj, SchemaBase):
206212
args = tuple(_deep_copy(arg) for arg in obj._args)
207213
kwds = {
@@ -221,7 +227,7 @@ def _deep_copy(obj, ignore=()):
221227
return obj
222228

223229
try:
224-
deep = list(deep)
230+
deep: List[Any] = list(deep)
225231
except TypeError:
226232
deep_is_list = False
227233
else:
@@ -237,36 +243,36 @@ def _deep_copy(obj, ignore=()):
237243
copy[attr] = _shallow_copy(copy._get(attr))
238244
return copy
239245

240-
def _get(self, attr, default=Undefined):
246+
def _get(self, attr: str, default: Any = Undefined) -> Any:
241247
"""Get an attribute, returning default if not present."""
242-
attr = self._kwds.get(attr, Undefined)
248+
attr: Any = self._kwds.get(attr, Undefined)
243249
if attr is Undefined:
244250
attr = default
245251
return attr
246252

247-
def __getattr__(self, attr):
253+
def __getattr__(self, attr: str) -> Any:
248254
# reminder: getattr is called after the normal lookups
249255
if attr == "_kwds":
250256
raise AttributeError()
251257
if attr in self._kwds:
252258
return self._kwds[attr]
253259
else:
254260
try:
255-
_getattr = super(SchemaBase, self).__getattr__
261+
_getattr: Callable[[str], Any] = super(SchemaBase, self).__getattr__
256262
except AttributeError:
257263
_getattr = super(SchemaBase, self).__getattribute__
258264
return _getattr(attr)
259265

260-
def __setattr__(self, item, val):
266+
def __setattr__(self, item: str, val: Any) -> None:
261267
self._kwds[item] = val
262268

263-
def __getitem__(self, item):
269+
def __getitem__(self, item: str) -> Any:
264270
return self._kwds[item]
265271

266-
def __setitem__(self, item, val):
272+
def __setitem__(self, item: str, val: Any) -> None:
267273
self._kwds[item] = val
268274

269-
def __repr__(self):
275+
def __repr__(self) -> str:
270276
if self._kwds:
271277
args = (
272278
"{}: {!r}".format(key, val)
@@ -280,14 +286,14 @@ def __repr__(self):
280286
else:
281287
return "{}({!r})".format(self.__class__.__name__, self._args[0])
282288

283-
def __eq__(self, other):
289+
def __eq__(self, other: Any) -> bool:
284290
return (
285291
type(self) is type(other)
286292
and self._args == other._args
287293
and self._kwds == other._kwds
288294
)
289295

290-
def to_dict(self, validate=True, ignore=None, context=None):
296+
def to_dict(self, validate: Union[bool, str] = True, ignore: Optional[Sequence[str]] = None, context: Optional[Dict[Any, Any]] = None) -> Union[AltairObj, Dict[Any, Any], str, float]:
291297
"""Return a dictionary representation of the object
292298
293299
Parameters
@@ -341,8 +347,8 @@ def to_dict(self, validate=True, ignore=None, context=None):
341347
return result
342348

343349
def to_json(
344-
self, validate=True, ignore=[], context={}, indent=2, sort_keys=True, **kwargs
345-
):
350+
self, validate: Union[bool, str] = True, ignore: Sequence[str] = (), context: Optional[Dict[Any, Any]] = None, indent: int = 2, sort_keys: bool = True, **kwargs: Any
351+
) -> str:
346352
"""Emit the JSON representation for this object as a string.
347353
348354
Parameters
@@ -370,16 +376,18 @@ def to_json(
370376
spec : string
371377
The JSON specification of the chart object.
372378
"""
379+
if not context:
380+
context = {}
373381
dct = self.to_dict(validate=validate, ignore=ignore, context=context)
374382
return json.dumps(dct, indent=indent, sort_keys=sort_keys, **kwargs)
375383

376384
@classmethod
377-
def _default_wrapper_classes(cls):
385+
def _default_wrapper_classes(cls) -> Iterable[Type[SchemaBase]]:
378386
"""Return the set of classes used within cls.from_dict()"""
379387
return _subclasses(SchemaBase)
380388

381389
@classmethod
382-
def from_dict(cls, dct, validate=True, _wrapper_classes=None):
390+
def from_dict(cls: Type[T], dct: Mapping[str, Any], validate: bool = True, _wrapper_classes: Optional[Union[Iterable[Type[SchemaBase]], Iterable[Type[T]]]] = None) -> T:
383391
"""Construct class from a dictionary representation
384392
385393
Parameters
@@ -411,7 +419,7 @@ def from_dict(cls, dct, validate=True, _wrapper_classes=None):
411419
return converter.from_dict(dct, cls)
412420

413421
@classmethod
414-
def from_json(cls, json_string, validate=True, **kwargs):
422+
def from_json(cls: Type[T], json_string: str, validate: bool = True, **kwargs: Any) -> T:
415423
"""Instantiate the object from a valid JSON string
416424
417425
Parameters
@@ -432,42 +440,42 @@ def from_json(cls, json_string, validate=True, **kwargs):
432440
return cls.from_dict(dct, validate=validate)
433441

434442
@classmethod
435-
def validate(cls, instance, schema=None):
443+
def validate(cls: Type[T], instance: Any, schema: Optional[Mapping[str, Any]] = None) -> None:
436444
"""
437445
Validate the instance against the class schema in the context of the
438446
rootschema.
439447
"""
440448
if schema is None:
441449
schema = cls._schema
442-
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
450+
resolver: RefResolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
443451
return jsonschema.validate(
444452
instance, schema, cls=cls._validator, resolver=resolver
445453
)
446454

447455
@classmethod
448-
def resolve_references(cls, schema=None):
456+
def resolve_references(cls, schema: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
449457
"""Resolve references in the context of this object's schema or root schema."""
450458
return _resolve_references(
451459
schema=(schema or cls._schema),
452460
root=(cls._rootschema or cls._schema or schema),
453461
)
454462

455463
@classmethod
456-
def validate_property(cls, name, value, schema=None):
464+
def validate_property(cls, name: str, value: Any, schema: Optional[Mapping[str, Any]] = None) -> None:
457465
"""
458466
Validate a property against property schema in the context of the
459467
rootschema
460468
"""
461469
value = _todict(value, validate=False, context={})
462470
props = cls.resolve_references(schema or cls._schema).get("properties", {})
463-
resolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
471+
resolver: RefResolver = jsonschema.RefResolver.from_schema(cls._rootschema or cls._schema)
464472
return jsonschema.validate(value, props.get(name, {}), resolver=resolver)
465473

466-
def __dir__(self):
474+
def __dir__(self) -> List[str]:
467475
return list(self._kwds.keys())
468476

469477

470-
def _passthrough(*args, **kwds):
478+
def _passthrough(*args: Any, **kwds: Any) -> Union[Any, Dict[str, Any]]:
471479
return args[0] if args else kwds
472480

473481

@@ -481,16 +489,16 @@ class _FromDict(object):
481489

482490
_hash_exclude_keys = ("definitions", "title", "description", "$schema", "id")
483491

484-
def __init__(self, class_list):
492+
def __init__(self, class_list: Iterable[Type[Any]]) -> None:
485493
# Create a mapping of a schema hash to a list of matching classes
486494
# This lets us quickly determine the correct class to construct
487-
self.class_dict = collections.defaultdict(list)
495+
self.class_dict: DefaultDict[Any, List[Any] ]= collections.defaultdict(list)
488496
for cls in class_list:
489497
if cls._schema is not None:
490498
self.class_dict[self.hash_schema(cls._schema)].append(cls)
491499

492500
@classmethod
493-
def hash_schema(cls, schema, use_json=True):
501+
def hash_schema(cls, schema: Mapping[str, Any], use_json: bool = True) -> int:
494502
"""
495503
Compute a python hash for a nested dictionary which
496504
properly handles dicts, lists, sets, and tuples.
@@ -513,7 +521,7 @@ def hash_schema(cls, schema, use_json=True):
513521
return hash(s)
514522
else:
515523

516-
def _freeze(val):
524+
def _freeze(val: Union[Dict[Any, Any], Set[Any], Sequence[Any], GenericT]) -> Union[FrozenSet[Any], Tuple[Any], GenericT]:
517525
if isinstance(val, dict):
518526
return frozenset((k, _freeze(v)) for k, v in val.items())
519527
elif isinstance(val, set):
@@ -526,8 +534,8 @@ def _freeze(val):
526534
return hash(_freeze(schema))
527535

528536
def from_dict(
529-
self, dct, cls=None, schema=None, rootschema=None, default_class=_passthrough
530-
):
537+
self, dct: Union[Mapping[str, Any], SchemaBase], cls: Optional[Type[T]] = None, schema: Optional[Mapping[str, Any]] = None, rootschema: Optional[Mapping[str, Any]] = None, default_class: Any = _passthrough
538+
) -> Union[T, SchemaBase]:
531539
"""Construct an object from a dict representation"""
532540
if (schema is None) == (cls is None):
533541
raise ValueError("Must provide either cls or schema, but not both.")
@@ -553,7 +561,7 @@ def from_dict(
553561
if "anyOf" in schema or "oneOf" in schema:
554562
schemas = schema.get("anyOf", []) + schema.get("oneOf", [])
555563
for possible_schema in schemas:
556-
resolver = jsonschema.RefResolver.from_schema(rootschema)
564+
resolver: RefResolver = jsonschema.RefResolver.from_schema(rootschema)
557565
try:
558566
jsonschema.validate(dct, possible_schema, resolver=resolver)
559567
except jsonschema.ValidationError:
@@ -569,7 +577,7 @@ def from_dict(
569577
if isinstance(dct, dict):
570578
# TODO: handle schemas for additionalProperties/patternProperties
571579
props = schema.get("properties", {})
572-
kwds = {}
580+
kwds: Mapping[str, Any] = {}
573581
for key, val in dct.items():
574582
if key in props:
575583
val = self.from_dict(val, schema=props[key], rootschema=rootschema)
@@ -578,7 +586,7 @@ def from_dict(
578586

579587
elif isinstance(dct, list):
580588
item_schema = schema.get("items", {})
581-
dct = [
589+
dct: List[Union[T, SchemaBase]] = [
582590
self.from_dict(val, schema=item_schema, rootschema=rootschema)
583591
for val in dct
584592
]

0 commit comments

Comments
 (0)