11"""Plotting helpers."""
22
33import itertools
4- from collections .abc import Sequence
4+ from collections .abc import Callable , Sequence
5+ from typing import Any
56
67import altair as alt
78import numpy as np
89import pandas as pd
910from numpy .typing import ArrayLike
11+ from scipy .interpolate import CubicSpline
1012
1113from .. import unit_
1214from ..unit_ import UNITS , Units , UnitsData
@@ -158,10 +160,19 @@ class Mark:
158160 line = "line"
159161
160162
163+ def regular_scale (val_range : tuple [float , float ]) -> alt .Scale :
164+ """Generate a regular scale specification.
165+
166+ :param val_range: Range
167+ :return: Scale
168+ """
169+ return alt .Scale (domain = val_range )
170+
171+
161172def log_scale (val_range : tuple [float , float ]) -> alt .Scale :
162173 """Generate a log scale specification.
163174
164- :param val_range: Rante
175+ :param val_range: Range
165176 :return: Scale
166177 """
167178 return alt .Scale (type = "log" , domain = log_scale_domain (val_range ))
@@ -267,11 +278,32 @@ def recompose_base10(mant: float, exp: int) -> float:
267278MARKS = (Mark .point , Mark .line )
268279
269280
281+ def transformed_spline_interpolator (
282+ x_data : ArrayLike ,
283+ y_data : ArrayLike ,
284+ x_trans : Callable [[ArrayLike ], ArrayLike ] = lambda x : x ,
285+ y_trans : Callable [[ArrayLike ], ArrayLike ] = lambda y : y ,
286+ y_trans_inv : Callable [[ArrayLike ], ArrayLike ] = lambda y : y ,
287+ ) -> Callable [[Any ], np .ndarray ]:
288+ """Generate an inerpolator from data.
289+
290+ :param y_data: Y data
291+ :param x_data: X data
292+ :return: Y interpolator
293+ """
294+ interp_trans_ = CubicSpline (x_trans (x_data ), y_trans (y_data ))
295+
296+ def interp_ (x : Any ) -> np .ndarray :
297+ return np .asarray (y_trans_inv (interp_trans_ (x_trans (x ))))
298+
299+ return interp_
300+
301+
270302def general (
271303 y_data : Sequence [Sequence [float ]],
272304 x_data : Sequence [float ], # noqa: N803
305+ labels : Sequence [str ],
273306 * ,
274- labels : Sequence [str ] | None = None ,
275307 colors : Sequence [str ] | None = None ,
276308 x_label : str | None = None , # noqa: RUF001
277309 y_label : str | None = None , # noqa: RUF001
@@ -280,18 +312,11 @@ def general(
280312 x_axis : alt .Axis | None = None ,
281313 y_axis : alt .Axis | None = None ,
282314 mark : str = Mark .line ,
315+ mark_kwargs : dict | None = None ,
316+ legend : bool = True ,
283317) -> alt .Chart :
284318 """Display as simple plot.
285319
286- We should eventually be able to handle everything through this.
287-
288- :param others: Other rate constants
289- :param others_labels: Labels for other rate constants
290- :param T_range: Temperature range
291- :param P: Pressure
292- :param x_label: X-axis label
293- :param y_label: Y-axis label
294- :param point: Whether to mark with points instead of a line
295320 :return: Chart
296321 """
297322 x_label = "" if x_label is None else x_label
@@ -308,12 +333,10 @@ def general(
308333 else [* POINT_COLOR_CYCLE , * LINE_COLOR_CYCLE ]
309334 )
310335
311- nk , nT = np .shape (y_data ) # noqa: N806
312- colors = colors or list (itertools .islice (itertools .cycle (color_cycle ), nk ))
313- keep_legend = labels is not None
314- labels = labels or [f"k{ i + 1 } " for i in range (nk )]
315- assert len (x_data ) == nT , f"{ x_data } !~ { y_data } "
316- assert len (labels ) == nk , f"{ labels } !~ { y_data } "
336+ ny , nx = np .shape (y_data ) # noqa: N806
337+ colors = colors or list (itertools .islice (itertools .cycle (color_cycle ), ny ))
338+ assert len (x_data ) == nx , f"{ x_data } !~ { y_data } "
339+ assert len (labels ) == ny , f"{ labels } !~ { y_data } "
317340
318341 # Gather data from functons
319342 data_dct = dict (zip (labels , y_data , strict = True ))
@@ -322,18 +345,18 @@ def general(
322345 # Prepare encoding parameters
323346 x = alt .X ("x" , title = x_label , scale = x_scale_ , axis = x_axis_ )
324347 y = alt .Y ("value:Q" , title = y_label , scale = y_scale_ , axis = y_axis_ )
325- color = (
326- alt . Color ( "key:N" , scale = alt . Scale ( domain = labels , range = colors ))
327- if keep_legend
328- else alt .value ( colors [ 0 ])
348+ color = alt . Color (
349+ "key:N" ,
350+ scale = alt . Scale ( domain = labels , range = colors ),
351+ legend = alt .Undefined if legend else None ,
329352 )
330353
331354 chart = alt .Chart (data )
332- chart = (
333- chart . mark_point ( filled = True , opacity = 1 )
334- if mark == Mark . point
335- else chart . mark_line ()
336- )
355+ kwargs = {} if mark_kwargs is None else mark_kwargs
356+ if mark == Mark . point :
357+ chart = chart . mark_point ( ** kwargs )
358+ else :
359+ chart = chart . mark_line ( ** kwargs )
337360
338361 # Create chart
339362 return chart .transform_fold (fold = list (data_dct .keys ())).encode (
@@ -439,6 +462,7 @@ def arrhenius( # noqa: PLR0913
439462 x_unit : str | None = None ,
440463 y_unit : str | None = None ,
441464 mark : str = Mark .line ,
465+ mark_kwargs : dict | None = None ,
442466) -> alt .Chart :
443467 """Display as Arrhenius plot.
444468
@@ -503,11 +527,12 @@ def arrhenius( # noqa: PLR0913
503527 )
504528
505529 chart = alt .Chart (data )
506- chart = (
507- chart .mark_point (filled = True , opacity = 1 )
508- if mark == Mark .point
509- else chart .mark_line ()
510- )
530+ if mark == Mark .point :
531+ kwargs = {"filled" : True , "opacity" : 1 } if mark_kwargs is None else mark_kwargs
532+ chart = chart .mark_point (** kwargs )
533+ else :
534+ kwargs = {} if mark_kwargs is None else mark_kwargs
535+ chart = chart .mark_line (** kwargs )
511536
512537 # Create chart
513538 return chart .transform_fold (fold = list (data_dct .keys ())).encode (
0 commit comments