diff --git a/unseen/eva.py b/unseen/eva.py index 049ebca..da1a3d8 100644 --- a/unseen/eva.py +++ b/unseen/eva.py @@ -72,6 +72,7 @@ def fit_gev( fc=None, floc=None, fscale=None, + bounds=None, retry_fit=False, assert_good_fit=False, goodness_of_fit_kwargs={}, @@ -100,6 +101,8 @@ def fit_gev( Initial guess of trend parameters. If None, the trend is fixed at zero. fc, floc, fscale : float, default None Fixed values for the shape, location and scale parameters. + bounds : list of tuples, optional + Custom bounds for the shape, loc0, loc1, scale0 and scale1 parameters. retry_fit : bool, default True Retry fit with initial estimate generated by passing data[::2] to fitstart. The best fit is returned. See notes. @@ -172,7 +175,18 @@ def fit_gev( def check_support(dparams, covariate=None): """Check if the GEV parameters are within the support of the distribution.""" c, loc, scale = unpack_gev_params(dparams, covariate) - assert np.isfinite(loc) and np.isfinite(scale) and scale > 0 + supported = ( + np.isfinite(loc).all() and np.isfinite(scale).all() and (scale > 0).all() + ) + if not supported: + warnings.warn(f"GEV parameters are not supported: {dparams}.") + + def ns_gev_parameter_bounds(): + """Get bounds for the GEV parameters.""" + bounds = [(-np.inf, np.inf)] * 5 + bounds[0] = (-np.inf, np.inf) + bounds[3] = (1e-6, np.inf) # Positive scale parameter + return bounds def _fit_1d( data, @@ -184,6 +198,7 @@ def _fit_1d( fc, floc, fscale, + bounds, retry_fit, assert_good_fit, pick_best_model, @@ -192,6 +207,7 @@ def _fit_1d( use_basinhopping, basinhopping_kwargs, goodness_of_fit_kwargs, + scipy_fit_kwargs, ): """Estimate distribution parameters.""" if np.all(~np.isfinite(data)): @@ -203,15 +219,9 @@ def _fit_1d( # Drop NaNs in data mask = np.isfinite(data) data = data[mask] - if not stationary: + if not stationary and not np.isscalar(covariate): covariate = covariate[mask] - scipy_fit_kwargs = {} - - for kw in ["fc", "floc", "fscale"]: - if eval(kw) is not None: - scipy_fit_kwargs[kw] = eval(kw) - # Initial estimates of distribution parameters for MLE if isinstance(fitstart, str): dparams_i = _fitstart_1d(data, fitstart, scipy_fit_kwargs) @@ -256,11 +266,7 @@ def _fit_1d( if not stationary: dparams_ns_i = [dparams_i[0], dparams_i[1], loc1, dparams_i[2], scale1] - # Define parameter bounds and fix - bounds = [(-np.inf, np.inf)] * 5 - # bounds[0] = (-1, 1) - bounds[3] = (1e-6, np.inf) # Allow positive scale parameter - # # Set initial value and bounds of fixed parameters + # Set initial value and bounds of fixed parameters for i, fixed in zip([0, 1, 3], [fc, floc, fscale]): if fixed is not None: dparams_ns_i[i] = fixed @@ -289,7 +295,6 @@ def _fit_1d( _gev_nllf, dparams_ns_i, args=(data, covariate), bounds=bounds ) - # todo: add minimize success checks dparams_ns = np.array([i for i in res.x], dtype="float64") # Stationary and nonstationary model relative goodness of fit @@ -334,6 +339,9 @@ def _fit_1d( warnings.warn(f"Data fit failed (p-value={pvalue} NLL={nll:.3f}).") # check_support(dparams, covariate) + if not np.isnan([dparams]).any(): + check_support(dparams, covariate) + return dparams if stationary and pick_best_model: @@ -344,12 +352,26 @@ def _fit_1d( if covariate is not None: covariate = _format_covariate(data, covariate, core_dim) else: - # check if stationary or loca1 and scale1 are None using np.all - assert stationary is True or ( - loc1 is None and scale1 is None - ), "Covariate must be provided for a nonstationary fit." + # Check if stationary or loc1 and scale1 are None using np.all + assert stationary is True, "Covariate must be provided for a nonstationary fit." covariate = 0 # No covariate (can't be None) + if stationary and use_basinhopping: + # Restrict the trend parameters to zero if stationary + stationary = False + loc1, scale1 = None, None + kwargs.update({"stationary": False, "loc1": loc1, "scale1": scale1}) + + if bounds is None: + kwargs["bounds"] = ns_gev_parameter_bounds() + + scipy_fit_kwargs = {} + + for kw in ["fc", "floc", "fscale"]: + if eval(kw) is not None: + scipy_fit_kwargs[kw] = eval(kw) + + kwargs["scipy_fit_kwargs"] = scipy_fit_kwargs # Input core dimensions if core_dim is not None and hasattr(covariate, core_dim): # Covariate has the same core dimension as data @@ -359,6 +381,7 @@ def _fit_1d( input_core_dims = [[core_dim], []] n_params = 3 if stationary else 5 + # Fit data to distribution parameters dparams = apply_ufunc( _fit_1d, @@ -373,13 +396,26 @@ def _fit_1d( dask_gufunc_kwargs={"output_sizes": {"dparams": n_params}}, ) - # Format output (consistent with xclim) + # Format output + if loc1 is None and scale1 is None: + # Remove both trend parameters if not used + stationary = True + if hasattr(dparams, "isel"): + dparams = dparams.isel(dparams=[0, 1, 3], drop=True) + else: + dparams = dparams[[0, 1, 3]] + if isinstance(data, DataArray): - if n_params == 3: - # todo: change loc to location - dparams.coords["dparams"] = ["c", "loc", "scale"] + if dparams["dparams"].size == 3: + dparams.coords["dparams"] = ["c", "location", "scale"] else: - dparams.coords["dparams"] = ["c", "loc0", "loc1", "scale0", "scale1"] + dparams.coords["dparams"] = [ + "c", + "location_0", + "location_1", + "scale_0", + "scale_1", + ] # Add coordinates for the distribution parameters dist_name = "genextreme" if stationary else "nonstationary genextreme" @@ -417,7 +453,7 @@ def penalised_sum(x): return total + penalty_sum -def _gev_nllf(dparams, x, covariate=None): +def _gev_nllf(dparams, x, covariate=0): """GEV penalised negative log-likelihood function. Parameters @@ -448,10 +484,10 @@ def _gev_nllf(dparams, x, covariate=None): The NLLF is not finite when the shape is nonzero and Z is negative because the PDF is zero (i.e., ``log(0)=inf)``). """ + shape, loc, scale = unpack_gev_params(dparams, covariate) shape = -shape # Reverse shape sign for consistency with scipy.stats results - # transform scale parameter so that it is always positive valid = scale > 0 s = (x - loc) / scale @@ -559,8 +595,8 @@ def __call__(self, x): """take a random step but ensure the new position is within the bounds""" min_step = np.maximum(self.xmin - x, -self.stepsize) max_step = np.minimum(self.xmax - x, self.stepsize) + x = x + self.rng.uniform(low=min_step, high=max_step, size=x.shape) - x += self.rng.uniform(low=min_step, high=max_step, size=x.shape) return x @@ -739,7 +775,7 @@ def get_best_GEV_model_1d( return dparams -def unpack_gev_params(dparams, covariate=None): +def unpack_gev_params(dparams, covariate=0): """Unpack shape, loc, scale from dparams. Parameters @@ -779,7 +815,7 @@ def unpack_gev_params(dparams, covariate=None): return shape, loc, scale -def get_return_period(event, dparams=None, covariate=None, **kwargs): +def get_return_period(event, dparams=None, covariate=0, **kwargs): """Get return periods for a given events. Parameters @@ -818,7 +854,7 @@ def get_return_period(event, dparams=None, covariate=None, **kwargs): return 1.0 / probability -def get_return_level(return_period, dparams=None, covariate=None, **kwargs): +def get_return_level(return_period, dparams=None, covariate=0, **kwargs): """Get the return levels for given return periods. Parameters @@ -899,6 +935,13 @@ def _empirical_return_period(da, event): return ri +def empirical_return_level(da, return_period, **kwargs): + """Empirical return level of an event (1D).""" + cdf = ecdf(da).cdf + probability = 1 - (1 / return_period) + return np.interp(probability, cdf.probabilities, cdf.quantiles) + + def get_empirical_return_level(da, return_period, core_dim="time"): """Calculate the empirical return period of an event. @@ -917,14 +960,8 @@ def get_empirical_return_level(da, return_period, core_dim="time"): The event return level (e.g., 100 mm of rainfall) """ - def _empirical_return_level(da, period): - """Empirical return level of an event (1D).""" - sf = ecdf(da).sf - probability = 1 - (1 / period) - return np.interp(probability, sf.probabilities, sf.quantiles) - return_level = apply_ufunc( - _empirical_return_level, + empirical_return_level, da, return_period, input_core_dims=[[core_dim], []], @@ -967,6 +1004,7 @@ def gev_confidence_interval( n_resamples=1000, ci=0.95, core_dim="time", + covariate=0, fit_kwargs={}, ): """ @@ -999,10 +1037,10 @@ def gev_confidence_interval( ci_bounds : xarray.DataArray Confidence intervals with lower and upper bounds along dim 'quantile' """ - # todo: add max_shape_ratio? + # Replace core dim with the one from the fit_kwargs if it exists core_dim = fit_kwargs.pop("core_dim", core_dim) - covariate = fit_kwargs.pop("covariate", None) + covariate = fit_kwargs.pop("covariate", covariate) rng = np.random.default_rng(seed=0) if dparams is None: @@ -1040,11 +1078,11 @@ def gev_confidence_interval( if return_period is not None: result = get_return_level( - return_period, gev_params_resampled, core_dim=core_dim + return_period, gev_params_resampled, core_dim=core_dim, covariate=covariate ) elif return_level is not None: result = get_return_period( - return_level, gev_params_resampled, core_dim=core_dim + return_level, gev_params_resampled, core_dim=core_dim, covariate=covariate ) # Bounds of confidence intervals @@ -1548,9 +1586,9 @@ def spatial_plot_gev_parameters( ) # Add coastlines and lat/lon labels ax.coastlines() - ax.set_title(f"{params[i]}") ax.set_xlabel(None) ax.set_ylabel(None) + ax.set_title(f"{params[i].title()} parameter") ax.xaxis.set_major_formatter(LongitudeFormatter()) ax.yaxis.set_major_formatter(LatitudeFormatter()) ax.xaxis.set_minor_locator(AutoMinorLocator()) @@ -1563,11 +1601,11 @@ def spatial_plot_gev_parameters( ax.yaxis.set_visible(True) if dataset_name: - fig.suptitle(f"{dataset_name} GEV parameters", y=0.8 if stationary else 0.99) + fig.suptitle(f"{dataset_name} GEV parameters", y=0.75 if stationary else 0.97) - if not stationary: - # Hide the empty subplot - axes[-1].set_visible(False) + # Hide empty subplots + for ax in [ax for ax in axes if not ax.collections]: + ax.axis("off") plt.tight_layout() if outfile: @@ -1587,10 +1625,9 @@ def _parse_command_line(): parser.add_argument("outfile", type=str, help="Output file") parser.add_argument( "--stack_dims", - type=str, - nargs="*", - # default=["ensemble", "init_date", "lead_time"], - help="Dimensions to stack", + action="store_true", + default=False, + help="Stack ensemble, init_date, and lead_time dimensions", ) parser.add_argument("--core_dim", type=str, default="time", help="Core dimension") parser.add_argument( @@ -1626,6 +1663,12 @@ def _parse_command_line(): default=False, help="Retry fit if it doesn't pass the goodness of fit test", ) + parser.add_argument( + "--use_basinhopping", + action="store_true", + default=False, + help="Use basinhopping optimizer to find the best fit", + ) parser.add_argument( "--assert_good_fit", action="store_true", @@ -1729,11 +1772,16 @@ def _main(): else: ds = ds.where(ds[args.lead_dim] >= args.min_lead) - # Stack dimensions along new "sample" dimension - if all([dim in ds[args.var].dims for dim in args.stack_dims]): - ds = ds.stack(**{"sample": args.stack_dims}, create_index=False) + # Stack ensemble, init and lead dimensions along new "sample" dimension + if args.stack_dims: + # Check if the dimensions exist in the dataset + dims = [] + for dim in (args.ensemble_dim, args.init_dim, args.lead_dim): + if dim in ds.dims: + dims.append(dim) + ds = ds.stack(**{"sample": dims}, create_index=False) ds = ds.chunk(dict(sample=-1)) # fixes CAFE large chunk error - args.core_dim = "sample" + args.core_dim = "sample" # Set core dimension to the new stacked dimension # Drop the maximum value before fitting if args.drop_max: @@ -1754,6 +1802,7 @@ def _main(): fitstart=args.fitstart, covariate=covariate, retry_fit=args.retry_fit, + use_basinhopping=args.use_basinhopping, assert_good_fit=args.assert_good_fit, pick_best_model=args.pick_best_model, ) diff --git a/unseen/fileio.py b/unseen/fileio.py index cc3d00b..f1b2999 100644 --- a/unseen/fileio.py +++ b/unseen/fileio.py @@ -10,6 +10,7 @@ import pandas as pd import shutil import yaml +import warnings import xarray as xr import zipfile @@ -37,6 +38,7 @@ def open_dataset( shapefile=None, shapefile_label_header=None, shape_overlap=None, + shape_buffer=None, combine_shapes=False, spatial_agg="none", time_dim="time", @@ -221,6 +223,7 @@ def open_dataset( overlap_fraction=shape_overlap, header=shapefile_label_header, combine_shapes=combine_shapes, + shape_buffer=shape_buffer, lat_dim=lat_dim, lon_dim=lon_dim, ) @@ -240,7 +243,7 @@ def open_dataset( if no_leap_days: ds = ds.sel(time=~((ds[time_dim].dt.month == 2) & (ds[time_dim].dt.day == 29))) if rolling_sum_window: - ds = ds.rolling({time_dim: rolling_sum_window}).sum() + ds = ds.rolling({time_dim: rolling_sum_window}).sum(dim=time_dim) if time_freq: assert time_agg, "Provide a time_agg" assert variables, "Variables argument is required for temporal aggregation" @@ -517,7 +520,8 @@ def _fix_metadata(ds, metadata_file): if "round_coords" in metadata_dict: for coord in metadata_dict["round_coords"]: - ds = ds.assign_coords({coord: ds[coord].round(decimals=6)}) + ds = ds.assign_coords({coord: ds[coord].round(decimals=3)}) + warnings.warn(f"Rounded {coord} to 3 decimal places") if "units" in metadata_dict: for var, units in metadata_dict["units"].items(): @@ -785,6 +789,12 @@ def _parse_command_line(): default=None, help="Fraction that a grid cell must overlap with a shape to be included", ) + parser.add_argument( + "--shp_buffer", + type=float, + default=None, + help="Buffer the shape by this amount (in degrees)", + ) parser.add_argument( "--shp_header", type=str, @@ -883,6 +893,12 @@ def _parse_command_line(): default=False, help="Force a standard calendar when opening each file", ) + parser.add_argument( + "--lead_dim_max_size", + type=int, + default=None, + help="Maximum size of the lead dimension (e.g. 9) [default=None]", + ) args = parser.parse_args() return args @@ -911,6 +927,7 @@ def _main(): "shapefile": args.shapefile, "shapefile_label_header": args.shp_header, "shape_overlap": args.shp_overlap, + "shape_buffer": args.shp_buffer, "combine_shapes": args.combine_shapes, "spatial_agg": args.spatial_agg, "lat_dim": args.lat_dim, @@ -947,6 +964,12 @@ def _main(): **kwargs, ) temporal_dim = "lead_time" + + if args.lead_dim_max_size is not None: + if ds["lead_time"].size > args.lead_dim_max_size: + # Drop the leads after the max size + ds = ds.isel(lead_time=slice(0, args.lead_dim_max_size)) + else: ds = open_dataset(args.infiles, **kwargs) temporal_dim = args.time_dim diff --git a/unseen/general_utils.py b/unseen/general_utils.py index d9960b9..e1c829f 100644 --- a/unseen/general_utils.py +++ b/unseen/general_utils.py @@ -4,6 +4,7 @@ from matplotlib.ticker import AutoMinorLocator import matplotlib.pyplot as plt import numpy as np +import subprocess import xarray as xr from xarray import Dataset from xclim.core import units @@ -160,8 +161,10 @@ def regrid(ds, ds_grid, method="conservative", **kwargs): Notes ----- - The input and target grids should have the same coordinate names. - - Recommended using the "conservative" method for regridding from fine to course and "bilinear" for the opposite. + - Recommended using the "conservative" method for regridding from fine to + coarse and "bilinear" for the opposite. """ + # Copy attributes global_attrs = ds.attrs if isinstance(ds, Dataset): @@ -180,6 +183,51 @@ def regrid(ds, ds_grid, method="conservative", **kwargs): return ds_regrid +def get_model_makefile_dict(cwd, project_details, model, model_details, obs_details): + """Get dictionary of variables defined in config files and makefile. + + Parameters + ---------- + cwd : str + Directory of makefile + project_details : str + Project details file + model : str + Model name + model_details : str + Model details file + obs_details : str + Observed data details file + + Returns + ------- + model_var_dict : dict + Dictionary of model variables defined in the makefile and details files + """ + + args = [ + "make", + "print_file_vars", + f"PROJECT_DETAILS={project_details}", + f"MODEL={model}", + f"MODEL_DETAILS={model_details}", + f"OBS_DETAILS={obs_details}", + ] + + result = subprocess.run(args, capture_output=True, text=True, cwd=cwd) + + # Read stdout into dictionary + model_var_dict = {} + for line in result.stdout.splitlines(): + if "=" in line: + key, value = line.split("=", 1) + model_var_dict[key.lower()] = value + + # Sort dictionary by key + model_var_dict = dict(sorted(model_var_dict.items())) + return model_var_dict + + def plot_timeseries_scatter( da, da_obs=None, @@ -219,6 +267,7 @@ def plot_timeseries_scatter( ax : matplotlib.axes.Axes Axis object """ + if units is None: if "units" in da.attrs: units = da.attrs["units"] @@ -297,7 +346,8 @@ def plot_timeseries_box_plot( Notes ----- - Ensure all time dimensions are set to the correct frequency before calling this function. + Ensure all time dimensions are set to the correct frequency before calling + this function. Examples -------- @@ -308,6 +358,7 @@ def plot_timeseries_box_plot( da = da.stack({"sample": ["ensemble", "init_date", "lead_time"]}) plot_timeseries_box_plot(da, time_dim="time") """ + if units is None: if "units" in da.attrs: units = da.attrs["units"] diff --git a/unseen/similarity.py b/unseen/similarity.py index 6230864..42bc465 100644 --- a/unseen/similarity.py +++ b/unseen/similarity.py @@ -200,6 +200,7 @@ def similarity_spatial_plot(ds, dataset_name=None, outfile=None, alpha=0.05): outfile : str, optional Filename to save the plot """ + fig, axes = plt.subplots( 2, 2, @@ -210,21 +211,25 @@ def similarity_spatial_plot(ds, dataset_name=None, outfile=None, alpha=0.05): constrained_layout=True, ) + # Iterate through vars: ks_statistic, ks_pval, ad_statistic, ad_pval for ax, var in zip(axes.flat, ds.data_vars): + + kwargs = {} if "statistic" in var: long_name = ds[var].attrs["long_name"].replace("_", " ").title() - if ds[var].min() < 0: - kwargs = dict(cmap=plt.cm.coolwarm) + if ds[var].min() > 0: + kwargs["cmap"] = plt.cm.viridis else: - kwargs = dict(cmap=plt.cm.viridis) + kwargs["cmap"] = plt.cm.RdBu_r + elif "pval" in var: - long_name = f"{long_name} p-value" - kwargs = dict( - cmap=plt.cm.seismic, - norm=TwoSlopeNorm(vcenter=alpha, vmin=0, vmax=0.4), - ) - kwargs["cmap"].set_bad("gray") + long_name = f"{long_name} p-value" # use previous long_name + # Centre the colormap at alpha + kwargs["cmap"] = plt.cm.coolwarm_r + vmax = 0.5 if var == "ks_pval" else 0.25 + kwargs["norm"] = TwoSlopeNorm(vcenter=alpha, vmin=0, vmax=vmax) + kwargs["cmap"].set_bad("gray") ds[var].plot( ax=ax, transform=PlateCarree(), diff --git a/unseen/spatial_selection.py b/unseen/spatial_selection.py index a50f942..006b60b 100644 --- a/unseen/spatial_selection.py +++ b/unseen/spatial_selection.py @@ -136,6 +136,7 @@ def select_shapefile_regions( overlap_fraction=None, header=None, combine_shapes=False, + shape_buffer=None, lat_dim="lat", lon_dim="lon", ): @@ -156,6 +157,8 @@ def select_shapefile_regions( Name of the shapefile column containing the region names combine_shape : bool, default False Create an extra region which combines them all + shape_buffer : float, default None + Buffer the shapes by this amount (in degrees) lat_dim: str, default 'lat' Name of the latitude dimension in ds lon_dim: str, default 'lon' @@ -175,6 +178,8 @@ def select_shapefile_regions( "sum", "weighted_mean", "median", + "min", + "max", "none", ], "Invalid spatial aggregation method" @@ -188,12 +193,16 @@ def select_shapefile_regions( new_dim_names[lon_dim] = "lon" if new_dim_names: ds = ds.rename_dims(new_dim_names) + assert "lat" in ds.coords, "Latitude coordinate must be called lat" assert "lon" in ds.coords, "Longitude coordinate must be called lon" lons = ds["lon"].values lats = ds["lat"].values + if shape_buffer is not None: + shapes = shapefile_with_buffer(shapes, shape_buffer, tolerance=10000) + if overlap_fraction: if isinstance(overlap_fraction, str): overlap_fraction = float(overlap_fraction) # Fix cmd line input @@ -208,7 +217,13 @@ def select_shapefile_regions( mask = _squeeze_and_drop_region(mask) # Spatial aggregation - agg_func_map = {"mean": np.nanmean, "sum": np.nansum, "median": np.nanmedian} + agg_func_map = { + "mean": np.nanmean, + "sum": np.nansum, + "median": np.nanmedian, + "min": np.nanmin, + "max": np.nanmax, + } if agg in agg_func_map.keys() and not (overlap_fraction or combine_shapes): ds = ds.groupby(mask).reduce(agg_func_map[agg], keep_attrs=True) @@ -295,8 +310,8 @@ def fraction_overlap_mask(shapes_gp, lons, lats, min_overlap): assert min_overlap > 0.0, "Minimum overlap must be fractional value > 0" assert min_overlap <= 1.0, "Minimum overlap must be fractional value <= 1.0" - _check_regular_grid(lons) - _check_regular_grid(lats) + # _check_regular_grid(lons) + # _check_regular_grid(lats) shapes_rm = regionmask.from_geopandas(shapes_gp) fraction = overlap_fraction(shapes_rm, lons, lats) @@ -353,6 +368,43 @@ def overlap_fraction(shapes_rm, lons, lats): return mask_sampled +def shapefile_with_buffer(shapes, shape_buffer, tolerance=10_000): + """Re-project geometries to a projected CRS, coarsen and buffer. + + Parameters + ---------- + shapes : geopandas.GeoDataFrame + Shapes/regions + shape_buffer : float + Buffer the shapes by this amount (in degrees) + tolerance : float, default 10_000 + Coarsen the geometries to this tolerance (in metres) + + Returns + ------- + shapes : geopandas.GeoDataFrame + GeoDataFrame with buffered geometries + """ + + # Get original projection CRS + original_crs = shapes.crs + if original_crs is None: + original_crs = "EPSG:4326" # WGS 84 + + # Re-project geometries to a projected CRS (units of metres) + shapes = shapes.to_crs(epsg=3395) + + # Convert buffer from degrees to m + buffer_m = shape_buffer * 111_320 # 1 degree ~ 111.32 km + + # Coarsen the geometries and then buffer + shapes["geometry"] = shapes["geometry"].simplify(tolerance).buffer(buffer_m) + + # Re-project back to projection CRS (degrees) + shapes = shapes.to_crs(original_crs) + return shapes + + def _sample_coord(coord): """Sample coordinates for the fractional overlap calculation.""" diff --git a/unseen/stability.py b/unseen/stability.py index a9e4a4c..65e4bf4 100644 --- a/unseen/stability.py +++ b/unseen/stability.py @@ -1,12 +1,13 @@ """Functions and command line program for stability testing.""" import argparse - +import functools import matplotlib.pyplot as plt import numpy as np import pandas as pd -from scipy.stats import genextreme +from scipy.stats import genextreme, _resampling import seaborn as sns +import xarray as xr from . import fileio from . import eva @@ -271,6 +272,128 @@ def plot_return_by_time( ax.legend() +def statistic_by_lead_confidence_interval( + da, + statistic, + sample_size=None, + n_resamples=9999, + confidence_level=0.95, + method="percentile", + rng=np.random.default_rng(0), + ensemble_dim="ensemble", + init_dim="init_date", + lead_dim="lead_time", + **kwargs, +): + """Estimate confidence intervals for a statistic using bootstrapping. + + Similar to `scipy.stats.bootstrap`, with optional sample size of resamples. + + Parameters + ---------- + da : xarray.DataArray + Data with dimensions (ensemble_dim, init_dim, lead_dim, ...) + statistic : callable + Function to calculate the statistic (e.g. np.median) + sample_size : int, optional + Size of the resample. If None, based on ensemble * init dim sizes + n_resamples : int, default 1000 + Number of bootstrap resamples + confidence_level : float, default 0.95 + Confidence level for the confidence intervals + method : {"percentile", "bca"}, default "percentile" + Method for calculating the confidence intervals + rng : numpy.random.Generator, optional + ensemble_dim : str, default "ensemble" + init_dim : str, default "init_date" + lead_dim : str, default "lead_time" + kwargs : dict, optional + Additional keyword arguments passed to the statistic function + + Returns + ------- + ci : xarray.DataArray + Confidence intervals stacked along dimension "bounds" (lower and upper). + + Notes + ----- + The statistic function should accept a 1D array and return a scalar. + If method is "bca", the statistic function should also accept kw `axis`. + """ + + def bootstrap_1d( + data, statistic, sample_size, n_resamples, confidence_level, method + ): + """Bootstrap confidence interval for a statistic function.""" + + # Resample the data + theta_hat_b = [] + for _ in range(n_resamples): + resample = rng.choice(data, size=sample_size, replace=True) + theta_hat_b.append(statistic(resample)) + theta_hat_b = np.array(theta_hat_b) + + alpha = (1 - confidence_level) / 2 + + if method == "percentile": + interval = (alpha, 1 - alpha) + percentile_func = np.percentile + + elif method.lower() == "bca": + interval = _resampling._bca_interval( + (data,), + statistic, + axis=-1, + alpha=alpha, + theta_hat_b=theta_hat_b, + batch=None, + )[:2] + percentile_func = _resampling._percentile_along_axis + + ci_l = percentile_func(theta_hat_b, interval[0] * 100) + ci_u = percentile_func(theta_hat_b, interval[1] * 100) + return np.array([ci_l, ci_u]) + + if sample_size is None: + # Calculate the size of the resample (# samples per lead) + sample_size = da[ensemble_dim].size * da[init_dim].size + + # Pass kwargs to the statistic function + if kwargs: + statistic = functools.partial(statistic, **kwargs) + + # Stack the data along the sample dimension + da_stacked = da.stack( + sample=[ensemble_dim, init_dim, lead_dim], create_index=False + ).dropna("sample", how="all") + + # Ensure sample dim is on axis -1 for vectorization + da_stacked = da_stacked.transpose(..., "sample") + + ci = xr.apply_ufunc( + bootstrap_1d, + da_stacked, + input_core_dims=[["sample"]], + output_core_dims=[["bounds"]], + vectorize=True, + dask="parallelized", + kwargs=dict( + statistic=statistic, + sample_size=sample_size, + n_resamples=n_resamples, + confidence_level=confidence_level, + method=method, + ), + output_dtypes=[float], + dask_gufunc_kwargs={"output_sizes": {"bounds": 2}}, + ) + + # Assign the "bounds" dimension labels + ci = ci.assign_coords(bounds=["lower", "upper"]) + ci.attrs["long_name"] = f"{confidence_level:%} Confidence Interval" + return ci + + def create_plot( da_fcst, metric, diff --git a/unseen/tests/test_eva.py b/unseen/tests/test_eva.py index 4bfbb44..6ff5c9a 100644 --- a/unseen/tests/test_eva.py +++ b/unseen/tests/test_eva.py @@ -174,6 +174,35 @@ def test_fit_ns_gev_1d_pick_best_model(example_da_gev, test, trend): assert np.all(dparams[4] == 0) # No trend in scale +@pytest.mark.parametrize("example_da_gev", ["xarray", "numpy", "dask"], indirect=True) +def test_fit_gev_1d_basinhopping(example_da_gev): + """Run stationary GEV fit using 1D array with `use_basinhopping=True`.""" + data, dparams_i = example_da_gev + dparams = fit_gev(data, stationary=True, use_basinhopping=True) + + assert dparams.shape == (3,) # Should return 3 parameters + + # Check fitted params match params used to create data + npt.assert_allclose(dparams, dparams_i, rtol=rtol) + + +@pytest.mark.parametrize("example_da_gev", ["xarray", "dask"], indirect=True) +def test_fit_ns_gev_1d_basinhopping(example_da_gev): + """Run non-stationary GEV fit using 1D array & check results.""" + data, _ = example_da_gev + data = add_example_gev_trend(data) + covariate = xr.DataArray(np.arange(data.time.size), dims="time") + + dparams = fit_gev( + data, + stationary=False, + core_dim="time", + covariate=covariate, + use_basinhopping=True, + ) + assert np.all(dparams[2] > 0) # Positive trend in location + + @pytest.mark.parametrize("example_da_gev", ["xarray", "numpy", "dask"], indirect=True) def test_get_return_period(example_da_gev): """Run get_return_period for a single event using 1d data.""" diff --git a/unseen/time_utils.py b/unseen/time_utils.py index 87ea649..5b7d230 100644 --- a/unseen/time_utils.py +++ b/unseen/time_utils.py @@ -29,9 +29,13 @@ def get_agg_dates(ds, var, target_freq, agg_method, time_dim="time"): ------- event_datetimes_str : xarray.DataArray Event dates (YYYY-MM-DD) for the resampled array + + todo: replace loop with + ds.resample(index=target_freq, label="left").map(xr.DataArray.idxmax, dim=time_dim, keep_attrs=True) """ - ds_arg = ds[var].resample(time=target_freq, label="left") + da = ds[var] if isinstance(ds, xr.Dataset) else ds + ds_arg = da.resample({time_dim: target_freq}, label="left") if agg_method == "max": dates = [da.idxmax(time_dim) for _, da in ds_arg] elif agg_method == "min": @@ -111,10 +115,12 @@ def temporal_aggregation( start_time = ds[time_dim] if min_tsteps: - counts = ds[variables[0]].resample(time=target_freq).count(dim=time_dim) + # Count the number of time steps in each resampled group + counts = ds[variables[0]].resample({time_dim: target_freq}).count(dim=time_dim) if input_freq == target_freq[0]: pass + elif agg_method in ["max", "min"]: if agg_dates: # Record the date of each time aggregated event (e.g. annual max) @@ -122,20 +128,22 @@ def temporal_aggregation( ds, variables[0], target_freq, agg_method, time_dim=time_dim ) if agg_method == "max": - ds = ds.resample(time=target_freq).max(dim=time_dim, keep_attrs=True) + ds = ds.resample({time_dim: target_freq}).max(dim=time_dim, keep_attrs=True) else: - ds = ds.resample(time=target_freq).min(dim=time_dim, keep_attrs=True) + ds = ds.resample({time_dim: target_freq}).min(dim=time_dim, keep_attrs=True) if agg_dates: # Add the event time to the resampled array ds = ds.assign(event_time=agg_dates_var) elif agg_method == "sum": - ds = ds.resample(time=target_freq).sum(dim=time_dim, keep_attrs=True) + ds = ds.resample({time_dim: target_freq}).sum(dim=time_dim, keep_attrs=True) for var in variables: ds[var].attrs["units"] = _update_rate(ds[var], input_freq, target_freq) elif agg_method == "mean": if input_freq == "D": - ds = ds.resample(time=target_freq).mean(dim=time_dim, keep_attrs=True) + ds = ds.resample({time_dim: target_freq}).mean( + dim=time_dim, keep_attrs=True + ) elif input_freq == "M": ds = _monthly_downsample_mean(ds, target_freq, variables, time_dim=time_dim) else: @@ -172,15 +180,16 @@ def temporal_aggregation( assert np.unique(ds[time_dim].dt.month)[0] == month if min_tsteps: - # First and last time points likely have insufficient time steps - counts = counts.isel({time_dim: [0, -1]}) - # Check the minimum count for the first and last time points - counts = counts.min([dim for dim in ds[variables[0]].dims if dim != time_dim]) - # Remove first and/or last time if they have insufficient time steps - if counts[0] < min_tsteps: - ds = ds.isel({time_dim: slice(1, None)}) - if counts[-1] < min_tsteps: - ds = ds.isel({time_dim: slice(None, -1)}) + # Find resampled groups with fewer than min_tsteps + invalid_mask = counts < min_tsteps + # Reduce the mask to the time dimension (if there are other dimensions) + dims = [dim for dim in ds.dims if dim != time_dim] + if len(dims) > 0: + invalid_mask = invalid_mask.any(dims) + # Convert boolean mask to array indexes + invalid_inds = np.arange(counts[time_dim].size)[invalid_mask] + # Drop the invalid time steps + ds = ds.drop_isel({time_dim: invalid_inds}) if reindexed: ds = ds.compute() @@ -377,9 +386,9 @@ def _monthly_downsample_mean(ds, target_freq, variables, time_dim="time"): """ days_in_month = ds[time_dim].dt.days_in_month - weighted_mean = (ds * days_in_month).resample(time=target_freq).sum( + weighted_mean = (ds * days_in_month).resample({time_dim: target_freq}).sum( dim=time_dim, keep_attrs=True - ) / days_in_month.resample(time=target_freq).sum(dim=time_dim) + ) / days_in_month.resample({time_dim: target_freq}).sum(dim=time_dim) weighted_mean.attrs = ds.attrs for var in variables: weighted_mean[var].attrs = ds[var].attrs