diff --git a/changes/120.feature.rst b/changes/120.feature.rst new file mode 100644 index 00000000..c3247829 --- /dev/null +++ b/changes/120.feature.rst @@ -0,0 +1,2 @@ +Refactor ``AbstractDataModel`` into a Python protocol as that is how it was +effectively used. diff --git a/docs/source/model_library.rst b/docs/source/model_library.rst index c065b74c..9c7f8073 100644 --- a/docs/source/model_library.rst +++ b/docs/source/model_library.rst @@ -4,7 +4,7 @@ Model Library ============= `~stpipe.library.AbstractModelLibrary` is a container designed to allow efficient processing of -collections of `~stpipe.datamodel.AbstractDataModel` instances created from an association. +collections of `~stpipe.protocols.DataModel` instances created from an association. `~stpipe.library.AbstractModelLibrary` is an ordered collection (like a `list`) but provides: @@ -234,8 +234,8 @@ allowing the `~stpipe.step.Step` to generate an :ref:`library_on_disk` ``Step.process`` can extend the above pattern to support additional inputs (for example a single -`~stpipe.datamodel.AbstractDataModel` or filename containing -a `~stpipe.datamodel.AbstractDataModel`) to allow more +`~stpipe.protocols.DataModel` or filename containing +a `~stpipe.protocols.DataModel`) to allow more flexible data processings, although some consideration should be given to how to handle input that does not contain association metadata. Does it make sense @@ -252,9 +252,9 @@ Isolated Processing Let's say we have a `~stpipe.step.Step`, ``flux_calibration`` that performs an operation that is only concerned with the data -for a single `~stpipe.datamodel.AbstractDataModel` at a time. +for a single `~stpipe.protocols.DataModel` at a time. This step applies a function ``calibrate_model_flux`` that -accepts a single `~stpipe.datamodel.AbstractDataModel` and index as an input. +accepts a single `~stpipe.protocols.DataModel` and index as an input. 
Its ``Step.process`` function can make good use of `~stpipe.library.AbstractModelLibrary.map_function` to apply this method to each model in the library. diff --git a/src/stpipe/config_parser.py b/src/stpipe/config_parser.py index cb1af2a1..9f376cdf 100644 --- a/src/stpipe/config_parser.py +++ b/src/stpipe/config_parser.py @@ -21,7 +21,7 @@ from . import utilities from .config import StepConfig -from .datamodel import AbstractDataModel +from .protocols import DataModel from .utilities import _not_set # Configure logger @@ -82,7 +82,7 @@ def _output_file_check(path): def _is_datamodel(value): """Verify that value is a DataModel.""" - if isinstance(value, AbstractDataModel): + if isinstance(value, DataModel): return value raise VdtTypeError(value) @@ -92,7 +92,7 @@ def _is_string_or_datamodel(value): """Verify that value is either a string (nominally a reference file path) or a DataModel (possibly one with no corresponding file.) """ - if isinstance(value, AbstractDataModel): + if isinstance(value, DataModel): return value if isinstance(value, str): diff --git a/src/stpipe/library.py b/src/stpipe/library.py index 79f295af..d8aa4a9e 100644 --- a/src/stpipe/library.py +++ b/src/stpipe/library.py @@ -10,7 +10,7 @@ import asdf -from .datamodel import AbstractDataModel +from .protocols import DataModel __all__ = [ "LibraryError", @@ -771,7 +771,7 @@ def _model_to_filename(self, model): return model_filename def _to_group_id(self, model_or_filename, index): - if isinstance(model_or_filename, AbstractDataModel): + if isinstance(model_or_filename, DataModel): getter = self._model_to_group_id else: getter = self._filename_to_group_id diff --git a/src/stpipe/pipeline.py b/src/stpipe/pipeline.py index 70eaf00c..4599071d 100644 --- a/src/stpipe/pipeline.py +++ b/src/stpipe/pipeline.py @@ -151,7 +151,7 @@ def get_config_from_reference(cls, dataset, disable=None, crds_observatory=None) Either a class or instance of a class derived from `Step`. 
- dataset : `stpipe.datamodel.AbstractDataModel` + dataset : `stpipe.protocols.DataModel` A model of the input file. Metadata on this input file will be used by the CRDS "bestref" algorithm to obtain a reference file. diff --git a/src/stpipe/datamodel.py b/src/stpipe/protocols.py similarity index 50% rename from src/stpipe/datamodel.py rename to src/stpipe/protocols.py index 008e2aaa..c58739dd 100644 --- a/src/stpipe/datamodel.py +++ b/src/stpipe/protocols.py @@ -1,50 +1,53 @@ -import abc +from __future__ import annotations +from abc import abstractmethod +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable -class AbstractDataModel(abc.ABC): +if TYPE_CHECKING: + from collections.abc import Callable + from os import PathLike + + +@runtime_checkable +class DataModel(Protocol): """ - This Abstract Base Class is intended to cover multiple implementations of - data models so that each will be considered an appropriate subclass of this - class without requiring that they inherit this class. + This is a protocol to describe the methods and properties that define a + DataModel for the purposes of stpipe. This is a runtime checkable protocol + meaning that any object can be `isinstance` checked against this protocol + and will succeed even if the object does not inherit from this class. + Moreover, this object will act as an `abc.ABC` class if it is inherited from. Any datamodel class instance that desires to be considered an instance of - AbstractDataModel must implement the following methods. + DataModel must fully implement the protocol in order to pass the `isinstance` check. In addition, although it isn't yet checked (the best approach for supporting this is still being considered), such instances must have a meta.filename attribute.
""" - @classmethod - def __subclasshook__(cls, c_): - """ - Pseudo subclass check based on these attributes and methods - """ - if cls is AbstractDataModel: - mro = c_.__mro__ - if ( - any(hasattr(CC, "crds_observatory") for CC in mro) - and any(hasattr(CC, "get_crds_parameters") for CC in mro) - and any(hasattr(CC, "save") for CC in mro) - ): - return True - return False - @property - @abc.abstractmethod - def crds_observatory(self): + @abstractmethod + def crds_observatory(self) -> str: """This should return a string identifying the observatory as CRDS expects it""" + ... - @abc.abstractmethod - def get_crds_parameters(self): + @abstractmethod + def get_crds_parameters(self) -> dict[str, Any]: """ This should return a dictionary of key/value pairs corresponding to the parkey values CRDS is using to match reference files. Typically it returns all metadata simple values. """ + ... - @abc.abstractmethod - def save(self, path, dir_path=None, *args, **kwargs): + @abstractmethod + def save( + self, + path: PathLike | Callable[..., PathLike], + dir_path: PathLike | None = None, + *args, + **kwargs, + ) -> PathLike: """ Save to a file. @@ -64,3 +67,4 @@ def save(self, path, dir_path=None, *args, **kwargs): output_path: str The file path the model was saved in. """ + ... diff --git a/src/stpipe/step.py b/src/stpipe/step.py index 8776eb5f..c6e5498d 100644 --- a/src/stpipe/step.py +++ b/src/stpipe/step.py @@ -33,9 +33,9 @@ DISCOURAGED_TYPES = None from . import config, config_parser, crds_client, log, utilities -from .datamodel import AbstractDataModel from .format_template import FormatTemplate from .library import AbstractModelLibrary +from .protocols import DataModel from .utilities import _not_set @@ -398,7 +398,7 @@ def _check_args(self, args, discouraged_types, msg): for i, arg in enumerate(args): if isinstance(arg, discouraged_types): self.log.error( - "%s %s object. Use an instance of AbstractDataModel instead.", + "%s %s object.
Use an instance of DataModel instead.", msg, i, ) @@ -492,7 +492,7 @@ def run(self, *args): e, ) library.shelve(model, i) - elif isinstance(args[0], AbstractDataModel): + elif isinstance(args[0], DataModel): if self.class_alias is not None: if isinstance(args[0], Sequence): for model in args[0]: @@ -506,7 +506,7 @@ def run(self, *args): "header: %s", e, ) - elif isinstance(args[0], AbstractDataModel): + elif isinstance(args[0], DataModel): try: args[0][ f"meta.cal_step.{self.class_alias}" @@ -567,9 +567,7 @@ def run(self, *args): for idx, result in enumerate(results_to_save): if len(results_to_save) <= 1: idx = None - if isinstance( - result, (AbstractDataModel | AbstractModelLibrary) - ): + if isinstance(result, (DataModel | AbstractModelLibrary)): self.save_model(result, idx=idx) elif hasattr(result, "save"): try: @@ -609,7 +607,7 @@ def finalize_result(self, result, reference_files_used): Parameters ---------- - result : a datamodel that is an instance of AbstractDataModel or + result : a datamodel that is an instance of DataModel or collections.abc.Sequence One step result (potentially of many). @@ -780,7 +778,7 @@ def get_ref_override(self, reference_file_type): """ override_name = crds_client.get_override_name(reference_file_type) path = getattr(self, override_name, None) - if isinstance(path, AbstractDataModel): + if isinstance(path, DataModel): return path return abspath(path) if path and path != "N/A" else path @@ -795,7 +793,7 @@ def get_reference_file(self, input_file, reference_file_type): Parameters ---------- - input_file : a datamodel that is an instance of AbstractDataModel + input_file : a datamodel that is an instance of DataModel A model of the input file. Metadata on this input file will be used by the CRDS "bestref" algorithm to obtain a reference file. 
@@ -810,7 +808,7 @@ def get_reference_file(self, input_file, reference_file_type): """ override = self.get_ref_override(reference_file_type) if override is not None: - if isinstance(override, AbstractDataModel): + if isinstance(override, DataModel): self._reference_files_used.append( (reference_file_type, override.override_handle) ) @@ -846,7 +844,7 @@ def get_config_from_reference(cls, dataset, disable=None, crds_observatory=None) cls : stpipe.Step Either a class or instance of a class derived from `Step`. - dataset : A datamodel that is an instance of AbstractDataModel + dataset : A datamodel that is an instance of DataModel A model of the input file. Metadata on this input file will be used by the CRDS "bestref" algorithm to obtain a reference file. @@ -874,7 +872,7 @@ def get_config_from_reference(cls, dataset, disable=None, crds_observatory=None) if crds_observatory is None: raise ValueError("Need a valid name for crds_observatory.") else: - # If the dataset is not an operable instance of AbstractDataModel, + # If the dataset is not an operable instance of DataModel, # log as such and return an empty config object try: with cls._datamodels_open(dataset, asn_n_members=1) as model: @@ -884,7 +882,7 @@ def get_config_from_reference(cls, dataset, disable=None, crds_observatory=None) crds_parameters = model.get_crds_parameters() crds_observatory = model.crds_observatory except (OSError, TypeError, ValueError): - logger.warning("Input dataset is not an instance of AbstractDataModel.") + logger.warning("Input dataset is not an instance of DataModel.") disable = True # Check if retrieval should be attempted. @@ -930,7 +928,7 @@ def set_primary_input(self, obj, exclusive=True): Parameters ---------- - obj : str, pathlib.Path, or instance of AbstractDataModel + obj : str, pathlib.Path, or instance of DataModel The object to base the name on. If a datamodel, use Datamodel.meta.filename. 
@@ -945,7 +943,7 @@ def set_primary_input(self, obj, exclusive=True): if not exclusive or parent_input_filename is None: if isinstance(obj, str | Path): self._input_filename = str(obj) - elif isinstance(obj, AbstractDataModel): + elif isinstance(obj, DataModel): try: self._input_filename = obj.meta.filename except AttributeError: @@ -967,7 +965,7 @@ def save_model( Parameters ---------- - model : a instance of AbstractDataModel + model : an instance of DataModel The model to save. suffix : str @@ -1166,7 +1164,7 @@ def open_model(self, init, **kwargs): Returns ------- - datamodel : instance of AbstractDataModel + datamodel : instance of DataModel Object opened as a datamodel """ # Use the parent method if available, since this step diff --git a/src/stpipe/subproc.py b/src/stpipe/subproc.py index b9c0ad04..31c7771c 100644 --- a/src/stpipe/subproc.py +++ b/src/stpipe/subproc.py @@ -1,7 +1,7 @@ import os import subprocess -from .datamodel import AbstractDataModel +from .protocols import DataModel from .step import Step @@ -28,7 +28,7 @@ class SystemCall(Step): def process(self, *args): newargs = [] for i, arg in enumerate(args): - if isinstance(arg, AbstractDataModel): + if isinstance(arg, DataModel): filename = f"{self.qualified_name}.{i:04d}.{self.output_ext}" arg.save(filename) newargs.append(filename) diff --git a/tests/test_abstract_datamodel.py b/tests/test_abstract_datamodel.py deleted file mode 100644 index 3dcaa740..00000000 --- a/tests/test_abstract_datamodel.py +++ /dev/null @@ -1,57 +0,0 @@ -""" -Test that the AbstractDataModel interface works properly -""" - -import pytest - -from stpipe.datamodel import AbstractDataModel - - -def test_roman_datamodel(): - roman_datamodels = pytest.importorskip("roman_datamodels.datamodels") - from roman_datamodels.maker_utils import mk_level2_image - - roman_image_tree = mk_level2_image() - image_model = roman_datamodels.ImageModel(roman_image_tree) - assert isinstance(image_model, AbstractDataModel) - - -def
test_jwst_datamodel(): - jwst_datamodel = pytest.importorskip("stdatamodels.jwst.datamodels") - image_model = jwst_datamodel.ImageModel() - assert isinstance(image_model, AbstractDataModel) - - -class GoodDataModel: - def __init__(self): - pass - - def crds_observatory(self): - pass - - def get_crds_parameters(self): - pass - - def save(self): - pass - - -class BadDataModel: - def __init__(self): - pass - - def crds_observatory(self): - pass - - def get_crds_parameters(self): - pass - - -def test_good_datamodel(): - gdm = GoodDataModel() - assert isinstance(gdm, AbstractDataModel) - - -def test_bad_datamodel(): - gdm = BadDataModel() - assert not isinstance(gdm, AbstractDataModel) diff --git a/tests/test_jwst.py b/tests/test_jwst.py new file mode 100644 index 00000000..4bb300b9 --- /dev/null +++ b/tests/test_jwst.py @@ -0,0 +1,31 @@ +""" +Integration tests with JWST pipeline +""" + +from inspect import getmembers, isclass + +import pytest + +from stpipe.protocols import DataModel + +datamodels = pytest.importorskip("stdatamodels.jwst.datamodels") + + +def test_jwst_datamodel(): + """Smoke test to ensure the JWST datamodels work with the DataModel protocol.""" + jwst_datamodel = pytest.importorskip("stdatamodels.jwst.datamodels") + image_model = jwst_datamodel.ImageModel() + assert isinstance(image_model, DataModel) + + +@pytest.mark.parametrize( + "model", + [ + model[1] + for model in getmembers(datamodels, isclass) + if issubclass(model[1], datamodels.JwstDataModel) + ], +) +def test_datamodel(model): + """Test that all JWST datamodels work with the DataModel protocol.""" + assert isinstance(model(), DataModel) diff --git a/tests/test_library.py b/tests/test_library.py index 5a7e0853..3d142e1c 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -8,7 +8,6 @@ import asdf import pytest -from stpipe.datamodel import AbstractDataModel from stpipe.library import ( AbstractModelLibrary, BorrowError, @@ -16,6 +15,7 @@ NoGroupID, _Ledger, ) +from 
stpipe.protocols import DataModel _GROUP_IDS = ["1", "1", "2"] _N_MODELS = len(_GROUP_IDS) @@ -38,7 +38,7 @@ class _Meta: pass -class DataModel: +class ExampleDataModel: def __init__(self, **kwargs): self.meta = _Meta() self.meta.__dict__.update(kwargs) @@ -57,7 +57,7 @@ def save(self, path, **kwargs): def _load_model(filename): with asdf.open(filename) as af: - return DataModel(**af.tree) + return ExampleDataModel(**af.tree) class ModelLibrary(AbstractModelLibrary): @@ -113,7 +113,7 @@ def example_models(): """ models = [] for i in range(_N_MODELS): - m = DataModel(group_id=_GROUP_IDS[i], index=i) + m = ExampleDataModel(group_id=_GROUP_IDS[i], index=i) m.meta.filename = f"{i}.asdf" models.append(m) return models @@ -219,7 +219,7 @@ def test_init_from_models_no_ondisk(example_models): ModelLibrary(example_models, on_disk=True) -@pytest.mark.parametrize("invalid", (None, ModelLibrary([]), DataModel())) +@pytest.mark.parametrize("invalid", (None, ModelLibrary([]), ExampleDataModel())) def test_invalid_init(invalid): """ Test that some unsupported init values produce errors. @@ -408,7 +408,7 @@ def test_closed_library_model_shelve(example_library): an error. """ with pytest.raises(ClosedLibraryError, match="ModelLibrary is not open"): - example_library.shelve(DataModel(), 0) + example_library.shelve(ExampleDataModel(), 0) def test_closed_library_model_iter(example_library): @@ -577,7 +577,7 @@ def test_shelve_unknown_model(example_library, use_index): with lib_ctx: # to catch the error for the un-returned model with example_library: example_library.borrow(0) - new_model = DataModel() + new_model = ExampleDataModel() if use_index: ctx = contextlib.nullcontext() @@ -758,7 +758,7 @@ def test_ledger(): based on index and index based on models. 
""" ledger = _Ledger() - model = DataModel() + model = ExampleDataModel() ledger[0] = model assert ledger[0] == model assert ledger[model] == 0 @@ -773,10 +773,21 @@ def test_ledger(): def test_library_datamodel_relationship(): """ Smoke test to make sure the relationship between - AbstractModelLibrary and AbstractDataModel doesn't + AbstractModelLibrary and DataModel doesn't change. """ - assert not issubclass(AbstractModelLibrary, AbstractDataModel) + dm = ExampleDataModel() + lib = ModelLibrary([]) + + assert isinstance(dm, DataModel) + assert isinstance(lib, AbstractModelLibrary) + + assert not isinstance(dm, AbstractModelLibrary) + assert not isinstance(lib, DataModel) + + # Protocols don't support issubclass for complex reasons + # so we check the relationship with isinstance on instances instead + # assert not issubclass(AbstractModelLibrary, DataModel) def test_library_is_not_sequence(): diff --git a/tests/test_protocols.py b/tests/test_protocols.py new file mode 100644 index 00000000..eb4239a7 --- /dev/null +++ b/tests/test_protocols.py @@ -0,0 +1,40 @@ +""" +Test that the DataModel interface of the protocol works properly +""" + +from stpipe.protocols import DataModel + + +class GoodDataModel: + def __init__(self): + pass + + def crds_observatory(self): + pass + + def get_crds_parameters(self): + pass + + def save(self): + pass + + +class BadDataModel: + def __init__(self): + pass + + def crds_observatory(self): + pass + + def get_crds_parameters(self): + pass + + +def test_good_datamodel(): + gdm = GoodDataModel() + assert isinstance(gdm, DataModel) + + +def test_bad_datamodel(): + gdm = BadDataModel() + assert not isinstance(gdm, DataModel) diff --git a/tests/test_roman.py b/tests/test_roman.py new file mode 100644 index 00000000..7e7c31ff --- /dev/null +++ b/tests/test_roman.py @@ -0,0 +1,35 @@ +""" +Integration tests with Roman pipeline +""" + +from inspect import getmembers, isclass + +import pytest + +from stpipe.protocols import DataModel + 
+datamodels = pytest.importorskip("roman_datamodels.datamodels") + + +def test_roman_datamodel(): + """Smoke test to ensure the Roman datamodels work with the DataModel protocol.""" + roman_datamodels = pytest.importorskip("roman_datamodels.datamodels") + from roman_datamodels.maker_utils import mk_level2_image + + roman_image_tree = mk_level2_image() + image_model = roman_datamodels.ImageModel(roman_image_tree) + assert isinstance(image_model, DataModel) + + +@pytest.mark.parametrize( + "model", + [ + model[1] + for model in getmembers(datamodels, isclass) + if model[1] != datamodels.DataModel + and issubclass(model[1], datamodels.DataModel) + ], +) +def test_datamodel(model): + """Test that all Roman datamodels work with the DataModel protocol.""" + assert isinstance(model(), DataModel)