Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion howso/engine/tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,4 +415,10 @@ def test_react_aggregate(self, data: DataFrame, trainee: Trainee):
assert isinstance(value, DataFrame)

for feature in data.columns:
assert feature in total_df.columns
assert feature in total_df.columns

def test_to_memory(self, trainee):
"""
Test the passthrough to `to_memory` in the Trainee class.
"""
assert isinstance(trainee.to_memory(), bytes)
46 changes: 38 additions & 8 deletions howso/engine/trainee.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

import typing as t
import uuid
import warnings
from collections.abc import (
Callable,
Collection,
Generator,
Iterable,
Mapping,
MutableMapping,
)
from copy import deepcopy
from pathlib import Path
import typing as t
import uuid
import warnings

from pandas import (
DataFrame,
Expand All @@ -28,15 +29,16 @@
LocalSaveableProtocol,
ProjectClient,
)
from howso.client.schemas import AggregateReaction, GroupReaction
from howso.client.schemas import Project as BaseProject
from howso.client.schemas import Reaction
from howso.client.schemas import Session as BaseSession
from howso.client.schemas import Trainee as BaseTrainee
from howso.client.schemas import (
AggregateReaction,
GroupReaction,
Reaction,
TraineeRuntime,
TraineeRuntimeOptions,
)
from howso.client.schemas import Project as BaseProject
from howso.client.schemas import Session as BaseSession
from howso.client.schemas import Trainee as BaseTrainee
from howso.client.typing import (
AblationThresholdMap,
CaseIndices,
Expand All @@ -57,6 +59,7 @@
TargetedModel,
ValueMasses,
)
from howso.direct.client import HowsoDirectClient
from howso.engine.client import get_client
from howso.engine.project import Project
from howso.engine.session import Session
Expand Down Expand Up @@ -4428,6 +4431,33 @@ def from_schema(
return schema
return cls.from_dict(dict(schema.to_dict(), client=client))

def to_memory(
self,
*,
file_type: t.Literal["amlg", "caml"] = "amlg",
trainee_path: Iterable[str] | None = None,
) -> bytes | None:
"""
Get the Trainee file data as bytes.

Parameters
----------
file_type : {"amlg", "caml"}, default "amlg"
The type of byte data to return.
trainee_path : Iterable of str, optional
The hierarchy path to a sub-Trainee from the root Trainee specified by `trainee_id`.

Returns
-------
bytes or None
The Trainee file data as bytes. Or None if the `trainee_id` and/or `trainee_path` does not refer to
a valid Trainee.
"""
if not isinstance(self.client, HowsoDirectClient):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might prefer to condition this on the existence of the trainee_to_memory method, or use a runtime-checkable protocol.

@runtime_checkable
class CanLoadSaveMemory(Protocol):
  def trainee_to_memory(self, id: str, file_type: Literal["amlg", "caml"], trainee_path: something) -> bytes | None: ...

if isinstance(self.client, CanLoadSaveMemory):
  return self.client.trainee_to_memory(...)
raise NotImplementedError

raise NotImplementedError("The `to_memory` method is only supported when using `HowsoDirectClient`.")
return self.client.trainee_to_memory(self.id, file_type=file_type, trainee_path=trainee_path)


@classmethod
def from_dict(cls, schema: Mapping) -> Trainee:
"""
Expand Down
Loading