diff --git a/pyproject.toml b/pyproject.toml index 3852d1c..28f218e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,14 @@ Source = "https://github.com/aiidateam/aiida-pythonjob" [project.entry-points."aiida.data"] "pythonjob.pickled_data" = "aiida_pythonjob.data.pickled_data:PickledData" "pythonjob.pickled_function" = "aiida_pythonjob.data.pickled_function:PickledFunction" +"pythonjob.ase.atoms.Atoms" = "aiida_pythonjob.data.atoms:AtomsData" +"pythonjob.builtins.int" = "aiida.orm.nodes.data.int:Int" +"pythonjob.builtins.float" = "aiida.orm.nodes.data.float:Float" +"pythonjob.builtins.str" = "aiida.orm.nodes.data.str:Str" +"pythonjob.builtins.bool" = "aiida.orm.nodes.data.bool:Bool" +"pythonjob.builtins.list"="aiida_pythonjob.data.data_with_value:List" +"pythonjob.builtins.dict"="aiida_pythonjob.data.data_with_value:Dict" + [project.entry-points."aiida.calculations"] "pythonjob.pythonjob" = "aiida_pythonjob.calculations.pythonjob:PythonJob" diff --git a/src/aiida_pythonjob/data/__init__.py b/src/aiida_pythonjob/data/__init__.py index 2a00bfc..8d543fc 100644 --- a/src/aiida_pythonjob/data/__init__.py +++ b/src/aiida_pythonjob/data/__init__.py @@ -1,4 +1,5 @@ from .pickled_data import PickledData from .pickled_function import PickledFunction +from .serializer import general_serializer, serialize_to_aiida_nodes -__all__ = ("PickledData", "PickledFunction") +__all__ = ("PickledData", "PickledFunction", "serialize_to_aiida_nodes", "general_serializer") diff --git a/src/aiida_pythonjob/data/atoms.py b/src/aiida_pythonjob/data/atoms.py new file mode 100644 index 0000000..262794d --- /dev/null +++ b/src/aiida_pythonjob/data/atoms.py @@ -0,0 +1,53 @@ +import numpy as np +from aiida.orm import Data +from ase import Atoms +from ase.db.row import atoms2dict + +__all__ = ("AtomsData",) + + +class AtomsData(Data): + """Data to represent a ASE Atoms.""" + + _cached_atoms = None + + def __init__(self, value=None, **kwargs): + """Initialise a `AtomsData` node instance. + + :param value: ASE Atoms instance to initialise the `AtomsData` node from + """ + atoms = value or Atoms() + super().__init__(**kwargs) + data, keys = self.atoms2dict(atoms) + self.base.attributes.set_many(data) + self.base.attributes.set("keys", keys) + + @classmethod + def atoms2dict(cls, atoms): + data = atoms2dict(atoms) + data.pop("unique_id") + keys = list(data.keys()) + formula = atoms.get_chemical_formula() + data = cls._convert_numpy_to_native(data) + data["formula"] = formula + data["symbols"] = atoms.get_chemical_symbols() + return data, keys + + @classmethod + def _convert_numpy_to_native(cls, data): + """Convert numpy types to Python native types for JSON compatibility.""" + for key, value in data.items(): + if isinstance(value, np.bool_): + data[key] = bool(value) + elif isinstance(value, np.ndarray): + data[key] = value.tolist() + elif isinstance(value, np.generic): + data[key] = value.item() + return data + + @property + def value(self): + keys = self.base.attributes.get("keys") + data = self.base.attributes.get_many(keys) + data = dict(zip(keys, data)) + return Atoms(**data) diff --git a/src/aiida_pythonjob/data/data_with_value.py b/src/aiida_pythonjob/data/data_with_value.py new file mode 100644 index 0000000..469b810 --- /dev/null +++ b/src/aiida_pythonjob/data/data_with_value.py @@ -0,0 +1,13 @@ +from aiida import orm + + +class Dict(orm.Dict): + @property + def value(self): + return self.get_dict() + + +class List(orm.List): + @property + def value(self): + return self.get_list() diff --git a/src/aiida_pythonjob/data/pickled_data.py b/src/aiida_pythonjob/data/pickled_data.py index a0afdaa..e8bbed0 100644 --- a/src/aiida_pythonjob/data/pickled_data.py +++ b/src/aiida_pythonjob/data/pickled_data.py @@ -7,18 +7,6 @@ from aiida import orm -class Dict(orm.Dict): - @property - def value(self): - return self.get_dict() - - -class List(orm.List): - @property - def value(self): - return self.get_list() - - class PickledData(orm.Data): """Data to represent a pickled value using cloudpickle.""" diff --git a/tests/test_data.py b/tests/test_data.py index 185d404..cf3d981 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -1,3 +1,4 @@ +import aiida from aiida_pythonjob import PickledFunction @@ -21,3 +22,15 @@ def generate_structures( "builtins": {"list", "float"}, "numpy": {"array"}, } + + +def test_python_job(): + """Test a simple python node.""" + from aiida_pythonjob.data.pickled_data import PickledData + from aiida_pythonjob.data.serializer import serialize_to_aiida_nodes + + inputs = {"a": 1, "b": 2.0, "c": set()} + new_inputs = serialize_to_aiida_nodes(inputs) + assert isinstance(new_inputs["a"], aiida.orm.Int) + assert isinstance(new_inputs["b"], aiida.orm.Float) + assert isinstance(new_inputs["c"], PickledData)