diff --git a/src/autochem/rate/data.py b/src/autochem/rate/data.py index 1ca42ac7..1d954024 100644 --- a/src/autochem/rate/data.py +++ b/src/autochem/rate/data.py @@ -13,13 +13,13 @@ import pydantic from numpy.polynomial import chebyshev from numpy.typing import ArrayLike, NDArray -from pydantic import BeforeValidator +from pydantic import BeforeValidator, model_validator from pydantic_core import core_schema from .. import unit_ from ..unit_ import UNITS, C, D, Dimension, UnitManager, Units, UnitsData, const from ..util import arrh, chemkin, func, mess, plot -from ..util.type_ import Frozen, NDArray_, Scalable, Scalers, SubclassTyped +from ..util.type_ import NDArray_, Scalable, Scalers, SubclassTyped from . import blend from .blend import BlendingFunction_ @@ -34,7 +34,7 @@ class Key: k = "k" -class BaseRate(UnitManager, Frozen, Scalable, SubclassTyped, abc.ABC): +class BaseRate(UnitManager, Scalable, SubclassTyped, abc.ABC): """Abstract base class for rate constants.""" order: int = 1 @@ -189,6 +189,16 @@ class Rate(BaseRate): "k_high": D.rate_constant, } + @model_validator(mode="after") + def sort_temperatures(self) -> Self: + idxs = np.argsort(self.T) + self.T = np.take(self.T, idxs).tolist() + self.k_data = self.k_data[idxs] + self.k_high = ( + None if self.k_high is None else np.take(self.k_high, idxs).tolist() + ) + return self + def __truediv__(self, other: "Rate" | ArrayLike) -> Self: """Scalar division. @@ -467,7 +477,53 @@ def merge_equivalent(self, other: "Rate", *, tol: float = 0.1) -> "Rate": return self.model_copy(update={"k_data": k_data, "k_high": k_high}) -class RateFit(BaseRate): +class BoundedMixin(pydantic.BaseModel): + """Mixin to define bounded calculator.""" + + T_min: float | None = None + T_max: float | None = None + + def in_bounds( + self, + T: ArrayLike, # noqa: N803 + ) -> NDArray[np.bool_]: + """Determine whether temperature(s) are in bounds. + + :param T: Temperature(s) + :return: Boolean value(s) + """ + T = np.array(T, dtype=np.float64) # noqa: N806 + greater_than_min = ( + np.ones_like(T, dtype=bool) if self.T_min is None else self.T_min <= T + ) + less_than_max = ( + np.ones_like(T, dtype=bool) if self.T_max is None else self.T_max >= T + ) + return greater_than_min & less_than_max + + def all_in_bounds( + self, + T: ArrayLike, # noqa: N803 + ) -> bool: + """Determine whether all temperature(s) are in bounds. + + :param T: Temperature(s) + :return: `True` if they are + """ + return np.all(self.in_bounds(T)).item() + + def assert_all_in_bounds( + self, + T: ArrayLike, # noqa: N803 + ) -> None: + """Assert that all temperature(s) are in bounds. + + :param T: Temperature(s) + """ + assert self.all_in_bounds(T), f"{self.T_min} !<= {T} !<= {self.T_max}" + + +class RateFit(BaseRate, BoundedMixin): """Rate fit abstract base classs.""" efficiencies: dict[str, float] = pydantic.Field(default_factory=dict) @@ -534,6 +590,7 @@ def __call__( ) -> NDArray[np.float128]: """Evaluate rate constant.""" T_, _ = func.normalize_arguments((T, P)) # noqa: N806 + T_ = np.where(self.in_bounds(T_), T_, np.nan) R = const.value(C.gas, UNITS) # noqa: N806 kTP = self.A * (T_**self.b) * np.exp(-self.E / (R * T_)) # noqa: N806 return func.normalize_values(kTP, (T, P)) @@ -1276,9 +1333,6 @@ def display_p( # noqa: PLR0913 if y_unit: y_label = f"{y_label} ({y_unit})" - nr = len(rates) - labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None) - def make_chart( ixs: Sequence[int], rates: Sequence[BaseRate], diff --git a/src/autochem/unit_/_manager.py b/src/autochem/unit_/_manager.py index 6688df37..89c106e8 100644 --- a/src/autochem/unit_/_manager.py +++ b/src/autochem/unit_/_manager.py @@ -9,13 +9,12 @@ import pydantic from numpy.typing import NDArray -from ..util.type_ import Frozen from . import dim from .dim import Dimension from .system import UNITS, Units, UnitsData -class UnitManager(Frozen, abc.ABC): +class UnitManager(abc.ABC): _dimensions: ClassVar[dict[str, Dimension]] def __init__(self, units: UnitsData | None = None, **kwargs: object) -> None: diff --git a/src/autochem/unit_/_unit.py b/src/autochem/unit_/_unit.py index eb009eef..6f1deece 100644 --- a/src/autochem/unit_/_unit.py +++ b/src/autochem/unit_/_unit.py @@ -8,4 +8,4 @@ def string(unit: pint.Unit) -> str: def pretty_string(unit: pint.Unit) -> str: - return format(unit, "~P") + return format(unit, "~P").replace("particle", "molecule") diff --git a/src/autochem/util/plot.py b/src/autochem/util/plot.py index d83ef878..d6822dbb 100644 --- a/src/autochem/util/plot.py +++ b/src/autochem/util/plot.py @@ -1,12 +1,14 @@ """Plotting helpers.""" import itertools -from collections.abc import Sequence +from collections.abc import Callable, Sequence +from typing import Any import altair as alt import numpy as np import pandas as pd from numpy.typing import ArrayLike +from scipy.interpolate import CubicSpline from .. import unit_ from ..unit_ import UNITS, Units, UnitsData @@ -158,10 +160,19 @@ class Mark: line = "line" +def regular_scale(val_range: tuple[float, float]) -> alt.Scale: + """Generate a regular scale specification. + + :param val_range: Range + :return: Scale + """ + return alt.Scale(domain=val_range) + + def log_scale(val_range: tuple[float, float]) -> alt.Scale: """Generate a log scale specification. - :param val_range: Rante + :param val_range: Range :return: Scale """ return alt.Scale(type="log", domain=log_scale_domain(val_range)) @@ -267,11 +278,32 @@ def recompose_base10(mant: float, exp: int) -> float: MARKS = (Mark.point, Mark.line) +def transformed_spline_interpolator( + x_data: ArrayLike, + y_data: ArrayLike, + x_trans: Callable[[ArrayLike], ArrayLike] = lambda x: x, + y_trans: Callable[[ArrayLike], ArrayLike] = lambda y: y, + y_trans_inv: Callable[[ArrayLike], ArrayLike] = lambda y: y, +) -> Callable[[Any], np.ndarray]: + """Generate an inerpolator from data. + + :param y_data: Y data + :param x_data: X data + :return: Y interpolator + """ + interp_trans_ = CubicSpline(x_trans(x_data), y_trans(y_data)) + + def interp_(x: Any) -> np.ndarray: + return np.asarray(y_trans_inv(interp_trans_(x_trans(x)))) + + return interp_ + + def general( y_data: Sequence[Sequence[float]], x_data: Sequence[float], # noqa: N803 + labels: Sequence[str], *, - labels: Sequence[str] | None = None, colors: Sequence[str] | None = None, x_label: str | None = None, # noqa: RUF001 y_label: str | None = None, # noqa: RUF001 @@ -280,18 +312,11 @@ def general( x_axis: alt.Axis | None = None, y_axis: alt.Axis | None = None, mark: str = Mark.line, + mark_kwargs: dict | None = None, + legend: bool = True, ) -> alt.Chart: """Display as simple plot. - We should eventually be able to handle everything through this. - - :param others: Other rate constants - :param others_labels: Labels for other rate constants - :param T_range: Temperature range - :param P: Pressure - :param x_label: X-axis label - :param y_label: Y-axis label - :param point: Whether to mark with points instead of a line :return: Chart """ x_label = "" if x_label is None else x_label @@ -308,12 +333,10 @@ def general( else [*POINT_COLOR_CYCLE, *LINE_COLOR_CYCLE] ) - nk, nT = np.shape(y_data) # noqa: N806 - colors = colors or list(itertools.islice(itertools.cycle(color_cycle), nk)) - keep_legend = labels is not None - labels = labels or [f"k{i + 1}" for i in range(nk)] - assert len(x_data) == nT, f"{x_data} !~ {y_data}" - assert len(labels) == nk, f"{labels} !~ {y_data}" + ny, nx = np.shape(y_data) # noqa: N806 + colors = colors or list(itertools.islice(itertools.cycle(color_cycle), ny)) + assert len(x_data) == nx, f"{x_data} !~ {y_data}" + assert len(labels) == ny, f"{labels} !~ {y_data}" # Gather data from functons data_dct = dict(zip(labels, y_data, strict=True)) @@ -322,18 +345,18 @@ def general( # Prepare encoding parameters x = alt.X("x", title=x_label, scale=x_scale_, axis=x_axis_) y = alt.Y("value:Q", title=y_label, scale=y_scale_, axis=y_axis_) - color = ( - alt.Color("key:N", scale=alt.Scale(domain=labels, range=colors)) - if keep_legend - else alt.value(colors[0]) + color = alt.Color( + "key:N", + scale=alt.Scale(domain=labels, range=colors), + legend=alt.Undefined if legend else None, ) chart = alt.Chart(data) - chart = ( - chart.mark_point(filled=True, opacity=1) - if mark == Mark.point - else chart.mark_line() - ) + kwargs = {} if mark_kwargs is None else mark_kwargs + if mark == Mark.point: + chart = chart.mark_point(**kwargs) + else: + chart = chart.mark_line(**kwargs) # Create chart return chart.transform_fold(fold=list(data_dct.keys())).encode( @@ -439,6 +462,7 @@ def arrhenius( # noqa: PLR0913 x_unit: str | None = None, y_unit: str | None = None, mark: str = Mark.line, + mark_kwargs: dict | None = None, ) -> alt.Chart: """Display as Arrhenius plot. @@ -503,11 +527,12 @@ def arrhenius( # noqa: PLR0913 ) chart = alt.Chart(data) - chart = ( - chart.mark_point(filled=True, opacity=1) - if mark == Mark.point - else chart.mark_line() - ) + if mark == Mark.point: + kwargs = {"filled": True, "opacity": 1} if mark_kwargs is None else mark_kwargs + chart = chart.mark_point(**kwargs) + else: + kwargs = {} if mark_kwargs is None else mark_kwargs + chart = chart.mark_line(**kwargs) # Create chart return chart.transform_fold(fold=list(data_dct.keys())).encode(