Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve errors #478

Merged
merged 14 commits into from
Aug 18, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 46 additions & 34 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def _is_none(x: Any) -> bool:
return x is None


def _term_compatible(
def _assert_term_compatible(
y: PyTree[ArrayLike],
args: PyTree[Any],
terms: PyTree[AbstractTerm],
term_structure: PyTree,
contr_kwargs: PyTree[dict],
) -> bool:
) -> None:
error_msg = "term_structure"

def _check(term_cls, term, term_contr_kwargs, yi):
Expand All @@ -136,18 +136,19 @@ def _check(term_cls, term, term_contr_kwargs, yi):
for term, arg, term_contr_kwarg in zip(
term.terms, get_args(_tmp), term_contr_kwargs
):
if not _term_compatible(yi, args, term, arg, term_contr_kwarg):
raise ValueError
_assert_term_compatible(yi, args, term, arg, term_contr_kwarg)
else:
raise ValueError
raise ValueError(
f"Term {term} is not a MultiTerm but is expected to be."
)
else:
# Check that `term` is an instance of `term_cls` (ignoring any generic
# parameterization).
origin_cls = get_origin_no_specials(term_cls, error_msg)
if origin_cls is None:
origin_cls = term_cls
if not isinstance(term, origin_cls):
raise ValueError
raise ValueError(f"Term {term} is not an instance of {origin_cls}.")

# Now check the generic parametrization of `term_cls`; can be one of:
# -----------------------------------------
Expand All @@ -162,32 +163,37 @@ def _check(term_cls, term, term_contr_kwargs, yi):
pass
elif n_term_args == 2:
vf_type_expected, control_type_expected = term_args
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
try:
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
except Exception as e:
raise ValueError(f"Error while tracing {term}.vf: " + str(e))
vf_type_compatible = eqx.filter_eval_shape(
better_isinstance, vf_type, vf_type_expected
)
if not vf_type_compatible:
raise ValueError
raise ValueError(f"Vector field term {term} is incompatible.")

contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
try:
control_type = eqx.filter_eval_shape(contr, 0.0, 0.0)
except Exception as e:
raise ValueError(f"Error while tracing {term}.contr: " + str(e))
control_type_compatible = eqx.filter_eval_shape(
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError
raise ValueError(f"Control term {term} is incompatible.")
else:
assert False, "Malformed term structure"
# If we've got to this point then the term is compatible

try:
with jax.numpy_dtype_promotion("standard"):
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
except ValueError:
except Exception as e:
# ValueError may also arise from mismatched tree structures
return False
return True
raise ValueError("Terms are not compatible with solver! " + str(e))


def _is_subsaveat(x: Any) -> bool:
Expand Down Expand Up @@ -1006,29 +1012,35 @@ def _promote(yi):
del timelikes

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
) and _term_compatible(
y0, args, terms, (ODETerm, AbstractTerm), solver.term_compatible_contr_kwargs
):
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general and SDE-specific "
"solvers!",
stacklevel=2,
)
terms = MultiTerm(*terms)
if isinstance(solver, (EulerHeun, ItoMilstein, StratonovichMilstein)):
try:
_assert_term_compatible(
y0,
args,
terms,
(ODETerm, AbstractTerm),
solver.term_compatible_contr_kwargs,
)
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general "
"and SDE-specific solvers!",
stacklevel=2,
)
terms = MultiTerm(*terms)
except Exception as _:
pass

# Error checking
if not _term_compatible(
y0, args, terms, solver.term_structure, solver.term_compatible_contr_kwargs
):
raise ValueError(
"`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
f"structure {solver.term_structure}"
)
_assert_term_compatible(
y0,
args,
terms,
solver.term_structure,
solver.term_compatible_contr_kwargs,
)

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
Expand Down
6 changes: 5 additions & 1 deletion diffrax/_solver/srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,11 @@ def make_zeros_aux(leaf):
# Now the diffusion related stuff
# Brownian increment (and space-time Lévy area)
bm_inc = diffusion.contr(t0, t1, use_levy=True)
assert isinstance(bm_inc, self.minimal_levy_area)
if not isinstance(bm_inc, self.minimal_levy_area):
raise ValueError(
f"The brownian increment {bm_inc} does not have the "
f"minimal Levy Area {self.minimal_levy_area}."
)
w = bm_inc.W

# b looks similar regardless of whether we have additive noise or not
Expand Down
6 changes: 4 additions & 2 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ class TestSolver(diffrax.Euler):
compatible_vf, compatible_control
)
for term in incompatible_terms:
with pytest.raises(ValueError, match=r"`terms` must be a PyTree of"):
with pytest.raises(ValueError, match=r"Terms are not compatible with solver! "):
diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, (jnp.zeros((2, 1)),))

diffrax.diffeqsolve(
Expand Down Expand Up @@ -786,5 +786,7 @@ def func(self, terms, t0, y0, args):
):
if term is compatible_term and y0 is compatible_y0:
continue
with pytest.raises(ValueError, match=r"`terms` must be a PyTree of"):
with pytest.raises(
ValueError, match=r"Terms are not compatible with solver! "
):
diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, y0)
Loading