From 97e0531b738dece959a884f18821a7b4e160c01a Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 8 Apr 2024 17:57:59 -0400 Subject: [PATCH] update data checks for 3D properly --- specparam/objs/data.py | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index a7d92e95..4006d1ab 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -151,16 +151,16 @@ def _regenerate_freqs(self): self.freqs = gen_freqs(self.freq_range, self.freq_res) - def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): + def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): """Prepare input data for adding to current object. Parameters ---------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power values, which must be input in linear space. - 1d vector, or 2d as [n_power_spectra, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. freq_range : list of [float, float] Frequency range to restrict power spectrum to. If None, keeps the entire range. @@ -170,10 +170,10 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): Returns ------- freqs : 1d array - Frequency values for the power_spectrum, in linear space. - power_spectrum : 1d or 2d array + Frequency values for `powers`, in linear space. + powers : 1d or 2d or 3d array Power spectrum values, in log10 scale. - 1d vector, or 2d as [n_power_specta, n_freqs]. + 1d vector, or 2d as [n_spectra, n_freqs], or 3d as [n_events, n_spectra, n_freqs]. freq_range : list of [float, float] Minimum and maximum values of the frequency vector. freq_res : float @@ -188,20 +188,21 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): """ # Check that data are the right types - if not isinstance(freqs, np.ndarray) or not isinstance(power_spectrum, np.ndarray): + if not isinstance(freqs, np.ndarray) or not isinstance(powers, np.ndarray): raise DataError("Input data must be numpy arrays.") # Check that data have the right dimensionality - if freqs.ndim != 1 or (power_spectrum.ndim != spectra_dim): + if freqs.ndim != 1 or (powers.ndim != spectra_dim): raise DataError("Inputs are not the right dimensions.") # Check that data sizes are compatible - if freqs.shape[-1] != power_spectrum.shape[-1]: + if (spectra_dim < 3 and freqs.shape[-1] != powers.shape[-1]) or \ + spectra_dim == 3 and freqs.shape[-1] != powers.shape[1]: raise InconsistentDataError("The input frequencies and power spectra " "are not consistent size.") # Check if power values are complex - if np.iscomplexobj(power_spectrum): + if np.iscomplexobj(powers): raise DataError("Input power spectra are complex values. " "Model fitting does not currently support complex inputs.") @@ -209,17 +210,17 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): # If they end up as float32, or less, scipy curve_fit fails (sometimes implicitly) if freqs.dtype != 'float64': freqs = freqs.astype('float64') - if power_spectrum.dtype != 'float64': - power_spectrum = power_spectrum.astype('float64') + if powers.dtype != 'float64': + powers = powers.astype('float64') - # Check frequency range, trim the power_spectrum range if requested + # Check frequency range, trim the power values range if requested if freq_range: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, freq_range) + freqs, powers = trim_spectrum(freqs, powers, freq_range) # Check if freqs start at 0 and move up one value if so # Aperiodic fit gets an inf if freq of 0 is included, which leads to an error if freqs[0] == 0.0: - freqs, power_spectrum = trim_spectrum(freqs, power_spectrum, [freqs[1], freqs.max()]) + freqs, powers = trim_spectrum(freqs, powers, [freqs[1], freqs.max()]) if self.verbose: print("\nFITTING WARNING: Skipping frequency == 0, " "as this causes a problem with fitting.") @@ -229,7 +230,7 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): freq_res = freqs[1] - freqs[0] # Log power values - power_spectrum = np.log10(power_spectrum) + powers = np.log10(powers) ## Data checks - run checks on inputs based on check modes @@ -241,14 +242,14 @@ def _prepare_data(self, freqs, power_spectrum, freq_range, spectra_dim=1): "The model expects equidistant frequency values in linear space.") if self._check_data: # Check if there are any infs / nans, and raise an error if so - if np.any(np.isinf(power_spectrum)) or np.any(np.isnan(power_spectrum)): + if np.any(np.isinf(powers)) or np.any(np.isnan(powers)): error_msg = ("The input power spectra data, after logging, contains NaNs or Infs. " "This will cause the fitting to fail. " "One reason this can happen is if inputs are already logged. " "Input data should be in linear spacing, not log.") raise DataError(error_msg) - return freqs, power_spectrum, freq_range, freq_res + return freqs, powers, freq_range, freq_res class BaseData2D(BaseData):