Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/autochem/rate/_reaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def make_chart(
labels_ = None if labels is None else [labels[i] for i in ixs]
colors_ = None if colors is None else [colors[i] for i in ixs]
(T, *Ts), ks = zip( # noqa: N806
*(r.plot_data(T_range=T_range, P=P, units=units) for r in rates_),
*(r.plot_data(T=T_range, P=P, units=units) for r in rates_),
strict=True,
)
for T_ in Ts: # noqa: N806
Expand Down
153 changes: 131 additions & 22 deletions src/autochem/rate/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import abc
import warnings
from collections.abc import Mapping, Sequence
from numbers import Number
from typing import Annotated, ClassVar, Literal, Self

import altair as alt
Expand Down Expand Up @@ -69,20 +70,31 @@ def plot_mark(self) -> str:

def plot_data(
self,
T_range: tuple[float, float] = (400, 1250), # noqa: N803
P: float = 1, # noqa: N803
T: float | tuple[float, float] = (400, 1250), # noqa: N803
P: float | tuple[float, float] = 1, # noqa: N803
units: UnitsData | None = None,
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
"""Display as an Arrhenius plot.

:param T_range: Temperature range
:param P: Pressure
:param T: Temperature or temperature range
:param P: Pressure or pressure range
:param units: Units
:return: Chart
"""
T = np.linspace(*T_range, 1000) # noqa: N806
k = self(T=T, P=P, units=units)
return T, k
if isinstance(T, Sequence) and isinstance(P, Number):
P_ = P
T_ = np.linspace(*T, 1000) # noqa: N806
x_data = T_
elif isinstance(T, Number) and isinstance(P, Sequence):
P_ = np.linspace(*P, 1000) # noqa: N806
T_ = T
x_data = P_
else:
msg = f"3-dimensional plotting not yet implemented:\nT={T}\nP={P}"
raise ValueError(msg)

y_data = self(T_, P_, units=units)
return x_data, y_data # type: ignore

def display( # noqa: PLR0913
self,
Expand All @@ -103,7 +115,7 @@ def display( # noqa: PLR0913
:param y_label: Y-axis label
:return: Chart
"""
T, k = self.plot_data(T_range=T_range, P=P, units=units) # noqa: N806
T, k = self.plot_data(T=T_range, P=P, units=units) # noqa: N806
return plot.arrhenius(
ks=[k],
T=T,
Expand Down Expand Up @@ -200,23 +212,33 @@ def plot_mark(self) -> str:

def plot_data(
self,
T_range: tuple[float, float] = (400, 1250), # noqa: N803
P: float = 1, # noqa: N803
T: float | tuple[float, float] = (400, 1250), # noqa: N803
P: float | tuple[float, float] = 1, # noqa: N803
units: UnitsData | None = None,
) -> tuple[NDArray, NDArray]:
"""Display as an Arrhenius plot.

:param T_range: Temperature range
:param P: Pressure
:param T: Temperature or temperature range
:param P: Pressure or pressure range
:param units: Units
:return: Chart
"""
T_min, T_max = T_range # noqa: N806
k = self(T=self.T, P=P, units=units)
(ix,) = np.where(
np.greater_equal(self.T, T_min) & np.less_equal(self.T, T_max),
)
return np.take(self.T, ix), np.take(k, ix)
if isinstance(T, Sequence) and isinstance(P, Number):
T_ = self.T
P_ = P
(i_,) = np.where(np.greater_equal(T_, T[0]) & np.less_equal(T_, T[1]))
x_data = np.take(T_, i_)
elif isinstance(T, Number) and isinstance(P, Sequence):
T_ = T
P_ = self.P
(i_,) = np.where(np.greater_equal(P_, P[0]) & np.less_equal(P_, P[1]))
x_data = np.take(P_, i_)
else:
msg = f"3-dimensional plotting not yet implemented:\nT={T}\nP={P}"
raise ValueError(msg)

y_data = np.take(self(T_, P_, units=units), i_)
return x_data, y_data

@unit_.manage_units([D.temperature, D.pressure], D.rate_constant)
def __call__(
Expand Down Expand Up @@ -1163,9 +1185,6 @@ def display( # noqa: PLR0913
raise ValueError(msg)
order = rate0.order

nr = len(rates)
labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None)

plot_ = plot.arrhenius if plot_type == "arrh" else plot.simple

def make_chart(
Expand All @@ -1179,7 +1198,7 @@ def make_chart(
labels_ = None if labels is None else [labels[i] for i in ixs]
colors_ = None if colors is None else [colors[i] for i in ixs]
(T, *Ts), ks = zip( # noqa: N806
*(r.plot_data(T_range=T_range, P=P, units=units) for r in rates_),
*(r.plot_data(T=T_range, P=P, units=units) for r in rates_),
strict=True,
)
for T_ in Ts: # noqa: N806
Expand Down Expand Up @@ -1211,3 +1230,93 @@ def make_chart(
return (
chart if not others else alt.layer(*charts).resolve_scale(color="independent")
)


def display_p( # noqa: PLR0913
rate_: BaseRate | Sequence[BaseRate],
*,
T: float = 825, # noqa: N803
P_range: tuple[float, float] = (0.1, 100), # noqa: N803
units: UnitsData | None = None,
label: str | Sequence[str] | None = None,
color: str | Sequence[str] | None = None,
y_label: str | None = None, # noqa: RUF001
y_unit: str | None = None, # noqa: RUF001
check_order: bool = True,
) -> alt.Chart:
"""Display one or more reaction rates on an Arrhenius plot.

:param rxn_: Reaction rate(s)
:param T_range: Temperature range, defaults to (400, 1250)
:param P: Pressure
:param label_: Label(s), defaults to None
:param color_: Color(s), defaults to None
:param x_label: X-axis label
:param y_label: Y-axis label
"""
rates = [rate_] if isinstance(rate_, BaseRate) else rate_
labels = [label] if isinstance(label, str) else label
colors = [color] if isinstance(color, str) else color
rate0, *rates_ = rates
if check_order:
for other_rate in rates_:
if not rate0.order == other_rate.order:
msg = f"Mismatched reaction orders: {rate0} !~ {other_rate}"
raise ValueError(msg)
order = rate0.order

units = UNITS if units is None else Units.model_validate(units)
x_unit = unit_.pretty_string(units.pressure)
x_label = f"𝑃 ({x_unit})"

y_label = "𝑘" if y_label is None else y_label
y_unit = (
unit_.pretty_string(units.rate_constant(order)) if y_unit is None else y_unit
)
if y_unit:
y_label = f"{y_label} ({y_unit})"

nr = len(rates)
labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None)

def make_chart(
ixs: Sequence[int],
rates: Sequence[BaseRate],
labels: Sequence[str] | None,
colors: Sequence[str] | None,
mark: str,
) -> alt.Chart:
rates_ = [rates[i] for i in ixs]
labels_ = None if labels is None else [labels[i] for i in ixs]
colors_ = None if colors is None else [colors[i] for i in ixs]
(P, *Ps), ks = zip( # noqa: N806
*(r.plot_data(T=T, P=P_range, units=units) for r in rates_),
strict=True,
)
for P_ in Ps: # noqa: N806
assert np.allclose(P, P_), f"{P} !~ {P_}"
return plot.general(
y_data=ks,
x_data=P,
labels=labels_,
colors=colors_,
x_label=x_label,
y_label=y_label,
x_scale=plot.log_scale(P_range),
x_axis=plot.log_scale_axis(P_range),
mark=mark,
)

charts = []
for mark in (plot.Mark.line, plot.Mark.point):
ixs = [i for i, r in enumerate(rates) if r.plot_mark == mark]
if ixs:
chart = make_chart(
ixs, rates=rates, labels=labels, colors=colors, mark=mark
)
charts.append(chart)

chart, *others = charts
return (
chart if not others else alt.layer(*charts).resolve_scale(color="independent")
)
Loading