diff --git a/howso/engine/tests/test_engine.py b/howso/engine/tests/test_engine.py index d5060628..96f4ce45 100644 --- a/howso/engine/tests/test_engine.py +++ b/howso/engine/tests/test_engine.py @@ -415,4 +415,10 @@ def test_react_aggregate(self, data: DataFrame, trainee: Trainee): assert isinstance(value, DataFrame) for feature in data.columns: - assert feature in total_df.columns \ No newline at end of file + assert feature in total_df.columns + + def test_to_memory(self, trainee): + """ + Test the passthrough to `to_memory` in the Trainee class. + """ + assert isinstance(trainee.to_memory(), bytes) \ No newline at end of file diff --git a/howso/engine/trainee.py b/howso/engine/trainee.py index c7efdca6..2d12dcad 100644 --- a/howso/engine/trainee.py +++ b/howso/engine/trainee.py @@ -1,17 +1,18 @@ from __future__ import annotations +import typing as t +import uuid +import warnings from collections.abc import ( Callable, Collection, Generator, + Iterable, Mapping, MutableMapping, ) from copy import deepcopy from pathlib import Path -import typing as t -import uuid -import warnings from pandas import ( DataFrame, @@ -28,15 +29,16 @@ LocalSaveableProtocol, ProjectClient, ) -from howso.client.schemas import AggregateReaction, GroupReaction -from howso.client.schemas import Project as BaseProject -from howso.client.schemas import Reaction -from howso.client.schemas import Session as BaseSession -from howso.client.schemas import Trainee as BaseTrainee from howso.client.schemas import ( + AggregateReaction, + GroupReaction, + Reaction, TraineeRuntime, TraineeRuntimeOptions, ) +from howso.client.schemas import Project as BaseProject +from howso.client.schemas import Session as BaseSession +from howso.client.schemas import Trainee as BaseTrainee from howso.client.typing import ( AblationThresholdMap, CaseIndices, @@ -57,6 +59,7 @@ TargetedModel, ValueMasses, ) +from howso.direct.client import HowsoDirectClient from howso.engine.client import get_client from howso.engine.project import Project from howso.engine.session import Session @@ -4428,6 +4431,33 @@ def from_schema( return schema return cls.from_dict(dict(schema.to_dict(), client=client)) + def to_memory( + self, + *, + file_type: t.Literal["amlg", "caml"] = "amlg", + trainee_path: Iterable[str] | None = None, + ) -> bytes | None: + """ + Get the Trainee file data as bytes. + + Parameters + ---------- + file_type : {"amlg", "caml"}, default "amlg" + The type of byte data to return. + trainee_path : Iterable of str, optional + The hierarchy path to a sub-Trainee from the root Trainee specified by `trainee_id`. + + Returns + ------- + bytes or None + The Trainee file data as bytes. Or None if the `trainee_id` and/or `trainee_path` does not refer to + a valid Trainee. + """ + if not isinstance(self.client, HowsoDirectClient): + raise NotImplementedError("The `to_memory` method is only supported when using `HowsoDirectClient`.") + return self.client.trainee_to_memory(self.id, file_type=file_type, trainee_path=trainee_path) + + @classmethod def from_dict(cls, schema: Mapping) -> Trainee: """