Skip to content

Commit f21dd20

Browse files
authored
Merge pull request #716 from avcopan/dev
More plotting improvements
2 parents 4ce2c18 + 3798a77 commit f21dd20

File tree

4 files changed

+120
-42
lines changed

4 files changed

+120
-42
lines changed

src/autochem/rate/data.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
import pydantic
1414
from numpy.polynomial import chebyshev
1515
from numpy.typing import ArrayLike, NDArray
16-
from pydantic import BeforeValidator
16+
from pydantic import BeforeValidator, model_validator
1717
from pydantic_core import core_schema
1818

1919
from .. import unit_
2020
from ..unit_ import UNITS, C, D, Dimension, UnitManager, Units, UnitsData, const
2121
from ..util import arrh, chemkin, func, mess, plot
22-
from ..util.type_ import Frozen, NDArray_, Scalable, Scalers, SubclassTyped
22+
from ..util.type_ import NDArray_, Scalable, Scalers, SubclassTyped
2323
from . import blend
2424
from .blend import BlendingFunction_
2525

@@ -34,7 +34,7 @@ class Key:
3434
k = "k"
3535

3636

37-
class BaseRate(UnitManager, Frozen, Scalable, SubclassTyped, abc.ABC):
37+
class BaseRate(UnitManager, Scalable, SubclassTyped, abc.ABC):
3838
"""Abstract base class for rate constants."""
3939

4040
order: int = 1
@@ -189,6 +189,16 @@ class Rate(BaseRate):
189189
"k_high": D.rate_constant,
190190
}
191191

192+
@model_validator(mode="after")
193+
def sort_temperatures(self) -> Self:
194+
idxs = np.argsort(self.T)
195+
self.T = np.take(self.T, idxs).tolist()
196+
self.k_data = self.k_data[idxs]
197+
self.k_high = (
198+
None if self.k_high is None else np.take(self.k_high, idxs).tolist()
199+
)
200+
return self
201+
192202
def __truediv__(self, other: "Rate" | ArrayLike) -> Self:
193203
"""Scalar division.
194204
@@ -467,7 +477,53 @@ def merge_equivalent(self, other: "Rate", *, tol: float = 0.1) -> "Rate":
467477
return self.model_copy(update={"k_data": k_data, "k_high": k_high})
468478

469479

470-
class RateFit(BaseRate):
480+
class BoundedMixin(pydantic.BaseModel):
481+
"""Mixin to define bounded calculator."""
482+
483+
T_min: float | None = None
484+
T_max: float | None = None
485+
486+
def in_bounds(
487+
self,
488+
T: ArrayLike, # noqa: N803
489+
) -> NDArray[np.bool_]:
490+
"""Determine whether temperature(s) are in bounds.
491+
492+
:param T: Temperature(s)
493+
:return: Boolean value(s)
494+
"""
495+
T = np.array(T, dtype=np.float64) # noqa: N806
496+
greater_than_min = (
497+
np.ones_like(T, dtype=bool) if self.T_min is None else self.T_min <= T
498+
)
499+
less_than_max = (
500+
np.ones_like(T, dtype=bool) if self.T_max is None else self.T_max >= T
501+
)
502+
return greater_than_min & less_than_max
503+
504+
def all_in_bounds(
505+
self,
506+
T: ArrayLike, # noqa: N803
507+
) -> bool:
508+
"""Determine whether all temperature(s) are in bounds.
509+
510+
:param T: Temperature(s)
511+
:return: `True` if they are
512+
"""
513+
return np.all(self.in_bounds(T)).item()
514+
515+
def assert_all_in_bounds(
516+
self,
517+
T: ArrayLike, # noqa: N803
518+
) -> None:
519+
"""Assert that all temperature(s) are in bounds.
520+
521+
:param T: Temperature(s)
522+
"""
523+
assert self.all_in_bounds(T), f"{self.T_min} !<= {T} !<= {self.T_max}"
524+
525+
526+
class RateFit(BaseRate, BoundedMixin):
471527
"""Rate fit abstract base classs."""
472528

473529
efficiencies: dict[str, float] = pydantic.Field(default_factory=dict)
@@ -534,6 +590,7 @@ def __call__(
534590
) -> NDArray[np.float128]:
535591
"""Evaluate rate constant."""
536592
T_, _ = func.normalize_arguments((T, P)) # noqa: N806
593+
T_ = np.where(self.in_bounds(T_), T_, np.nan)
537594
R = const.value(C.gas, UNITS) # noqa: N806
538595
kTP = self.A * (T_**self.b) * np.exp(-self.E / (R * T_)) # noqa: N806
539596
return func.normalize_values(kTP, (T, P))
@@ -1276,9 +1333,6 @@ def display_p( # noqa: PLR0913
12761333
if y_unit:
12771334
y_label = f"{y_label} ({y_unit})"
12781335

1279-
nr = len(rates)
1280-
labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None)
1281-
12821336
def make_chart(
12831337
ixs: Sequence[int],
12841338
rates: Sequence[BaseRate],

src/autochem/unit_/_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
import pydantic
1010
from numpy.typing import NDArray
1111

12-
from ..util.type_ import Frozen
1312
from . import dim
1413
from .dim import Dimension
1514
from .system import UNITS, Units, UnitsData
1615

1716

18-
class UnitManager(Frozen, abc.ABC):
17+
class UnitManager(abc.ABC):
1918
_dimensions: ClassVar[dict[str, Dimension]]
2019

2120
def __init__(self, units: UnitsData | None = None, **kwargs: object) -> None:

src/autochem/unit_/_unit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,4 @@ def string(unit: pint.Unit) -> str:
88

99

1010
def pretty_string(unit: pint.Unit) -> str:
11-
return format(unit, "~P")
11+
return format(unit, "~P").replace("particle", "molecule")

src/autochem/util/plot.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Plotting helpers."""
22

33
import itertools
4-
from collections.abc import Sequence
4+
from collections.abc import Callable, Sequence
5+
from typing import Any
56

67
import altair as alt
78
import numpy as np
89
import pandas as pd
910
from numpy.typing import ArrayLike
11+
from scipy.interpolate import CubicSpline
1012

1113
from .. import unit_
1214
from ..unit_ import UNITS, Units, UnitsData
@@ -158,10 +160,19 @@ class Mark:
158160
line = "line"
159161

160162

163+
def regular_scale(val_range: tuple[float, float]) -> alt.Scale:
164+
"""Generate a regular scale specification.
165+
166+
:param val_range: Range
167+
:return: Scale
168+
"""
169+
return alt.Scale(domain=val_range)
170+
171+
161172
def log_scale(val_range: tuple[float, float]) -> alt.Scale:
162173
"""Generate a log scale specification.
163174
164-
:param val_range: Rante
175+
:param val_range: Range
165176
:return: Scale
166177
"""
167178
return alt.Scale(type="log", domain=log_scale_domain(val_range))
@@ -267,11 +278,32 @@ def recompose_base10(mant: float, exp: int) -> float:
267278
MARKS = (Mark.point, Mark.line)
268279

269280

281+
def transformed_spline_interpolator(
282+
x_data: ArrayLike,
283+
y_data: ArrayLike,
284+
x_trans: Callable[[ArrayLike], ArrayLike] = lambda x: x,
285+
y_trans: Callable[[ArrayLike], ArrayLike] = lambda y: y,
286+
y_trans_inv: Callable[[ArrayLike], ArrayLike] = lambda y: y,
287+
) -> Callable[[Any], np.ndarray]:
288+
"""Generate an inerpolator from data.
289+
290+
:param y_data: Y data
291+
:param x_data: X data
292+
:return: Y interpolator
293+
"""
294+
interp_trans_ = CubicSpline(x_trans(x_data), y_trans(y_data))
295+
296+
def interp_(x: Any) -> np.ndarray:
297+
return np.asarray(y_trans_inv(interp_trans_(x_trans(x))))
298+
299+
return interp_
300+
301+
270302
def general(
271303
y_data: Sequence[Sequence[float]],
272304
x_data: Sequence[float], # noqa: N803
305+
labels: Sequence[str],
273306
*,
274-
labels: Sequence[str] | None = None,
275307
colors: Sequence[str] | None = None,
276308
x_label: str | None = None, # noqa: RUF001
277309
y_label: str | None = None, # noqa: RUF001
@@ -280,18 +312,11 @@ def general(
280312
x_axis: alt.Axis | None = None,
281313
y_axis: alt.Axis | None = None,
282314
mark: str = Mark.line,
315+
mark_kwargs: dict | None = None,
316+
legend: bool = True,
283317
) -> alt.Chart:
284318
"""Display as simple plot.
285319
286-
We should eventually be able to handle everything through this.
287-
288-
:param others: Other rate constants
289-
:param others_labels: Labels for other rate constants
290-
:param T_range: Temperature range
291-
:param P: Pressure
292-
:param x_label: X-axis label
293-
:param y_label: Y-axis label
294-
:param point: Whether to mark with points instead of a line
295320
:return: Chart
296321
"""
297322
x_label = "" if x_label is None else x_label
@@ -308,12 +333,10 @@ def general(
308333
else [*POINT_COLOR_CYCLE, *LINE_COLOR_CYCLE]
309334
)
310335

311-
nk, nT = np.shape(y_data) # noqa: N806
312-
colors = colors or list(itertools.islice(itertools.cycle(color_cycle), nk))
313-
keep_legend = labels is not None
314-
labels = labels or [f"k{i + 1}" for i in range(nk)]
315-
assert len(x_data) == nT, f"{x_data} !~ {y_data}"
316-
assert len(labels) == nk, f"{labels} !~ {y_data}"
336+
ny, nx = np.shape(y_data) # noqa: N806
337+
colors = colors or list(itertools.islice(itertools.cycle(color_cycle), ny))
338+
assert len(x_data) == nx, f"{x_data} !~ {y_data}"
339+
assert len(labels) == ny, f"{labels} !~ {y_data}"
317340

318341
# Gather data from functons
319342
data_dct = dict(zip(labels, y_data, strict=True))
@@ -322,18 +345,18 @@ def general(
322345
# Prepare encoding parameters
323346
x = alt.X("x", title=x_label, scale=x_scale_, axis=x_axis_)
324347
y = alt.Y("value:Q", title=y_label, scale=y_scale_, axis=y_axis_)
325-
color = (
326-
alt.Color("key:N", scale=alt.Scale(domain=labels, range=colors))
327-
if keep_legend
328-
else alt.value(colors[0])
348+
color = alt.Color(
349+
"key:N",
350+
scale=alt.Scale(domain=labels, range=colors),
351+
legend=alt.Undefined if legend else None,
329352
)
330353

331354
chart = alt.Chart(data)
332-
chart = (
333-
chart.mark_point(filled=True, opacity=1)
334-
if mark == Mark.point
335-
else chart.mark_line()
336-
)
355+
kwargs = {} if mark_kwargs is None else mark_kwargs
356+
if mark == Mark.point:
357+
chart = chart.mark_point(**kwargs)
358+
else:
359+
chart = chart.mark_line(**kwargs)
337360

338361
# Create chart
339362
return chart.transform_fold(fold=list(data_dct.keys())).encode(
@@ -439,6 +462,7 @@ def arrhenius( # noqa: PLR0913
439462
x_unit: str | None = None,
440463
y_unit: str | None = None,
441464
mark: str = Mark.line,
465+
mark_kwargs: dict | None = None,
442466
) -> alt.Chart:
443467
"""Display as Arrhenius plot.
444468
@@ -503,11 +527,12 @@ def arrhenius( # noqa: PLR0913
503527
)
504528

505529
chart = alt.Chart(data)
506-
chart = (
507-
chart.mark_point(filled=True, opacity=1)
508-
if mark == Mark.point
509-
else chart.mark_line()
510-
)
530+
if mark == Mark.point:
531+
kwargs = {"filled": True, "opacity": 1} if mark_kwargs is None else mark_kwargs
532+
chart = chart.mark_point(**kwargs)
533+
else:
534+
kwargs = {} if mark_kwargs is None else mark_kwargs
535+
chart = chart.mark_line(**kwargs)
511536

512537
# Create chart
513538
return chart.transform_fold(fold=list(data_dct.keys())).encode(

0 commit comments

Comments
 (0)