diff --git a/src/autochem/rate/_reaction.py b/src/autochem/rate/_reaction.py index 6f3af520..88dd9ece 100644 --- a/src/autochem/rate/_reaction.py +++ b/src/autochem/rate/_reaction.py @@ -294,7 +294,7 @@ def make_chart( labels_ = None if labels is None else [labels[i] for i in ixs] colors_ = None if colors is None else [colors[i] for i in ixs] (T, *Ts), ks = zip( # noqa: N806 - *(r.plot_data(T_range=T_range, P=P, units=units) for r in rates_), + *(r.plot_data(T=T_range, P=P, units=units) for r in rates_), strict=True, ) for T_ in Ts: # noqa: N806 diff --git a/src/autochem/rate/data.py b/src/autochem/rate/data.py index 5d5f18b0..1ca42ac7 100644 --- a/src/autochem/rate/data.py +++ b/src/autochem/rate/data.py @@ -3,6 +3,7 @@ import abc import warnings from collections.abc import Mapping, Sequence +from numbers import Number from typing import Annotated, ClassVar, Literal, Self import altair as alt @@ -69,20 +70,31 @@ def plot_mark(self) -> str: def plot_data( self, - T_range: tuple[float, float] = (400, 1250), # noqa: N803 - P: float = 1, # noqa: N803 + T: float | tuple[float, float] = (400, 1250), # noqa: N803 + P: float | tuple[float, float] = 1, # noqa: N803 units: UnitsData | None = None, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Display as an Arrhenius plot. - :param T_range: Temperature range - :param P: Pressure + :param T: Temperature or temperature range + :param P: Pressure or pressure range :param units: Units :return: Chart """ - T = np.linspace(*T_range, 1000) # noqa: N806 - k = self(T=T, P=P, units=units) - return T, k + if isinstance(T, Sequence) and isinstance(P, Number): + P_ = P + T_ = np.linspace(*T, 1000) # noqa: N806 + x_data = T_ + elif isinstance(T, Number) and isinstance(P, Sequence): + P_ = np.linspace(*P, 1000) # noqa: N806 + T_ = T + x_data = P_ + else: + msg = f"3-dimensional plotting not yet implemented:\nT={T}\nP={P}" + raise ValueError(msg) + + y_data = self(T_, P_, units=units) + return x_data, y_data # type: ignore def display( # noqa: PLR0913 self, @@ -103,7 +115,7 @@ def display( # noqa: PLR0913 :param y_label: Y-axis label :return: Chart """ - T, k = self.plot_data(T_range=T_range, P=P, units=units) # noqa: N806 + T, k = self.plot_data(T=T_range, P=P, units=units) # noqa: N806 return plot.arrhenius( ks=[k], T=T, @@ -200,23 +212,33 @@ def plot_mark(self) -> str: def plot_data( self, - T_range: tuple[float, float] = (400, 1250), # noqa: N803 - P: float = 1, # noqa: N803 + T: float | tuple[float, float] = (400, 1250), # noqa: N803 + P: float | tuple[float, float] = 1, # noqa: N803 units: UnitsData | None = None, ) -> tuple[NDArray, NDArray]: """Display as an Arrhenius plot. - :param T_range: Temperature range - :param P: Pressure + :param T: Temperature or temperature range + :param P: Pressure or pressure range :param units: Units :return: Chart """ - T_min, T_max = T_range # noqa: N806 - k = self(T=self.T, P=P, units=units) - (ix,) = np.where( - np.greater_equal(self.T, T_min) & np.less_equal(self.T, T_max), - ) - return np.take(self.T, ix), np.take(k, ix) + if isinstance(T, Sequence) and isinstance(P, Number): + T_ = self.T + P_ = P + (i_,) = np.where(np.greater_equal(T_, T[0]) & np.less_equal(T_, T[1])) + x_data = np.take(T_, i_) + elif isinstance(T, Number) and isinstance(P, Sequence): + T_ = T + P_ = self.P + (i_,) = np.where(np.greater_equal(P_, P[0]) & np.less_equal(P_, P[1])) + x_data = np.take(P_, i_) + else: + msg = f"3-dimensional plotting not yet implemented:\nT={T}\nP={P}" + raise ValueError(msg) + + y_data = np.take(self(T_, P_, units=units), i_) + return x_data, y_data @unit_.manage_units([D.temperature, D.pressure], D.rate_constant) def __call__( @@ -1163,9 +1185,6 @@ def display( # noqa: PLR0913 raise ValueError(msg) order = rate0.order - nr = len(rates) - labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None) - plot_ = plot.arrhenius if plot_type == "arrh" else plot.simple def make_chart( @@ -1179,7 +1198,7 @@ def make_chart( labels_ = None if labels is None else [labels[i] for i in ixs] colors_ = None if colors is None else [colors[i] for i in ixs] (T, *Ts), ks = zip( # noqa: N806 - *(r.plot_data(T_range=T_range, P=P, units=units) for r in rates_), + *(r.plot_data(T=T_range, P=P, units=units) for r in rates_), strict=True, ) for T_ in Ts: # noqa: N806 @@ -1211,3 +1230,93 @@ def make_chart( return ( chart if not others else alt.layer(*charts).resolve_scale(color="independent") ) + + +def display_p( # noqa: PLR0913 + rate_: BaseRate | Sequence[BaseRate], + *, + T: float = 825, # noqa: N803 + P_range: tuple[float, float] = (0.1, 100), # noqa: N803 + units: UnitsData | None = None, + label: str | Sequence[str] | None = None, + color: str | Sequence[str] | None = None, + y_label: str | None = None, # noqa: RUF001 + y_unit: str | None = None, # noqa: RUF001 + check_order: bool = True, +) -> alt.Chart: + """Display one or more reaction rates on an Arrhenius plot. + + :param rxn_: Reaction rate(s) + :param T_range: Temperature range, defaults to (400, 1250) + :param P: Pressure + :param label_: Label(s), defaults to None + :param color_: Color(s), defaults to None + :param x_label: X-axis label + :param y_label: Y-axis label + """ + rates = [rate_] if isinstance(rate_, BaseRate) else rate_ + labels = [label] if isinstance(label, str) else label + colors = [color] if isinstance(color, str) else color + rate0, *rates_ = rates + if check_order: + for other_rate in rates_: + if not rate0.order == other_rate.order: + msg = f"Mismatched reaction orders: {rate0} !~ {other_rate}" + raise ValueError(msg) + order = rate0.order + + units = UNITS if units is None else Units.model_validate(units) + x_unit = unit_.pretty_string(units.pressure) + x_label = f"𝑃 ({x_unit})" + + y_label = "𝑘" if y_label is None else y_label + y_unit = ( + unit_.pretty_string(units.rate_constant(order)) if y_unit is None else y_unit + ) + if y_unit: + y_label = f"{y_label} ({y_unit})" + + nr = len(rates) + labels = labels or ([f"k{i + 1}" for i in range(nr)] if nr > 1 else None) + + def make_chart( + ixs: Sequence[int], + rates: Sequence[BaseRate], + labels: Sequence[str] | None, + colors: Sequence[str] | None, + mark: str, + ) -> alt.Chart: + rates_ = [rates[i] for i in ixs] + labels_ = None if labels is None else [labels[i] for i in ixs] + colors_ = None if colors is None else [colors[i] for i in ixs] + (P, *Ps), ks = zip( # noqa: N806 + *(r.plot_data(T=T, P=P_range, units=units) for r in rates_), + strict=True, + ) + for P_ in Ps: # noqa: N806 + assert np.allclose(P, P_), f"{P} !~ {P_}" + return plot.general( + y_data=ks, + x_data=P, + labels=labels_, + colors=colors_, + x_label=x_label, + y_label=y_label, + x_scale=plot.log_scale(P_range), + x_axis=plot.log_scale_axis(P_range), + mark=mark, + ) + + charts = [] + for mark in (plot.Mark.line, plot.Mark.point): + ixs = [i for i, r in enumerate(rates) if r.plot_mark == mark] + if ixs: + chart = make_chart( + ixs, rates=rates, labels=labels, colors=colors, mark=mark + ) + charts.append(chart) + + chart, *others = charts + return ( + chart if not others else alt.layer(*charts).resolve_scale(color="independent") + ) diff --git a/src/autochem/util/plot.py b/src/autochem/util/plot.py index 59245e26..d83ef878 100644 --- a/src/autochem/util/plot.py +++ b/src/autochem/util/plot.py @@ -15,7 +15,7 @@ class Color: """Color hex values.""" - # Line colors: + # Core colors: blue = "#0066ff" red = "#ff0000" green = "#1ab73a" @@ -23,17 +23,103 @@ class Color: purple = "#8533ff" pink = "#d0009a" yellow = "#ffcd00" - teal = "#00b3b3" # bright cyan-green - lime = "#b6e300" # light, vivid green-yellow - magenta = "#ff33cc" # vibrant pink-purple - sky = "#33bbff" # lighter blue variant - olive = "#808000" # muted yellow-green + # Extra colors: + teal = "#008080" + cyan = "#00ffff" + magenta = "#ff00ff" + lime = "#00ff00" + navy = "#000080" + maroon = "#800000" + olive = "#808000" + coral = "#ff7f50" + gold = "#ffd700" + sky_blue = "#87ceeb" + violet = "#ee82ee" + indigo = "#4b0082" + salmon = "#fa8072" + mint = "#98ff98" + peach = "#ffdab9" + forest_green = "#228b22" + mustard = "#ffdb58" + steel_blue = "#4682b4" + plum = "#dda0dd" + ochre = "#cc7722" # Point colors: black = "#000000" gray = "#808080ff" light_gray = "#bfbfbfff" brown = "#916e6e" - brown2 = "#a0522d" # earthy neutral + sienna = "#a0522d" + + # ChatGPT generated colors: + # # --- Core colors (your originals) --- + # blue = "#0066ff" + # red = "#ff0000" + # green = "#1ab73a" + # orange = "#ef7810" + # purple = "#8533ff" + # pink = "#d0009a" + # yellow = "#ffcd00" + # black = "#000000" + # gray = "#808080" + # light_gray = "#bfbfbf" + # brown = "#916e6e" + + # # --- Strong blues --- + # navy = "#003f5c" + # royal_blue = "#4169e1" + # sky_blue = "#1ca8dd" + # teal = "#008080" + # turquoise = "#00a0b0" + + # # --- Strong reds / magentas --- + # crimson = "#dc143c" + # brick = "#b22222" + # carmine = "#960018" + # magenta = "#ff00ff" + # raspberry = "#b03060" + + # # --- Oranges & yellows --- + # amber = "#ffbf00" + # burnt_orange = "#cc5500" + # gold = "#d4a017" + # ochre = "#c07a28" + # mustard = "#e1ad01" + + # # --- Greens --- + # forest_green = "#228b22" + # emerald = "#50c878" + # lime_green = "#32cd32" + # olive = "#6b8e23" + # jade = "#00a86b" + + # # --- Purples & violets --- + # indigo = "#4b0082" + # violet = "#7b68ee" + # plum = "#8e4585" + # mauve = "#7d5ba6" + # orchid = "#da70d6" + + # # --- Cyans & aquas --- + # cyan = "#17becf" + # sea_green = "#20b2aa" + # steel_blue = "#4682b4" + # cerulean = "#007ba7" + # azure = "#007fff" + + # # --- Earth tones & neutrals --- + # sienna = "#a0522d" + # copper = "#b87333" + # chocolate = "#7b3f00" + # slate_gray = "#708090" + # charcoal = "#36454f" + + # # --- Extras for balance --- + # coral = "#ff7f50" + # maroon = "#800000" + # olive_drab = "#556b2f" + # midnight_blue = "#191970" + # royal_purple = "#7851a9" LINE_COLOR_CYCLE = [ @@ -45,10 +131,15 @@ class Color: Color.yellow, Color.orange, Color.teal, - Color.lime, + Color.cyan, Color.magenta, - Color.sky, + Color.lime, + Color.navy, + Color.maroon, Color.olive, + Color.coral, + Color.gold, + Color.sky_blue, ] @@ -57,7 +148,6 @@ class Color: Color.gray, Color.light_gray, Color.brown, - Color.brown2, ] @@ -68,9 +158,189 @@ class Mark: line = "line" +def log_scale(val_range: tuple[float, float]) -> alt.Scale: + """Generate a log scale specification. + + :param val_range: Rante + :return: Scale + """ + return alt.Scale(type="log", domain=log_scale_domain(val_range)) + + +def log_scale_axis(val_range: tuple[float, float]) -> alt.Axis: + """Generate a nice log scale axis. + + :param val_range: Range + :return: Axis + """ + max_exp = np.max(np.abs(np.log10(val_range))) + fmt = ".0e" if max_exp > 3 else alt.Undefined + vals = log_scale_values(val_range) + label_expr_condition = " ||\n".join( + f"(abs(datum.value - {v}) < 1e-5)" for v in vals + ) + label_expr = f"({label_expr_condition}) ? datum.label : ''" + return alt.Axis(format=fmt, values=log_scale_ticks(val_range), labelExpr=label_expr) + + +def log_scale_domain(val_range: tuple[float, float]) -> tuple[float, float]: + """Determine log scale ticks for a given range. + + :param val_range: Range + :return: Ticks + """ + # Determine tick min and max + val_min, val_max = val_range + mant_min, exp_min = decompose_base10(val_min) + mant_max, exp_max = decompose_base10(val_max) + start = recompose_base10(mant=np.floor(mant_min), exp=exp_min) + stop = recompose_base10(mant=np.ceil(mant_max), exp=exp_max) + return start, stop + + +def log_scale_values(val_range: tuple[float, float]) -> list[float]: + """Determine log scale ticks for a given range. + + :param val_range: Range + :return: Ticks + """ + val_min, val_max = log_scale_domain(val_range) + _, exp_min = decompose_base10(val_min) + _, exp_max = decompose_base10(val_max) + + # Add power of 10 steps in between + exp_start = exp_min + 1 + exp_stop = exp_max + exp_count = exp_stop - exp_start + 1 + powers_of_10 = [] + if exp_count > 0: + powers_of_10 = np.logspace( + exp_start, exp_stop, num=exp_count, endpoint=True + ).tolist() + return [val_min, *powers_of_10, val_max] + + +def log_scale_ticks(val_range: tuple[float, float]) -> list[float]: + """Determine log scale ticks for a given range. + + :param val_range: Range + :return: Ticks + """ + vals = log_scale_values(val_range) + if len(vals) > 5: + return vals + + # If the scale is not too large, add intervening ticks + bounds = [*vals, None] + ticks = [] + for start, stop in itertools.pairwise(bounds): + ticks.append(start) + if stop is not None: + _, exp = decompose_base10(start) + vals = [recompose_base10(m, exp) for m in range(2, 10)] + vals = [v for v in vals if start < v and v < stop] + ticks.extend(vals) + return ticks + + +def decompose_base10(val: float) -> tuple[float, int]: + """Decompose value into base-10 mantissa and exponent. + + :param val: Value + :return: Mantissa and exponent + """ + exp = np.floor(np.log10(val)).astype(int) + mant = val / (10.0**exp) + return float(mant), int(exp) + + +def recompose_base10(mant: float, exp: int) -> float: + """Recompose value from base-10 mantissa and exponent + + :param mant: Mantissa + :param exp: Exponent + :return: Value + """ + return float(mant * 10.0**exp) + + MARKS = (Mark.point, Mark.line) +def general( + y_data: Sequence[Sequence[float]], + x_data: Sequence[float], # noqa: N803 + *, + labels: Sequence[str] | None = None, + colors: Sequence[str] | None = None, + x_label: str | None = None, # noqa: RUF001 + y_label: str | None = None, # noqa: RUF001 + x_scale: alt.Scale | None = None, + y_scale: alt.Scale | None = None, + x_axis: alt.Axis | None = None, + y_axis: alt.Axis | None = None, + mark: str = Mark.line, +) -> alt.Chart: + """Display as simple plot. + + We should eventually be able to handle everything through this. + + :param others: Other rate constants + :param others_labels: Labels for other rate constants + :param T_range: Temperature range + :param P: Pressure + :param x_label: X-axis label + :param y_label: Y-axis label + :param point: Whether to mark with points instead of a line + :return: Chart + """ + x_label = "" if x_label is None else x_label + y_label = "" if y_label is None else y_label + x_scale_ = alt.Undefined if x_scale is None else x_scale + y_scale_ = alt.Undefined if y_scale is None else y_scale + x_axis_ = alt.Undefined if x_axis is None else x_axis + y_axis_ = alt.Undefined if y_axis is None else y_axis + + assert mark in MARKS, f"{mark} not in {MARKS}" + color_cycle = ( + LINE_COLOR_CYCLE + if mark == Mark.line + else [*POINT_COLOR_CYCLE, *LINE_COLOR_CYCLE] + ) + + nk, nT = np.shape(y_data) # noqa: N806 + colors = colors or list(itertools.islice(itertools.cycle(color_cycle), nk)) + keep_legend = labels is not None + labels = labels or [f"k{i + 1}" for i in range(nk)] + assert len(x_data) == nT, f"{x_data} !~ {y_data}" + assert len(labels) == nk, f"{labels} !~ {y_data}" + + # Gather data from functons + data_dct = dict(zip(labels, y_data, strict=True)) + data = pd.DataFrame({"x": x_data, **data_dct}) + + # Prepare encoding parameters + x = alt.X("x", title=x_label, scale=x_scale_, axis=x_axis_) + y = alt.Y("value:Q", title=y_label, scale=y_scale_, axis=y_axis_) + color = ( + alt.Color("key:N", scale=alt.Scale(domain=labels, range=colors)) + if keep_legend + else alt.value(colors[0]) + ) + + chart = alt.Chart(data) + chart = ( + chart.mark_point(filled=True, opacity=1) + if mark == Mark.point + else chart.mark_line() + ) + + # Create chart + return chart.transform_fold(fold=list(data_dct.keys())).encode( + x=x, y=y, color=color + ) + + def simple( ks: Sequence[Sequence[float]], T: Sequence[float], # noqa: N803 @@ -169,7 +439,6 @@ def arrhenius( # noqa: PLR0913 x_unit: str | None = None, y_unit: str | None = None, mark: str = Mark.line, - domain: tuple[float, float] | None = None, ) -> alt.Chart: """Display as Arrhenius plot. @@ -216,31 +485,21 @@ def arrhenius( # noqa: PLR0913 data = pd.DataFrame({"x": np.divide(1000, T), **data_dct}) # Determine exponent range - if domain is None: - vals_arr = np.array(list(data_dct.values())) - is_nan = np.isnan(vals_arr) - exp_arr = np.log10(vals_arr, where=~is_nan) - exp_arr[is_nan] = 0.0 - exp_arr = np.rint(exp_arr).astype(int) - exp_max = np.max(exp_arr).item() - exp_min = np.min(exp_arr).item() - y_vals = [10**x for x in range(exp_min, exp_max + 2)] - domain = alt.Undefined - else: - y_vals = alt.Undefined + vals_arr = np.array(list(data_dct.values())) + y_range = (np.nanmin(vals_arr), np.nanmax(vals_arr)) # Prepare encoding parameters x = alt.X("x", title=x_label, scale=alt.Scale(zero=False)) - y = ( - alt.Y("value:Q", title=y_label) - .scale(type="log", domain=domain) - .axis(format=".1e") - .axis(format=".1e", values=y_vals) + y = alt.Y( + "value:Q", + title=y_label, + scale=log_scale(y_range), + axis=log_scale_axis(y_range), ) - color = ( - alt.Color("key:N", scale=alt.Scale(domain=labels, range=colors)) - if keep_legend - else alt.value(colors[0]) + color = alt.Color( + "key:N", + scale=alt.Scale(domain=labels, range=colors), + legend=alt.Undefined if keep_legend else None, ) chart = alt.Chart(data)