diff --git a/circadian/models.py b/circadian/models.py index e5430e7..19fc6de 100644 --- a/circadian/models.py +++ b/circadian/models.py @@ -85,15 +85,9 @@ def __call__(self, timepoint: float) -> np.ndarray: # state of the system # %% ../nbs/api/00_models.ipynb 11 @patch_to(DynamicalTrajectory) -def __getitem__(self, time_idx: int) -> Tuple[float, np.ndarray]: - "Return the time and state at index idx" - # idx input checking - if not isinstance(time_idx, int): - raise TypeError("idx must be int") - if time_idx < -1 or time_idx >= len(self.time): - raise ValueError(f"idx must be within 0 and {len(self.time)-1}, got {time_idx}") - - return self.time[time_idx], self.states[time_idx, ...] +def __getitem__(self, idx): + "Index directly into the states array (time_idx, state_idx, batch_idx)" + return self.states[idx] # %% ../nbs/api/00_models.ipynb 12 @patch_to(DynamicalTrajectory) diff --git a/nbs/api/00_models.ipynb b/nbs/api/00_models.ipynb index b75f9aa..6cbf80d 100644 --- a/nbs/api/00_models.ipynb +++ b/nbs/api/00_models.ipynb @@ -194,15 +194,9 @@ "#| export\n", "#| hide\n", "@patch_to(DynamicalTrajectory)\n", - "def __getitem__(self, time_idx: int) -> Tuple[float, np.ndarray]:\n", - " \"Return the time and state at index idx\"\n", - " # idx input checking\n", - " if not isinstance(time_idx, int):\n", - " raise TypeError(\"idx must be int\")\n", - " if time_idx < -1 or time_idx >= len(self.time):\n", - " raise ValueError(f\"idx must be within 0 and {len(self.time)-1}, got {time_idx}\")\n", - " \n", - " return self.time[time_idx], self.states[time_idx, ...]" + "def __getitem__(self, idx):\n", + " \"Index directly into the states array (time_idx, state_idx, batch_idx)\"\n", + " return self.states[idx]\n" ] }, { @@ -4818,4 +4812,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/nbs/test/test_models.ipynb b/nbs/test/test_models.ipynb index 7b8c66c..795ccad 100644 --- a/nbs/test/test_models.ipynb +++ b/nbs/test/test_models.ipynb @@ -124,23 +124,26 @@ "states[:, 0] = np.sin(time)\n", "states[:, 1] = np.cos(time)\n", "traj = DynamicalTrajectory(time, states)\n", - "test_eq(traj[0], (0.0, np.array([0.0, 1.0])))\n", - "states = traj[-1][1]\n", - "difference = np.abs(np.sum(states - (np.sin(np.pi), np.cos(np.pi))))\n", - "test_eq(difference < 1e-4, True)\n", - "# handle batch\n", + "# index by time point\n", + "test_eq(traj[0], np.array([0.0, 1.0]))\n", + "test_eq(np.abs(traj[-1][0] - np.sin(np.pi)) < 1e-4, True)\n", + "# index by time point and state\n", + "test_eq(traj[0, 0], 0.0)\n", + "test_eq(traj[0, 1], 1.0)\n", + "# slice\n", + "test_eq(traj[:2].shape, (2, variables))\n", + "# batch trajectory\n", "batches = 5\n", "batch_states = np.zeros((total_timepoints, variables, batches))\n", "batch_states[:, 0, :] = np.sin(time)[:, None]\n", "batch_states[:, 1, :] = np.cos(time)[:, None]\n", "batch_traj = DynamicalTrajectory(time, batch_states)\n", - "test_eq(batch_traj[0][0], 0.0)\n", - "test_eq(np.all(batch_traj[0][1][0]==np.zeros(batches)), True)\n", - "test_eq(np.all(batch_traj[0][1][1]==np.ones(batches)), True)\n", - "# test error handling\n", - "test_fail(lambda: traj[\"1\"], contains=\"idx must be int\")\n", - "test_fail(lambda: traj[-2], contains=\"idx must be within 0 and\")\n", - "test_fail(lambda: traj[1000], contains=\"idx must be within 0 and\")" + "test_eq(batch_traj[0].shape, (variables, batches))\n", + "test_eq(np.all(batch_traj[0, 0] == np.zeros(batches)), True)\n", + "test_eq(np.all(batch_traj[0, 1] == np.ones(batches)), True)\n", + "# index by time point, state, and batch\n", + "test_eq(batch_traj[0, 0, 0], 0.0)\n", + "test_eq(batch_traj[0, 1, 0], 1.0)" ] }, { @@ -2161,4 +2164,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/setup.py b/setup.py index 86fa086..afa0d8d 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -from pkg_resources import parse_version +from packaging.version import Version as parse_version from configparser import ConfigParser import setuptools from setuptools import Extension