Skip to content

Commit

Permalink
✨ feat(vecs): initialize space from non-vectors (#357)
Browse files Browse the repository at this point in the history
nstarman authored Jan 25, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent f89da37 commit 2df61b1
Showing 1 changed file with 52 additions and 16 deletions.
68 changes: 52 additions & 16 deletions src/coordinax/_src/vectors/space/core.py
Original file line number Diff line number Diff line change
@@ -104,36 +104,49 @@ class Space(AbstractVector, ImmutableMap[Dimension, AbstractVector]): # type: i
>>> w.mT.shapes
mappingproxy({'length': (2, 1), 'speed': (2, 1)})
There are convenience ways to initialize the vectors in the space:
>>> space = cx.Space.from_({"length": u.Quantity([1, 2, 3], "km"),
... "speed": u.Quantity([4, 5, 6], "km/s")})
>>> print(space)
Space({
'length': <CartesianPos3D (x[km], y[km], z[km])
[1 2 3]>,
'speed': <CartesianVel3D (d_x[km / s], d_y[km / s], d_z[km / s])
[4 5 6]>
})
"""

_data: dict[str, AbstractVector] = eqx.field(init=False)

def __init__( # pylint: disable=super-init-not-called # TODO: resolve this
self,
/,
*args: Mapping[DimensionLike, AbstractVector]
| tuple[DimensionLike, AbstractVector]
| Iterable[tuple[DimensionLike, AbstractVector]],
**kwargs: AbstractVector,
*args: Mapping[DimensionLike, Any]
| tuple[DimensionLike, Any]
| Iterable[tuple[DimensionLike, Any]],
**kwargs: Any,
) -> None:
# Process the input data
# Consolidate the inputs into a single dict, then process keys & values.
raw = dict(*args, **kwargs) # process the input data
keys = [_get_dimension_name(k) for k in raw]
keys = eqx.error_if(
keys,
len(keys) < len(raw),
f"Space(**input) contained duplicate keys {set(raw) - set(keys)}.",
)
# TODO: check the key dimension makes sense for the value

# Process the keys
dims = tuple(u.dimension(k) for k in raw)
keys = tuple(_get_dimension_name(dim) for dim in dims)
# Convert the values to vectors
values = tuple(vector(v) for v in raw.values())

# TODO: check the dimension makes sense for the value

# Check that the shapes are broadcastable
keys = eqx.error_if(
keys,
not _can_broadcast_shapes(*(v.shape for v in raw.values())),
values = eqx.error_if(
values,
not _can_broadcast_shapes(*map(jnp.shape, values)),
"vector shapes are not broadcastable.",
)

ImmutableMap.__init__(self, dict(zip(keys, raw.values(), strict=True)))
ImmutableMap.__init__(self, dict(zip(keys, values, strict=True)))

@classmethod
def _dimensionality(cls) -> int:
@@ -581,6 +594,29 @@ def vector(
return cls(length=q, speed=p, acceleration=a)


@dispatch
def vector(
cls: type[Space],
obj: Mapping[str, Any],
) -> Space:
"""Construct a Space from a Mapping.
Examples
--------
>>> import unxt as u
>>> import coordinax as cx
>>> space = cx.Space.from_({ 'length': u.Quantity([1, 2, 3], "m") })
>>> print(space)
Space({
'length': <CartesianPos3D (x[m], y[m], z[m])
[1 2 3]>
})
"""
return cls({k: vector(v) for k, v in obj.items()})


# ===============================================================
# Vector API dispatches

0 comments on commit 2df61b1

Please sign in to comment.