Skip to content

Commit

Permalink
Work on documentation (15).
Browse files Browse the repository at this point in the history
- Write "top-down view of Feedbax" example.
- Move `feedbax.model` to `feedbax._model`.
- Rename `"statics_step"` stage in `Mechanics` to `"kinematics_update"`
  • Loading branch information
mlprt committed Feb 28, 2024
1 parent bd252fa commit 147fb42
Show file tree
Hide file tree
Showing 17 changed files with 29 additions and 28 deletions.
6 changes: 3 additions & 3 deletions docs/api/model.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Model base classes

::: feedbax.model.AbstractModel
::: feedbax.AbstractModel

::: feedbax.model.ModelInput
::: feedbax.ModelInput

::: feedbax.model.wrap_stateless_callable
::: feedbax.wrap_stateless_callable
6 changes: 3 additions & 3 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ I've developed Feedbax over the last few months, as I've learned JAX. My short-t

By making the library open source now, I hope to receive some feedback about those decisions. To make that easier I've created GitHub [issues](https://github.com/mlprt/feedbax/issues) documenting my choices and uncertainties. The issues largely fall into a few categories:

1. Structure: Some of the abstractions I've chosen are probably clumsy. It would be good to know about that, at this point. Maybe we can make some changes for the better! In approximate order of significance: #19, #12, #1, #21.
2. Features: There are many small additions that could be made, especially to the pre-built models and tasks. There are also a few major improvements which I am anticipating in the near future, such as *online learning* (#21).
3. Typing: Typing in Feedbax is a mess, at the moment. I have been learning to use the typing system recently. However, I haven't been constraining myself with type checker errors. I know I've done some things that probably won't work. See issues. (#7, #8, #9, #11)
1. Structure: Some of the abstractions I've chosen are probably clumsy. It would be good to know about that, at this point. Maybe we can make some changes for the better! In approximate order of significance: #19, #12, #1, #5, #21.
2. Features: There are many small additions that could be made, especially to the pre-built models and tasks. There are also a few major improvements which I am anticipating in the near future, such as *online learning* (#21). #10,
3. Typing: Typing in Feedbax is a mess, at the moment. I have been learning to use the typing system recently. However, I haven't been constraining myself with type checker errors. I know I've done some things that probably won't work, and that there may not be clever solutions to some of the issues. See issues: (#7, #8, #9, #11)

If you are an experienced Python or JAX user:

Expand Down
2 changes: 1 addition & 1 deletion feedbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import warnings

from feedbax._io import save, load
from feedbax.model import wrap_stateless_callable
from feedbax._model import AbstractModel, ModelInput, wrap_stateless_callable
from feedbax._staged import (
AbstractStagedModel,
ModelStage,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion feedbax/_staged.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from jaxtyping import Array, PRNGKeyArray, PyTree
import numpy as np

from feedbax.model import AbstractModel, ModelInput
from feedbax._model import AbstractModel, ModelInput
from feedbax.intervene import AbstractIntervenor
from feedbax.misc import indent_str
from feedbax.state import AbstractState, StateT
Expand Down
2 changes: 1 addition & 1 deletion feedbax/bodies.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from feedbax.channel import Channel, ChannelSpec, ChannelState
from feedbax.intervene import AbstractIntervenor
from feedbax.model import MultiModel
from feedbax._model import MultiModel
from feedbax.mechanics import Mechanics, MechanicsState
from feedbax.nn import NetworkState
from feedbax._staged import AbstractStagedModel, ModelStage
Expand Down
2 changes: 1 addition & 1 deletion feedbax/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import jax.numpy as jnp
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from feedbax.model import AbstractModel
from feedbax._model import AbstractModel
from feedbax.state import CartesianState, StateBounds, StateT


Expand Down
2 changes: 1 addition & 1 deletion feedbax/intervene.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from jaxtyping import Array, ArrayLike, Float, PRNGKeyArray, PyTree

from feedbax.misc import get_unique_label
from feedbax.model import AbstractModel
from feedbax._model import AbstractModel
from feedbax.state import AbstractState, StateT
from feedbax._tree import tree_call

Expand Down
2 changes: 1 addition & 1 deletion feedbax/iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from jaxtyping import Array, PRNGKeyArray, PyTree, Shaped
from tqdm.auto import tqdm

from feedbax.model import AbstractModel
from feedbax._model import AbstractModel
from feedbax.state import StateT
from feedbax._tree import tree_take, tree_set

Expand Down
4 changes: 2 additions & 2 deletions feedbax/mechanics/mechanics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from feedbax.intervene import AbstractIntervenor
from feedbax.mechanics.plant import AbstractPlant, PlantState

from feedbax.model import wrap_stateless_callable
from feedbax._model import wrap_stateless_callable
from feedbax._staged import AbstractStagedModel, ModelStage
from feedbax.state import AbstractState, CartesianState

Expand Down Expand Up @@ -90,7 +90,7 @@ def model_spec(self) -> OrderedDict[str, ModelStage]:
where_input=lambda input, state: state.effector.force,
where_state=lambda state: state.plant.skeleton,
),
"statics_step": ModelStage(
"kinematics_update": ModelStage(
# the `plant` module directly implements non-ODE operations
callable=lambda self: self.plant,
where_input=lambda input, state: input,
Expand Down
3 changes: 2 additions & 1 deletion feedbax/mechanics/muscle.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np

from feedbax.dynamics import AbstractDynamicalSystem
from feedbax.model import AbstractModel
from feedbax.state import AbstractState, StateBounds


Expand Down Expand Up @@ -122,7 +123,7 @@ class AbstractActivationFunction(eqx.Module):
def __call__(self, input: Array, state: AbstractMuscleState) -> Array: ...


class AbstractMuscle(eqx.Module):
class AbstractMuscle(AbstractModel[AbstractMuscleState]):
"""Base class for muscle models.
Attributes:
Expand Down
14 changes: 7 additions & 7 deletions feedbax/mechanics/plant.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ class AbstractPlant(
a muscle, which is not described by a differential equation but is
directly dependent in each instant on the skeletal geometry.
Static updates are specified in `model_spec`, and dynamic updates in
`dynamics_spec`. Calling an `AbstractPlant` instance will only perform
the static updates defined in `model_spec`. Normally we wrap the
instance in a `Mechanics` instance to discretize the dynamics—then,
calling the `Mechanics` instance will perform both sets of updates.
Kinematic/geometric updates are specified in `model_spec`, and dynamic updates
in `dynamics_spec`. Calling an `AbstractPlant` instance will only perform the
kinematic updates defined in `model_spec`. Normally we wrap the instance in a
`Mechanics` instance to discretize the dynamics—then, calling the `Mechanics`
instance will perform both sets of updates.
Attributes:
skeleton: The model of skeletal dynamics.
Expand Down Expand Up @@ -113,7 +113,7 @@ def vector_field(

@abstractproperty
def model_spec(self) -> OrderedDict[str, ModelStage]:
"""Specifies static/instantaneous updates to the musculoskeletal state."""
"""Specifies kinematic updates to the musculoskeletal state."""
...

@abstractproperty
Expand Down Expand Up @@ -349,7 +349,7 @@ def __init__(

@property
def model_spec(self) -> OrderedDict[str, ModelStage]:
"""Specifies static updates to the musculoskeletal state."""
"""Specifies kinematic updates to the musculoskeletal state."""
return OrderedDict(
{
"clip_skeleton_state": ModelStage(
Expand Down
2 changes: 1 addition & 1 deletion feedbax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from feedbax.intervene import AbstractIntervenor
from feedbax.model import wrap_stateless_callable
from feedbax._model import wrap_stateless_callable
from feedbax.misc import (
identity_func,
interleave_unequal,
Expand Down
4 changes: 2 additions & 2 deletions feedbax/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@
from feedbax.intervene import AbstractIntervenorInput, TimeSeriesParam
from feedbax.loss import AbstractLoss, LossDict
from feedbax._mapping import AbstractTransformedOrderedDict
from feedbax.model import ModelInput
from feedbax._model import ModelInput
from feedbax.state import AbstractState, CartesianState, StateT
from feedbax._tree import tree_call

if TYPE_CHECKING:
from feedbax.model import AbstractModel
from feedbax._model import AbstractModel

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion feedbax/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from feedbax import loss
from feedbax.loss import AbstractLoss, LossDict
from feedbax.misc import TqdmLoggingHandler, delete_contents
from feedbax.model import AbstractModel, ModelInput
from feedbax._model import AbstractModel, ModelInput
import feedbax.plot as plot
from feedbax.state import StateT
from feedbax.task import AbstractTask, AbstractTaskTrialSpec
Expand Down
2 changes: 1 addition & 1 deletion feedbax/xabdeef/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import optax

from feedbax import get_ensemble
from feedbax.model import AbstractModel
from feedbax._model import AbstractModel
from feedbax.task import AbstractTask, SimpleReaches
from feedbax.train import TaskTrainer, TaskTrainerHistory
from feedbax.xabdeef.losses import simple_reach_loss
Expand Down
2 changes: 1 addition & 1 deletion feedbax/xabdeef/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from feedbax.bodies import SimpleFeedback
from feedbax.mechanics.plant import DirectForceInput
from feedbax.misc import identity_func
from feedbax.model import AbstractModel
from feedbax._model import AbstractModel
from feedbax.iterate import Iterator
from feedbax.mechanics import Mechanics
from feedbax.mechanics.skeleton.pointmass import PointMass
Expand Down

0 comments on commit 147fb42

Please sign in to comment.