From 91bbbe25eef065a3dfb7c9c1a805f2b8154106aa Mon Sep 17 00:00:00 2001 From: NikoOinonen Date: Tue, 12 Mar 2024 16:28:45 +0200 Subject: [PATCH] Added test for TarDataGenerator. --- mlspm/data_generation.py | 6 ++- tests/test_data_generation.py | 93 +++++++++++++++++++++++++++++++---- 2 files changed, 89 insertions(+), 10 deletions(-) diff --git a/mlspm/data_generation.py b/mlspm/data_generation.py index c1a910b..a254958 100644 --- a/mlspm/data_generation.py +++ b/mlspm/data_generation.py @@ -133,12 +133,16 @@ class TarDataGenerator: The npz files should contain the following entries: - - ``'array'``: An array containing the potential/density on a 3D grid. + - ``'data'``: An array containing the potential/density on a 3D grid. The potential is assumed to be in + units of eV and density in units of e/Å^3. - ``'origin'``: Lattice origin in 3D space as an array of shape ``(3,)``. - ``'lattice'``: Lattice vectors as an array of shape ``(3, 3)``, where the rows are the vectors. - ``'xyz'``: Atom xyz coordinates as an array of shape ``(n_atoms, 3)``. - ``'Z'``: Atom atomic numbers as an array of shape ``(n_atoms,)``. + Note: it is recommended to use ``multiprocessing.set_start_method('spawn')`` when using the :class:`TarDataGenerator`. + Otherwise a lot of warnings about leaked memory objects may be thrown on exit. + Arguments: samples: List of sample dicts as :class:`TarSampleList`. File paths should be relative to ``base_path``. base_path: Path to the directory with the tar files. diff --git a/tests/test_data_generation.py b/tests/test_data_generation.py index ff77e6a..2bce31e 100644 --- a/tests/test_data_generation.py +++ b/tests/test_data_generation.py @@ -1,4 +1,4 @@ - +import io from pathlib import Path from shutil import rmtree import tarfile @@ -10,8 +10,8 @@ def test_tar_writer(): from mlspm.data_generation import TarWriter - base_path = Path('./test_writer') - base_name = 'test' + base_path = Path("./test_writer") + base_name = "test" base_path.mkdir(exist_ok=True) @@ -20,12 +20,12 @@ def test_tar_writer(): X = [np.random.rand(128, 128, 10), np.random.rand(128, 128, 10)] Y = [np.random.rand(128, 128), np.random.rand(128, 128)] xyzs = np.concatenate([np.random.rand(10, 3), np.random.randint(1, 10, (10, 1))], axis=1) - tar_writer.add_sample(X, xyzs, Y, comment_str='test comment') + tar_writer.add_sample(X, xyzs, Y, comment_str="test comment") - assert (base_path / 'test_0.tar').exists() - assert (base_path / 'test_1.tar').exists() + assert (base_path / "test_0.tar").exists() + assert (base_path / "test_1.tar").exists() - with tarfile.open(base_path / 'test_0.tar') as ft: + with tarfile.open(base_path / "test_0.tar") as ft: names = [m.name for m in ft.getmembers()] assert len(names) == 10 * (2 * 10 + 1 + 2) assert "0.00.0.png" in names @@ -42,5 +42,80 @@ def test_tar_writer(): pass rmtree(base_path) - -test_tar_writer() \ No newline at end of file + + +def test_tar_data_generator(): + + from mlspm.data_generation import TarDataGenerator, get_tarinfo + from ppafm.ocl.oclUtils import init_env + import multiprocessing as mp + + # Without this will throw a lot of warnings on exit + # mp.set_start_method('spawn') + + # Loading data into HartreePotential etc. in TarDataGenerator requires for the pyopencl context to be setup + init_env(i_platform=0) + + # Make dummy data + tar_path_hartree = Path("./test_hartree.tar") + tar_path_rho = Path("./test_rho.tar") + n_sample = 5 + with tarfile.open(tar_path_hartree, "w") as f_hartree, tarfile.open(tar_path_rho, "w") as f_rho: + + hartrees = [] + rhos = [] + xyzs = [] + Zs = [] + lvecs = [] + rots = [] + names = [] + + for i_sample in range(n_sample): + + hartree = np.random.rand(10, 15, 12).astype(np.float32) + rho = np.random.rand(12, 10, 8) + xyz = np.random.rand(10, 3) + Z = np.random.rand(10) + lvec = np.random.rand(4, 3) + rot = np.random.rand(3, 3) + + hartree_bytes = io.BytesIO() + rho_bytes = io.BytesIO() + np.savez(hartree_bytes, data=hartree, origin=lvec[0], lattice=lvec[1:], xyz=xyz, Z=Z) + np.savez(rho_bytes, data=rho, origin=lvec[0], lattice=lvec[1:], xyz=xyz, Z=Z) + hartree_bytes.seek(0) + rho_bytes.seek(0) + + name = f"{i_sample}.npz" + f_hartree.addfile(get_tarinfo(name, hartree_bytes), hartree_bytes) + f_rho.addfile(get_tarinfo(name, rho_bytes), rho_bytes) + + hartrees.append(hartree) + rhos.append(rho) + xyzs.append(xyz) + Zs.append(Z) + lvecs.append(lvec) + rots.append([rot]) + names.append(name) + + sample_list = [ + { + "hartree": (tar_path_hartree, names), + "rho": (tar_path_rho, names), + "rots": rots, + } + ] + + generator = TarDataGenerator(sample_list, base_path='./', n_proc=1) + + for i_sample, sample in enumerate(generator): + assert np.allclose(sample['xyzs'], xyzs[i_sample]) + assert np.allclose(sample['Zs'], Zs[i_sample]) + assert np.allclose(sample['rot'], rots[i_sample]) + assert np.allclose(sample['qs'].array, -hartrees[i_sample]) + assert np.allclose(sample['qs'].lvec, lvecs[i_sample]) + assert np.allclose(sample['rho_sample'].array, rhos[i_sample]) + assert np.allclose(sample['rho_sample'].lvec, lvecs[i_sample]) + + tar_path_hartree.unlink() + tar_path_rho.unlink()