Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 61 additions & 7 deletions src/autochem/rate/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
import pydantic
from numpy.polynomial import chebyshev
from numpy.typing import ArrayLike, NDArray
from pydantic import BeforeValidator
from pydantic import BeforeValidator, model_validator
from pydantic_core import core_schema

from .. import unit_
from ..unit_ import UNITS, C, D, Dimension, UnitManager, Units, UnitsData, const
from ..util import arrh, chemkin, func, mess, plot
from ..util.type_ import Frozen, NDArray_, Scalable, Scalers, SubclassTyped
from ..util.type_ import NDArray_, Scalable, Scalers, SubclassTyped
from . import blend
from .blend import BlendingFunction_

Expand All @@ -34,7 +34,7 @@ class Key:
k = "k"


class BaseRate(UnitManager, Frozen, Scalable, SubclassTyped, abc.ABC):
class BaseRate(UnitManager, Scalable, SubclassTyped, abc.ABC):
"""Abstract base class for rate constants."""

order: int = 1
Expand Down Expand Up @@ -189,6 +189,16 @@ class Rate(BaseRate):
"k_high": D.rate_constant,
}

@model_validator(mode="after")
def sort_temperatures(self) -> Self:
idxs = np.argsort(self.T)
self.T = np.take(self.T, idxs).tolist()
self.k_data = self.k_data[idxs]
self.k_high = (
None if self.k_high is None else np.take(self.k_high, idxs).tolist()
)
return self

def __truediv__(self, other: "Rate" | ArrayLike) -> Self:
"""Scalar division.

Expand Down Expand Up @@ -467,7 +477,53 @@ def merge_equivalent(self, other: "Rate", *, tol: float = 0.1) -> "Rate":
return self.model_copy(update={"k_data": k_data, "k_high": k_high})


class RateFit(BaseRate):
class BoundedMixin(pydantic.BaseModel):
"""Mixin to define bounded calculator."""

T_min: float | None = None
T_max: float | None = None

def in_bounds(
self,
T: ArrayLike, # noqa: N803
) -> NDArray[np.bool_]:
"""Determine whether temperature(s) are in bounds.

:param T: Temperature(s)
:return: Boolean value(s)
"""
T = np.array(T, dtype=np.float64) # noqa: N806
greater_than_min = (
np.ones_like(T, dtype=bool) if self.T_min is None else self.T_min <= T
)
less_than_max = (
np.ones_like(T, dtype=bool) if self.T_max is None else self.T_max >= T
)
return greater_than_min & less_than_max

def all_in_bounds(
self,
T: ArrayLike, # noqa: N803
) -> bool:
"""Determine whether all temperature(s) are in bounds.

:param T: Temperature(s)
:return: `True` if they are
"""
return np.all(self.in_bounds(T)).item()

def assert_all_in_bounds(
self,
T: ArrayLike, # noqa: N803
) -> None:
"""Assert that all temperature(s) are in bounds.

:param T: Temperature(s)
"""
assert self.all_in_bounds(T), f"{self.T_min} !<= {T} !<= {self.T_max}"


class RateFit(BaseRate, BoundedMixin):
"""Rate fit abstract base classs."""

efficiencies: dict[str, float] = pydantic.Field(default_factory=dict)
Expand Down Expand Up @@ -534,6 +590,7 @@ def __call__(
) -> NDArray[np.float128]:
"""Evaluate rate constant."""
T_, _ = func.normalize_arguments((T, P)) # noqa: N806
T_ = np.where(self.in_bounds(T_), T_, np.nan)
R = const.value(C.gas, UNITS) # noqa: N806
kTP = self.A * (T_**self.b) * np.exp(-self.E / (R * T_)) # noqa: N806
return func.normalize_values(kTP, (T, P))
Expand Down Expand Up @@ -1276,9 +1333,6 @@ def display_p( # noqa: PLR0913
if y_unit:
y_label = f"{y_label} ({y_unit})"

nr = len(rates)
labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None)

def make_chart(
ixs: Sequence[int],
rates: Sequence[BaseRate],
Expand Down
3 changes: 1 addition & 2 deletions src/autochem/unit_/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import pydantic
from numpy.typing import NDArray

from ..util.type_ import Frozen
from . import dim
from .dim import Dimension
from .system import UNITS, Units, UnitsData


class UnitManager(Frozen, abc.ABC):
class UnitManager(abc.ABC):
_dimensions: ClassVar[dict[str, Dimension]]

def __init__(self, units: UnitsData | None = None, **kwargs: object) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/autochem/unit_/_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ def string(unit: pint.Unit) -> str:


def pretty_string(unit: pint.Unit) -> str:
return format(unit, "~P")
return format(unit, "~P").replace("particle", "molecule")
89 changes: 57 additions & 32 deletions src/autochem/util/plot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Plotting helpers."""

import itertools
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Any

import altair as alt
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from scipy.interpolate import CubicSpline

from .. import unit_
from ..unit_ import UNITS, Units, UnitsData
Expand Down Expand Up @@ -158,10 +160,19 @@ class Mark:
line = "line"


def regular_scale(val_range: tuple[float, float]) -> alt.Scale:
"""Generate a regular scale specification.

:param val_range: Range
:return: Scale
"""
return alt.Scale(domain=val_range)


def log_scale(val_range: tuple[float, float]) -> alt.Scale:
"""Generate a log scale specification.

:param val_range: Rante
:param val_range: Range
:return: Scale
"""
return alt.Scale(type="log", domain=log_scale_domain(val_range))
Expand Down Expand Up @@ -267,11 +278,32 @@ def recompose_base10(mant: float, exp: int) -> float:
MARKS = (Mark.point, Mark.line)


def transformed_spline_interpolator(
x_data: ArrayLike,
y_data: ArrayLike,
x_trans: Callable[[ArrayLike], ArrayLike] = lambda x: x,
y_trans: Callable[[ArrayLike], ArrayLike] = lambda y: y,
y_trans_inv: Callable[[ArrayLike], ArrayLike] = lambda y: y,
) -> Callable[[Any], np.ndarray]:
"""Generate an inerpolator from data.

:param y_data: Y data
:param x_data: X data
:return: Y interpolator
"""
interp_trans_ = CubicSpline(x_trans(x_data), y_trans(y_data))

def interp_(x: Any) -> np.ndarray:
return np.asarray(y_trans_inv(interp_trans_(x_trans(x))))

return interp_


def general(
y_data: Sequence[Sequence[float]],
x_data: Sequence[float], # noqa: N803
labels: Sequence[str],
*,
labels: Sequence[str] | None = None,
colors: Sequence[str] | None = None,
x_label: str | None = None, # noqa: RUF001
y_label: str | None = None, # noqa: RUF001
Expand All @@ -280,18 +312,11 @@ def general(
x_axis: alt.Axis | None = None,
y_axis: alt.Axis | None = None,
mark: str = Mark.line,
mark_kwargs: dict | None = None,
legend: bool = True,
) -> alt.Chart:
"""Display as simple plot.

We should eventually be able to handle everything through this.

:param others: Other rate constants
:param others_labels: Labels for other rate constants
:param T_range: Temperature range
:param P: Pressure
:param x_label: X-axis label
:param y_label: Y-axis label
:param point: Whether to mark with points instead of a line
:return: Chart
"""
x_label = "" if x_label is None else x_label
Expand All @@ -308,12 +333,10 @@ def general(
else [*POINT_COLOR_CYCLE, *LINE_COLOR_CYCLE]
)

nk, nT = np.shape(y_data) # noqa: N806
colors = colors or list(itertools.islice(itertools.cycle(color_cycle), nk))
keep_legend = labels is not None
labels = labels or [f"k{i + 1}" for i in range(nk)]
assert len(x_data) == nT, f"{x_data} !~ {y_data}"
assert len(labels) == nk, f"{labels} !~ {y_data}"
ny, nx = np.shape(y_data) # noqa: N806
colors = colors or list(itertools.islice(itertools.cycle(color_cycle), ny))
assert len(x_data) == nx, f"{x_data} !~ {y_data}"
assert len(labels) == ny, f"{labels} !~ {y_data}"

# Gather data from functons
data_dct = dict(zip(labels, y_data, strict=True))
Expand All @@ -322,18 +345,18 @@ def general(
# Prepare encoding parameters
x = alt.X("x", title=x_label, scale=x_scale_, axis=x_axis_)
y = alt.Y("value:Q", title=y_label, scale=y_scale_, axis=y_axis_)
color = (
alt.Color("key:N", scale=alt.Scale(domain=labels, range=colors))
if keep_legend
else alt.value(colors[0])
color = alt.Color(
"key:N",
scale=alt.Scale(domain=labels, range=colors),
legend=alt.Undefined if legend else None,
)

chart = alt.Chart(data)
chart = (
chart.mark_point(filled=True, opacity=1)
if mark == Mark.point
else chart.mark_line()
)
kwargs = {} if mark_kwargs is None else mark_kwargs
if mark == Mark.point:
chart = chart.mark_point(**kwargs)
else:
chart = chart.mark_line(**kwargs)

# Create chart
return chart.transform_fold(fold=list(data_dct.keys())).encode(
Expand Down Expand Up @@ -439,6 +462,7 @@ def arrhenius( # noqa: PLR0913
x_unit: str | None = None,
y_unit: str | None = None,
mark: str = Mark.line,
mark_kwargs: dict | None = None,
) -> alt.Chart:
"""Display as Arrhenius plot.

Expand Down Expand Up @@ -503,11 +527,12 @@ def arrhenius( # noqa: PLR0913
)

chart = alt.Chart(data)
chart = (
chart.mark_point(filled=True, opacity=1)
if mark == Mark.point
else chart.mark_line()
)
if mark == Mark.point:
kwargs = {"filled": True, "opacity": 1} if mark_kwargs is None else mark_kwargs
chart = chart.mark_point(**kwargs)
else:
kwargs = {} if mark_kwargs is None else mark_kwargs
chart = chart.mark_line(**kwargs)

# Create chart
return chart.transform_fold(fold=list(data_dct.keys())).encode(
Expand Down