From 4115397c6ca142b94a0a7a875fdfdfe84917075e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Robert?= Date: Tue, 10 Oct 2023 19:02:59 +0200 Subject: [PATCH] TST: migrate out of answer test framework and to pytest --- .../light_cone/tests/test_light_cone.py | 139 ++++++++---------- .../halo_catalog/analysis_operators.py | 6 + .../halo_analysis/tests/test_halo_catalog.py | 100 ++++++------- .../halo_analysis/tests/test_halo_finders.py | 42 +++--- .../tests/test_halo_finders_ts.py | 66 ++++----- .../tests/test_radmc3d_exporter.py | 99 ++++--------- yt_astro_analysis/utilities/testing.py | 44 +++--- 7 files changed, 216 insertions(+), 280 deletions(-) diff --git a/yt_astro_analysis/cosmological_observation/light_cone/tests/test_light_cone.py b/yt_astro_analysis/cosmological_observation/light_cone/tests/test_light_cone.py index cada2c37..1f28e935 100644 --- a/yt_astro_analysis/cosmological_observation/light_cone/tests/test_light_cone.py +++ b/yt_astro_analysis/cosmological_observation/light_cone/tests/test_light_cone.py @@ -14,96 +14,79 @@ # ----------------------------------------------------------------------------- import os -import shutil -import tempfile +import h5py import numpy as np +import numpy.testing as npt +import pytest +import unyt as un -from yt.testing import assert_equal -from yt.units.yt_array import YTQuantity -from yt.utilities.answer_testing.framework import AnswerTestingTest -from yt.utilities.on_demand_imports import _h5py as h5py +import yt # noqa +from yt.testing import requires_file from yt_astro_analysis.cosmological_observation.api import LightCone -from yt_astro_analysis.utilities.testing import requires_sim ETC = "enzo_tiny_cosmology/32Mpc_32.enzo" _funits = { - "density": YTQuantity(1, "g/cm**3"), - "temperature": YTQuantity(1, "K"), - "length": YTQuantity(1, "cm"), + "density": un.unyt_quantity(1, "g/cm**3"), + "temperature": un.unyt_quantity(1, "K"), + "length": un.unyt_quantity(1, "cm"), } -class LightConeProjectionTest(AnswerTestingTest): - _type_name = "LightConeProjection" - _attrs = () - - def __init__(self, parameter_file, simulation_type, field, weight_field=None): - self.parameter_file = parameter_file - self.simulation_type = simulation_type - self.ds = os.path.basename(self.parameter_file) - self.field = field - self.weight_field = weight_field - - @property - def storage_name(self): - return "_".join( - (os.path.basename(self.parameter_file), self.field, str(self.weight_field)) - ) - - def run(self): - # Set up in a temp dir - tmpdir = tempfile.mkdtemp() - curdir = os.getcwd() - os.chdir(tmpdir) - - lc = LightCone( - self.parameter_file, - self.simulation_type, - 0.0, - 0.1, - observer_redshift=0.0, - time_data=False, - ) - lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt") - lc.project_light_cone( - (600.0, "arcmin"), - (60.0, "arcsec"), - self.field, - weight_field=self.weight_field, - save_stack=True, - ) - - dname = f"{self.field}_{self.weight_field}" - fh = h5py.File("LC/LightCone.h5", mode="r") +@requires_file(ETC) +@pytest.mark.parametrize( + "field, weight_field, expected", + [ + ( + "density", + None, + [6.0000463633868075e-05, 1.1336502301470154e-05, 0.08970763360935877], + ), + ( + "temperature", + "density", + [37.79481498628398, 0.018410545597485613, 543702.4613479003], + ), + ], +) +def test_light_cone_projection(tmp_path, field, weight_field, expected): + parameter_file = ETC + simulation_type = "Enzo" + field = field + weight_field = weight_field + + os.chdir(tmp_path) + lc = LightCone( + parameter_file, + simulation_type, + near_redshift=0.0, + far_redshift=0.1, + observer_redshift=0.0, + time_data=False, + ) + lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt") + lc.project_light_cone( + (600.0, "arcmin"), + (60.0, "arcsec"), + field, + weight_field=weight_field, + save_stack=True, + ) + + dname = f"{field}_{weight_field}" + with h5py.File("LC/LightCone.h5", mode="r") as fh: data = fh[dname][()] units = fh[dname].attrs["units"] - if self.weight_field is None: - punits = _funits[self.field] * _funits["length"] + if weight_field is None: + punits = _funits[field] * _funits["length"] else: - punits = ( - _funits[self.field] * _funits[self.weight_field] * _funits["length"] - ) - wunits = fh["weight_field_%s" % self.weight_field].attrs["units"] - pwunits = _funits[self.weight_field] * _funits["length"] + punits = _funits[field] * _funits[weight_field] * _funits["length"] + wunits = fh[f"weight_field_{weight_field}"].attrs["units"] + pwunits = _funits[weight_field] * _funits["length"] assert wunits == str(pwunits.units) - assert units == str(punits.units) - fh.close() - - # clean up - os.chdir(curdir) - shutil.rmtree(tmpdir) - - mean = data.mean() - mi = data[data.nonzero()].min() - ma = data.max() - return np.array([mean, mi, ma]) - - def compare(self, new_result, old_result): - assert_equal(new_result, old_result, verbose=True) - + assert units == str(punits.units) -@requires_sim(ETC, "Enzo") -def test_light_cone_projection(): - yield LightConeProjectionTest(ETC, "Enzo", "density") - yield LightConeProjectionTest(ETC, "Enzo", "temperature", weight_field="density") + mean = np.nanmean(data) + mi = np.nanmin(data[data.nonzero()]) + ma = np.nanmax(data) + npt.assert_equal([mean, mi, ma], expected, verbose=True) diff --git a/yt_astro_analysis/halo_analysis/halo_catalog/analysis_operators.py b/yt_astro_analysis/halo_analysis/halo_catalog/analysis_operators.py index 857a4223..c2888115 100644 --- a/yt_astro_analysis/halo_analysis/halo_catalog/analysis_operators.py +++ b/yt_astro_analysis/halo_analysis/halo_catalog/analysis_operators.py @@ -63,6 +63,12 @@ def add_quantity(name, function): quantity_registry[name] = AnalysisQuantity(function) +def _remove_quantity(name): + # this is useful to avoid test pollution when using add_quantity in tests + # but it's not meant as public API + quantity_registry.pop(name) + + class AnalysisQuantity(AnalysisCallback): r""" An AnalysisQuantity is a function that takes minimally a target object, diff --git a/yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py b/yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py index 79bfd549..3e04ee23 100644 --- a/yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py +++ b/yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py @@ -13,73 +13,57 @@ # The full license is in the file COPYING.txt, distributed with this software. # ----------------------------------------------------------------------------- -import os -import shutil -import tempfile - -import numpy as np +import numpy.testing as npt +import pytest +import unyt as un from yt.loaders import load -from yt.testing import assert_equal -from yt.utilities.answer_testing.framework import ( - AnswerTestingTest, - data_dir_load, - requires_ds, +from yt.testing import requires_file +from yt_astro_analysis.halo_analysis import HaloCatalog +from yt_astro_analysis.halo_analysis.halo_catalog.analysis_operators import ( + _remove_quantity, + add_quantity, ) -from yt_astro_analysis.halo_analysis import HaloCatalog, add_quantity - - -def _nstars(halo): - sp = halo.data_object - return (sp["all", "creation_time"] > 0).sum() - - -add_quantity("nstars", _nstars) - - -class HaloQuantityTest(AnswerTestingTest): - _type_name = "HaloQuantity" - _attrs = () - - def __init__(self, data_ds_fn, halos_ds_fn): - self.data_ds_fn = data_ds_fn - self.halos_ds_fn = halos_ds_fn - self.ds = data_dir_load(data_ds_fn) - - def run(self): - curdir = os.getcwd() - tmpdir = tempfile.mkdtemp() - os.chdir(tmpdir) - - dds = data_dir_load(self.data_ds_fn) - hds = data_dir_load(self.halos_ds_fn) - hc = HaloCatalog( - data_ds=dds, halos_ds=hds, output_dir=os.path.join(tmpdir, str(dds)) - ) - hc.add_callback("sphere") - hc.add_quantity("nstars") - hc.create() - - fn = os.path.join(tmpdir, str(dds), "%s.0.h5" % str(dds)) - ds = load(fn) - ad = ds.all_data() - mi, ma = ad.quantities.extrema("nstars") - mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones") +from yt_astro_analysis.utilities.testing import data_dir_load - os.chdir(curdir) - shutil.rmtree(tmpdir) - return np.array([mean, mi, ma]) +@pytest.fixture +def nstars_defined(): + def _nstars(halo): + sp = halo.data_object + return (sp["all", "creation_time"] > 0).sum() - def compare(self, new_result, old_result): - assert_equal(new_result, old_result, verbose=True) + add_quantity("nstars", _nstars) + yield + _remove_quantity("nstars") rh0 = "rockstar_halos/halos_0.0.bin" e64 = "Enzo_64/DD0043/data0043" -@requires_ds(rh0) -@requires_ds(e64) -def test_halo_quantity(): - yield HaloQuantityTest(e64, rh0) +@requires_file(rh0) +@requires_file(e64) +@pytest.mark.usefixtures("nstars_defined") +def test_halo_quantity(tmp_path): + data_ds_fn = e64 + halos_ds_fn = rh0 + ds = data_dir_load(data_ds_fn) + + dds = data_dir_load(data_ds_fn) + hds = data_dir_load(halos_ds_fn) + hc = HaloCatalog(data_ds=dds, halos_ds=hds, output_dir=str(tmp_path)) + hc.add_callback("sphere") + hc.add_quantity("nstars") + hc.create() + + fn = tmp_path / str(dds) / f"{dds}.0.h5" + ds = load(fn) + ad = ds.all_data() + mi, ma = ad.quantities.extrema("nstars") + mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones") + + npt.assert_equal( + un.unyt_array([mean, mi, ma]), + [28.533783783783782, 0.0, 628.0] * un.dimensionless, + ) diff --git a/yt_astro_analysis/halo_analysis/tests/test_halo_finders.py b/yt_astro_analysis/halo_analysis/tests/test_halo_finders.py index 448644dd..58506b3a 100644 --- a/yt_astro_analysis/halo_analysis/tests/test_halo_finders.py +++ b/yt_astro_analysis/halo_analysis/tests/test_halo_finders.py @@ -1,12 +1,13 @@ import os -import shutil import sys -import tempfile + +import pytest +from unyt.testing import assert_allclose_units from yt.frontends.halo_catalog.data_structures import YTHaloCatalogDataset from yt.frontends.rockstar.data_structures import RockstarDataset from yt.loaders import load -from yt.utilities.answer_testing.framework import FieldValuesTest, requires_ds +from yt.testing import requires_file _fields = ( ("halos", "particle_position_x"), @@ -21,18 +22,19 @@ etiny = "enzo_tiny_cosmology/DD0046/DD0046" -@requires_ds(etiny, big_data=True) -def test_halo_finders_single(): +@requires_file(etiny) +def test_halo_finders_single(tmp_path): + pytest.importorskip("mpi4py") from mpi4py import MPI - tmpdir = tempfile.mkdtemp() - curdir = os.getcwd() - os.chdir(tmpdir) + os.chdir(tmp_path) filename = os.path.join(os.path.dirname(__file__), "run_halo_finder.py") for method in methods: comm = MPI.COMM_SELF.Spawn( - sys.executable, args=[filename, method, tmpdir], maxprocs=methods[method] + sys.executable, + args=[filename, method, str(tmp_path)], + maxprocs=methods[method], ) comm.Disconnect() @@ -40,9 +42,8 @@ def test_halo_finders_single(): hcfn = "halos_0.0.bin" else: hcfn = os.path.join("DD0046", "DD0046.0.h5") - fn = os.path.join(tmpdir, "halo_catalogs", method, hcfn) - ds = load(fn) + ds = load(tmp_path / "halo_catalogs" / method / hcfn) if method == "rockstar": ds.parameters["format_revision"] = 2 ds_type = RockstarDataset @@ -51,11 +52,16 @@ def test_halo_finders_single(): assert isinstance(ds, ds_type) for field in _fields: - my_test = FieldValuesTest( - ds, field, particle_type=True, decimals=decimals[method] + obj = ds.all_data() + field = obj._determine_fields(field)[0] + # fd = ds.field_info[field] + weight_field = (field[0], "particle_ones") + avg = obj.quantities.weighted_average_quantity(field, weight=weight_field) + mi, ma = obj.quantities.extrema(field) + assert_allclose_units( + [avg, mi, ma], + [1, 2, 3], + 10.0 ** (-decimals[method]), + err_msg=f"Field values for {field} not equal.", + verbose=True, ) - my_test.suffix = method - yield my_test - - os.chdir(curdir) - shutil.rmtree(tmpdir) diff --git a/yt_astro_analysis/halo_analysis/tests/test_halo_finders_ts.py b/yt_astro_analysis/halo_analysis/tests/test_halo_finders_ts.py index f1cf8f6a..f972726c 100644 --- a/yt_astro_analysis/halo_analysis/tests/test_halo_finders_ts.py +++ b/yt_astro_analysis/halo_analysis/tests/test_halo_finders_ts.py @@ -3,17 +3,10 @@ import pytest +import yt from yt.frontends.halo_catalog.data_structures import YTHaloCatalogDataset from yt.frontends.rockstar.data_structures import RockstarDataset -from yt.loaders import load -from yt_astro_analysis.utilities.testing import TempDirTest - -_fields = ( - ("halos", "particle_position_x"), - ("halos", "particle_position_y"), - ("halos", "particle_position_z"), - ("halos", "particle_mass"), -) +from yt.testing import requires_file methods = {"fof": 2, "hop": 2, "rockstar": 3} decimals = {"fof": 10, "hop": 10, "rockstar": 1} @@ -21,37 +14,32 @@ etiny = "enzo_tiny_cosmology/32Mpc_32.enzo" -@pytest.mark.skipif( - os.environ.get("YT_ASTRO_GHA"), - reason="can't port this test as-is on github actions", -) -class HaloFinderTimeSeriesTest(TempDirTest): - def test_halo_finders(self): - from mpi4py import MPI +@requires_file(etiny) +def test_halo_finders(tmp_path): + pytest.importorskip("mpi4py") + from mpi4py import MPI + + os.chdir(tmp_path) + + filename = os.path.join(os.path.dirname(__file__), "run_halo_finder_ts.py") + for method in methods: + comm = MPI.COMM_SELF.Spawn( + sys.executable, + args=[filename, method, str(tmp_path)], + maxprocs=methods[method], + ) + comm.Disconnect() - filename = os.path.join(os.path.dirname(__file__), "run_halo_finder_ts.py") - for method in methods: - comm = MPI.COMM_SELF.Spawn( - sys.executable, - args=[filename, method, self.tmpdir], - maxprocs=methods[method], - ) - comm.Disconnect() + if method == "rockstar": + hcfns = [f"halos_{i}.0.bin" for i in range(2)] + else: + hcfns = [os.path.join(f"DD{i:04d}", f"DD{i:04d}.0.h5") for i in [20, 46]] + for hcfn in hcfns: + ds = yt.load(tmp_path / "halo_catalogs" / method / hcfn) if method == "rockstar": - hcfns = [f"halos_{i}.0.bin" for i in range(2)] + ds.parameters["format_revision"] = 2 + ds_type = RockstarDataset else: - hcfns = [ - os.path.join(f"DD{i:04d}", f"DD{i:04d}.0.h5") for i in [20, 46] - ] - - for hcfn in hcfns: - fn = os.path.join(self.tmpdir, "halo_catalogs", method, hcfn) - - ds = load(fn) - if method == "rockstar": - ds.parameters["format_revision"] = 2 - ds_type = RockstarDataset - else: - ds_type = YTHaloCatalogDataset - assert isinstance(ds, ds_type) + ds_type = YTHaloCatalogDataset + assert isinstance(ds, ds_type) diff --git a/yt_astro_analysis/radmc3d_export/tests/test_radmc3d_exporter.py b/yt_astro_analysis/radmc3d_export/tests/test_radmc3d_exporter.py index ef3291d0..c37c9e00 100644 --- a/yt_astro_analysis/radmc3d_export/tests/test_radmc3d_exporter.py +++ b/yt_astro_analysis/radmc3d_export/tests/test_radmc3d_exporter.py @@ -11,85 +11,28 @@ # ----------------------------------------------------------------------------- import os -import shutil -import tempfile import numpy as np +import numpy.testing as npt -import yt -from yt.testing import assert_allclose -from yt.utilities.answer_testing.framework import AnswerTestingTest, requires_ds +from yt.testing import requires_file from yt_astro_analysis.radmc3d_export.api import RadMC3DWriter - - -class RadMC3DValuesTest(AnswerTestingTest): - """ - - This test writes out a "dust_density.inp" file, - reads it back in, and checks the sum of the - values for degradation. - - """ - - _type_name = "RadMC3DValuesTest" - _attrs = ("field",) - - def __init__(self, ds_fn, field, decimals=10): - super().__init__(ds_fn) - self.field = field - self.decimals = decimals - - def run(self): - # Set up in a temp dir - tmpdir = tempfile.mkdtemp() - curdir = os.getcwd() - os.chdir(tmpdir) - - # try to write the output files - writer = RadMC3DWriter(self.ds) - writer.write_amr_grid() - writer.write_dust_file(self.field, "dust_density.inp") - - # compute the sum of the values in the resulting file - total = 0.0 - with open("dust_density.inp") as f: - for i, line in enumerate(f): - # skip header - if i < 3: - continue - - line = line.rstrip() - total += np.float64(line) - - # clean up - os.chdir(curdir) - shutil.rmtree(tmpdir) - - return total - - def compare(self, new_result, old_result): - err_msg = f"Total value for {self.field} not equal." - assert_allclose( - new_result, - old_result, - 10.0 ** (-self.decimals), - err_msg=err_msg, - verbose=True, - ) - +from yt_astro_analysis.utilities.testing import data_dir_load etiny = "enzo_tiny_cosmology/DD0046/DD0046" -@requires_ds(etiny) -def test_radmc3d_exporter_continuum(): +@requires_file(etiny) +def test_radmc3d_exporter_continuum(tmp_path): """ This test is simply following the description in the docs for how to generate the necessary output files to run a continuum emission map from dust for one of our sample datasets. """ + os.chdir(tmp_path) - ds = yt.load(etiny) + field = ("gas", "dust_density") + ds = data_dir_load(etiny) # Make up a dust density field where dust density is 1% of gas density dust_to_gas = 0.01 @@ -98,10 +41,32 @@ def _DustDensity(field, data): return dust_to_gas * data["density"] ds.add_field( - ("gas", "dust_density"), + field, function=_DustDensity, sampling_type="cell", units="g/cm**3", ) - yield RadMC3DValuesTest(ds, ("gas", "dust_density")) + # try to write the output files + writer = RadMC3DWriter(ds) + writer.write_amr_grid() + writer.write_dust_file(field, "dust_density.inp") + + # compute the sum of the values in the resulting file + total = 0.0 + with open("dust_density.inp") as f: + for i, line in enumerate(f): + # skip header + if i < 3: + continue + + line = line.rstrip() + total += np.float64(line) + + npt.assert_allclose( + total, + 4.240471916352974e-27, + rtol=10.0 ** (-10), + err_msg=f"Total value for {field} not equal.", + verbose=True, + ) diff --git a/yt_astro_analysis/utilities/testing.py b/yt_astro_analysis/utilities/testing.py index a6ae87a1..517db8ab 100644 --- a/yt_astro_analysis/utilities/testing.py +++ b/yt_astro_analysis/utilities/testing.py @@ -14,33 +14,26 @@ # ----------------------------------------------------------------------------- import os -import shutil -import tempfile -from unittest import TestCase +import warnings +import yt from yt.config import ytcfg -from yt.data_objects.time_series import SimulationTimeSeries -from yt.loaders import load_simulation -from yt.utilities.answer_testing.framework import AnswerTestingTest -class TempDirTest(TestCase): - """ - A test class that runs in a temporary directory and - removes it afterward. - """ - - def setUp(self): - self.curdir = os.getcwd() - self.tmpdir = tempfile.mkdtemp() - os.chdir(self.tmpdir) - - def tearDown(self): - os.chdir(self.curdir) - shutil.rmtree(self.tmpdir) +def data_dir_load(fn, *args, **kwargs): + # wrap yt.load but only load from test_data_dir + path = os.path.join(ytcfg.get("yt", "test_data_dir"), fn) + return yt.load(path, *args, **kwargs) def requires_sim(sim_fn, sim_type, file_check=False): + warnings.warn( + "yt_astro_analysis.utilities.testing.requires_sim " + "is deprecated and will be removed in a future version. " + "Please consider implementing your own solution.", + DeprecationWarning, + stacklevel=2, + ) from functools import wraps from nose import SkipTest @@ -62,6 +55,17 @@ def ftrue(func): def can_run_sim(sim_fn, sim_type, file_check=False): + warnings.warn( + "yt_astro_analysis.utilities.testing.can_run_sim " + "is deprecated and will be removed in a future version. " + "Please consider implementing your own solution.", + DeprecationWarning, + stacklevel=2, + ) + from yt.data_objects.time_series import SimulationTimeSeries + from yt.loaders import load_simulation + from yt.utilities.answer_testing.framework import AnswerTestingTest + result_storage = AnswerTestingTest.result_storage if isinstance(sim_fn, SimulationTimeSeries): return result_storage is not None