Skip to content

Commit

Permalink
TST: migrate out of answer test framework and to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Oct 10, 2023
1 parent 3e878b5 commit 4115397
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,96 +14,79 @@
# -----------------------------------------------------------------------------

import os
import shutil
import tempfile

import h5py
import numpy as np
import numpy.testing as npt
import pytest
import unyt as un

from yt.testing import assert_equal
from yt.units.yt_array import YTQuantity
from yt.utilities.answer_testing.framework import AnswerTestingTest
from yt.utilities.on_demand_imports import _h5py as h5py
import yt # noqa
from yt.testing import requires_file
from yt_astro_analysis.cosmological_observation.api import LightCone
from yt_astro_analysis.utilities.testing import requires_sim

ETC = "enzo_tiny_cosmology/32Mpc_32.enzo"
_funits = {
"density": YTQuantity(1, "g/cm**3"),
"temperature": YTQuantity(1, "K"),
"length": YTQuantity(1, "cm"),
"density": un.unyt_quantity(1, "g/cm**3"),
"temperature": un.unyt_quantity(1, "K"),
"length": un.unyt_quantity(1, "cm"),
}


class LightConeProjectionTest(AnswerTestingTest):
_type_name = "LightConeProjection"
_attrs = ()

def __init__(self, parameter_file, simulation_type, field, weight_field=None):
self.parameter_file = parameter_file
self.simulation_type = simulation_type
self.ds = os.path.basename(self.parameter_file)
self.field = field
self.weight_field = weight_field

@property
def storage_name(self):
return "_".join(
(os.path.basename(self.parameter_file), self.field, str(self.weight_field))
)

def run(self):
# Set up in a temp dir
tmpdir = tempfile.mkdtemp()
curdir = os.getcwd()
os.chdir(tmpdir)

lc = LightCone(
self.parameter_file,
self.simulation_type,
0.0,
0.1,
observer_redshift=0.0,
time_data=False,
)
lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt")
lc.project_light_cone(
(600.0, "arcmin"),
(60.0, "arcsec"),
self.field,
weight_field=self.weight_field,
save_stack=True,
)

dname = f"{self.field}_{self.weight_field}"
fh = h5py.File("LC/LightCone.h5", mode="r")
@requires_file(ETC)
@pytest.mark.parametrize(
"field, weight_field, expected",
[
(
"density",
None,
[6.0000463633868075e-05, 1.1336502301470154e-05, 0.08970763360935877],
),
(
"temperature",
"density",
[37.79481498628398, 0.018410545597485613, 543702.4613479003],
),
],
)
def test_light_cone_projection(tmp_path, field, weight_field, expected):
parameter_file = ETC
simulation_type = "Enzo"
field = field
weight_field = weight_field

os.chdir(tmp_path)
lc = LightCone(
parameter_file,
simulation_type,
near_redshift=0.0,
far_redshift=0.1,
observer_redshift=0.0,
time_data=False,
)
lc.calculate_light_cone_solution(seed=123456789, filename="LC/solution.txt")
lc.project_light_cone(
(600.0, "arcmin"),
(60.0, "arcsec"),
field,
weight_field=weight_field,
save_stack=True,
)

dname = f"{field}_{weight_field}"
with h5py.File("LC/LightCone.h5", mode="r") as fh:
data = fh[dname][()]
units = fh[dname].attrs["units"]
if self.weight_field is None:
punits = _funits[self.field] * _funits["length"]
if weight_field is None:
punits = _funits[field] * _funits["length"]
else:
punits = (
_funits[self.field] * _funits[self.weight_field] * _funits["length"]
)
wunits = fh["weight_field_%s" % self.weight_field].attrs["units"]
pwunits = _funits[self.weight_field] * _funits["length"]
punits = _funits[field] * _funits[weight_field] * _funits["length"]
wunits = fh[f"weight_field_{weight_field}"].attrs["units"]
pwunits = _funits[weight_field] * _funits["length"]
assert wunits == str(pwunits.units)
assert units == str(punits.units)
fh.close()

# clean up
os.chdir(curdir)
shutil.rmtree(tmpdir)

mean = data.mean()
mi = data[data.nonzero()].min()
ma = data.max()
return np.array([mean, mi, ma])

def compare(self, new_result, old_result):
assert_equal(new_result, old_result, verbose=True)

assert units == str(punits.units)

@requires_sim(ETC, "Enzo")
def test_light_cone_projection():
yield LightConeProjectionTest(ETC, "Enzo", "density")
yield LightConeProjectionTest(ETC, "Enzo", "temperature", weight_field="density")
mean = np.nanmean(data)
mi = np.nanmin(data[data.nonzero()])
ma = np.nanmax(data)
npt.assert_equal([mean, mi, ma], expected, verbose=True)
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ def add_quantity(name, function):
quantity_registry[name] = AnalysisQuantity(function)


def _remove_quantity(name):
# this is useful to avoid test pollution when using add_quantity in tests
# but it's not meant as public API
quantity_registry.pop(name)


class AnalysisQuantity(AnalysisCallback):
r"""
An AnalysisQuantity is a function that takes minimally a target object,
Expand Down
100 changes: 42 additions & 58 deletions yt_astro_analysis/halo_analysis/tests/test_halo_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,73 +13,57 @@
# The full license is in the file COPYING.txt, distributed with this software.
# -----------------------------------------------------------------------------

import os
import shutil
import tempfile

import numpy as np
import numpy.testing as npt
import pytest
import unyt as un

from yt.loaders import load
from yt.testing import assert_equal
from yt.utilities.answer_testing.framework import (
AnswerTestingTest,
data_dir_load,
requires_ds,
from yt.testing import requires_file
from yt_astro_analysis.halo_analysis import HaloCatalog
from yt_astro_analysis.halo_analysis.halo_catalog.analysis_operators import (
_remove_quantity,
add_quantity,
)
from yt_astro_analysis.halo_analysis import HaloCatalog, add_quantity


def _nstars(halo):
sp = halo.data_object
return (sp["all", "creation_time"] > 0).sum()


add_quantity("nstars", _nstars)


class HaloQuantityTest(AnswerTestingTest):
_type_name = "HaloQuantity"
_attrs = ()

def __init__(self, data_ds_fn, halos_ds_fn):
self.data_ds_fn = data_ds_fn
self.halos_ds_fn = halos_ds_fn
self.ds = data_dir_load(data_ds_fn)

def run(self):
curdir = os.getcwd()
tmpdir = tempfile.mkdtemp()
os.chdir(tmpdir)

dds = data_dir_load(self.data_ds_fn)
hds = data_dir_load(self.halos_ds_fn)
hc = HaloCatalog(
data_ds=dds, halos_ds=hds, output_dir=os.path.join(tmpdir, str(dds))
)
hc.add_callback("sphere")
hc.add_quantity("nstars")
hc.create()

fn = os.path.join(tmpdir, str(dds), "%s.0.h5" % str(dds))
ds = load(fn)
ad = ds.all_data()
mi, ma = ad.quantities.extrema("nstars")
mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones")
from yt_astro_analysis.utilities.testing import data_dir_load

os.chdir(curdir)
shutil.rmtree(tmpdir)

return np.array([mean, mi, ma])
@pytest.fixture
def nstars_defined():
def _nstars(halo):
sp = halo.data_object
return (sp["all", "creation_time"] > 0).sum()

def compare(self, new_result, old_result):
assert_equal(new_result, old_result, verbose=True)
add_quantity("nstars", _nstars)
yield
_remove_quantity("nstars")


rh0 = "rockstar_halos/halos_0.0.bin"
e64 = "Enzo_64/DD0043/data0043"


@requires_ds(rh0)
@requires_ds(e64)
def test_halo_quantity():
yield HaloQuantityTest(e64, rh0)
@requires_file(rh0)
@requires_file(e64)
@pytest.mark.usefixtures("nstars_defined")
def test_halo_quantity(tmp_path):
data_ds_fn = e64
halos_ds_fn = rh0
ds = data_dir_load(data_ds_fn)

dds = data_dir_load(data_ds_fn)
hds = data_dir_load(halos_ds_fn)
hc = HaloCatalog(data_ds=dds, halos_ds=hds, output_dir=str(tmp_path))
hc.add_callback("sphere")
hc.add_quantity("nstars")
hc.create()

fn = tmp_path / str(dds) / f"{dds}.0.h5"
ds = load(fn)
ad = ds.all_data()
mi, ma = ad.quantities.extrema("nstars")
mean = ad.quantities.weighted_average_quantity("nstars", "particle_ones")

npt.assert_equal(
un.unyt_array([mean, mi, ma]),
[28.533783783783782, 0.0, 628.0] * un.dimensionless,
)
42 changes: 24 additions & 18 deletions yt_astro_analysis/halo_analysis/tests/test_halo_finders.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import shutil
import sys
import tempfile

import pytest
from unyt.testing import assert_allclose_units

from yt.frontends.halo_catalog.data_structures import YTHaloCatalogDataset
from yt.frontends.rockstar.data_structures import RockstarDataset
from yt.loaders import load
from yt.utilities.answer_testing.framework import FieldValuesTest, requires_ds
from yt.testing import requires_file

_fields = (
("halos", "particle_position_x"),
Expand All @@ -21,28 +22,28 @@
etiny = "enzo_tiny_cosmology/DD0046/DD0046"


@requires_ds(etiny, big_data=True)
def test_halo_finders_single():
@requires_file(etiny)
def test_halo_finders_single(tmp_path):
pytest.importorskip("mpi4py")
from mpi4py import MPI

tmpdir = tempfile.mkdtemp()
curdir = os.getcwd()
os.chdir(tmpdir)
os.chdir(tmp_path)

filename = os.path.join(os.path.dirname(__file__), "run_halo_finder.py")
for method in methods:
comm = MPI.COMM_SELF.Spawn(
sys.executable, args=[filename, method, tmpdir], maxprocs=methods[method]
sys.executable,
args=[filename, method, str(tmp_path)],
maxprocs=methods[method],
)
comm.Disconnect()

if method == "rockstar":
hcfn = "halos_0.0.bin"
else:
hcfn = os.path.join("DD0046", "DD0046.0.h5")
fn = os.path.join(tmpdir, "halo_catalogs", method, hcfn)

ds = load(fn)
ds = load(tmp_path / "halo_catalogs" / method / hcfn)
if method == "rockstar":
ds.parameters["format_revision"] = 2
ds_type = RockstarDataset
Expand All @@ -51,11 +52,16 @@ def test_halo_finders_single():
assert isinstance(ds, ds_type)

for field in _fields:
my_test = FieldValuesTest(
ds, field, particle_type=True, decimals=decimals[method]
obj = ds.all_data()
field = obj._determine_fields(field)[0]
# fd = ds.field_info[field]
weight_field = (field[0], "particle_ones")
avg = obj.quantities.weighted_average_quantity(field, weight=weight_field)
mi, ma = obj.quantities.extrema(field)
assert_allclose_units(
[avg, mi, ma],
[1, 2, 3],
10.0 ** (-decimals[method]),
err_msg=f"Field values for {field} not equal.",
verbose=True,
)
my_test.suffix = method
yield my_test

os.chdir(curdir)
shutil.rmtree(tmpdir)
Loading

0 comments on commit 4115397

Please sign in to comment.