diff --git a/specparam/algorithms/algorithm.py b/specparam/algorithms/algorithm.py index a98292dbe..0c60c131d 100644 --- a/specparam/algorithms/algorithm.py +++ b/specparam/algorithms/algorithm.py @@ -1,89 +1,11 @@ -"""Define original spectral fitting algorithm object.""" +"""Define object to manage algorithm implementations.""" + +from specparam.data import ModelSettings +from specparam.modes.items import OBJ_DESC ################################################################################################### ################################################################################################### -class SettingsDefinition(): - """Defines a set of algorithm settings. - - Parameters - ---------- - settings : dict - Settings definition. - Each key should be a str name of a setting. - Each value should be a dictionary with keys 'type' and 'description', with str values. - """ - - def __init__(self, settings): - """Initialize settings definition.""" - - self._settings = settings - - - def _get_settings_subdict(self, field): - """Helper function to select from settings dictionary.""" - - return {label : self._settings[label][field] for label in self._settings.keys()} - - - @property - def names(self): - """Make property alias for setting names.""" - - return list(self._settings.keys()) - - - @property - def types(self): - """Make property alias for setting types.""" - - return self._get_settings_subdict('type') - - - @property - def descriptions(self): - """Make property alias for setting descriptions.""" - - return self._get_settings_subdict('description') - - - def make_setting_str(self, name): - """Make a setting docstring string. - - Parameters - ---------- - name : str - Setting name to make string for. - - Returns - ------- - str - Setting docstring string. 
- """ - - setting_str = '' + \ - ' ' + name + ' : ' + self.types[name] + '\n' \ - ' ' + self.descriptions[name] - - return setting_str - - - def make_docstring(self): - """Make docstring for all settings. - - Returns - ------- - str - Docstring for all settings. - """ - - pieces = [self.make_setting_str(name) for name in self.names] - pieces = [' Parameters', ' ----------'] + pieces - docstring = '\n'.join(pieces) - - return docstring - - class AlgorithmDefinition(): """Defines an algorithm definition description. @@ -123,17 +45,89 @@ class Algorithm(): Algorithm information. """ - def __init__(self, name, description, settings): + def __init__(self, name, description, settings, debug=False): """Initialize Algorithm object.""" self.algorithm = AlgorithmDefinition(name, description, settings) + self.set_debug(debug) + - def _fit_prechecks(): + def _fit_prechecks(self): """Prechecks to run before the fit function - if are some, overload this function.""" - pass - def _fit(): + def _fit(self): """Required fit function, to be overloaded.""" - pass + + + def add_settings(self, settings): + """Add settings into object from a ModelSettings object. + + Parameters + ---------- + settings : ModelSettings + A data object containing the settings for a power spectrum model. + """ + + for setting in OBJ_DESC['settings']: + setattr(self, setting, getattr(settings, setting)) + + self._check_loaded_settings(settings._asdict()) + + + def get_settings(self): + """Return user defined settings of the current object. + + Returns + ------- + ModelSettings + Object containing the settings from the current object. + """ + + return ModelSettings(**{key : getattr(self, key) \ + for key in OBJ_DESC['settings']}) + + + def get_debug(self): + """Return object debug status.""" + + return self._debug + + + def set_debug(self, debug): + """Set debug state, which controls if an error is raised if model fitting is unsuccessful. 
+ + Parameters + ---------- + debug : bool + Whether to run in debug state. + """ + + self._debug = debug + + + def _check_loaded_settings(self, data): + """Check if settings added, and update the object as needed. + + Parameters + ---------- + data : dict + A dictionary of data that has been added to the object. + """ + + # If settings not loaded from file, clear from object, so that default + # settings, which are potentially wrong for loaded data, aren't kept + if not set(OBJ_DESC['settings']).issubset(set(data.keys())): + + # Reset all public settings to None + for setting in OBJ_DESC['settings']: + setattr(self, setting, None) + + # Reset internal settings so that they are consistent with what was loaded + # Note that this will set internal settings to None, if public settings unavailable + self._reset_internal_settings() + + + def _reset_internal_settings(self): + """Can be overloaded if any resetting needed for internal settings.""" diff --git a/specparam/algorithms/settings.py b/specparam/algorithms/settings.py new file mode 100644 index 000000000..ef3225ac4 --- /dev/null +++ b/specparam/algorithms/settings.py @@ -0,0 +1,84 @@ +"""Define an algorithm settings object and related functionality.""" + +################################################################################################### +################################################################################################### + +class SettingsDefinition(): + """Defines a set of algorithm settings. + + Parameters + ---------- + settings : dict + Settings definition. + Each key should be a str name of a setting. + Each value should be a dictionary with keys 'type' and 'description', with str values. 
+ """ + + def __init__(self, settings): + """Initialize settings definition.""" + + self._settings = settings + + + def _get_settings_subdict(self, field): + """Helper function to select from settings dictionary.""" + + return {label : self._settings[label][field] for label in self._settings.keys()} + + + @property + def names(self): + """Make property alias for setting names.""" + + return list(self._settings.keys()) + + + @property + def types(self): + """Make property alias for setting types.""" + + return self._get_settings_subdict('type') + + + @property + def descriptions(self): + """Make property alias for setting descriptions.""" + + return self._get_settings_subdict('description') + + + def make_setting_str(self, name): + """Make a setting docstring string. + + Parameters + ---------- + name : str + Setting name to make string for. + + Returns + ------- + str + Setting docstring string. + """ + + setting_str = '' + \ + ' ' + name + ' : ' + self.types[name] + '\n' \ + ' ' + self.descriptions[name] + + return setting_str + + + def make_docstring(self): + """Make docstring for all settings. + + Returns + ------- + str + Docstring for all settings. 
+ """ + + pieces = [self.make_setting_str(name) for name in self.names] + pieces = [' Parameters', ' ----------'] + pieces + docstring = '\n'.join(pieces) + + return docstring diff --git a/specparam/algorithms/spectral_fit.py b/specparam/algorithms/spectral_fit.py index 357f8cc44..adf281a96 100644 --- a/specparam/algorithms/spectral_fit.py +++ b/specparam/algorithms/spectral_fit.py @@ -6,13 +6,12 @@ from numpy.linalg import LinAlgError from scipy.optimize import curve_fit -from specparam.modes.funcs import gaussian_function -from specparam.reports.strings import gen_width_warning_str from specparam.modutils.errors import FitError from specparam.utils.select import groupby from specparam.reports.strings import gen_width_warning_str from specparam.measures.params import compute_gauss_std -from specparam.algorithms.algorithm import SettingsDefinition, Algorithm +from specparam.algorithms.algorithm import Algorithm +from specparam.algorithms.settings import SettingsDefinition ################################################################################################### ################################################################################################### @@ -36,7 +35,7 @@ 'type' : 'float, optional, default: 2.0', 'description' : \ 'Relative threshold for detecting peaks.\n ' \ - 'This threshold is defined in relative units of the power spectrum (standard deviation).', + 'Threshold is defined in relative units of the power spectrum (standard deviation).', }, }) @@ -93,7 +92,8 @@ class SpectralFitAlgorithm(Algorithm): def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0, peak_threshold=2.0, ap_percentile_thresh=0.025, ap_guess=None, ap_bounds=None, cf_bound=1.5, bw_std_edge=1.0, gauss_overlap_thresh=0.75, - maxfev=5000, tol=0.00001): + maxfev=5000, tol=0.00001, + data=None, modes=None, results=None, verbose=False): """Initialize base model object""" # Initialize base algorithm object with algorithm metadata @@ -103,6 
+103,12 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h settings=SPECTRAL_FIT_SETTINGS, ) + ## TEMP: + self.data = data + self.modes = modes + self.results = results + self.verbose = verbose + ## Public settings self.peak_width_limits = peak_width_limits self.max_n_peaks = max_n_peaks @@ -123,7 +129,9 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h ## Set internal settings, based on inputs, and initialize data & results attributes self._reset_internal_settings() - self._reset_data_results(True, True, True) + + # TEMP: is this still needed? + #self._reset_data_results(True, True, True) def _fit_prechecks(self): @@ -135,8 +143,8 @@ def _fit_prechecks(self): low given the frequency resolution of the data. """ - if 1.5 * self.freq_res >= self.peak_width_limits[0]: - print(gen_width_warning_str(self.freq_res, self.peak_width_limits[0])) + if 1.5 * self.data.freq_res >= self.peak_width_limits[0]: + print(gen_width_warning_str(self.data.freq_res, self.peak_width_limits[0])) def _fit(self): @@ -150,33 +158,35 @@ def _fit(self): ## FIT PROCEDURES # Take an initial fit of the aperiodic component - temp_aperiodic_params_ = self._robust_ap_fit(self.freqs, self.power_spectrum) - temp_ap_fit = self.aperiodic_mode.func(self.freqs, *temp_aperiodic_params_) + temp_aperiodic_params_ = self._robust_ap_fit(self.data.freqs, self.data.power_spectrum) + temp_ap_fit = self.modes.aperiodic.func(self.data.freqs, *temp_aperiodic_params_) # Find peaks from the flattened power spectrum, and fit them with gaussians - temp_spectrum_flat = self.power_spectrum - temp_ap_fit - self.gaussian_params_ = self._fit_peaks(temp_spectrum_flat) + temp_spectrum_flat = self.data.power_spectrum - temp_ap_fit + self.results.gaussian_params_ = self._fit_peaks(temp_spectrum_flat) # Calculate the peak fit # Note: if no peaks are found, this creates a flat (all zero) peak fit - self._peak_fit = self.periodic_mode.func(\ - self.freqs, 
*np.ndarray.flatten(self.gaussian_params_)) + self.results._peak_fit = self.modes.periodic.func(\ + self.data.freqs, *np.ndarray.flatten(self.results.gaussian_params_)) # Create peak-removed (but not flattened) power spectrum - self._spectrum_peak_rm = self.power_spectrum - self._peak_fit + self.results._spectrum_peak_rm = self.data.power_spectrum - self.results._peak_fit # Run final aperiodic fit on peak-removed power spectrum - self.aperiodic_params_ = self._simple_ap_fit(self.freqs, self._spectrum_peak_rm) - self._ap_fit = self.aperiodic_mode.func(self.freqs, *self.aperiodic_params_) + self.results.aperiodic_params_ = self._simple_ap_fit(\ + self.data.freqs, self.results._spectrum_peak_rm) + self.results._ap_fit = self.modes.aperiodic.func(\ + self.data.freqs, *self.results.aperiodic_params_) # Create remaining model components: flatspec & full power_spectrum model fit - self._spectrum_flat = self.power_spectrum - self._ap_fit - self.modeled_spectrum_ = self._peak_fit + self._ap_fit + self.results._spectrum_flat = self.data.power_spectrum - self.results._ap_fit + self.results.modeled_spectrum_ = self.results._peak_fit + self.results._ap_fit ## PARAMETER UPDATES # Convert gaussian definitions to peak parameters - self.peak_params_ = self._create_peak_params(self.gaussian_params_) + self.results.peak_params_ = self._create_peak_params(self.results.gaussian_params_) def _reset_internal_settings(self): @@ -213,7 +223,7 @@ def _get_ap_guess(self, freqs, power_spectrum): if not self._ap_guess: ap_guess = [] - for label in self.aperiodic_mode.params.labels: + for label in self.modes.aperiodic.params.labels: if label == 'offset': # Offset guess is the power value for lowest available frequency ap_guess.append(power_spectrum[0]) @@ -245,11 +255,11 @@ def _set_ap_bounds(self, ap_bounds): if ap_bounds: msg = 'Provided aperiodic bounds do not have right length for fit function.' 
assert len(self._ap_bounds[0]) == len(self._ap_bounds[1]) == \ - self.aperiodic_mode.n_params, msg + self.modes.aperiodic.n_params, msg self._ap_bounds = ap_bounds else: - self._ap_bounds = (tuple([-np.inf] * self.aperiodic_mode.n_params), - tuple([np.inf] * self.aperiodic_mode.n_params)) + self._ap_bounds = (tuple([-np.inf] * self.modes.aperiodic.n_params), + tuple([np.inf] * self.modes.aperiodic.n_params)) def _simple_ap_fit(self, freqs, power_spectrum): @@ -278,7 +288,7 @@ def _simple_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, freqs, power_spectrum, + aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs, power_spectrum, p0=ap_guess, bounds=self._ap_bounds, maxfev=self._maxfev, check_finite=False, ftol=self._tol, xtol=self._tol, gtol=self._tol) @@ -313,7 +323,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): # Do a quick, initial aperiodic fit popt = self._simple_ap_fit(freqs, power_spectrum) - initial_fit = self.aperiodic_mode.func(freqs, *popt) + initial_fit = self.modes.aperiodic.func(freqs, *popt) # Flatten power_spectrum based on initial aperiodic fit flatspec = power_spectrum - initial_fit @@ -332,7 +342,7 @@ def _robust_ap_fit(self, freqs, power_spectrum): try: with warnings.catch_warnings(): warnings.simplefilter("ignore") - aperiodic_params, _ = curve_fit(self.aperiodic_mode.func, + aperiodic_params, _ = curve_fit(self.modes.aperiodic.func, freqs_ignore, spectrum_ignore, p0=popt, bounds=self._ap_bounds, maxfev=self._maxfev, check_finite=False, @@ -368,7 +378,7 @@ def _fit_peaks(self, flatspec): flat_iter = np.copy(flatspec) # Initialize matrix of guess parameters for peak fitting - guess = np.empty([0, self.periodic_mode.n_params]) + guess = np.empty([0, self.modes.periodic.n_params]) # Find peak: loop through, finding a candidate peak, & fit with a guess peak # Stopping procedures: limit on # of peaks, or relative or 
absolute height thresholds @@ -383,7 +393,7 @@ def _fit_peaks(self, flatspec): break # Set the guess parameters for gaussian fitting, specifying the mean and height - guess_freq = self.freqs[max_ind] + guess_freq = self.data.freqs[max_ind] guess_height = max_height # Halt fitting process if candidate peak drops below minimum height @@ -408,7 +418,7 @@ def _fit_peaks(self, flatspec): # Use the shortest side to estimate full-width, half max (converted to Hz) # and use this to estimate that guess for gaussian standard deviation - fwhm = short_side * 2 * self.freq_res + fwhm = short_side * 2 * self.data.freq_res guess_std = compute_gauss_std(fwhm) except ValueError: @@ -427,12 +437,12 @@ def _fit_peaks(self, flatspec): current_guess_params = (guess_freq, guess_height, guess_std) ## TEMP - if self.periodic_mode.name == 'skewnorm': + if self.modes.periodic.name == 'skewnorm': guess_skew = 0 current_guess_params = (guess_freq, guess_height, guess_std, guess_skew) guess = np.vstack((guess, current_guess_params)) - peak_gauss = self.periodic_mode.func(self.freqs, *current_guess_params) + peak_gauss = self.modes.periodic.func(self.data.freqs, *current_guess_params) flat_iter = flat_iter - peak_gauss # Check peaks based on edges, and on overlap, dropping any that violate requirements @@ -444,7 +454,7 @@ def _fit_peaks(self, flatspec): gaussian_params = self._fit_peak_guess(flatspec, guess) gaussian_params = gaussian_params[gaussian_params[:, 0].argsort()] else: - gaussian_params = np.empty([0, self.periodic_mode.n_params]) + gaussian_params = np.empty([0, self.modes.periodic.n_params]) return gaussian_params @@ -466,10 +476,10 @@ def _get_pe_bounds(self, guess): # Check that CF bounds are within frequency range # If they are not, update them to be restricted to frequency range - lo_bound = [bound if bound[0] > self.freq_range[0] else \ - [self.freq_range[0], *bound[1:]] for bound in lo_bound] - hi_bound = [bound if bound[0] < self.freq_range[1] else \ - [self.freq_range[1], 
*bound[1:]] for bound in hi_bound] + lo_bound = [bound if bound[0] > self.data.freq_range[0] else \ + [self.data.freq_range[0], *bound[1:]] for bound in lo_bound] + hi_bound = [bound if bound[0] < self.data.freq_range[1] else \ + [self.data.freq_range[1], *bound[1:]] for bound in hi_bound] # Unpacks the embedded lists into flat tuples # This is what the fit function requires as input @@ -497,11 +507,11 @@ def _fit_peak_guess(self, flatspec, guess): # Fit the peaks try: - pe_params, _ = curve_fit(self.periodic_mode.func, - self.freqs, flatspec, + pe_params, _ = curve_fit(self.modes.periodic.func, + self.data.freqs, flatspec, p0=np.ndarray.flatten(guess), bounds=self._get_pe_bounds(guess), - jac=self.periodic_mode.jacobian, + jac=self.modes.periodic.jacobian, maxfev=self._maxfev, check_finite=False, ftol=self._tol, xtol=self._tol, gtol=self._tol) @@ -516,7 +526,7 @@ def _fit_peak_guess(self, flatspec, guess): raise FitError(error_msg) from excp # Re-organize params into 2d matrix - pe_params = np.array(groupby(pe_params, self.periodic_mode.n_params)) + pe_params = np.array(groupby(pe_params, self.modes.periodic.n_params)) return pe_params @@ -541,8 +551,8 @@ def _drop_peak_cf(self, guess): # Check if peaks within drop threshold from the edge of the frequency range keep_peak = \ - (np.abs(np.subtract(cf_params, self.freq_range[0])) > bw_params) & \ - (np.abs(np.subtract(cf_params, self.freq_range[1])) > bw_params) + (np.abs(np.subtract(cf_params, self.data.freq_range[0])) > bw_params) & \ + (np.abs(np.subtract(cf_params, self.data.freq_range[1])) > bw_params) # Drop peaks that fail the center frequency edge criterion guess = np.array([gu for (gu, keep) in zip(guess, keep_peak) if keep]) @@ -626,21 +636,23 @@ def _create_peak_params(self, gaus_params): with `freqs`, `modeled_spectrum_` and `_ap_fit` all required to be available. 
""" - peak_params = np.empty((len(gaus_params), self.periodic_mode.n_params)) + peak_params = np.empty((len(gaus_params), self.modes.periodic.n_params)) for ii, peak in enumerate(gaus_params): # Gets the index of the power_spectrum at the frequency closest to the CF of the peak - ind = np.argmin(np.abs(self.freqs - peak[0])) + ind = np.argmin(np.abs(self.data.freqs - peak[0])) # Collect peak parameter data - if self.periodic_mode.name == 'gaussian': ## TEMP - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], + if self.modes.periodic.name == 'gaussian': ## TEMP + peak_params[ii] = [peak[0], + self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], peak[2] * 2] ## TEMP: - if self.periodic_mode.name == 'skewnorm': - peak_params[ii] = [peak[0], self.modeled_spectrum_[ind] - self._ap_fit[ind], + if self.modes.periodic.name == 'skewnorm': + peak_params[ii] = [peak[0], + self.results.modeled_spectrum_[ind] - self.results._ap_fit[ind], peak[2] * 2, peak[3]] return peak_params diff --git a/specparam/data/__init__.py b/specparam/data/__init__.py index 850f5b404..89a1eef89 100644 --- a/specparam/data/__init__.py +++ b/specparam/data/__init__.py @@ -1,3 +1,3 @@ """Data sub-module.""" -from .data import ModelSettings, ModelChecks, SpectrumMetaData, FitResults, SimParams +from .data import ModelModes, ModelSettings, ModelChecks, SpectrumMetaData, FitResults, SimParams diff --git a/specparam/data/periodic.py b/specparam/data/periodic.py index 6643aad28..a92681334 100644 --- a/specparam/data/periodic.py +++ b/specparam/data/periodic.py @@ -44,7 +44,7 @@ def get_band_peak(model, band, select_highest=True, threshold=None, >>> betas = get_band_peak(model, [13, 30], select_highest=False) # doctest:+SKIP """ - return get_band_peak_arr(getattr(model, attribute + '_'), band, + return get_band_peak_arr(getattr(model.results, attribute + '_'), band, select_highest, threshold, thresh_param) @@ -97,7 +97,7 @@ def get_band_peak_group(group, band, 
threshold=None, thresh_param='PW', attribut >>> betas = get_band_peak_group(group, [13, 30], threshold=0.1) # doctest:+SKIP """ - return get_band_peak_group_arr(group.get_params(attribute), band, len(group), + return get_band_peak_group_arr(group.results.get_params(attribute), band, len(group.results), threshold, thresh_param) @@ -127,8 +127,8 @@ def get_band_peak_event(event, band, threshold=None, thresh_param='PW', attribut Array of peak data, organized as [n_events, n_time_windows, n_peak_params]. """ - peaks = np.zeros([event.n_events, event.n_time_windows, 3]) - for ind in range(event.n_events): + peaks = np.zeros([event.data.n_events, event.data.n_time_windows, 3]) + for ind in range(event.data.n_events): peaks[ind, :, :] = get_band_peak_group(\ event.get_group(ind, None, 'group'), band, threshold, thresh_param, attribute) diff --git a/specparam/io/models.py b/specparam/io/models.py index f7bd7a512..d670394f3 100644 --- a/specparam/io/models.py +++ b/specparam/io/models.py @@ -45,16 +45,22 @@ def save_model(model, file_name, file_path=None, append=False, """ # Convert object to dictionary & convert all arrays to lists, for JSON serializing + # This 'flattens' the object, getting all relevant attributes in the same dictionary obj_dict = dict_array_to_lst(model.__dict__) + data_dict = dict_array_to_lst(model.data.__dict__) + results_dict = dict_array_to_lst(model.results.__dict__) + algo_dict = dict_array_to_lst(model.algorithm.__dict__) + obj_dict = {**obj_dict, **data_dict, **results_dict, **algo_dict} # Convert modes object to their saveable string name - obj_dict['aperiodic_mode'] = obj_dict['aperiodic_mode'].name - obj_dict['periodic_mode'] = obj_dict['periodic_mode'].name + obj_dict['aperiodic_mode'] = obj_dict['modes'].aperiodic.name + obj_dict['periodic_mode'] = obj_dict['modes'].periodic.name + mode_labels = ['aperiodic_mode', 'periodic_mode'] # Set and select which variables to keep. 
Use a set to drop any potential overlap # Note that results also saves frequency information to be able to recreate freq vector keep = set((OBJ_DESC['results'] + OBJ_DESC['meta_data'] if save_results else []) + \ - (OBJ_DESC['settings'] + OBJ_DESC['modes'] if save_settings else []) + \ + (OBJ_DESC['settings'] + mode_labels if save_settings else []) + \ (OBJ_DESC['data'] if save_data else [])) obj_dict = dict_select_keys(obj_dict, keep) @@ -139,11 +145,11 @@ def save_event(event, file_name, file_path=None, append=False, if save_settings and not save_results and not save_data: fg.save(file_name, file_path, append=append, save_settings=True) else: - ndigits = len(str(len(event))) - for ind, gres in enumerate(event.event_group_results): - fg.group_results = gres + ndigits = len(str(len(event.results))) + for ind, gres in enumerate(event.results.event_group_results): + fg.results.group_results = gres if save_data: - fg.power_spectra = event.spectrograms[ind, :, :].T + fg.data.power_spectra = event.data.spectrograms[ind, :, :].T fg.save(file_name + '_{:0{ndigits}d}'.format(ind, ndigits=ndigits), file_path=file_path, append=append, save_results=save_results, save_settings=save_settings, save_data=save_data) @@ -274,7 +280,7 @@ def _save_group(group, f_obj, save_results, save_settings, save_data): # For results & data, loop across all data and/or models, and save each out to a new line if save_results or save_data: - for ind in range(len(group.group_results)): + for ind in range(len(group.results.group_results)): model = group.get_model(ind, regenerate=False) save_model(model, file_name=f_obj, file_path=None, append=False, save_results=save_results, save_data=save_data) diff --git a/specparam/measures/gof.py b/specparam/measures/gof.py index a02e98c84..89aeef0eb 100644 --- a/specparam/measures/gof.py +++ b/specparam/measures/gof.py @@ -79,7 +79,7 @@ def compute_gof(power_spectrum, modeled_spectrum, gof_metric='r_squared'): if isinstance(gof_metric, str): gof = 
GOF_FUNCS[gof_metric.lower()](power_spectrum, modeled_spectrum) - elif isfunction(error_metric): + elif isfunction(gof_metric): gof = gof_metric(power_spectrum, modeled_spectrum) return gof diff --git a/specparam/measures/pointwise.py b/specparam/measures/pointwise.py index b7e7ed30c..819152c79 100644 --- a/specparam/measures/pointwise.py +++ b/specparam/measures/pointwise.py @@ -37,15 +37,15 @@ def compute_pointwise_error(model, plot_errors=True, return_errors=False, **plt_ If there are no model results available to calculate model error from. """ - if not model.has_data: + if not model.data.has_data: raise NoDataError("Data must be available in the object to calculate errors.") - if not model.has_model: + if not model.results.has_model: raise NoModelError("No model is available to use, can not proceed.") - errors = compute_pointwise_error_arr(model.modeled_spectrum_, model.power_spectrum) + errors = compute_pointwise_error_arr(model.results.modeled_spectrum_, model.data.power_spectrum) if plot_errors: - plot_spectral_error(model.freqs, errors, **plt_kwargs) + plot_spectral_error(model.data.freqs, errors, **plt_kwargs) if return_errors: return errors @@ -79,34 +79,34 @@ def compute_pointwise_error_group(group, plot_errors=True, return_errors=False, If there are no model results available to calculate model errors from. 
""" - if not np.any(group.power_spectra): + if not group.data.has_data: raise NoDataError("Data must be available in the object to calculate errors.") - if not group.has_model: + if not group.results.has_model: raise NoModelError("No model is available to use, can not proceed.") - errors = np.zeros_like(group.power_spectra) + errors = np.zeros_like(group.data.power_spectra) - for ind, (res, data) in enumerate(zip(group, group.power_spectra)): + for ind, (res, data) in enumerate(zip(group.results, group.data.power_spectra)): - model = gen_model(group.freqs, res.aperiodic_params, res.gaussian_params) + model = gen_model(group.data.freqs, res.aperiodic_params, res.gaussian_params) errors[ind, :] = np.abs(model - data) mean = np.mean(errors, 0) standard_dev = np.std(errors, 0) if plot_errors: - plot_spectral_error(group.freqs, mean, standard_dev, **plt_kwargs) + plot_spectral_error(group.data.freqs, mean, standard_dev, **plt_kwargs) if return_errors: return errors -def compute_pointwise_error_arr(data_model, data): +def compute_pointwise_error_arr(model, data): """Calculate point-wise error between original data and a model fit of that data. Parameters ---------- - data_model : 1d array + model : 1d array The model of the data. data : 1d array The original data that is being modeled. @@ -117,4 +117,4 @@ def compute_pointwise_error_arr(data_model, data): Calculated values of the difference between the data and the model. 
""" - return np.abs(data_model - data) + return np.abs(model - data) diff --git a/specparam/models/__init__.py b/specparam/models/__init__.py index 1aec7d75a..d565feae4 100644 --- a/specparam/models/__init__.py +++ b/specparam/models/__init__.py @@ -5,4 +5,4 @@ from .time import SpectralTimeModel from .event import SpectralTimeEventModel from .utils import (compare_model_objs, average_group, average_reconstructions, - combine_model_objs, fit_models_3d) \ No newline at end of file + combine_model_objs, fit_models_3d) diff --git a/specparam/models/event.py b/specparam/models/event.py index 7cd402717..28c0142e8 100644 --- a/specparam/models/event.py +++ b/specparam/models/event.py @@ -2,6 +2,7 @@ import numpy as np +from specparam.modes.modes import Modes from specparam.models import SpectralModel from specparam.objs.base import BaseObject3D from specparam.algorithms.spectral_fit import SpectralFitAlgorithm @@ -18,7 +19,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): +class SpectralTimeEventModel(BaseObject3D): """Model a set of event as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -58,15 +59,15 @@ class SpectralTimeEventModel(SpectralFitAlgorithm, BaseObject3D): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - BaseObject3D.__init__(self, - aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), - periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug=kwargs.pop('debug', False), - verbose=kwargs.pop('verbose', True)) + self.modes = Modes(aperiodic=kwargs.pop('aperiodic_mode', 'fixed'), + periodic=kwargs.pop('periodic_mode', 'gaussian')) - SpectralFitAlgorithm.__init__(self, *args, **kwargs) + BaseObject3D.__init__(self, modes=self.modes, verbose=kwargs.pop('verbose', True)) - self._reset_event_results() + self.algorithm = SpectralFitAlgorithm(*args, **kwargs, + data=self.data, modes=self.modes, results=self.results, verbose=self.verbose) + + self.results._reset_event_results() def report(self, freqs=None, spectrograms=None, freq_range=None, @@ -128,14 +129,14 @@ def save_report(self, file_name, file_path=None, add_settings=True): save_event_report(self, file_name, file_path, add_settings) - def get_model(self, event_ind, window_ind, regenerate=True): + def get_model(self, event_ind=None, window_ind=None, regenerate=True): """Get a model fit object for a specified index. Parameters ---------- - event_ind : int + event_ind : int, optional Index for which event to extract from. - window_ind : int + window_ind : int, optional Index for which time window to extract from. regenerate : bool, optional, default: False Whether to regenerate the model fits for the requested model. @@ -143,23 +144,20 @@ def get_model(self, event_ind, window_ind, regenerate=True): Returns ------- model : SpectralModel - The FitResults data loaded into a model object. + The data and fit results loaded into a model object. 
""" # Initialize model object, with same settings, metadata, & check states as current object - model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) - model.add_meta_data(self.get_meta_data()) - model.set_checks(*self.get_checks()) - model.set_debug(self.get_debug()) + model = super().get_model() # Add data for specified single power spectrum, if available - if self.has_data: - model.power_spectrum = self.spectrograms[event_ind][:, window_ind] + if self.data.has_data: + model.data.power_spectrum = self.data.spectrograms[event_ind][:, window_ind] # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.event_group_results[event_ind][window_ind]) + model.results.add_results(self.results.event_group_results[event_ind][window_ind]) if regenerate: - model._regenerate_model() + model.results._regenerate_model(self.data.freqs) return model @@ -206,9 +204,9 @@ def to_df(self, peak_org=None): """ if peak_org is not None: - df = event_group_to_dataframe(self.event_group_results, peak_org) + df = event_group_to_dataframe(self.results.event_group_results, peak_org) else: - df = dict_to_df(flatten_results_dict(self.get_results())) + df = dict_to_df(flatten_results_dict(self.results.get_results())) return df @@ -222,5 +220,5 @@ def _fit_prechecks(self): checking and reporting on every spectrum and repeatedly re-raising the same warning. """ - if np.all(self.power_spectrum == self.spectrograms[0, :, 0]): + if np.all(self.data.power_spectrum == self.data.spectrograms[0, :, 0]): super()._fit_prechecks() diff --git a/specparam/models/group.py b/specparam/models/group.py index 58bd108bf..d6f844df1 100644 --- a/specparam/models/group.py +++ b/specparam/models/group.py @@ -5,6 +5,7 @@ Methods without defined docstrings import docs at runtime, from aliased external functions. 
""" +from specparam.modes.modes import Modes from specparam.models import SpectralModel from specparam.objs.base import BaseObject2D from specparam.algorithms.spectral_fit import SpectralFitAlgorithm @@ -20,7 +21,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D): +class SpectralGroupModel(BaseObject2D): """Model a group of power spectra as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. @@ -71,13 +72,13 @@ class SpectralGroupModel(SpectralFitAlgorithm, BaseObject2D): def __init__(self, *args, **kwargs): - BaseObject2D.__init__(self, - aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), - periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug=kwargs.pop('debug', False), - verbose=kwargs.pop('verbose', True)) + self.modes = Modes(aperiodic=kwargs.pop('aperiodic_mode', 'fixed'), + periodic=kwargs.pop('periodic_mode', 'gaussian')) - SpectralFitAlgorithm.__init__(self, *args, **kwargs) + BaseObject2D.__init__(self, modes=self.modes, verbose=kwargs.pop('verbose', True)) + + self.algorithm = SpectralFitAlgorithm(*args, **kwargs, + data=self.data, modes=self.modes, results=self.results, verbose=self.verbose) def report(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, @@ -172,7 +173,7 @@ def to_df(self, peak_org): Model results organized into a pandas object. """ - return group_to_dataframe(self.get_results(), peak_org) + return group_to_dataframe(self.results.get_results(), peak_org) def _fit_prechecks(self): @@ -184,5 +185,5 @@ def _fit_prechecks(self): checking and reporting on every spectrum and repeatedly re-raising the same warning. 
""" - if self.power_spectra[0, 0] == self.power_spectrum[0]: + if self.data.power_spectra[0, 0] == self.data.power_spectrum[0]: super()._fit_prechecks() diff --git a/specparam/models/model.py b/specparam/models/model.py index d8472172e..58b064941 100644 --- a/specparam/models/model.py +++ b/specparam/models/model.py @@ -7,6 +7,7 @@ import numpy as np +from specparam.modes.modes import Modes from specparam.objs.base import BaseObject from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS from specparam.reports.save import save_model_report @@ -17,13 +18,12 @@ from specparam.plts.model import plot_model from specparam.data.utils import get_model_params from specparam.data.conversions import model_to_dataframe -from specparam.sim.gen import gen_model ################################################################################################### ################################################################################################### @replace_docstring_sections([SPECTRAL_FIT_SETTINGS.make_docstring()]) -class SpectralModel(SpectralFitAlgorithm, BaseObject): +class SpectralModel(BaseObject): """Model a power spectrum as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -104,12 +104,15 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h verbose=True, **model_kwargs): """Initialize model object.""" - BaseObject.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=model_kwargs.pop('debug', False), verbose=verbose) + self.modes = Modes(aperiodic=aperiodic_mode, periodic=periodic_mode) - SpectralFitAlgorithm.__init__(self, peak_width_limits=peak_width_limits, - max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, - peak_threshold=peak_threshold, **model_kwargs) + BaseObject.__init__(self, modes=self.modes, verbose=verbose) + + self.algorithm = SpectralFitAlgorithm(peak_width_limits=peak_width_limits, + max_n_peaks=max_n_peaks, min_peak_height=min_peak_height, + peak_threshold=peak_threshold, + data=self.data, modes=self.modes, results=self.results, verbose=self.verbose, + **model_kwargs) def report(self, freqs=None, power_spectrum=None, freq_range=None, @@ -229,10 +232,10 @@ def get_params(self, name, col=None): If there are no fit peak (no peak parameters), this method will return NaN. """ - if not self.has_model: + if not self.results.has_model: raise NoModelError("No model fit results are available to extract, can not proceed.") - return get_model_params(self.get_results(), name, col) + return get_model_params(self.results.get_results(), name, col) @copy_doc_func_to_method(plot_model) @@ -268,11 +271,4 @@ def to_df(self, peak_org): Model results organized into a pandas object. 
""" - return model_to_dataframe(self.get_results(), peak_org) - - - def _regenerate_model(self): - """Regenerate model fit from parameters.""" - - self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model( - self.freqs, self.aperiodic_params_, self.gaussian_params_, return_components=True) + return model_to_dataframe(self.results.get_results(), peak_org) diff --git a/specparam/models/time.py b/specparam/models/time.py index c3f04458e..42af9c863 100644 --- a/specparam/models/time.py +++ b/specparam/models/time.py @@ -2,6 +2,7 @@ import numpy as np +from specparam.modes.modes import Modes from specparam.models import SpectralModel from specparam.objs.base import BaseObject2DT from specparam.algorithms.spectral_fit import SpectralFitAlgorithm @@ -17,7 +18,7 @@ @replace_docstring_sections([docs_get_section(SpectralModel.__doc__, 'Parameters'), docs_get_section(SpectralModel.__doc__, 'Notes')]) -class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT): +class SpectralTimeModel(BaseObject2DT): """Model a spectrogram as a combination of aperiodic and periodic components. WARNING: frequency and power values inputs must be in linear space. 
@@ -59,15 +60,15 @@ class SpectralTimeModel(SpectralFitAlgorithm, BaseObject2DT): def __init__(self, *args, **kwargs): """Initialize object with desired settings.""" - BaseObject2DT.__init__(self, - aperiodic_mode=kwargs.pop('aperiodic_mode', 'fixed'), - periodic_mode=kwargs.pop('periodic_mode', 'gaussian'), - debug=kwargs.pop('debug', False), - verbose=kwargs.pop('verbose', True)) + self.modes = Modes(aperiodic=kwargs.pop('aperiodic_mode', 'fixed'), + periodic=kwargs.pop('periodic_mode', 'gaussian')) - SpectralFitAlgorithm.__init__(self, *args, **kwargs) + BaseObject2DT.__init__(self, modes=self.modes, verbose=kwargs.pop('verbose', True)) - self._reset_time_results() + self.algorithm = SpectralFitAlgorithm(*args, **kwargs, + data=self.data, modes=self.modes, results=self.results, verbose=self.verbose) + + self.results._reset_time_results() def report(self, freqs=None, spectrogram=None, freq_range=None, @@ -153,9 +154,9 @@ def to_df(self, peak_org=None): """ if peak_org is not None: - df = group_to_dataframe(self.group_results, peak_org) + df = group_to_dataframe(self.results.group_results, peak_org) else: - df = dict_to_df(self.get_results()) + df = dict_to_df(self.results.get_results()) return df @@ -169,5 +170,5 @@ def _fit_prechecks(self): checking and reporting on every spectrum and repeatedly re-raising the same warning. 
""" - if np.all(self.power_spectrum == self.spectrogram[:, 0]): + if np.all(self.data.power_spectrum == self.data.spectrogram[:, 0]): super()._fit_prechecks() diff --git a/specparam/models/utils.py b/specparam/models/utils.py index 4733cbd95..3b35c83e4 100644 --- a/specparam/models/utils.py +++ b/specparam/models/utils.py @@ -4,13 +4,49 @@ from specparam.sim import gen_freqs from specparam.data import FitResults -from specparam.models import SpectralModel, SpectralGroupModel +from specparam.models import (SpectralModel, SpectralGroupModel, + SpectralTimeModel, SpectralTimeEventModel) from specparam.data.periodic import get_band_peak_group from specparam.modutils.errors import NoModelError, IncompatibleSettingsError ################################################################################################### ################################################################################################### +# Collect dictionary of all available models +MODELS = { + 'model' : SpectralModel, + 'group' : SpectralGroupModel, + 'time' : SpectralTimeModel, + 'event' : SpectralTimeEventModel, +} + + +def initialize_model_from_source(source, target): + """Initialize a model object based on a source model object. + + Parameters + ---------- + source : SpectralModel or Spectral*Model + Model object to initialize from. + target : {'model', 'group', 'time', 'event'} + Type of model object to initialize. + + Returns + ------- + model : Spectral*Model + Model object, of type `target`, initialized from source. + """ + + model = MODELS[target](**source.modes.get_modes()._asdict(), + **source.algorithm.get_settings()._asdict(), + verbose=source.verbose) + model.data.add_meta_data(source.data.get_meta_data()) + model.data.set_checks(*source.data.get_checks()) + model.algorithm.set_debug(source.algorithm.get_debug()) + + return model + + def compare_model_objs(model_objs, aspect): """Compare multiple model, checking for consistent attributes. 
@@ -29,11 +65,10 @@ def compare_model_objs(model_objs, aspect): # Check specified aspect of the objects are the same across instances for m_obj_1, m_obj_2 in zip(model_objs[:-1], model_objs[1:]): - if getattr(m_obj_1, 'get_' + aspect)() != getattr(m_obj_2, 'get_' + aspect)(): - consistent = False - break - else: - consistent = True + if aspect == 'settings': + consistent = m_obj_1.algorithm.get_settings() == m_obj_2.algorithm.get_settings() + if aspect == 'meta_data': + consistent = m_obj_1.data.get_meta_data() == m_obj_2.data.get_meta_data() return consistent @@ -65,7 +100,7 @@ def average_group(group, bands, avg_method='mean', regenerate=True): If there are no model fit results available to average across. """ - if not group.has_model: + if not group.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian} @@ -73,7 +108,7 @@ def average_group(group, bands, avg_method='mean', regenerate=True): raise ValueError("Requested average method not understood.") # Aperiodic parameters: extract & average - ap_params = avg_funcs[avg_method](group.get_params('aperiodic_params'), 0) + ap_params = avg_funcs[avg_method](group.results.get_params('aperiodic_params'), 0) # Periodic parameters: extract & average peak_params = [] @@ -94,21 +129,19 @@ def average_group(group, bands, avg_method='mean', regenerate=True): gauss_params = np.array(gauss_params) # Goodness of fit measures: extract & average - r2 = avg_funcs[avg_method](group.get_params('r_squared')) - error = avg_funcs[avg_method](group.get_params('error')) + r2 = avg_funcs[avg_method](group.results.get_params('r_squared')) + error = avg_funcs[avg_method](group.results.get_params('error')) # Collect all results together, to be added to the model object results = FitResults(ap_params, peak_params, r2, error, gauss_params) # Create the new model object, with settings, data info & results - model = SpectralModel() - 
model.add_settings(group.get_settings()) - model.add_meta_data(group.get_meta_data()) - model.add_results(results) + model = group.get_model() + model.results.add_results(results) # Generate the average model from the parameters if regenerate: - model._regenerate_model() + model.results._regenerate_model(group.data.freqs) return model @@ -132,20 +165,20 @@ def average_reconstructions(group, avg_method='mean'): Note that power values are in log10 space. """ - if not group.has_model: + if not group.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian} if avg_method not in avg_funcs.keys(): raise ValueError("Requested average method not understood.") - models = np.zeros(shape=group.power_spectra.shape) - for ind in range(len(group)): - models[ind, :] = group.get_model(ind, regenerate=True).modeled_spectrum_ + models = np.zeros(shape=group.data.power_spectra.shape) + for ind in range(len(group.results)): + models[ind, :] = group.get_model(ind, regenerate=True).results.modeled_spectrum_ avg_model = avg_funcs[avg_method](models, 0) - return group.freqs, avg_model + return group.data.freqs, avg_model def combine_model_objs(model_objs): @@ -184,11 +217,12 @@ def combine_model_objs(model_objs): "or meta data, and so cannot be combined.") # Initialize group model object, with settings derived from input objects - group = SpectralGroupModel(*model_objs[0].get_settings(), verbose=model_objs[0].verbose) + group = SpectralGroupModel(*model_objs[0].algorithm.get_settings(), + verbose=model_objs[0].verbose) # Use a temporary store to collect spectra, as we'll only add it if it is consistently present # We check how many frequencies by accessing meta data, in case of no frequency vector - meta_data = model_objs[0].get_meta_data() + meta_data = model_objs[0].data.get_meta_data() n_freqs = len(gen_freqs(meta_data.freq_range, meta_data.freq_res)) temp_power_spectra = np.empty([0, 
n_freqs]) @@ -197,28 +231,28 @@ def combine_model_objs(model_objs): # Add group object if isinstance(m_obj, SpectralGroupModel): - group.group_results.extend(m_obj.group_results) - if m_obj.power_spectra is not None: - temp_power_spectra = np.vstack([temp_power_spectra, m_obj.power_spectra]) + group.results.group_results.extend(m_obj.results.group_results) + if m_obj.data.power_spectra is not None: + temp_power_spectra = np.vstack([temp_power_spectra, m_obj.data.power_spectra]) # Add model object else: - group.group_results.append(m_obj.get_results()) - if m_obj.power_spectrum is not None: - temp_power_spectra = np.vstack([temp_power_spectra, m_obj.power_spectrum]) + group.results.group_results.append(m_obj.results.get_results()) + if m_obj.data.power_spectrum is not None: + temp_power_spectra = np.vstack([temp_power_spectra, m_obj.data.power_spectrum]) # If the number of collected power spectra is consistent, then add them to object - if len(group) == temp_power_spectra.shape[0]: - group.power_spectra = temp_power_spectra + if len(group.results) == temp_power_spectra.shape[0]: + group.data.power_spectra = temp_power_spectra # Set the status for freqs & data checking # Check states gets set as True if any of the inputs have it on, False otherwise - group.set_checks(\ - check_freqs=any(getattr(m_obj, '_check_freqs') for m_obj in model_objs), - check_data=any(getattr(m_obj, '_check_data') for m_obj in model_objs)) + group.data.set_checks(\ + check_freqs=any(getattr(m_obj.data, '_check_freqs') for m_obj in model_objs), + check_data=any(getattr(m_obj.data, '_check_data') for m_obj in model_objs)) # Add data information information - group.add_meta_data(model_objs[0].get_meta_data()) + group.data.add_meta_data(model_objs[0].data.get_meta_data()) return group diff --git a/specparam/modes/funcs.py b/specparam/modes/funcs.py index 264db01d4..1adb4aebd 100644 --- a/specparam/modes/funcs.py +++ b/specparam/modes/funcs.py @@ -94,7 +94,7 @@ def cauchy_function(xs, *params): ## 
APERIODIC FUNCTIONS def expo_function(xs, *params): - """Exponential fitting function, for fitting aperiodic component with a 'knee'. + """Exponential function, for fitting aperiodic component with a 'knee'. NOTE: this function requires linear frequency (not log). @@ -119,7 +119,7 @@ def expo_function(xs, *params): def expo_nk_function(xs, *params): - """Exponential fitting function, for fitting aperiodic component without a 'knee'. + """Exponential function, for fitting aperiodic component without a 'knee'. NOTE: this function requires linear frequency (not log). @@ -144,7 +144,7 @@ def expo_nk_function(xs, *params): def double_expo_function(xs, *params): - """Double exponential fitting function, for fitting aperiodic component with two exponents and a knee. + """Double exponential function, for fitting aperiodic component with two exponents and a knee. NOTE: this function requires linear frequency (not log). diff --git a/specparam/modes/info.py b/specparam/modes/info.py index feaf747d8..45a260f3c 100644 --- a/specparam/modes/info.py +++ b/specparam/modes/info.py @@ -28,19 +28,34 @@ def get_description(): """ attributes = { - 'results' : ['aperiodic_params_', 'gaussian_params_', 'peak_params_', - 'r_squared_', 'error_'], - 'settings' : ['peak_width_limits', 'max_n_peaks', - 'min_peak_height', 'peak_threshold'], - 'modes' : ['aperiodic_mode', 'periodic_mode'], + + # Data 'checks' : ['_check_freqs', '_check_data'], - 'debug' : ['_debug'], 'data' : ['power_spectrum', 'freq_range', 'freq_res'], 'meta_data' : ['freq_range', 'freq_res'], - 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', - 'peak_params_', 'gaussian_params_'], + + # Modes + #'modes' : ['aperiodic_mode', 'periodic_mode'], + + # Algorithm + 'settings' : ['peak_width_limits', 'max_n_peaks', + 'min_peak_height', 'peak_threshold'], + 'debug' : ['_debug'], + + # Results + 'results' : ['aperiodic_params_', 'gaussian_params_', 'peak_params_', + 'r_squared_', 'error_'], 'model_components' : 
['modeled_spectrum_', '_spectrum_flat', '_spectrum_peak_rm', '_ap_fit', '_peak_fit'], + + # Metrics + # 'metrics' : ['r_squared_', 'error_'] + + # General - data types + 'arrays' : ['freqs', 'power_spectrum', 'aperiodic_params_', + 'peak_params_', 'gaussian_params_'], + + # Mixed 'descriptors' : ['has_data', 'has_model', 'n_peaks_'] } @@ -118,20 +133,21 @@ def get_indices(aperiodic_mode): return indices -def get_info(model_obj, aspect): - """Get a selection of information from a model objects. +# TEMP: TO DROP? +# def get_info(model_obj, aspect): +# """Get a selection of information from a model objects. - Parameters - ---------- - model_obj : SpectralModel or SpectralGroupModel - Object to get attributes from. - aspect : {'settings', 'meta_data', 'results'} - Which set of attributes to compare the objects across. +# Parameters +# ---------- +# model_obj : SpectralModel or SpectralGroupModel +# Object to get attributes from. +# aspect : {'settings', 'meta_data', 'results'} +# Which set of attributes to compare the objects across. - Returns - ------- - dict - The set of specified info from the model object. - """ +# Returns +# ------- +# dict +# The set of specified info from the model object. 
+# """ - return {key : getattr(model_obj, key) for key in get_description()[aspect]} +# return {key : getattr(model_obj, key) for key in get_description()[aspect]} diff --git a/specparam/modes/mode.py b/specparam/modes/mode.py index 5ea745eae..8631c9e94 100644 --- a/specparam/modes/mode.py +++ b/specparam/modes/mode.py @@ -1,4 +1,4 @@ -"""Modes object.""" +"""Mode object.""" from specparam.utils.checks import check_input_options @@ -55,7 +55,7 @@ def __init__(self, name, component, description, func, jacobian, def __repr__(self): """Return representation of this object as the name.""" - return self.name + return 'MODE: ' + self.component + '-' + self.name def __eq__(self, other): diff --git a/specparam/modes/modes.py b/specparam/modes/modes.py new file mode 100644 index 000000000..6eee60f25 --- /dev/null +++ b/specparam/modes/modes.py @@ -0,0 +1,65 @@ +"""Modes object.""" + +from specparam.data import ModelModes +from specparam.modes.items import OBJ_DESC +from specparam.modes.mode import Mode +from specparam.modes.definitions import AP_MODES, PE_MODES + +################################################################################################### +################################################################################################### + +class Modes(): + """Defines a set of fit modes. + + Parameters + ---------- + aperiodic : str or Mode + Aperiodic mode. + periodic : str or Mode + Periodic mode. + """ + + def __init__(self, aperiodic, periodic): + """Initialize modes.""" + + self.aperiodic = check_mode_definition(aperiodic, AP_MODES) + self.periodic = check_mode_definition(periodic, PE_MODES) + + + def get_modes(self): + """Get the modes definition. + + Returns + ------- + modes_def : ModelModes + Modes definition. + """ + + return ModelModes(aperiodic_mode=self.aperiodic.name, periodic_mode=self.periodic.name) + + +def check_mode_definition(mode, options): + """Check a mode specification. 
+ + Parameters + ---------- + mode : str or Mode + Fit mode. If str, should be a label corresponding to an entry in `options`. + options : dict + Available modes. + + Raises + ------ + ValueError + If the mode definition is not found / understood. + """ + + if isinstance(mode, str): + assert mode in list(options.keys()), 'Specific Mode not found.' + mode = options[mode] + elif isinstance(mode, Mode): + mode = mode + else: + raise ValueError('Mode input not understood.') + + return mode diff --git a/specparam/objs/base.py b/specparam/objs/base.py index 852c6a72a..f2bba65d5 100644 --- a/specparam/objs/base.py +++ b/specparam/objs/base.py @@ -5,8 +5,10 @@ import numpy as np from specparam.utils.array import unlog +from specparam.utils.checks import check_inds +from specparam.modes.modes import Modes from specparam.modes.items import OBJ_DESC -from specparam.modes.definitions import AP_MODES, PE_MODES +from specparam.data.utils import get_results_by_ind from specparam.io.utils import get_files from specparam.io.files import load_json, load_jsonlines from specparam.io.models import save_model, save_group, save_event @@ -16,6 +18,7 @@ from specparam.objs.results import BaseResults, BaseResults2D, BaseResults2DT, BaseResults3D from specparam.objs.data import BaseData, BaseData2D, BaseData2DT, BaseData3D from specparam.objs.utils import run_parallel_group, run_parallel_event, pbar +from specparam.objs.metrics import Metrics ################################################################################################### ################################################################################################### @@ -23,6 +26,15 @@ class CommonBase(): """Define CommonBase object.""" + def __init__(self, verbose): + """Initialize object.""" + + self.metrics = Metrics() + self.metrics.set_defaults() + + self.verbose = verbose + + def copy(self): """Return a copy of the current object.""" @@ -59,7 +71,7 @@ def fit(self, freqs=None, power_spectrum=None, 
freq_range=None): self.add_data(freqs, power_spectrum, freq_range) # Check that data is available - if not self.has_data: + if not self.data.has_data: raise NoDataError("No data available to fit, can not proceed.") # In rare cases, the model fails to fit, and so uses try / except @@ -68,27 +80,31 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None): # If not set to fail on NaN or Inf data at add time, check data here # This serves as a catch all for curve_fits which will fail given NaN or Inf # Because FitError's are by default caught, this allows fitting to continue - if not self._check_data: - if np.any(np.isinf(self.power_spectrum)) or np.any(np.isnan(self.power_spectrum)): + if not self.data._check_data: + if np.any(np.isinf(self.data.power_spectrum)) or \ + np.any(np.isnan(self.data.power_spectrum)): raise FitError("Model fitting was skipped because there are NaN or Inf " "values in the data, which preclude model fitting.") # Call the fit function from the algorithm object - self._fit() + self.algorithm._fit() - # Compute goodness of fit & error measures - self._compute_model_gof() - self._compute_model_error() + # Compute post-fit metrics + self.metrics.compute_metrics(self.data, self.results) + + # TEMP: alias metric results into updated management + self.results.error_ = self.metrics['error-mae'].output + self.results.r_squared_ = self.metrics['gof-r_squared'].output except FitError: # If in debug mode, re-raise the error - if self._debug: + if self.algorithm._debug: raise # Clear any interim model results that may have run # Partial model results shouldn't be interpreted in light of overall failure - self._reset_results(clear_results=True) + self.results._reset_results(True) # Print out status if self.verbose: @@ -124,18 +140,19 @@ def get_data(self, component='full', space='log'): With space set as 'linear', this combination holds in linear space. 
""" - if not self.has_data: + if not self.data.has_data: raise NoDataError("No data available to fit, can not proceed.") assert space in ['linear', 'log'], "Input for 'space' invalid." if component == 'full': - output = self.power_spectrum if space == 'log' else unlog(self.power_spectrum) + output = self.data.power_spectrum if space == 'log' \ + else unlog(self.data.power_spectrum) elif component == 'aperiodic': - output = self._spectrum_peak_rm if space == 'log' else \ - unlog(self.power_spectrum) / unlog(self._peak_fit) + output = self.results._spectrum_peak_rm if space == 'log' else \ + unlog(self.data.power_spectrum) / unlog(self.results._peak_fit) elif component == 'peak': - output = self._spectrum_flat if space == 'log' else \ - unlog(self.power_spectrum) - unlog(self._ap_fit) + output = self.results._spectrum_flat if space == 'log' else \ + unlog(self.data.power_spectrum) - unlog(self.results._ap_fit) else: raise ValueError('Input for component invalid.') @@ -152,19 +169,41 @@ def _add_from_dict(self, data): """ for key in data.keys(): - setattr(self, key, data[key]) + if getattr(self, key, False) is not False: + setattr(self, key, data[key]) + elif getattr(self.data, key, False) is not False: + setattr(self.data, key, data[key]) + elif getattr(self.results, key, False) is not False: + setattr(self.results, key, data[key]) + + + def _check_loaded_modes(self, data): + """Check if fit modes added, and update the object as needed. + + Parameters + ---------- + data : dict + A dictionary of data that has been added to the object. 
+ """ + + # TEMP / ToDo: not quite clear if this is the right place + # And/or - might want a clearer process to 'reset' Modes + if 'aperiodic_mode' in data and 'periodic_mode' in data: + self.modes = Modes(aperiodic=data['aperiodic_mode'], + periodic=data['periodic_mode']) -class BaseObject(CommonBase, BaseResults, BaseData): + +class BaseObject(CommonBase): """Define Base object for fitting models to 1D data.""" - def __init__(self, aperiodic_mode=None, periodic_mode=None, debug=False, verbose=True): + def __init__(self, modes=None, verbose=False): """Initialize BaseObject object.""" - CommonBase.__init__(self) - BaseData.__init__(self) - BaseResults.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) + CommonBase.__init__(self, verbose=verbose) + + self.data = BaseData() + self.results = BaseResults(modes=modes) @replace_docstring_sections([docs_get_section(BaseData.add_data.__doc__, 'Parameters'), @@ -185,9 +224,9 @@ def add_data(self, freqs, power_spectrum, freq_range=None, clear_results=True): """ # Clear results, if present, unless indicated not to - self._reset_results(clear_results=self.has_model and clear_results) + self.results._reset_results(self.results.has_model and clear_results) - super().add_data(freqs, power_spectrum, freq_range=freq_range) + self.data.add_data(freqs, power_spectrum, freq_range=freq_range) @copy_doc_func_to_method(save_model) @@ -219,15 +258,15 @@ def load(self, file_name, file_path=None, regenerate=True): # Add loaded data to object and check loaded data self._add_from_dict(data) self._check_loaded_modes(data) - self._check_loaded_settings(data) - self._check_loaded_results(data) + self.algorithm._check_loaded_settings(data) + self.results._check_loaded_results(data) # Regenerate model components, based on what is available if regenerate: - if self.freq_res: - self._regenerate_freqs() - if np.all(self.freqs) and np.all(self.aperiodic_params_): - self._regenerate_model() + if 
self.data.freq_res: + self.data._regenerate_freqs() + if np.all(self.data.freqs) and np.all(self.results.aperiodic_params_): + self.results._regenerate_model(self.data.freqs) def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False): @@ -243,20 +282,20 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res Whether to clear model results attributes. """ - self._reset_data(clear_freqs, clear_spectrum) - self._reset_results(clear_results) + self.data._reset_data(clear_freqs, clear_spectrum) + self.results._reset_results(clear_results) -class BaseObject2D(CommonBase, BaseResults2D, BaseData2D): +class BaseObject2D(CommonBase): """Define Base object for fitting models to 2D data.""" - def __init__(self, aperiodic_mode=None, periodic_mode=None, debug=False, verbose=True): + def __init__(self, modes=None, verbose=True): """Initialize BaseObject2D object.""" - CommonBase.__init__(self) - BaseData2D.__init__(self) - BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) + CommonBase.__init__(self, verbose=verbose) + + self.data = BaseData2D() + self.results = BaseResults2D(modes=modes) def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): @@ -282,11 +321,11 @@ def add_data(self, freqs, power_spectra, freq_range=None, clear_results=True): # If any data is already present, then clear data & results # This is to ensure object consistency of all data & results - if clear_results and np.any(self.freqs): + if clear_results and np.any(self.data.freqs): self._reset_data_results(True, True, True, True) - self._reset_group_results() + self.results._reset_group_results() - super().add_data(freqs, power_spectra, freq_range=freq_range) + self.data.add_data(freqs, power_spectra, freq_range=freq_range) def fit(self, freqs=None, power_spectra=None, freq_range=None, n_jobs=1, progress=None): @@ -317,21 +356,22 @@ def fit(self, freqs=None, 
power_spectra=None, freq_range=None, n_jobs=1, progres # If 'verbose', print out a marker of what is being run if self.verbose and not progress: - print('Fitting model across {} power spectra.'.format(len(self.power_spectra))) + print('Fitting model across {} power spectra.'.format(len(self.data.power_spectra))) # Run linearly if n_jobs == 1: - self._reset_group_results(len(self.power_spectra)) + self.results._reset_group_results(len(self.data.power_spectra)) for ind, power_spectrum in \ - pbar(enumerate(self.power_spectra), progress, len(self)): + pbar(enumerate(self.data.power_spectra), progress, len(self.results)): self._pass_through_spectrum(power_spectrum) super().fit() - self.group_results[ind] = self._get_results() + self.results.group_results[ind] = self.results._get_results() # Run in parallel else: - self._reset_group_results() - self.group_results = run_parallel_group(self, self.power_spectra, n_jobs, progress) + self.results._reset_group_results() + self.results.group_results = run_parallel_group(\ + self, self.data.power_spectra, n_jobs, progress) # Clear the individual power spectrum and fit results of the current fit self._reset_data_results(clear_spectrum=True, clear_results=True) @@ -356,39 +396,110 @@ def load(self, file_name, file_path=None): """ # Clear results so as not to have possible prior results interfere - self._reset_group_results() + self.results._reset_group_results() power_spectra = [] for ind, data in enumerate(load_jsonlines(file_name, file_path)): + # If power spectra data is part of loaded data, collect to add to object + if 'power_spectrum' in data.keys(): + power_spectra.append(data.pop('power_spectrum')) + self._add_from_dict(data) # If settings are loaded, check and update based on the first line if ind == 0: self._check_loaded_modes(data) - self._check_loaded_settings(data) - - # If power spectra data is part of loaded data, collect to add to object - if 'power_spectrum' in data.keys(): - 
power_spectra.append(data['power_spectrum']) + self.algorithm._check_loaded_settings(data) # If results part of current data added, check and update object results if set(OBJ_DESC['results']).issubset(set(data.keys())): - self._check_loaded_results(data) - self.group_results.append(self._get_results()) + self.results._check_loaded_results(data) + self.results.group_results.append(self.results._get_results()) # Reconstruct frequency vector, if information is available to do so - if self.freq_range: - self._regenerate_freqs() + if self.data.freq_range: + self.data._regenerate_freqs() # Add power spectra data, if they were loaded if power_spectra: - self.power_spectra = np.array(power_spectra) + self.data.power_spectra = np.array(power_spectra) # Reset peripheral data from last loaded result, keeping freqs info self._reset_data_results(clear_spectrum=True, clear_results=True) + def get_model(self, ind=None, regenerate=True): + """Get a model fit object for a specified index. + + Parameters + ---------- + ind : int, optional + The index of the model from `group_results` to access. + If None, return a Model object with initialized settings, with no data or results. + regenerate : bool, optional, default: False + Whether to regenerate the model fits for the requested model. + + Returns + ------- + model : SpectralModel + The data and fit results loaded into a model object. 
+ """ + + # Local import - avoid circularity + from specparam.models.utils import initialize_model_from_source + + # Initialize model object, with same settings, metadata, & check mode as current object + model = initialize_model_from_source(self, 'model') + + # Add data for specified single power spectrum, if available + if ind is not None and self.data.has_data: + model.data.power_spectrum = self.data.power_spectra[ind] + + # Add results for specified power spectrum, regenerating full fit if requested + if ind: + model.results.add_results(self.results.group_results[ind]) + if regenerate: + model.results._regenerate_model(self.data.freqs) + + return model + + + def get_group(self, inds): + """Get a Group model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + + Returns + ------- + group : SpectralGroupModel + The requested selection of results data loaded into a new group model object. + """ + + # Local import - avoid circularity + from specparam.models.utils import initialize_model_from_source + + # Initialize a new model object, with same settings as current object + group = initialize_model_from_source(self, 'group') + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + + # Add data for specified power spectra, if available + if self.data.has_data: + group.data.power_spectra = self.data.power_spectra[inds, :] + + # Add results for specified power spectra + group.results.group_results = [self.results.group_results[ind] for ind in inds] + + return group + + def _pass_through_spectrum(self, power_spectrum): """Pass through a power spectrum to add to object. @@ -400,7 +511,7 @@ def _pass_through_spectrum(self, power_spectrum): have already undergone data checking during data adding. 
""" - self.power_spectrum = power_spectrum + self.data.power_spectrum = power_spectrum def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, @@ -419,21 +530,20 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, Whether to clear power spectra attribute. """ - self._reset_data(clear_freqs, clear_spectrum, clear_spectra) - self._reset_results(clear_results) + self.data._reset_data(clear_freqs, clear_spectrum, clear_spectra) + self.results._reset_results(clear_results) -class BaseObject2DT(BaseObject2D, BaseResults2DT, BaseData2DT): +class BaseObject2DT(BaseObject2D): """Define Base object for fitting models to 2D data - tranpose version.""" - def __init__(self, aperiodic_mode=None, periodic_mode=None, debug=False, verbose=True): + def __init__(self, modes=None, verbose=True): """Initialize BaseObject2DT object.""" - BaseData2DT.__init__(self) - BaseObject2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) - BaseResults2D.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) + BaseObject2D.__init__(self, modes=modes, verbose=verbose) + + self.data = BaseData2DT() + self.results = BaseResults2DT(modes=modes) def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, @@ -465,7 +575,7 @@ def fit(self, freqs=None, spectrogram=None, freq_range=None, peak_org=None, super().fit(freqs, spectrogram, freq_range, n_jobs, progress) if peak_org is not False: - self.convert_results(peak_org) + self.results.convert_results(peak_org) def load(self, file_name, file_path=None, peak_org=None): @@ -484,23 +594,67 @@ def load(self, file_name, file_path=None, peak_org=None): """ # Clear results so as not to have possible prior results interfere - self._reset_time_results() + self.results._reset_time_results() super().load(file_name, file_path=file_path) - if peak_org is not False and self.group_results: - 
self.convert_results(peak_org) + if peak_org is not False and self.results.group_results: + self.results.convert_results(peak_org) + + + def get_group(self, inds, output_type='time'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + inds : array_like of int or array_like of bool + Indices to extract from the object. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeModel or SpectralGroupModel + The requested selection of results data loaded into a new model object. + """ + + if output_type == 'time': + + # Local import - avoid circularity + from specparam.models.utils import initialize_model_from_source + + # Initialize a new model object, with same settings as current object + output = initialize_model_from_source(self, 'time') + + if inds is not None: + + # Check and convert indices encoding to list of int + inds = check_inds(inds) + # Add data for specified power spectra, if available + if self.data.has_data: + output.data.power_spectra = self.data.power_spectra[inds, :] -class BaseObject3D(BaseObject2DT, BaseResults3D, BaseData3D): + # Add results for specified power spectra + output.results.group_results = [self.results.group_results[ind] for ind in inds] + output.results.time_results = get_results_by_ind(self.results.time_results, inds) + + if output_type == 'group': + output = super().get_group(inds) + + return output + + +class BaseObject3D(BaseObject2DT): """Define Base object for fitting models to 3D data.""" - def __init__(self, aperiodic_mode=None, periodic_mode=None, debug=False, verbose=True): + def __init__(self, modes=None, verbose=True): """Initialize BaseObject3D object.""" - BaseData3D.__init__(self) - BaseObject2DT.__init__(self, aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) - BaseResults3D.__init__(self, 
aperiodic_mode=aperiodic_mode, periodic_mode=periodic_mode, - debug=debug, verbose=verbose) + BaseObject2DT.__init__(self, modes=modes, verbose=verbose) + + self.data = BaseData3D() + self.results = BaseResults3D(modes=modes) def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): @@ -512,8 +666,8 @@ def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): Frequency values for the power spectra, in linear space. spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + If list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If 3d array, should have shape [n_events, n_freqs, n_time_windows]. freq_range : list of [float, float], optional Frequency range to restrict power spectra to. If not provided, keeps the entire range. clear_results : bool, optional, default: True @@ -527,9 +681,9 @@ def add_data(self, freqs, spectrograms, freq_range=None, clear_results=True): """ if clear_results: - self._reset_event_results() + self.results._reset_event_results() - super().add_data(freqs, spectrograms, freq_range=freq_range) + self.data.add_data(freqs, spectrograms, freq_range=freq_range) def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, @@ -542,8 +696,8 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, Frequency values for the power_spectra, in linear space. spectrograms : 3d array or list of 2d array Matrix of power values, in linear space. - If a list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. - If a 3d array, should have shape [n_events, n_freqs, n_time_windows]. + If list of 2d arrays, each should be have the same shape of [n_freqs, n_time_windows]. + If 3d array, should have shape [n_events, n_freqs, n_time_windows]. 
freq_range : list of [float, float], optional Frequency range to fit the model to. If not provided, fits the entire given range. peak_org : int or Bands @@ -567,23 +721,25 @@ def fit(self, freqs=None, spectrograms=None, freq_range=None, peak_org=None, # If 'verbose', print out a marker of what is being run if self.verbose and not progress: print('Fitting model across {} events of {} windows.'.format(\ - len(self.spectrograms), self.n_time_windows)) + len(self.data.spectrograms), self.data.n_time_windows)) if n_jobs == 1: - self._reset_event_results(len(self.spectrograms)) - for ind, spectrogram in pbar(enumerate(self.spectrograms), progress, len(self)): - self.power_spectra = spectrogram.T + self.results._reset_event_results(len(self.data.spectrograms)) + for ind, spectrogram in \ + pbar(enumerate(self.data.spectrograms), progress, len(self.results)): + self.data.power_spectra = spectrogram.T super().fit(peak_org=False) - self.event_group_results[ind] = self.group_results - self._reset_group_results() + self.results.event_group_results[ind] = self.results.group_results + self.results._reset_group_results() self._reset_data_results(clear_spectra=True) else: fg = self.get_group(None, None, 'group') - self.event_group_results = run_parallel_event(fg, self.spectrograms, n_jobs, progress) + self.results.event_group_results = run_parallel_event(\ + fg, self.data.spectrograms, n_jobs, progress) if peak_org is not False: - self.convert_results(peak_org) + self.results.convert_results(peak_org) @copy_doc_func_to_method(save_event) @@ -612,15 +768,86 @@ def load(self, file_name, file_path=None, peak_org=None): spectrograms = [] for file in files: super().load(file, file_path, peak_org=False) - if self.group_results: - self.add_results(self.group_results, append=True) - if np.all(self.power_spectra): - spectrograms.append(self.spectrogram) - self.spectrograms = np.array(spectrograms) if spectrograms else None + if self.results.group_results: + 
self.results.add_results(self.results.group_results, append=True) + if np.all(self.data.power_spectra): + spectrograms.append(self.data.spectrogram) + self.data.spectrograms = np.array(spectrograms) if spectrograms else None - self._reset_group_results() - if peak_org is not False and self.event_group_results: - self.convert_results(peak_org) + self.results._reset_group_results() + if peak_org is not False and self.results.event_group_results: + self.results.convert_results(peak_org) + + + def get_group(self, event_inds, window_inds, output_type='event'): + """Get a new model object with the specified sub-selection of model fits. + + Parameters + ---------- + event_inds, window_inds : array_like of int or array_like of bool or None + Indices to extract from the object, for event and time windows. + If None, selects all available indices. + output_type : {'time', 'group'}, optional + Type of model object to extract: + 'event' : SpectralTimeEventObject + 'time' : SpectralTimeObject + 'group' : SpectralGroupObject + + Returns + ------- + output : SpectralTimeEventModel + The requested selection of results data loaded into a new model object. 
+ """ + + # Check and convert indices encoding to list of int + einds = check_inds(event_inds, self.data.n_events) + winds = check_inds(window_inds, self.data.n_time_windows) + + if output_type == 'event': + + # Local import - avoid circularity + from specparam.models.utils import initialize_model_from_source + + # Initialize a new model object, with same settings as current object + output = initialize_model_from_source(self, 'event') + + if event_inds is not None or window_inds is not None: + + # Add data for specified power spectra, if available + if self.data.has_data: + output.data.spectrograms = self.data.spectrograms[einds, :, :][:, :, winds] + + # Add results for specified power spectra - event group results + temp = [self.results.event_group_results[ei][wi] for ei in einds for wi in winds] + step = int(len(temp) / len(einds)) + output.results.event_group_results = \ + [temp[ind:ind+step] for ind in range(0, len(temp), step)] + + # Add results for specified power spectra - event time results + output.results.event_time_results = \ + {key : self.results.event_time_results[key][event_inds][:, window_inds] \ + for key in self.results.event_time_results} + + elif output_type in ['time', 'group']: + + if event_inds is not None or window_inds is not None: + + # Move specified results & data to `group_results` & `power_spectra` for export + self.results.group_results = \ + [self.results.event_group_results[ei][wi] for ei in einds for wi in winds] + if self.data.has_data: + self.data.power_spectra = \ + np.hstack(self.data.spectrograms[einds, :, :][:, :, winds]).T + + new_inds = range(0, len(self.results.group_results)) if \ + self.results.group_results else None + output = super().get_group(new_inds, output_type) + + # Clear the data that was moved for export + self.results._reset_group_results() + self._reset_data_results(clear_spectra=True) + + return output def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_results=False, @@ -641,5 
+868,5 @@ def _reset_data_results(self, clear_freqs=False, clear_spectrum=False, clear_res Whether to clear spectrograms attribute. """ - self._reset_data(clear_freqs, clear_spectrum, clear_spectra, clear_spectrograms) - self._reset_results(clear_results) + self.data._reset_data(clear_freqs, clear_spectrum, clear_spectra, clear_spectrograms) + self.results._reset_results(clear_results) diff --git a/specparam/objs/data.py b/specparam/objs/data.py index 6a31246c7..aa5abbb02 100644 --- a/specparam/objs/data.py +++ b/specparam/objs/data.py @@ -34,7 +34,7 @@ def __init__(self, check_freqs=True, check_data=True): self._reset_data(True, True) - # Define data check run modes + # Define data check run statuses self._check_freqs = check_freqs self._check_data = check_data @@ -43,7 +43,7 @@ def __init__(self, check_freqs=True, check_data=True): def has_data(self): """Indicator for if the object contains data.""" - return True if np.any(self.power_spectrum) else False + return bool(np.any(self.power_spectrum)) def add_data(self, freqs, power_spectrum, freq_range=None): @@ -93,7 +93,7 @@ def get_checks(self): Returns ------- ModelChecks - Object containing the run modes from the current object. + Object containing the check statuses from the current object. 
""" return ModelChecks(**{key.strip('_') : getattr(self, key) \ @@ -246,7 +246,7 @@ def _prepare_data(self, freqs, powers, freq_range, spectra_dim=1): # Log power values powers = np.log10(powers) - ## Data checks - run checks on inputs based on check modes + ## Data checks - run checks on inputs based on check statuses if self._check_freqs: # Check if the frequency data is unevenly spaced, and raise an error if so @@ -281,7 +281,7 @@ def __init__(self): def has_data(self): """Indicator for if the object contains data.""" - return True if np.any(self.power_spectra) else False + return bool(np.any(self.power_spectra)) def add_data(self, freqs, power_spectra, freq_range=None): diff --git a/specparam/objs/metrics.py b/specparam/objs/metrics.py new file mode 100644 index 000000000..f27bf9224 --- /dev/null +++ b/specparam/objs/metrics.py @@ -0,0 +1,154 @@ +"""Metrics object.""" + +################################################################################################### +################################################################################################### + +class Metric(): + """Define a metric to apply to a power spectrum model. + + Parameters + ---------- + measure : str + The type of measure, e.g. 'error' or 'gof'. + metric : str + The specific measure, e.g. 'r_squared'. + func : callable + The function that computes the metric. + """ + + def __init__(self, measure, metric, func): + """Initialize metric.""" + + self.measure = measure + self.metric = metric + self.func = func + self.output = None + + + def __repr__(self): + """Set string representation of object.""" + + return 'Metric: ' + self.label + + + @property + def label(self): + """Define label property.""" + + return self.measure + '-' + self.metric + + + def compute_metric(self, data, results): + """Compute metric. + + Parameters + ---------- + data : Data + Model data. + results : Results + Model results. 
+ """ + + self.output = self.func(data.power_spectrum, results.modeled_spectrum_) + + +class Metrics(): + """Define a collection of metrics. + + Parameters + ---------- + metrics : list of Metric or list of dict + Metric(s) to add to the object. + """ + + def __init__(self, metrics=None): + """Initialize metrics.""" + + self.metrics = [] + if metrics: + self.add_metrics(metrics) + + + def __getitem__(self, label): + """Index into the object based on metric label. + + Parameters + ---------- + label : str + Label of the metric to access. + """ + + for ind, clabel in enumerate(self.labels): + if label == clabel: + return self.metrics[ind] + + raise ValueError('Requested label not found.') + + + def add_metric(self, metric): + """Add a metric to the object. + + Parameters + ---------- + metric : Metric or dict + Metric to add to the object. + If dict, should have keys corresponding to a metric definition. + """ + + if isinstance(metric, dict): + metric = Metric(**metric) + + self.metrics.append(metric) + + + def add_metrics(self, metrics): + """Add metric(s) to object + + Parameters + ---------- + metrics : list of Metric or list of dict + Metric(s) to add to the object. + """ + + for metric in metrics: + self.add_metric(metric) + + + def compute_metrics(self, data, results): + """Compute all currently defined metrics. + + Parameters + ---------- + data : Data + Model data. + results : Results + Model results. 
+ """ + + for metric in self.metrics: + metric.compute_metric(data, results) + + + @property + def labels(self): + """Define alias for labels of all currently defined metrics.""" + + return [metric.label for metric in self.metrics] + + + @property + def outputs(self): + """Define alias for ouputs of all currently defined metrics.""" + + return {metric.label : metric.output for metric in self.metrics} + + + # TEMP: CHECK IF THIS IS HOW TO MANAGE THIS + def set_defaults(self): + """Set default metrics.""" + + from specparam.measures.error import compute_mean_abs_error + from specparam.measures.gof import compute_r_squared + + self.add_metrics([Metric('error', 'mae', compute_mean_abs_error), + Metric('gof', 'r_squared', compute_r_squared)]) diff --git a/specparam/objs/results.py b/specparam/objs/results.py index 7ab1fd988..154725aa9 100644 --- a/specparam/objs/results.py +++ b/specparam/objs/results.py @@ -5,15 +5,13 @@ import numpy as np from specparam.modes.items import OBJ_DESC -from specparam.modes.definitions import AP_MODES, PE_MODES from specparam.utils.array import unlog from specparam.utils.checks import check_inds, check_array_dim from specparam.modutils.errors import NoModelError -from specparam.data import FitResults, ModelSettings +from specparam.data import FitResults from specparam.data.conversions import group_to_dict, event_group_to_dict from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row -from specparam.measures.gof import compute_gof -from specparam.measures.error import compute_error +from specparam.sim.gen import gen_model ################################################################################################### ################################################################################################### @@ -22,31 +20,14 @@ class BaseResults(): """Base object for managing results.""" # pylint: disable=attribute-defined-outside-init, arguments-differ - def __init__(self, aperiodic_mode, 
periodic_mode, debug=False, - verbose=True, error_metric='MAE', gof_metric='r_squared'): + def __init__(self, modes=None): """Initialize BaseResults object.""" - # Set fit component modes - if isinstance(aperiodic_mode, str): - self.aperiodic_mode = AP_MODES[aperiodic_mode] - else: - self.aperiodic_mode = aperiodic_mode - if isinstance(periodic_mode, str): - self.periodic_mode = PE_MODES[periodic_mode] - else: - self.periodic_mode = periodic_mode - - # Set run approaches - self.set_debug(debug) - self.verbose = verbose + self.modes = modes # Initialize results attributes self._reset_results(True) - # Set private run settings - self._error_metric = error_metric - self._gof_metric = gof_metric - @property def has_model(self): @@ -60,7 +41,7 @@ def has_model(self): - necessarily defined, as floats, if model has been fit """ - return True if not np.all(np.isnan(self.aperiodic_params_)) else False + return not np.all(np.isnan(self.aperiodic_params_)) @property @@ -70,40 +51,6 @@ def n_peaks_(self): return self.peak_params_.shape[0] if self.has_model else None - def add_settings(self, settings): - """Add settings into object from a ModelSettings object. - - Parameters - ---------- - settings : ModelSettings - A data object containing the settings for a power spectrum model. - """ - - for setting in OBJ_DESC['settings']: - setattr(self, setting, getattr(settings, setting)) - - self._check_loaded_settings(settings._asdict()) - - - def get_settings(self): - """Return user defined settings of the current object. - - Returns - ------- - ModelSettings - Object containing the settings from the current object. - """ - - return ModelSettings(**{key : getattr(self, key) \ - for key in OBJ_DESC['settings']}) - - - def get_debug(self): - """Return object debug status.""" - - return self._debug - - def add_results(self, results): """Add results data into object from a FitResults object. 
@@ -135,7 +82,7 @@ def get_results(self): for key in OBJ_DESC['results']}) - def get_model(self, component='full', space='log'): + def get_component(self, component='full', space='log'): """Get a model component. Parameters @@ -181,56 +128,6 @@ def get_model(self, component='full', space='log'): return output - def set_debug(self, debug): - """Set debug state, which controls if an error is raised if model fitting is unsuccessful. - - Parameters - ---------- - debug : bool - Whether to run in debug state. - """ - - self._debug = debug - - - def _check_loaded_modes(self, data): - """Check if fit modes added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If fit mode information in loaded, reload mode definitions - self.aperiodic_mode = AP_MODES[data['aperiodic_mode']] \ - if 'aperiodic_mode' in data else None - self.periodic_mode = PE_MODES[data['periodic_mode']] \ - if 'periodic_mode' in data else None - - - def _check_loaded_settings(self, data): - """Check if settings added, and update the object as needed. - - Parameters - ---------- - data : dict - A dictionary of data that has been added to the object. - """ - - # If settings not loaded from file, clear from object, so that default - # settings, which are potentially wrong for loaded data, aren't kept - if not set(OBJ_DESC['settings']).issubset(set(data.keys())): - - # Reset all public settings to None - for setting in OBJ_DESC['settings']: - setattr(self, setting, None) - - # Reset internal settings so that they are consistent with what was loaded - # Note that this will set internal settings to None, if public settings unavailable - self._reset_internal_settings() - - def _check_loaded_results(self, data): """Check if results have been added and check data. 
@@ -247,10 +144,6 @@ def _check_loaded_results(self, data): self.gaussian_params_ = check_array_dim(self.gaussian_params_) - def _reset_internal_settings(self): - """"Can be overloaded if any resetting needed for internal settings.""" - - def _reset_results(self, clear_results=False): """Set, or reset, results attributes to empty. @@ -262,18 +155,16 @@ def _reset_results(self, clear_results=False): if clear_results: - # TEMP / Note - for ap / pe params, move to something like `xx_params` and `_xx_params` (?) - # Aperiodic parameters - if self.aperiodic_mode: - self.aperiodic_params_ = np.array([np.nan] * self.aperiodic_mode.n_params) + if self.modes: + self.aperiodic_params_ = np.array([np.nan] * self.modes.aperiodic.n_params) else: self.aperiodic_params_ = np.nan # Periodic parameters - if self.periodic_mode: - self.gaussian_params_ = np.empty([0, self.periodic_mode.n_params]) - self.peak_params_ = np.empty([0, self.periodic_mode.n_params]) + if self.modes: + self.gaussian_params_ = np.empty([0, self.modes.periodic.n_params]) + self.peak_params_ = np.empty([0, self.modes.periodic.n_params]) else: self.gaussian_params_ = np.nan self.peak_params_ = np.nan @@ -292,51 +183,27 @@ def _reset_results(self, clear_results=False): self._peak_fit = None - def _compute_model_gof(self, metric=None): - """Calculate the r-squared goodness of fit of the model, compared to the original data. - - Parameters - ---------- - metric : {'r_squared', 'adj_r_squared'}, optional - Which goodness of fit measure to compute: - * 'r_squared' : R-squared - * 'adj_r_squared' : Adjusted R-squared - - Notes - ----- - Which measure is applied is by default controlled by the `_gof_metric` attribute. - """ - - self.r_squared_ = compute_gof(self.power_spectrum, self.modeled_spectrum_) - - - def _compute_model_error(self, metric=None): - """Calculate the overall error of the model fit, compared to the original data. + def _regenerate_model(self, freqs): + """Regenerate model fit from parameters. 
Parameters ---------- - metric : {'MAE', 'MSE', 'RMSE'}, optional - Which error measure to compute: - * 'MAE' : mean absolute error - * 'MSE' : mean squared error - * 'RMSE' : root mean squared error - - Notes - ----- - Which measure is applied is by default controlled by the `_error_metric` attribute. + freqs : 1d array + Frequency values for the power_spectrum, in linear scale. """ - self.error_ = compute_error(self.power_spectrum, self.modeled_spectrum_, - self._error_metric if not metric else metric) + self.modeled_spectrum_, self._peak_fit, self._ap_fit = gen_model( + freqs, self.aperiodic_params_, + self.gaussian_params_, return_components=True) class BaseResults2D(BaseResults): """Base object for managing results - 2D version.""" - def __init__(self, aperiodic_mode, periodic_mode, debug=False, verbose=True): + def __init__(self, modes=None): """Initialize BaseResults2D object.""" - BaseResults.__init__(self, aperiodic_mode, periodic_mode, debug=debug, verbose=verbose) + BaseResults.__init__(self, modes=modes) self._reset_group_results() @@ -382,7 +249,7 @@ def _get_results(self): def has_model(self): """Indicator for if the object contains model fits.""" - return True if self.group_results else False + return bool(self.group_results) @property @@ -440,10 +307,10 @@ def drop(self, inds): This method sets the model fits as null, and preserves the shape of the model fits. """ - # Temp import - consider refactoring + # Local import - avoid circular from specparam import SpectralModel - null_model = SpectralModel(**self.get_settings()._asdict()).get_results() + null_model = SpectralModel(**self.modes.get_modes()._asdict()).results.get_results() for ind in check_inds(inds): self.group_results[ind] = null_model @@ -483,88 +350,13 @@ def get_params(self, name, col=None): return get_group_params(self.group_results, name, col) - def get_model(self, ind, regenerate=True): - """Get a model fit object for a specified index. 
- - Parameters - ---------- - ind : int - The index of the model from `group_results` to access. - regenerate : bool, optional, default: False - Whether to regenerate the model fits for the requested model. - - Returns - ------- - model : SpectralModel - The FitResults data loaded into a model object. - """ - - # Local import - avoid circular - from specparam import SpectralModel - - # Initialize model object, with same settings, metadata, & check mode as current object - model = SpectralModel(**self.get_settings()._asdict(), verbose=self.verbose) - model.add_meta_data(self.get_meta_data()) - model.set_checks(*self.get_checks()) - model.set_debug(self.get_debug()) - - # Add data for specified single power spectrum, if available - if self.has_data: - model.power_spectrum = self.power_spectra[ind] - - # Add results for specified power spectrum, regenerating full fit if requested - model.add_results(self.group_results[ind]) - if regenerate: - model._regenerate_model() - - return model - - - def get_group(self, inds): - """Get a Group model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - - Returns - ------- - group : SpectralGroupModel - The requested selection of results data loaded into a new group model object. 
- """ - - # Local import - avoid circular - from specparam import SpectralGroupModel - - # Initialize a new model object, with same settings as current object - group = SpectralGroupModel(**self.get_settings()._asdict(), verbose=self.verbose) - group.add_meta_data(self.get_meta_data()) - group.set_checks(*self.get_checks()) - group.set_debug(self.get_debug()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - group.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - group.group_results = [self.group_results[ind] for ind in inds] - - return group - - class BaseResults2DT(BaseResults2D): """Base object for managing results - 2D transpose version.""" - def __init__(self, aperiodic_mode, periodic_mode, debug=False, verbose=True): + def __init__(self, modes=None): """Initialize BaseResults2DT object.""" - BaseResults2D.__init__(self, aperiodic_mode, periodic_mode, debug=debug, verbose=verbose) + BaseResults2D.__init__(self, modes=modes) self._reset_time_results() @@ -587,52 +379,6 @@ def get_results(self): return self.time_results - def get_group(self, inds, output_type='time'): - """Get a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - inds : array_like of int or array_like of bool - Indices to extract from the object. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeModel or SpectralGroupModel - The requested selection of results data loaded into a new model object. 
- """ - - if output_type == 'time': - - # Local import - avoid circular - from specparam import SpectralTimeModel - - # Initialize a new model object, with same settings as current object - output = SpectralTimeModel(**self.get_settings()._asdict(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if inds is not None: - - # Check and convert indices encoding to list of int - inds = check_inds(inds) - - # Add data for specified power spectra, if available - if self.has_data: - output.power_spectra = self.power_spectra[inds, :] - - # Add results for specified power spectra - output.group_results = [self.group_results[ind] for ind in inds] - output.time_results = get_results_by_ind(self.time_results, inds) - - if output_type == 'group': - output = super().get_group(inds) - - return output - - def drop(self, inds): """Drop one or more model fit results from the object. @@ -668,10 +414,10 @@ def convert_results(self, peak_org): class BaseResults3D(BaseResults2DT): """Base object for managing results - 3D version.""" - def __init__(self, aperiodic_mode, periodic_mode, debug=False, verbose=True): + def __init__(self, modes=None): """Initialize BaseResults3D object.""" - BaseResults2DT.__init__(self, aperiodic_mode, periodic_mode, debug=debug, verbose=verbose) + BaseResults2DT.__init__(self, modes=modes) self._reset_event_results() @@ -731,7 +477,7 @@ def drop(self, drop_inds=None, window_inds=None): # Local import - avoid circular from specparam import SpectralModel - null_model = SpectralModel(**self.get_settings()._asdict()).get_results() + null_model = SpectralModel(**self.modes.get_modes()._asdict()).results.get_results() drop_inds = drop_inds if isinstance(drop_inds, dict) else \ dict(zip(check_inds(drop_inds), repeat(window_inds))) @@ -800,75 +546,6 @@ def get_params(self, name, col=None): return [get_group_params(gres, name, col) for gres in self.event_group_results] - def get_group(self, event_inds, window_inds, output_type='event'): - """Get 
a new model object with the specified sub-selection of model fits. - - Parameters - ---------- - event_inds, window_inds : array_like of int or array_like of bool or None - Indices to extract from the object, for event and time windows. - If None, selects all available indices. - output_type : {'time', 'group'}, optional - Type of model object to extract: - 'event' : SpectralTimeEventObject - 'time' : SpectralTimeObject - 'group' : SpectralGroupObject - - Returns - ------- - output : SpectralTimeEventModel - The requested selection of results data loaded into a new model object. - """ - - # Local import - avoid circular - from specparam import SpectralTimeEventModel - - # Check and convert indices encoding to list of int - einds = check_inds(event_inds, self.n_events) - winds = check_inds(window_inds, self.n_time_windows) - - if output_type == 'event': - - # Initialize a new model object, with same settings as current object - output = SpectralTimeEventModel(**self.get_settings()._asdict(), verbose=self.verbose) - output.add_meta_data(self.get_meta_data()) - - if event_inds is not None or window_inds is not None: - - # Add data for specified power spectra, if available - if self.has_data: - output.spectrograms = self.spectrograms[einds, :, :][:, :, winds] - - # Add results for specified power spectra - event group results - temp = [self.event_group_results[ei][wi] for ei in einds for wi in winds] - step = int(len(temp) / len(einds)) - output.event_group_results = \ - [temp[ind:ind+step] for ind in range(0, len(temp), step)] - - # Add results for specified power spectra - event time results - output.event_time_results = \ - {key : self.event_time_results[key][event_inds][:, window_inds] \ - for key in self.event_time_results} - - elif output_type in ['time', 'group']: - - if event_inds is not None or window_inds is not None: - - # Move specified results & data to `group_results` & `power_spectra` for export - self.group_results = \ - 
[self.event_group_results[ei][wi] for ei in einds for wi in winds] - if self.has_data: - self.power_spectra = np.hstack(self.spectrograms[einds, :, :][:, :, winds]).T - - new_inds = range(0, len(self.group_results)) if self.group_results else None - output = super().get_group(new_inds, output_type) - - self._reset_group_results() - self._reset_data_results(clear_spectra=True) - - return output - - def convert_results(self, peak_org): """Convert the event results to be organized across events and time windows. diff --git a/specparam/objs/utils.py b/specparam/objs/utils.py index 537c19afc..f8aa27d1e 100644 --- a/specparam/objs/utils.py +++ b/specparam/objs/utils.py @@ -49,9 +49,9 @@ def _par_fit_group(power_spectrum, group): """Function to partialize for running in parallel - group.""" group._pass_through_spectrum(power_spectrum) - group._fit() + group.algorithm._fit() - return group._get_results() + return group.results._get_results() ## EVENT @@ -66,10 +66,10 @@ def run_parallel_event(model, data, n_jobs, progress): def _par_fit_event(spectrogram, model): """Function to partialize for running in parallel - event.""" - model.power_spectra = spectrogram.T + model.data.power_spectra = spectrogram.T model.fit() - return model.get_results() + return model.results.get_results() ################################################################################################### ## PROGRESS BARS diff --git a/specparam/plts/annotate.py b/specparam/plts/annotate.py index a9249fb7b..05c411d51 100644 --- a/specparam/plts/annotate.py +++ b/specparam/plts/annotate.py @@ -33,41 +33,44 @@ def plot_annotated_peak_search(model): # Recalculate the initial aperiodic fit and flattened spectrum that # is the same as the one that is used in the peak fitting procedure - flatspec = model.power_spectrum - \ - gen_aperiodic(model.freqs, - model._robust_ap_fit(model.freqs, model.power_spectrum), - model.aperiodic_mode.name) + flatspec = model.data.power_spectrum - \ + 
gen_aperiodic(model.data.freqs, + model.algorithm._robust_ap_fit(model.data.freqs, model.data.power_spectrum), + model.modes.aperiodic.name) # Calculate ylims of the plot that are scaled to the range of the data ylims = [min(flatspec) - 0.1 * np.abs(min(flatspec)), max(flatspec) + 0.1 * max(flatspec)] # Sort parameters by peak height - gaussian_params = model.gaussian_params_[model.gaussian_params_[:, 1].argsort()][::-1] + gaussian_params = model.results.gaussian_params_[model.results.gaussian_params_[:, 1].argsort()][::-1] # Loop through the iterative search for each peak - for ind in range(model.n_peaks_ + 1): + for ind in range(model.results.n_peaks_ + 1): # This forces the creation of a new plotting axes per iteration ax = check_ax(None, PLT_FIGSIZES['spectral']) - plot_spectra(model.freqs, flatspec, ax=ax, linewidth=2.5, - label='Flattened Spectrum', color=PLT_COLORS['data']) - plot_spectra(model.freqs, [model.peak_threshold * np.std(flatspec)]*len(model.freqs), ax=ax, - label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed') - plot_spectra(model.freqs, [model.min_peak_height]*len(model.freqs), ax=ax, - label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed') + plot_spectra(model.data.freqs, flatspec, linewidth=2.5, + label='Flattened Spectrum', color=PLT_COLORS['data'], ax=ax) + plot_spectra(model.data.freqs, + [model.algorithm.peak_threshold * np.std(flatspec)] * len(model.data.freqs), + label='Relative Threshold', color='orange', linewidth=2.5, + linestyle='dashed', ax=ax) + plot_spectra(model.data.freqs, [model.algorithm.min_peak_height]*len(model.data.freqs), + label='Absolute Threshold', color='red', linewidth=2.5, + linestyle='dashed', ax=ax) maxi = np.argmax(flatspec) - ax.plot(model.freqs[maxi], flatspec[maxi], '.', + ax.plot(model.data.freqs[maxi], flatspec[maxi], '.', color=PLT_COLORS['periodic'], alpha=0.75, markersize=30) ax.set_ylim(ylims) ax.set_title('Iteration #' + str(ind+1), fontsize=16) - if ind < 
model.n_peaks_: + if ind < model.results.n_peaks_: - gauss = gaussian_function(model.freqs, *gaussian_params[ind, :]) - plot_spectra(model.freqs, gauss, ax=ax, label='Gaussian Fit', + gauss = gaussian_function(model.data.freqs, *gaussian_params[ind, :]) + plot_spectra(model.data.freqs, gauss, ax=ax, label='Gaussian Fit', color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) flatspec = flatspec - gauss @@ -101,7 +104,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, """ # Check that model is available - if not model.has_model: + if not model.results.has_model: raise NoModelError("No model is available to plot, can not proceed.") # Settings @@ -123,7 +126,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, 'alpha' : 0.75, 'lw' : lw2}}) # Get freqs for plotting, and convert to log if needed - freqs = model.freqs if not plt_log else np.log10(model.freqs) + freqs = model.data.freqs if not plt_log else np.log10(model.data.freqs) ## Buffers: for spacing things out on the plot (scaled by plot values) x_buff1 = max(freqs) * 0.1 @@ -135,10 +138,10 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # See: https://github.com/matplotlib/matplotlib/issues/12820. Fixed in 3.2.1. 
bug_buff = 0.000001 - if annotate_peaks and model.n_peaks_: + if annotate_peaks and model.results.n_peaks_: # Extract largest peak, to annotate, grabbing gaussian params - gauss = get_band_peak(model, model.freq_range, attribute='gaussian_params') + gauss = get_band_peak(model, model.data.freq_range, attribute='gaussian_params') peak_ctr, peak_hgt, peak_wid = gauss bw_freqs = [peak_ctr - 0.5 * compute_fwhm(peak_wid), @@ -148,7 +151,7 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, peak_ctr = np.log10(peak_ctr) bw_freqs = np.log10(bw_freqs) - peak_top = model.power_spectrum[nearest_ind(freqs, peak_ctr)] + peak_top = model.data.power_spectrum[nearest_ind(freqs, peak_ctr)] # Annotate Peak CF ax.annotate('Center Frequency', @@ -182,24 +185,24 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Offset # Add a line to indicate offset, without adjusting plot limits below it ax.set_autoscaley_on(False) - ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.modeled_spectrum_[0]], + ax.plot([freqs[0], freqs[0]], [ax.get_ylim()[0], model.results.modeled_spectrum_[0]], color=PLT_COLORS['aperiodic'], linewidth=lw2, alpha=0.5) ax.annotate('Offset', - xy=(freqs[0]+bug_buff, model.power_spectrum[0]-y_buff1), - xytext=(freqs[0]-x_buff1, model.power_spectrum[0]-y_buff1), + xy=(freqs[0]+bug_buff, model.data.power_spectrum[0]-y_buff1), + xytext=(freqs[0]-x_buff1, model.data.power_spectrum[0]-y_buff1), verticalalignment='center', horizontalalignment='center', arrowprops=dict(facecolor=PLT_COLORS['aperiodic'], shrink=shrink), color=PLT_COLORS['aperiodic'], fontsize=fontsize) # Annotate Aperiodic Knee - if model.aperiodic_mode.name == 'knee': + if model.modes.aperiodic.name == 'knee': # Find the knee frequency point to annotate knee_freq = compute_knee_frequency(model.get_params('aperiodic', 'knee'), model.get_params('aperiodic', 'exponent')) knee_freq = np.log10(knee_freq) if plt_log else knee_freq - knee_pow = 
model.power_spectrum[nearest_ind(freqs, knee_freq)] + knee_pow = model.data.power_spectrum[nearest_ind(freqs, knee_freq)] # Add a dot to the plot indicating the knee frequency ax.plot(knee_freq, knee_pow, 'o', color=PLT_COLORS['aperiodic'], ms=ms1*1.5, alpha=0.7) @@ -214,8 +217,8 @@ def plot_annotated_model(model, plt_log=False, annotate_peaks=True, # Annotate Aperiodic Exponent mid_ind = int(len(freqs)/2) ax.annotate('Exponent', - xy=(freqs[mid_ind], model.power_spectrum[mid_ind]), - xytext=(freqs[mid_ind]-x_buff2, model.power_spectrum[mid_ind]-y_buff1), + xy=(freqs[mid_ind], model.data.power_spectrum[mid_ind]), + xytext=(freqs[mid_ind]-x_buff2, model.data.power_spectrum[mid_ind]-y_buff1), verticalalignment='center', arrowprops=dict(facecolor=PLT_COLORS['aperiodic'], shrink=shrink), color=PLT_COLORS['aperiodic'], fontsize=fontsize) diff --git a/specparam/plts/event.py b/specparam/plts/event.py index a5835691a..e105f55e8 100644 --- a/specparam/plts/event.py +++ b/specparam/plts/event.py @@ -22,12 +22,12 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_event_model(event_model, **plot_kwargs): +def plot_event_model(event, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a SpectralTimeEventModel object. Parameters ---------- - event_model : SpectralTimeEventModel + event : SpectralTimeEventModel Object containing results from fitting power spectra across events. **plot_kwargs Keyword arguments to apply to the plot. @@ -38,14 +38,14 @@ def plot_event_model(event_model, **plot_kwargs): If the model object does not have model fit data available to plot. 
""" - if not event_model.has_model: + if not event.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") - pe_labels = get_periodic_labels(event_model.event_time_results) + pe_labels = get_periodic_labels(event.results.event_time_results) band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) - has_knee = 'knee' in event_model.event_time_results.keys() + has_knee = 'knee' in event.results.event_time_results.keys() height_ratios = [1] * (3 if has_knee else 2) + [0.25, 1, 1, 1, 1] * n_bands + [0.25] + [1, 1] axes = plot_kwargs.pop('axes', None) @@ -55,13 +55,13 @@ def plot_event_model(event_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 5 * n_bands])) axes = cycle(axes) - xlim = [0, event_model.n_time_windows - 1] + xlim = [0, event.data.n_time_windows - 1] # 01: aperiodic params alabels = ['offset', 'knee', 'exponent'] if has_knee else ['offset', 'exponent'] for alabel in alabels: plot_param_over_time_yshade(\ - None, event_model.event_time_results[alabel], + None, event.results.event_time_results[alabel], label=alabel, drop_xticks=True, add_xlabel=False, xlim=xlim, title='Aperiodic Parameters' if alabel == 'offset' else None, color=PARAM_COLORS[alabel], ax=next(axes)) @@ -71,12 +71,12 @@ def plot_event_model(event_model, **plot_kwargs): for band_ind in range(n_bands): for plabel in ['cf', 'pw', 'bw']: plot_param_over_time_yshade(None, \ - event_model.event_time_results[pe_labels[plabel][band_ind]], + event.results.event_time_results[pe_labels[plabel][band_ind]], label=plabel.upper(), drop_xticks=True, add_xlabel=False, xlim=xlim, title='Periodic Parameters - ' + band_labels[band_ind] if plabel == 'cf' else None, color=PARAM_COLORS[plabel], ax=next(axes)) plot_param_over_time_yshade(None, \ - compute_presence(event_model.event_time_results[pe_labels[plabel][band_ind]], + compute_presence(event.results.event_time_results[pe_labels[plabel][band_ind]], output='percent'), label='Presence (%)', 
drop_xticks=True, add_xlabel=False, xlim=xlim, color=PARAM_COLORS['presence'], ax=next(axes)) @@ -85,7 +85,7 @@ def plot_event_model(event_model, **plot_kwargs): # 03: goodness of fit for glabel in ['error', 'r_squared']: plot_param_over_time_yshade(\ - None, event_model.event_time_results[glabel], label=glabel, + None, event.results.event_time_results[glabel], label=glabel, drop_xticks=False if glabel == 'r_squared' else True, add_xlabel=True if glabel == 'r_squared' else False, title='Goodness of Fit' if glabel == 'error' else None, diff --git a/specparam/plts/group.py b/specparam/plts/group.py index e1aae0a94..8e94f036a 100644 --- a/specparam/plts/group.py +++ b/specparam/plts/group.py @@ -36,7 +36,7 @@ def plot_group_model(group, **plot_kwargs): If the model object does not have model fit data available to plot. """ - if not group.has_model: + if not group.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") fig = plt.figure(figsize=plot_kwargs.pop('figsize', PLT_FIGSIZES['group'])) @@ -75,12 +75,12 @@ def plot_group_aperiodic(group, ax=None, **plot_kwargs): Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ - if group.aperiodic_mode.name == 'knee': - plot_scatter_2(group.get_params('aperiodic_params', 'exponent'), 'Exponent', - group.get_params('aperiodic_params', 'knee'), 'Knee', + if group.modes.aperiodic.name == 'knee': + plot_scatter_2(group.results.get_params('aperiodic_params', 'exponent'), 'Exponent', + group.results.get_params('aperiodic_params', 'knee'), 'Knee', 'Aperiodic Fit', ax=ax) else: - plot_scatter_1(group.get_params('aperiodic_params', 'exponent'), 'Exponent', + plot_scatter_1(group.results.get_params('aperiodic_params', 'exponent'), 'Exponent', 'Aperiodic Fit', ax=ax) @@ -100,8 +100,8 @@ def plot_group_goodness(group, ax=None, **plot_kwargs): Additional plot related keyword arguments, with styling options managed by ``style_plot``. 
""" - plot_scatter_2(group.get_params('error'), 'Error', - group.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax) + plot_scatter_2(group.results.get_params('error'), 'Error', + group.results.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax) @savefig @@ -120,5 +120,5 @@ def plot_group_peak_frequencies(group, ax=None, **plot_kwargs): Additional plot related keyword arguments, with styling options managed by ``style_plot``. """ - plot_hist(group.get_params('peak_params', 0)[:, 0], 'Center Frequency', - 'Peaks - Center Frequencies', x_lims=group.freq_range, ax=ax) + plot_hist(group.results.get_params('peak_params', 0)[:, 0], 'Center Frequency', + 'Peaks - Center Frequencies', x_lims=group.data.freq_range, ax=ax) diff --git a/specparam/plts/model.py b/specparam/plts/model.py index db093f02f..e313174d1 100644 --- a/specparam/plts/model.py +++ b/specparam/plts/model.py @@ -73,21 +73,21 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp log_powers = False # Plot the data, if available - if model.has_data or custom_spectrum: + if model.data.has_data or custom_spectrum: data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, 'label' : 'Original Spectrum' if add_legend else None} data_kwargs = check_plot_kwargs(data_kwargs, data_defaults) - plot_spectra(freqs if custom_spectrum else model.freqs, - power_spectrum if custom_spectrum else model.power_spectrum, + plot_spectra(freqs if custom_spectrum else model.data.freqs, + power_spectrum if custom_spectrum else model.data.power_spectrum, log_freqs, log_powers if not custom_spectrum else True, freq_range, ax=ax, **data_kwargs) # Add the full model fit, and components (if requested) - if model.has_model: + if model.results.has_model: model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, 'label' : 'Full Model Fit' if add_legend else None} model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) - plot_spectra(model.freqs, 
model.modeled_spectrum_, + plot_spectra(model.data.freqs, model.results.modeled_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit @@ -96,7 +96,7 @@ def plot_model(model, plot_peaks=None, plot_aperiodic=True, freqs=None, power_sp 'alpha' : 0.5, 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None} aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) - plot_spectra(model.freqs, model._ap_fit, + plot_spectra(model.data.freqs, model.results._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit @@ -169,12 +169,12 @@ def _add_peaks_shade(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.gaussian_params_: + for peak in model.results.gaussian_params_: - peak_freqs = np.log10(model.freqs) if plt_log else model.freqs - peak_line = model._ap_fit + gen_periodic(model.freqs, peak) + peak_freqs = np.log10(model.data.freqs) if plt_log else model.data.freqs + peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) - ax.fill_between(peak_freqs, peak_line, model._ap_fit, **plot_kwargs) + ax.fill_between(peak_freqs, peak_line, model.results._ap_fit, **plot_kwargs) def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): @@ -195,9 +195,9 @@ def _add_peaks_dot(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.peak_params_: + for peak in model.results.peak_params_: - ap_point = np.interp(peak[0], model.freqs, model._ap_fit) + ap_point = np.interp(peak[0], model.data.freqs, model.results._ap_fit) freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak @@ -225,14 +225,14 @@ def _add_peaks_outline(model, 
plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.gaussian_params_: + for peak in model.results.gaussian_params_: # Define the frequency range around each peak to plot - peak bandwidth +/- 3 peak_range = [peak[0] - peak[2]*3, peak[0] + peak[2]*3] # Generate a peak reconstruction for each peak, and trim to desired range - peak_line = model._ap_fit + gen_periodic(model.freqs, peak) - peak_freqs, peak_line = trim_spectrum(model.freqs, peak_line, peak_range) + peak_line = model.results._ap_fit + gen_periodic(model.data.freqs, peak) + peak_freqs, peak_line = trim_spectrum(model.data.freqs, peak_line, peak_range) # Plot the peak outline peak_freqs = np.log10(peak_freqs) if plt_log else peak_freqs @@ -259,7 +259,7 @@ def _add_peaks_line(model, plt_log, ax, **plot_kwargs): ylims = ax.get_ylim() - for peak in model.peak_params_: + for peak in model.results.peak_params_: freq_point = np.log10(peak[0]) if plt_log else peak[0] ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) @@ -289,9 +289,9 @@ def _add_peaks_width(model, plt_log, ax, **plot_kwargs): defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) - for peak in model.gaussian_params_: + for peak in model.results.gaussian_params_: - peak_top = model.power_spectrum[nearest_ind(model.freqs, peak[0])] + peak_top = model.data.power_spectrum[nearest_ind(model.data.freqs, peak[0])] bw_freqs = [peak[0] - 0.5 * compute_fwhm(peak[2]), peak[0] + 0.5 * compute_fwhm(peak[2])] diff --git a/specparam/plts/time.py b/specparam/plts/time.py index 2788a6c65..5b99e87d8 100644 --- a/specparam/plts/time.py +++ b/specparam/plts/time.py @@ -21,12 +21,12 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_time_model(time_model, **plot_kwargs): +def plot_time_model(time, **plot_kwargs): """Plot a figure with subplots 
visualizing the parameters from a SpectralTimeModel object. Parameters ---------- - time_model : SpectralTimeModel + time : SpectralTimeModel Object containing results from fitting power spectra across time windows. **plot_kwargs Keyword arguments to apply to the plot. @@ -37,11 +37,11 @@ def plot_time_model(time_model, **plot_kwargs): If the model object does not have model fit data available to plot. """ - if not time_model.has_model: + if not time.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") # Check band structure - pe_labels = get_periodic_labels(time_model.time_results) + pe_labels = get_periodic_labels(time.results.time_results) band_labels = get_band_labels(pe_labels) n_bands = len(pe_labels['cf']) @@ -52,16 +52,16 @@ def plot_time_model(time_model, **plot_kwargs): figsize=plot_kwargs.pop('figsize', [10, 4 + 2 * n_bands])) axes = cycle(axes) - xlim = [0, time_model.n_time_windows - 1] + xlim = [0, time.data.n_time_windows - 1] # 01: aperiodic parameters - ap_params = [time_model.time_results['offset'], - time_model.time_results['exponent']] + ap_params = [time.results.time_results['offset'], + time.results.time_results['exponent']] ap_labels = ['Offset', 'Exponent'] ap_colors = [PARAM_COLORS['offset'], PARAM_COLORS['exponent']] - if 'knee' in time_model.time_results.keys(): - ap_params.insert(1, time_model.time_results['knee']) + if 'knee' in time.results.time_results.keys(): + ap_params.insert(1, time.results.time_results['knee']) ap_labels.insert(1, 'Knee') ap_colors.insert(1, PARAM_COLORS['knee']) @@ -72,17 +72,17 @@ def plot_time_model(time_model, **plot_kwargs): for band_ind in range(n_bands): plot_params_over_time(\ None, - [time_model.time_results[pe_labels['cf'][band_ind]], - time_model.time_results[pe_labels['pw'][band_ind]], - time_model.time_results[pe_labels['bw'][band_ind]]], + [time.results.time_results[pe_labels['cf'][band_ind]], + time.results.time_results[pe_labels['pw'][band_ind]], + 
time.results.time_results[pe_labels['bw'][band_ind]]], labels=['CF', 'PW', 'BW'], add_xlabel=False, xlim=xlim, colors=[PARAM_COLORS['cf'], PARAM_COLORS['pw'], PARAM_COLORS['bw']], title='Periodic Parameters - ' + band_labels[band_ind], ax=next(axes)) # 03: goodness of fit plot_params_over_time(None, - [time_model.time_results['error'], - time_model.time_results['r_squared']], + [time.results.time_results['error'], + time.results.time_results['r_squared']], labels=['Error', 'R-squared'], xlim=xlim, colors=[PARAM_COLORS['error'], PARAM_COLORS['r_squared']], title='Goodness of Fit', ax=next(axes)) diff --git a/specparam/reports/save.py b/specparam/reports/save.py index a9fb9bb62..dbd5f73b7 100644 --- a/specparam/reports/save.py +++ b/specparam/reports/save.py @@ -116,12 +116,12 @@ def save_group_report(group, file_name, file_path=None, add_settings=True): @check_dependency(plt, 'matplotlib') -def save_time_report(time_model, file_name, file_path=None, add_settings=True): +def save_time_report(time, file_name, file_path=None, add_settings=True): """Generate and save out a PDF report for models of a spectrogram. Parameters ---------- - time_model : SpectralTimeModel + time : SpectralTimeModel Object with results from fitting a spectrogram. file_name : str Name to give the saved out file. 
@@ -132,7 +132,7 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): """ # Check model object for number of bands, to decide report size - pe_labels = get_periodic_labels(time_model.time_results) + pe_labels = get_periodic_labels(time.results.time_results) n_bands = len(pe_labels['cf']) # Initialize figure, defining number of axes based on model + what is to be plotted @@ -143,14 +143,14 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): figsize=REPORT_FIGSIZE) # First / top: text results - plot_text(gen_time_results_str(time_model), 0.5, 0.7, ax=axes[0]) + plot_text(gen_time_results_str(time), 0.5, 0.7, ax=axes[0]) # Second - data plots - time_model.plot(axes=axes[1:2+n_bands+1]) + time.plot(axes=axes[1:2+n_bands+1]) # Third - Model settings if add_settings: - plot_text(gen_settings_str(time_model, False), 0.5, 0.1, ax=axes[-1]) + plot_text(gen_settings_str(time, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(create_file_path(file_name, file_path, SAVE_FORMAT)) @@ -158,12 +158,12 @@ def save_time_report(time_model, file_name, file_path=None, add_settings=True): @check_dependency(plt, 'matplotlib') -def save_event_report(event_model, file_name, file_path=None, add_settings=True): +def save_event_report(event, file_name, file_path=None, add_settings=True): """Generate and save out a PDF report for models of a set of events. Parameters ---------- - event_model : SpectralTimeEventModel + event : SpectralTimeEventModel Object with results from fitting a group of power spectra. file_name : str Name to give the saved out file. 
@@ -174,9 +174,9 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) """ # Check model object for number of bands & aperiodic mode, to decide report size - pe_labels = get_periodic_labels(event_model.event_time_results) + pe_labels = get_periodic_labels(event.results.event_time_results) n_bands = len(pe_labels['cf']) - has_knee = 'knee' in event_model.event_time_results.keys() + has_knee = 'knee' in event.results.event_time_results.keys() # Initialize figure, defining number of axes based on model + what is to be plotted n_rows = 1 + (4 if has_knee else 3) + (n_bands * 5) + 2 + (1 if add_settings else 0) @@ -187,14 +187,14 @@ def save_event_report(event_model, file_name, file_path=None, add_settings=True) figsize=(REPORT_FIGSIZE[0], REPORT_FIGSIZE[1] + 7)) # First / top: text results - plot_text(gen_event_results_str(event_model), 0.5, 0.7, ax=axes[0]) + plot_text(gen_event_results_str(event), 0.5, 0.7, ax=axes[0]) # Second - data plots - event_model.plot(axes=axes[1:-1]) + event.plot(axes=axes[1:-1]) # Third - Model settings if add_settings: - plot_text(gen_settings_str(event_model, False), 0.5, 0.1, ax=axes[-1]) + plot_text(gen_settings_str(event, False), 0.5, 0.1, ax=axes[-1]) # Save out the report plt.savefig(create_file_path(file_name, file_path, SAVE_FORMAT)) diff --git a/specparam/reports/strings.py b/specparam/reports/strings.py index 38a42ef85..7f90893ad 100644 --- a/specparam/reports/strings.py +++ b/specparam/reports/strings.py @@ -86,12 +86,12 @@ def gen_version_str(concise=False): return output -def gen_modes_str(model_obj, description=False, concise=False): +def gen_modes_str(model, description=False, concise=False): """Generate a string representation of fit modes. Parameters ---------- - model_obj : SpectralModel or SpectralGroupModel or ModelModes + model : SpectralModel or Spectral*Model or ModelModes Object to access fit modes from. 
description : bool, optional, default: False Whether to also print out a description of the fit modes. @@ -123,9 +123,9 @@ def gen_modes_str(model_obj, description=False, concise=False): '', # Settings - include descriptions if requested - *[el for el in ['Periodic Mode : {}'.format(model_obj.periodic_mode.name), + *[el for el in ['Periodic Mode : {}'.format(model.modes.periodic.name), '{}'.format(desc['aperiodic_mode']), - 'Aperiodic Mode : {}'.format(model_obj.aperiodic_mode.name), + 'Aperiodic Mode : {}'.format(model.modes.aperiodic.name), '{}'.format(desc['aperiodic_mode'])] if el != ''], # Footer @@ -137,12 +137,13 @@ def gen_modes_str(model_obj, description=False, concise=False): return output -def gen_settings_str(model_obj, description=False, concise=False): + +def gen_settings_str(model, description=False, concise=False): """Generate a string representation of current fit settings. Parameters ---------- - model_obj : SpectralModel or SpectralGroupModel or ModelSettings + model : SpectralModel or Spectral*Model or ModelSettings Object to access settings from. description : bool, optional, default: False Whether to also print out a description of the settings. 
@@ -159,9 +160,9 @@ def gen_settings_str(model_obj, description=False, concise=False): desc = { 'peak_width_limits' : 'Limits for minimum and maximum peak widths, in Hz.', 'max_n_peaks' : 'Maximum number of peaks that can be extracted.', - 'min_peak_height' : 'Minimum absolute height of a peak, above the aperiodic component.', + 'min_peak_height' : 'Minimum absolute height of a peak above the aperiodic component.', 'peak_threshold' : 'Relative threshold for minimum height required for detecting peaks.', - } + } # Clear description for printing, if not requested if not description: @@ -177,13 +178,13 @@ def gen_settings_str(model_obj, description=False, concise=False): '', # Settings - include descriptions if requested - *[el for el in ['Peak Width Limits : {}'.format(model_obj.peak_width_limits), + *[el for el in ['Peak Width Limits : {}'.format(model.algorithm.peak_width_limits), '{}'.format(desc['peak_width_limits']), - 'Max Number of Peaks : {}'.format(model_obj.max_n_peaks), + 'Max Number of Peaks : {}'.format(model.algorithm.max_n_peaks), '{}'.format(desc['max_n_peaks']), - 'Minimum Peak Height : {}'.format(model_obj.min_peak_height), + 'Minimum Peak Height : {}'.format(model.algorithm.min_peak_height), '{}'.format(desc['min_peak_height']), - 'Peak Threshold: {}'.format(model_obj.peak_threshold), + 'Peak Threshold: {}'.format(model.algorithm.peak_threshold), '{}'.format(desc['peak_threshold'])] if el != ''], # Footer @@ -196,12 +197,12 @@ def gen_settings_str(model_obj, description=False, concise=False): return output -def gen_freq_range_str(model_obj, concise=False): +def gen_freq_range_str(model, concise=False): """Generate a string representation of the fit range that was used for the model. Parameters ---------- - model_obj : SpectralModel or SpectralGroupModel + model : SpectralModel or Spectral*Model Object to access settings from. concise : bool, optional, default: False Whether to print the report in concise mode. 
@@ -211,7 +212,7 @@ def gen_freq_range_str(model_obj, concise=False): If fit range is not available, will print out 'XX' for missing values. """ - freq_range = model_obj.freq_range if model_obj.has_data else ('XX', 'XX') + freq_range = model.data.freq_range if model.data.has_data else ('XX', 'XX') str_lst = [ @@ -275,12 +276,12 @@ def gen_methods_report_str(concise=False): return output -def gen_methods_text_str(model_obj=None): +def gen_methods_text_str(model=None): """Generate a string representation of a template methods report. Parameters ---------- - model_obj : SpectralModel or SpectralGroupModel, optional + model : SpectralModel or Spectral*Model, optional A model object with settings information available. If None, the text is returned as a template, without values. """ @@ -298,18 +299,18 @@ def gen_methods_text_str(model_obj=None): "{} to {} Hz." ) - if model_obj: - freq_range = model_obj.freq_range if model_obj.has_data else ('XX', 'XX') + if model: + freq_range = model.data.freq_range if model.data.has_data else ('XX', 'XX') else: freq_range = ('XX', 'XX') methods_str = template.format(MODULE_VERSION, - model_obj.aperiodic_mode.name if model_obj else 'XX', - model_obj.periodic_mode.name if model_obj else 'XX', - model_obj.peak_width_limits if model_obj else 'XX', - model_obj.max_n_peaks if model_obj else 'XX', - model_obj.min_peak_height if model_obj else 'XX', - model_obj.peak_threshold if model_obj else 'XX', + model.modes.aperiodic.name if model else 'XX', + model.modes.periodic.name if model else 'XX', + model.algorithm.peak_width_limits if model else 'XX', + model.algorithm.max_n_peaks if model else 'XX', + model.algorithm.min_peak_height if model else 'XX', + model.algorithm.peak_threshold if model else 'XX', *freq_range) return methods_str @@ -332,7 +333,7 @@ def gen_model_results_str(model, concise=False): """ # Returns a null report if no results are available - if np.all(np.isnan(model.aperiodic_params_)): + if 
np.all(np.isnan(model.results.aperiodic_params_)): return _no_model_str(concise) # Create the formatted strings for printing @@ -346,28 +347,29 @@ def gen_model_results_str(model, concise=False): # Frequency range and resolution 'The model was run on the frequency range {} - {} Hz'.format( - int(np.floor(model.freq_range[0])), int(np.ceil(model.freq_range[1]))), - 'Frequency Resolution is {:1.2f} Hz'.format(model.freq_res), + int(np.floor(model.data.freq_range[0])), int(np.ceil(model.data.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(model.data.freq_res), '', # Aperiodic parameters ('Aperiodic Parameters (offset, ' + \ - ('knee, ' if model.aperiodic_mode.name == 'knee' else '') + \ + ('knee, ' if model.modes.aperiodic.name == 'knee' else '') + \ 'exponent): '), - ', '.join(['{:2.4f}'] * len(model.aperiodic_params_)).format(*model.aperiodic_params_), + ', '.join(['{:2.4f}'] * \ + len(model.results.aperiodic_params_)).format(*model.results.aperiodic_params_), '', # Peak parameters '{} peaks were found:'.format( - len(model.peak_params_)), + len(model.results.peak_params_)), *['CF: {:6.2f}, PW: {:6.3f}, BW: {:5.2f}'.format(op[0], op[1], op[2]) \ - for op in model.peak_params_], + for op in model.results.peak_params_], '', # Goodness if fit 'Goodness of fit metrics:', - 'R^2 of model fit is {:5.4f}'.format(model.r_squared_), - 'Error of the fit is {:5.4f}'.format(model.error_), + 'R^2 of model fit is {:5.4f}'.format(model.results.r_squared_), + 'Error of the fit is {:5.4f}'.format(model.results.error_), '', # Footer @@ -400,16 +402,16 @@ def gen_group_results_str(group, concise=False): If no model fit data is available to report. 
""" - if not group.has_model: + if not group.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") # Extract all the relevant data for printing - n_peaks = len(group.get_params('peak_params')) - r2s = group.get_params('r_squared') - errors = group.get_params('error') - exps = group.get_params('aperiodic_params', 'exponent') - kns = group.get_params('aperiodic_params', 'knee') \ - if str(group.aperiodic_mode) == 'knee' else np.array([0]) + n_peaks = len(group.results.get_params('peak_params')) + r2s = group.results.get_params('r_squared') + errors = group.results.get_params('error') + exps = group.results.get_params('aperiodic_params', 'exponent') + kns = group.results.get_params('aperiodic_params', 'knee') \ + if group.modes.aperiodic.name == 'knee' else np.array([0]) str_lst = [ @@ -420,24 +422,25 @@ def gen_group_results_str(group, concise=False): '', # Group information - 'Number of power spectra in the Group: {}'.format(len(group.group_results)), - *[el for el in ['{} power spectra failed to fit'.format(group.n_null_)] if group.n_null_], + 'Number of power spectra in the Group: {}'.format(len(group.results.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(\ + group.results.n_null_)] if group.results.n_null_], '', # Frequency range and resolution 'The model was run on the frequency range {} - {} Hz'.format( - int(np.floor(group.freq_range[0])), int(np.ceil(group.freq_range[1]))), - 'Frequency Resolution is {:1.2f} Hz'.format(group.freq_res), + int(np.floor(group.data.freq_range[0])), int(np.ceil(group.data.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(group.data.freq_res), '', # Aperiodic parameters - knee fit status, and quick exponent description 'Power spectra were fit {} a knee.'.format(\ - 'with' if str(group.aperiodic_mode) == 'knee' else 'without'), + 'with' if group.modes.aperiodic.name == 'knee' else 'without'), '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: 
{:6.2f}, Max: {:6.2f}, Mean: {:5.2f}' .format(*compute_arr_desc(kns)), - ] if group.aperiodic_mode.name == 'knee'], + ] if group.modes.aperiodic.name == 'knee'], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' .format(*compute_arr_desc(exps)), @@ -465,12 +468,12 @@ def gen_group_results_str(group, concise=False): return output -def gen_time_results_str(time_model, concise=False): +def gen_time_results_str(time, concise=False): """Generate a string representation of time fit results. Parameters ---------- - time_model : SpectralTimeModel + time : SpectralTimeModel Object to access results from. concise : bool, optional, default: False Whether to print the report in concise mode. @@ -486,15 +489,15 @@ def gen_time_results_str(time_model, concise=False): If no model fit data is available to report. """ - if not time_model.has_model: + if not time.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") # Get parameter information needed for printing - pe_labels = get_periodic_labels(time_model.time_results) + pe_labels = get_periodic_labels(time.results.time_results) band_labels = [\ pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ for band_ind in range(len(pe_labels['cf']))] - has_knee = time_model.aperiodic_mode.name == 'knee' + has_knee = time.modes.aperiodic.name == 'knee' str_lst = [ @@ -505,47 +508,48 @@ def gen_time_results_str(time_model, concise=False): '', # Group information - 'Number of time windows fit: {}'.format(len(time_model.group_results)), - *[el for el in ['{} power spectra failed to fit'.format(time_model.n_null_)] \ - if time_model.n_null_], + 'Number of time windows fit: {}'.format(len(time.results.group_results)), + *[el for el in ['{} power spectra failed to fit'.format(time.results.n_null_)] \ + if time.results.n_null_], '', # Frequency range and resolution 'The model was run on the frequency range {} - {} Hz'.format( - int(np.floor(time_model.freq_range[0])), 
int(np.ceil(time_model.freq_range[1]))), - 'Frequency Resolution is {:1.2f} Hz'.format(time_model.freq_res), + int(np.floor(time.data.freq_range[0])), + int(np.ceil(time.data.freq_range[1]))), + 'Frequency Resolution is {:1.2f} Hz'.format(time.data.freq_res), '', # Aperiodic parameters - knee fit status, and quick exponent description 'Power spectra were fit {} a knee.'.format(\ - 'with' if time_model.aperiodic_mode.name == 'knee' else 'without'), + 'with' if time.modes.aperiodic.name == 'knee' else 'without'), '', 'Aperiodic Fit Values:', *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(*compute_arr_desc(time_model.time_results['knee']) \ + .format(*compute_arr_desc(time.results.time_results['knee']) \ if has_knee else [0, 0, 0]), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(time_model.time_results['exponent'])), + .format(*compute_arr_desc(time.results.time_results['exponent'])), '', # Periodic parameters 'Periodic params (mean values across windows):', *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( label, - np.nanmean(time_model.time_results[pe_labels['cf'][ind]]), - np.nanmean(time_model.time_results[pe_labels['pw'][ind]]), - np.nanmean(time_model.time_results[pe_labels['bw'][ind]]), - compute_presence(time_model.time_results[pe_labels['cf'][ind]], output='percent')) + np.nanmean(time.results.time_results[pe_labels['cf'][ind]]), + np.nanmean(time.results.time_results[pe_labels['pw'][ind]]), + np.nanmean(time.results.time_results[pe_labels['bw'][ind]]), + compute_presence(time.results.time_results[pe_labels['cf'][ind]], output='percent')) for ind, label in enumerate(band_labels)], '', # Goodness if fit 'Goodness of fit (mean values across windows):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(time_model.time_results['r_squared'])), + .format(*compute_arr_desc(time.results.time_results['r_squared'])), 'Errors - 
Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(time_model.time_results['error'])), + .format(*compute_arr_desc(time.results.time_results['error'])), '', # Footer @@ -557,12 +561,12 @@ def gen_time_results_str(time_model, concise=False): return output -def gen_event_results_str(event_model, concise=False): +def gen_event_results_str(event, concise=False): """Generate a string representation of event fit results. Parameters ---------- - event_model : SpectralTimeEventModel + event : SpectralTimeEventModel Object to access results from. concise : bool, optional, default: False Whether to print the report in concise mode. @@ -578,15 +582,15 @@ def gen_event_results_str(event_model, concise=False): If no model fit data is available to report. """ - if not event_model.has_model: + if not event.results.has_model: raise NoModelError("No model fit results are available, can not proceed.") # Extract all the relevant data for printing - pe_labels = get_periodic_labels(event_model.event_time_results) + pe_labels = get_periodic_labels(event.results.event_time_results) band_labels = [\ pe_labels['cf'][band_ind].split('_')[-1 if pe_labels['cf'][-2:] == 'cf' else 0] \ for band_ind in range(len(pe_labels['cf']))] - has_knee = event_model.aperiodic_mode.name == 'knee' + has_knee = event.modes.aperiodic.name == 'knee' str_lst = [ @@ -597,36 +601,37 @@ def gen_event_results_str(event_model, concise=False): '', # Group information - 'Number of events fit: {}'.format(len(event_model.event_group_results)), + 'Number of events fit: {}'.format(len(event.results.event_group_results)), '', # Frequency range and resolution 'The model was run on the frequency range {} - {} Hz'.format( - int(np.floor(event_model.freq_range[0])), int(np.ceil(event_model.freq_range[1]))), - 'Frequency Resolution is {:1.2f} Hz'.format(event_model.freq_res), + int(np.floor(event.data.freq_range[0])), + int(np.ceil(event.data.freq_range[1]))), + 'Frequency Resolution is {:1.2f} 
Hz'.format(event.data.freq_res), '', # Aperiodic parameters - knee fit status, and quick exponent description 'Power spectra were fit {} a knee.'.format(\ - 'with' if event_model.aperiodic_mode.name == 'knee' else 'without'), + 'with' if event.modes.aperiodic.name == 'knee' else 'without'), '', 'Aperiodic params (values across events):', - *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}' - .format(*compute_arr_desc(np.mean(event_model.event_time_results['knee'], 1) \ + *[el for el in [' Knees - Min: {:6.2f}, Max: {:6.2f}, Mean: {:6.2f}'.format(\ + *compute_arr_desc(np.mean(event.results.event_time_results['knee'], 1) \ if has_knee else [0, 0, 0])), ] if has_knee], 'Exponents - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(np.mean(event_model.event_time_results['exponent'], 1))), + .format(*compute_arr_desc(np.mean(event.results.event_time_results['exponent'], 1))), '', # Periodic parameters 'Periodic params (mean values across events):', *['{:>6s} - CF: {:5.2f}, PW: {:5.2f}, BW: {:5.2f}, Presence: {:3.1f}%'.format( label, - np.nanmean(event_model.event_time_results[pe_labels['cf'][ind]]), - np.nanmean(event_model.event_time_results[pe_labels['pw'][ind]]), - np.nanmean(event_model.event_time_results[pe_labels['bw'][ind]]), - compute_presence(event_model.event_time_results[pe_labels['cf'][ind]], + np.nanmean(event.results.event_time_results[pe_labels['cf'][ind]]), + np.nanmean(event.results.event_time_results[pe_labels['pw'][ind]]), + np.nanmean(event.results.event_time_results[pe_labels['bw'][ind]]), + compute_presence(event.results.event_time_results[pe_labels['cf'][ind]], average=True, output='percent')) for ind, label in enumerate(band_labels)], '', @@ -634,10 +639,10 @@ def gen_event_results_str(event_model, concise=False): # Goodness if fit 'Goodness of fit (values across events):', ' R2s - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(np.mean(event_model.event_time_results['r_squared'], 
1))), + .format(*compute_arr_desc(np.mean(event.results.event_time_results['r_squared'], 1))), 'Errors - Min: {:6.3f}, Max: {:6.3f}, Mean: {:5.3f}' - .format(*compute_arr_desc(np.mean(event_model.event_time_results['error'], 1))), + .format(*compute_arr_desc(np.mean(event.results.event_time_results['error'], 1))), '', # Footer diff --git a/specparam/tests/algorithms/test_algorithm.py b/specparam/tests/algorithms/test_algorithm.py index 2a9d8d21d..405aa9f68 100644 --- a/specparam/tests/algorithms/test_algorithm.py +++ b/specparam/tests/algorithms/test_algorithm.py @@ -1,26 +1,13 @@ """Tests for specparam.algorthms.algorithm.""" +from specparam.data import ModelSettings +from specparam.modes.items import OBJ_DESC + from specparam.algorithms.algorithm import * ################################################################################################### ################################################################################################### -def test_settings_definition(): - - tsettings = { - 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, - 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, - } - - settings = SettingsDefinition(tsettings) - assert settings._settings == tsettings - assert settings.names == list(tsettings.keys()) - assert settings.types - assert settings.descriptions - for label in tsettings.keys(): - assert settings.make_setting_str(label) - assert settings.make_docstring() - def test_algorithm_definition(): tname = 'test_algo' @@ -48,3 +35,23 @@ def test_algorithm(): algo = Algorithm(name=tname, description=tdescription, settings=tsettings) assert algo assert isinstance(algo.algorithm, AlgorithmDefinition) + +def test_algorithm_settings(): + + tname = 'test_algo' + tdescription = 'Test algorithm description' + tsettings = { + 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, + 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, + } + + talgo = Algorithm(name=tname, description=tdescription, 
settings=tsettings) + + settings = ModelSettings([1, 4], 6, 0, 2) + talgo.add_settings(settings) + for setting in OBJ_DESC['settings']: + assert getattr(talgo, setting) == getattr(settings, setting) + + settings_out = talgo.get_settings() + assert isinstance(settings_out, ModelSettings) + assert settings_out == settings diff --git a/specparam/tests/algorithms/test_settings.py b/specparam/tests/algorithms/test_settings.py new file mode 100644 index 000000000..30501a863 --- /dev/null +++ b/specparam/tests/algorithms/test_settings.py @@ -0,0 +1,22 @@ +"""Tests for specparam.algorithms.settings.""" + +from specparam.algorithms.settings import * + +################################################################################################### +################################################################################################### + +def test_settings_definition(): + + tsettings = { + 'a' : {'type' : 'a type desc', 'description' : 'a desc'}, + 'b' : {'type' : 'b type desc', 'description' : 'b desc'}, + } + + settings = SettingsDefinition(tsettings) + assert settings._settings == tsettings + assert settings.names == list(tsettings.keys()) + assert settings.types + assert settings.descriptions + for label in tsettings.keys(): + assert settings.make_setting_str(label) + assert settings.make_docstring() diff --git a/specparam/tests/algorithms/test_spectral_fit.py b/specparam/tests/algorithms/test_spectral_fit.py index ea4941e30..8c26f9177 100644 --- a/specparam/tests/algorithms/test_spectral_fit.py +++ b/specparam/tests/algorithms/test_spectral_fit.py @@ -1,8 +1,9 @@ """Tests for specparam.algorthms.spectral_fit.""" +from specparam.modes.modes import Modes from specparam.objs.base import BaseObject from specparam.sim import sim_power_spectrum -from specparam.algorithms.algorithm import AlgorithmDefinition +from specparam.algorithms.algorithm import Algorithm, AlgorithmDefinition from specparam.tests.tdata import default_spectrum_params @@ -15,9 +16,12 @@ def 
test_algorithm_inherit(): class TestAlgo(SpectralFitAlgorithm, BaseObject): def __init__(self): - BaseObject.__init__(self, aperiodic_mode='fixed', periodic_mode='gaussian') - SpectralFitAlgorithm.__init__(self) + self.modes = Modes(aperiodic='fixed', periodic='gaussian') + BaseObject.__init__(self) + self.algorithm = SpectralFitAlgorithm(\ + data=self.data, results=self.results, modes=self.modes) talgo = TestAlgo() - assert isinstance(talgo.algorithm, AlgorithmDefinition) + assert isinstance(talgo.algorithm, Algorithm) + assert isinstance(talgo.algorithm.algorithm, AlgorithmDefinition) talgo.fit(*sim_power_spectrum(*default_spectrum_params())) diff --git a/specparam/tests/core/__init__.py b/specparam/tests/core/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/specparam/tests/io/test_models.py b/specparam/tests/io/test_models.py index f9232a0cd..b6e12f455 100644 --- a/specparam/tests/io/test_models.py +++ b/specparam/tests/io/test_models.py @@ -10,6 +10,7 @@ from specparam import (SpectralModel, SpectralGroupModel, SpectralTimeModel, SpectralTimeEventModel) from specparam.modes.items import OBJ_DESC +from specparam.modes.modes import Modes from specparam.io.files import load_json from specparam.tests.tsettings import TEST_DATA_PATH @@ -125,14 +126,14 @@ def test_save_event(tfe): save_event(tfe, file_name=dat_file_name, file_path=TEST_DATA_PATH, save_data=True) assert os.path.exists(TEST_DATA_PATH / (set_file_name + '.json')) - for ind in range(len(tfe)): + for ind in range(len(tfe.results)): assert os.path.exists(TEST_DATA_PATH / (res_file_name + '_' + str(ind) + '.json')) assert os.path.exists(TEST_DATA_PATH / (dat_file_name + '_' + str(ind) + '.json')) # Test saving out all save elements file_name_all = 'test_event_all' save_event(tfe, file_name_all, TEST_DATA_PATH, False, True, True, True) - for ind in range(len(tfe)): + for ind in range(len(tfe.results)): assert os.path.exists(TEST_DATA_PATH / (file_name_all + '_' + str(ind) + 
'.json')) def test_load_file_contents(): @@ -160,14 +161,21 @@ def test_load_model(): assert isinstance(tfm, SpectralModel) # Check that all elements get loaded + assert isinstance(tfm.modes, Modes) for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) + assert not np.all(np.isnan(getattr(tfm.results, result))) for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None + assert getattr(tfm.algorithm, setting) is not None for data in OBJ_DESC['data']: - assert getattr(tfm, data) is not None + assert getattr(tfm.data, data) is not None for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) is not None + assert getattr(tfm.data, meta_dat) is not None + + # Check directory matches (loading didn't add any unexpected attributes) + cfm = SpectralModel() + assert dir(cfm) == dir(tfm) + assert dir(cfm.data) == dir(tfm.data) + assert dir(cfm.results) == dir(tfm.results) def test_load_group(): @@ -179,12 +187,18 @@ def test_load_group(): assert isinstance(tfg, SpectralGroupModel) # Check that all elements get loaded - assert len(tfg.group_results) > 0 + assert len(tfg.results.group_results) > 0 for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None - assert tfg.power_spectra is not None + assert getattr(tfg.algorithm, setting) is not None + assert tfg.data.power_spectra is not None for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfg, meta_dat) is not None + assert getattr(tfg.data, meta_dat) is not None + + # Check directory matches (loading didn't add any unexpected attributes) + cfg = SpectralGroupModel() + assert dir(cfg) == dir(tfg) + assert dir(cfg.data) == dir(tfg.data) + assert dir(cfg.results) == dir(tfg.results) def test_load_time(tbands): @@ -198,7 +212,13 @@ def test_load_time(tbands): # Load with bands definition tft2 = load_time(file_name, TEST_DATA_PATH, tbands) assert isinstance(tft2, SpectralTimeModel) - assert tft2.time_results + assert 
tft2.results.time_results + + # Check directory matches (loading didn't add any unexpected attributes) + cft = SpectralTimeModel() + assert dir(cft) == dir(tft2) + assert dir(cft.data) == dir(tft2.data) + assert dir(cft.results) == dir(tft2.results) def test_load_event(tbands): @@ -208,10 +228,16 @@ def test_load_event(tbands): # Load without bands definition tfe = load_event(file_name, TEST_DATA_PATH) assert isinstance(tfe, SpectralTimeEventModel) - assert len(tfe) > 1 + assert len(tfe.results) > 1 # Load with bands definition tfe2 = load_event(file_name, TEST_DATA_PATH, tbands) assert isinstance(tfe2, SpectralTimeEventModel) - assert tfe2.event_time_results - assert len(tfe2) > 1 + assert tfe2.results.event_time_results + assert len(tfe2.results) > 1 + + # Check directory matches (loading didn't add any unexpected attributes) + cfe = SpectralTimeEventModel() + assert dir(cfe) == dir(tfe2) + assert dir(cfe.data) == dir(tfe2.data) + assert dir(cfe.results) == dir(tfe2.results) diff --git a/specparam/tests/measures/test_error.py b/specparam/tests/measures/test_error.py index 52b272972..b48ac2254 100644 --- a/specparam/tests/measures/test_error.py +++ b/specparam/tests/measures/test_error.py @@ -9,26 +9,26 @@ def test_compute_mean_abs_error(tfm): - error = compute_mean_abs_error(tfm.power_spectrum, tfm.modeled_spectrum_) + error = compute_mean_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(error, float) def test_compute_mean_squared_error(tfm): - error = compute_mean_squared_error(tfm.power_spectrum, tfm.modeled_spectrum_) + error = compute_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(error, float) def test_compute_root_mean_squared_error(tfm): - error = compute_root_mean_squared_error(tfm.power_spectrum, tfm.modeled_spectrum_) + error = compute_root_mean_squared_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(error, float) def 
test_compute_median_abs_error(tfm): - error = compute_median_abs_error(tfm.power_spectrum, tfm.modeled_spectrum_) + error = compute_median_abs_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(error, float) def test_compute_error(tfm): for metric in ['mae', 'mse', 'rmse', 'medae']: - error = compute_error(tfm.power_spectrum, tfm.modeled_spectrum_) + error = compute_error(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(error, float) diff --git a/specparam/tests/measures/test_gof.py b/specparam/tests/measures/test_gof.py index 0a545b44c..4343c88f8 100644 --- a/specparam/tests/measures/test_gof.py +++ b/specparam/tests/measures/test_gof.py @@ -9,16 +9,16 @@ def test_compute_r_squared(tfm): - r_squared = compute_r_squared(tfm.power_spectrum, tfm.modeled_spectrum_) + r_squared = compute_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(r_squared, float) def test_compute_adj_r_squared(tfm): - r_squared = compute_adj_r_squared(tfm.power_spectrum, tfm.modeled_spectrum_, 5) + r_squared = compute_adj_r_squared(tfm.data.power_spectrum, tfm.results.modeled_spectrum_, 5) assert isinstance(r_squared, float) def test_compute_gof(tfm): for metric in ['r_squared', 'adj_r_squared']: - gof = compute_gof(tfm.power_spectrum, tfm.modeled_spectrum_) + gof = compute_gof(tfm.data.power_spectrum, tfm.results.modeled_spectrum_) assert isinstance(gof, float) diff --git a/specparam/tests/models/test_event.py b/specparam/tests/models/test_event.py index 37a21a27d..628ed883e 100644 --- a/specparam/tests/models/test_event.py +++ b/specparam/tests/models/test_event.py @@ -8,6 +8,7 @@ import numpy as np +from specparam.models import SpectralModel, SpectralGroupModel, SpectralTimeModel from specparam.sim import sim_spectrogram from specparam.modutils.dependencies import safe_import @@ -31,16 +32,16 @@ def test_event_model(): def test_event_getitem(tfe): - assert tfe[0] + assert tfe.results[0] def 
test_event_iter(tfe): - for out in tfe: + for out in tfe.results: assert out def test_event_n_peaks(tfe): - assert np.all(tfe.n_peaks_) + assert np.all(tfe.results.n_peaks_) def test_event_fit(): @@ -50,7 +51,7 @@ def test_event_fit(): tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys) - results = tfe.get_results() + results = tfe.results.get_results() assert results assert isinstance(results, dict) for key in results.keys(): @@ -66,7 +67,7 @@ def test_event_fit_par(): tfe = SpectralTimeEventModel(verbose=False) tfe.fit(xs, ys, n_jobs=2) - results = tfe.get_results() + results = tfe.results.get_results() assert results assert isinstance(results, dict) for key in results.keys(): @@ -102,17 +103,17 @@ def test_event_load(tbands): # Test loading results tfe = SpectralTimeEventModel(verbose=False) tfe.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) - assert tfe.event_time_results + assert tfe.results.event_time_results # Test loading settings tfe = SpectralTimeEventModel(verbose=False) tfe.load(file_name_set, TEST_DATA_PATH) - assert tfe.get_settings() + assert tfe.algorithm.get_settings() # Test loading data tfe = SpectralTimeEventModel(verbose=False) tfe.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tfe.spectrograms) + assert np.all(tfe.data.spectrograms) def test_event_get_model(tfe): @@ -123,12 +124,12 @@ def test_event_get_model(tfe): # Check with regenerating tfm1 = tfe.get_model(1, 1, True) assert tfm1 - assert np.all(tfm1.modeled_spectrum_) + assert np.all(tfm1.results.modeled_spectrum_) def test_event_get_params(tfe): for dname in ['aperiodic', 'peak', 'error', 'r_squared']: - assert np.any(tfe.get_params(dname)) + assert np.any(tfe.results.get_params(dname)) def test_event_get_group(tfe): @@ -140,30 +141,34 @@ def test_event_get_group(tfe): n_out = len(einds) * len(winds) ntfe1 = tfe.get_group(einds, winds) - assert ntfe1 - assert ntfe1.spectrograms.shape == (len(einds), len(tfe.freqs), len(winds)) - tkey = 
list(ntfe1.event_time_results.keys())[0] - assert ntfe1.event_time_results[tkey].shape == (len(einds), len(winds)) - assert len(ntfe1.event_group_results), len(ntfe1.event_group_results[0]) == (len(einds, len(winds))) + assert isinstance(ntfe1, SpectralTimeEventModel) + assert ntfe1.data.spectrograms.shape == (len(einds), len(tfe.data.freqs), len(winds)) + tkey = list(ntfe1.results.event_time_results.keys())[0] + assert ntfe1.results.event_time_results[tkey].shape == (len(einds), len(winds)) + assert len(ntfe1.results.event_group_results), len(ntfe1.results.event_group_results[0]) == (len(einds, len(winds))) # Test export sub-objects, including with None input ntft0 = tfe.get_group(None, None, 'time') + assert isinstance(ntft0, SpectralTimeModel) assert not isinstance(ntft0, SpectralTimeEventModel) - assert not ntft0.group_results + assert not ntft0.results.group_results ntft1 = tfe.get_group(einds, winds, 'time') + assert isinstance(ntft1, SpectralTimeModel) assert not isinstance(ntft1, SpectralTimeEventModel) - assert ntft1.group_results - assert len(ntft1.group_results) == len(ntft1.power_spectra) == n_out + assert ntft1.results.group_results + assert len(ntft1.results.group_results) == len(ntft1.data.power_spectra) == n_out ntfg0 = tfe.get_group(None, None, 'group') - assert not isinstance(ntfg0, SpectralTimeEventModel) - assert not ntfg0.group_results + assert isinstance(ntfg0, SpectralGroupModel) + assert not isinstance(ntfg0, (SpectralTimeModel, SpectralTimeEventModel)) + assert not ntfg0.results.group_results ntfg1 = tfe.get_group(einds, winds, 'group') - assert not isinstance(ntfg1, SpectralTimeEventModel) - assert ntfg1.group_results - assert len(ntfg1.group_results) == len(ntfg1.power_spectra) == n_out + assert isinstance(ntfg1, SpectralGroupModel) + assert not isinstance(ntfg1, (SpectralTimeModel, SpectralTimeEventModel)) + assert ntfg1.results.group_results + assert len(ntfg1.results.group_results) == len(ntfg1.data.power_spectra) == n_out def 
test_event_drop(): @@ -176,23 +181,23 @@ def test_event_drop(): # Check list drops event_inds = [0] window_inds = [1] - tfe.drop(event_inds, window_inds) - assert len(tfe) == len(ys) - dropped_fres = tfe.event_group_results[event_inds[0]][window_inds[0]] + tfe.results.drop(event_inds, window_inds) + assert len(tfe.results) == len(ys) + dropped_fres = tfe.results.event_group_results[event_inds[0]][window_inds[0]] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) - for key in tfe.event_time_results: - assert np.isnan(tfe.event_time_results[key][event_inds[0], window_inds[0]]) + for key in tfe.results.event_time_results: + assert np.isnan(tfe.results.event_time_results[key][event_inds[0], window_inds[0]]) # Check dictionary drops drop_inds = {0 : [2], 1 : [1, 2]} - tfe.drop(drop_inds) - assert len(tfe) == len(ys) - dropped_fres = tfe.event_group_results[0][drop_inds[0][0]] + tfe.results.drop(drop_inds) + assert len(tfe.results) == len(ys) + dropped_fres = tfe.results.event_group_results[0][drop_inds[0][0]] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) - for key in tfe.event_time_results: - assert np.isnan(tfe.event_time_results[key][0, drop_inds[0][0]]) + for key in tfe.results.event_time_results: + assert np.isnan(tfe.results.event_time_results[key][0, drop_inds[0][0]]) def test_event_to_df(tfe, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_group.py b/specparam/tests/models/test_group.py index 2c436bb76..99618b222 100644 --- a/specparam/tests/models/test_group.py +++ b/specparam/tests/models/test_group.py @@ -38,46 +38,46 @@ def test_group(): def test_getitem(tfg): """Check indexing, from custom `__getitem__` in group object.""" - assert tfg[0] + assert tfg.results[0] def test_iter(tfg): """Check iterating through group object.""" - for res in tfg: + for res in tfg.results: assert res def test_has_data(tfg): """Test the has_data property attribute, with and 
without data.""" - assert tfg.has_model + assert tfg.results.has_model ntfg = SpectralGroupModel() - assert not ntfg.has_data + assert not ntfg.data.has_data def test_has_model(tfg): """Test the has_model property attribute, with and without model fits.""" - assert tfg.has_model + assert tfg.results.has_model ntfg = SpectralGroupModel() - assert not ntfg.has_model + assert not ntfg.results.has_model def test_n_peaks(tfg): """Test the n_peaks property attribute.""" - assert tfg.n_peaks_ + assert tfg.results.n_peaks_ def test_n_null(tfg): """Test the n_null_ property attribute.""" # Since there should have been no failed fits, this should return 0 - assert tfg.n_null_ == 0 + assert tfg.results.n_null_ == 0 def test_null_inds(tfg): """Test the null_inds_ property attribute.""" # Since there should be no failed fits, this should return an empty list - assert tfg.null_inds_ == [] + assert tfg.results.null_inds_ == [] def test_fit_nk(): """Test group fit, no knee.""" @@ -87,7 +87,7 @@ def test_fit_nk(): tfg = SpectralGroupModel(verbose=False) tfg.fit(xs, ys) - out = tfg.get_results() + out = tfg.results.get_results() assert out assert len(out) == n_spectra @@ -104,7 +104,7 @@ def test_fit_nk_noise(): tfg.fit(xs, ys) # No accuracy checking here - just checking that it ran - assert tfg.has_model + assert tfg.results.has_model def test_fit_knee(): """Test group fit, with a knee.""" @@ -119,7 +119,7 @@ def test_fit_knee(): tfg.fit(xs, ys) # No accuracy checking here - just checking that it ran - assert tfg.has_model + assert tfg.results.has_model def test_fit_progress(tfg): """Test running group fitting, with a progress bar.""" @@ -136,30 +136,30 @@ def test_fg_fail(): # Use a fg with the max iterations set so low that it will fail to converge ntfg = SpectralGroupModel() - ntfg._maxfev = 5 + ntfg.algorithm._maxfev = 5 # Fit models, where some will fail, to see if it completes cleanly ntfg.fit(fs, ps) # Check that results are all - for res in ntfg.get_results(): + for res in 
ntfg.results.get_results(): assert res # Test that get_params works with failed model fits - outs1 = ntfg.get_params('aperiodic_params') - outs2 = ntfg.get_params('aperiodic_params', 'exponent') - outs3 = ntfg.get_params('peak_params') - outs4 = ntfg.get_params('peak_params', 0) - outs5 = ntfg.get_params('gaussian_params', 2) + outs1 = ntfg.results.get_params('aperiodic_params') + outs2 = ntfg.results.get_params('aperiodic_params', 'exponent') + outs3 = ntfg.results.get_params('peak_params') + outs4 = ntfg.results.get_params('peak_params', 0) + outs5 = ntfg.results.get_params('gaussian_params', 2) # Test shortcut labels - outs6 = ntfg.get_params('aperiodic') - outs6 = ntfg.get_params('peak', 'CF') + outs6 = ntfg.results.get_params('aperiodic') + outs6 = ntfg.results.get_params('peak', 'CF') # Test the property attributes related to null model fits # This checks that they do the right thing when there are null fits (failed fits) - assert ntfg.n_null_ > 0 - assert ntfg.null_inds_ + assert ntfg.results.n_null_ > 0 + assert ntfg.results.null_inds_ def test_drop(): """Test function to drop results from group object.""" @@ -173,24 +173,24 @@ def test_drop(): tfg.fit(xs, ys) drop_ind = 0 - tfg.drop(drop_ind) - dropped_fres = tfg.group_results[drop_ind] + tfg.results.drop(drop_ind) + dropped_fres = tfg.results.group_results[drop_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) # Test dropping multiple inds tfg.fit(xs, ys) drop_inds = [0, 2] - tfg.drop(drop_inds) + tfg.results.drop(drop_inds) for d_ind in drop_inds: - dropped_fres = tfg.group_results[d_ind] + dropped_fres = tfg.results.group_results[d_ind] for field in dropped_fres._fields: assert np.all(np.isnan(getattr(dropped_fres, field))) # Test that a group object that has had inds dropped still works with `get_params` - cfs = tfg.get_params('peak_params', 1) - exps = tfg.get_params('aperiodic_params', 'exponent') + cfs = tfg.results.get_params('peak_params', 1) + exps = 
tfg.results.get_params('aperiodic_params', 'exponent') assert np.all(np.isnan(exps[drop_inds])) assert np.all(np.invert(np.isnan(np.delete(exps, drop_inds)))) @@ -202,7 +202,7 @@ def test_fit_par(): tfg = SpectralGroupModel(verbose=False) tfg.fit(xs, ys, n_jobs=2) - out = tfg.get_results() + out = tfg.results.get_results() assert out assert len(out) == n_spectra @@ -225,13 +225,13 @@ def test_save_model_report(tfg): def test_get_results(tfg): """Check get results method.""" - assert tfg.get_results() + assert tfg.results.get_results() def test_get_params(tfg): """Check get_params method.""" for dname in ['aperiodic', 'peak', 'error', 'r_squared']: - assert np.any(tfg.get_params(dname)) + assert np.any(tfg.results.get_params(dname)) @plot_test def test_plot(tfg, skip_if_no_mpl): @@ -250,42 +250,42 @@ def test_load(): # Test loading just results tfg = SpectralGroupModel(verbose=False) tfg.load(file_name_res, TEST_DATA_PATH) - assert len(tfg.group_results) > 0 + assert len(tfg.results.group_results) > 0 # Test that settings and data are None for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is None - assert tfg.power_spectra is None + assert getattr(tfg.algorithm, setting) is None + assert tfg.data.power_spectra is None # Test loading just settings tfg = SpectralGroupModel(verbose=False) tfg.load(file_name_set, TEST_DATA_PATH) for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None + assert getattr(tfg.algorithm, setting) is not None # Test that results and data are None for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfg, result))) - assert tfg.power_spectra is None + assert np.all(np.isnan(getattr(tfg.results, result))) + assert tfg.data.power_spectra is None # Test loading just data tfg = SpectralGroupModel(verbose=False) tfg.load(file_name_dat, TEST_DATA_PATH) - assert tfg.power_spectra is not None + assert tfg.data.has_data # Test that settings and results are None for setting in OBJ_DESC['settings']: 
- assert getattr(tfg, setting) is None + assert getattr(tfg.algorithm, setting) is None for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfg, result))) + assert np.all(np.isnan(getattr(tfg.results, result))) # Test loading all elements tfg = SpectralGroupModel(verbose=False) file_name_all = 'test_group_all' tfg.load(file_name_all, TEST_DATA_PATH) - assert len(tfg.group_results) > 0 + assert len(tfg.results.group_results) > 0 for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) is not None - assert tfg.power_spectra is not None + assert getattr(tfg.algorithm, setting) is not None + assert tfg.data.has_data for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfg, meta_dat) is not None + assert getattr(tfg.data, meta_dat) is not None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -301,19 +301,29 @@ def test_report(skip_if_no_mpl): def test_get_model(tfg): """Check return of an individual model fit from a group object.""" + # Test with no ind (no data / results) + tfm = tfg.get_model() + assert tfm + # Check that settings are copied over properly, but data and results are empty + for setting in OBJ_DESC['settings']: + assert getattr(tfg.algorithm, setting) == getattr(tfm.algorithm, setting) + for result in OBJ_DESC['results']: + assert np.all(np.isnan(getattr(tfm.results, result))) + assert not tfm.data.power_spectrum + # Check without regenerating tfm0 = tfg.get_model(0, False) assert tfm0 # Check that settings are copied over properly for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) == getattr(tfm0, setting) + assert getattr(tfg.algorithm, setting) == getattr(tfm0.algorithm, setting) # Check with regenerating tfm1 = tfg.get_model(1, True) assert tfm1 # Check that regenerated model is created for result in OBJ_DESC['results']: - assert np.all(getattr(tfm1, result)) + assert np.all(getattr(tfm1.results, result)) # Test when object has no data (clear a copy of tfg) 
new_tfg = tfg.copy() @@ -322,7 +332,7 @@ def test_get_model(tfg): assert tfm2 # Check that data info is copied over properly for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm2, meta_dat) + assert getattr(tfm2.data, meta_dat) def test_get_group(tfg): """Check the return of a sub-sampled group object.""" @@ -330,8 +340,8 @@ def test_get_group(tfg): # Test with no inds nfg0 = tfg.get_group(None) assert isinstance(nfg0, SpectralGroupModel) - assert nfg0.get_settings() == tfg.get_settings() - assert nfg0.get_meta_data() == tfg.get_meta_data() + assert nfg0.algorithm.get_settings() == tfg.algorithm.get_settings() + assert nfg0.data.get_meta_data() == tfg.data.get_meta_data() # Check with list index inds1 = [1, 2] @@ -345,21 +355,21 @@ def test_get_group(tfg): # Check that settings are copied over properly for setting in OBJ_DESC['settings']: - assert getattr(tfg, setting) == getattr(nfg1, setting) - assert getattr(tfg, setting) == getattr(nfg2, setting) + assert getattr(tfg.algorithm, setting) == getattr(nfg1.algorithm, setting) + assert getattr(tfg.algorithm, setting) == getattr(nfg2.algorithm, setting) # Check that data info is copied over properly for meta_dat in OBJ_DESC['meta_data']: - assert getattr(nfg1, meta_dat) - assert getattr(nfg2, meta_dat) + assert getattr(nfg1.data, meta_dat) + assert getattr(nfg2.data, meta_dat) # Check that the correct data is extracted - assert_equal(tfg.power_spectra[inds1, :], nfg1.power_spectra) - assert_equal(tfg.power_spectra[inds2, :], nfg2.power_spectra) + assert_equal(tfg.data.power_spectra[inds1, :], nfg1.data.power_spectra) + assert_equal(tfg.data.power_spectra[inds2, :], nfg2.data.power_spectra) # Check that the correct results are extracted - assert [tfg.group_results[ind] for ind in inds1] == nfg1.group_results - assert [tfg.group_results[ind] for ind in inds2] == nfg2.group_results + assert [tfg.results.group_results[ind] for ind in inds1] == nfg1.results.group_results + assert [tfg.results.group_results[ind] for 
ind in inds2] == nfg2.results.group_results def test_fg_to_df(tfg, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_model.py b/specparam/tests/models/test_model.py index d66a319b5..6ea7e95e0 100644 --- a/specparam/tests/models/test_model.py +++ b/specparam/tests/models/test_model.py @@ -42,23 +42,23 @@ def test_model_object(): def test_has_data(tfm): """Test the has_data property attribute, with and without model fits.""" - assert tfm.has_data + assert tfm.data.has_data ntfm = SpectralModel() - assert not ntfm.has_data + assert not ntfm.data.has_data def test_has_model(tfm): """Test the has_model property attribute, with and without model fits.""" - assert tfm.has_model + assert tfm.results.has_model ntfm = SpectralModel() - assert not ntfm.has_model + assert not ntfm.results.has_model def test_n_peaks(tfm): """Test the n_peaks property attribute.""" - assert tfm.n_peaks_ + assert tfm.results.n_peaks_ def test_fit_nk(): """Test fit, no knee.""" @@ -73,11 +73,11 @@ def test_fit_nk(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.aperiodic_params_, [0.5, 0.1]) + assert np.allclose(ap_params, tfm.results.aperiodic_params_, [0.5, 0.1]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) def test_fit_nk_noise(): """Test fit on noisy data, to make sure nothing breaks.""" @@ -89,7 +89,7 @@ def test_fit_nk_noise(): tfm.fit(xs, ys) # No accuracy checking here - just checking that it ran - assert tfm.has_model + assert tfm.results.has_model def test_fit_knee(): """Test fit, with a knee.""" @@ -104,11 +104,11 @@ def test_fit_knee(): tfm.fit(xs, ys) # Check model results - aperiodic parameters - assert np.allclose(ap_params, tfm.aperiodic_params_, [1, 2, 0.2]) + assert np.allclose(ap_params, 
tfm.results.aperiodic_params_, [1, 2, 0.2]) # Check model results - gaussian parameters for ii, gauss in enumerate(groupby(gauss_params, 3)): - assert np.allclose(gauss, tfm.gaussian_params_[ii], [2.0, 0.5, 1.0]) + assert np.allclose(gauss, tfm.results.gaussian_params_[ii], [2.0, 0.5, 1.0]) def test_fit_measures(): """Test goodness of fit & error metrics, post model fitting.""" @@ -116,20 +116,20 @@ def test_fit_measures(): tfm = SpectralModel(verbose=False) # Hack fake data with known properties: total error magnitude 2 - tfm.power_spectrum = np.array([1, 2, 3, 4, 5]) - tfm.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) + tfm.data.power_spectrum = np.array([1, 2, 3, 4, 5]) + tfm.results.modeled_spectrum_ = np.array([1, 2, 5, 4, 5]) # Check default goodness of fit and error measures - tfm._compute_model_gof() - assert np.isclose(tfm.r_squared_, 0.75757575) - tfm._compute_model_error() - assert np.isclose(tfm.error_, 0.4) + tfm.metrics.compute_metrics(tfm.data, tfm.results) + assert np.isclose(tfm.metrics['error-mae'].output, 0.4) + assert np.isclose(tfm.metrics['gof-r_squared'].output, 0.75757575) - # Check with alternative error fit approach - tfm._compute_model_error(metric='MSE') - assert np.isclose(tfm.error_, 0.8) - tfm._compute_model_error(metric='RMSE') - assert np.isclose(tfm.error_, np.sqrt(0.8)) + # # TODO: fix / turn back on when adding update metric functionality + # # Check with alternative error fit metrics + # tfm.results._compute_model_error(metric='MSE') + # assert np.isclose(tfm.results.error_, 0.8) + # tfm.results._compute_model_error(metric='RMSE') + # assert np.isclose(tfm.results.error_, np.sqrt(0.8)) def test_checks(): """Test various checks, errors and edge cases for model fitting. 
@@ -163,7 +163,7 @@ def test_checks(): # Check freq of 0 issue xs, ys = sim_power_spectrum(*default_spectrum_params()) tfm.fit(xs, ys) - assert tfm.freqs[0] != 0 + assert tfm.data.freqs[0] != 0 # Check error for `check_freqs` - for if there is non-even frequency values with raises(DataError): @@ -192,46 +192,46 @@ def test_load(): tfm.load(file_name_res, TEST_DATA_PATH) # Check that result attributes get filled for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) + assert not np.all(np.isnan(getattr(tfm.results, result))) # Test that settings and data are None for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is None - assert getattr(tfm, 'power_spectrum') is None + assert getattr(tfm.algorithm, setting) is None + assert tfm.data.power_spectrum is None # Test loading just settings tfm = SpectralModel(verbose=False) file_name_set = 'test_model_set' tfm.load(file_name_set, TEST_DATA_PATH) for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None + assert getattr(tfm.algorithm, setting) is not None # Test that results and data are None for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) - assert tfm.power_spectrum is None + assert np.all(np.isnan(getattr(tfm.results, result))) + assert tfm.data.power_spectrum is None # Test loading just data tfm = SpectralModel(verbose=False) file_name_dat = 'test_model_dat' tfm.load(file_name_dat, TEST_DATA_PATH) - assert tfm.power_spectrum is not None + assert tfm.data.power_spectrum is not None # Test that settings and results are None for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is None + assert getattr(tfm.algorithm, setting) is None for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) + assert np.all(np.isnan(getattr(tfm.results, result))) # Test loading all elements tfm = SpectralModel(verbose=False) file_name_all = 'test_model_all' tfm.load(file_name_all, TEST_DATA_PATH) 
for result in OBJ_DESC['results']: - assert not np.all(np.isnan(getattr(tfm, result))) + assert not np.all(np.isnan(getattr(tfm.results, result))) for setting in OBJ_DESC['settings']: - assert getattr(tfm, setting) is not None + assert getattr(tfm.algorithm, setting) is not None for data in OBJ_DESC['data']: - assert getattr(tfm, data) is not None + assert getattr(tfm.data, data) is not None for meta_dat in OBJ_DESC['meta_data']: - assert getattr(tfm, meta_dat) is not None + assert getattr(tfm.data, meta_dat) is not None def test_add_data(): """Tests method to add data to model objects.""" @@ -244,22 +244,22 @@ def test_add_data(): # Test adding data tfm.add_data(freqs, pows) - assert tfm.has_data - assert np.all(tfm.freqs == freqs) - assert np.all(tfm.power_spectrum == np.log10(pows)) + assert tfm.data.has_data + assert np.all(tfm.data.freqs == freqs) + assert np.all(tfm.data.power_spectrum == np.log10(pows)) # Test that prior data does not get cleared, when requesting not to clear tfm._reset_data_results(True, True, True) - tfm.add_results(FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])) + tfm.results.add_results(FitResults([1, 1], [10, 0.5, 0.5], 0.95, 0.02, [10, 0.5, 0.25])) tfm.add_data(freqs, pows, clear_results=False) - assert tfm.has_data - assert tfm.has_model + assert tfm.data.has_data + assert tfm.results.has_model # Test that prior data does get cleared, when requesting not to clear tfm._reset_data_results(True, True, True) tfm.add_data(freqs, pows, clear_results=True) - assert tfm.has_data - assert not tfm.has_model + assert tfm.data.has_data + assert not tfm.results.has_model def test_get_params(tfm): """Test the get_params method.""" @@ -282,11 +282,11 @@ def test_get_data(tfm): for space in ['log', 'linear']: assert isinstance(tfm.get_data(comp, space), np.ndarray) -def test_get_model(tfm): +def test_get_component(tfm): for comp in ['full', 'aperiodic', 'peak']: for space in ['log', 'linear']: - assert isinstance(tfm.get_model(comp, 
space), np.ndarray) + assert isinstance(tfm.results.get_component(comp, space), np.ndarray) def test_prints(tfm): """Test methods that print (alias and pass through methods). @@ -312,14 +312,15 @@ def test_resets(): tfm = get_tfm() tfm._reset_data_results(True, True, True) - tfm._reset_internal_settings() + tfm.algorithm._reset_internal_settings() - for data in ['data', 'model_components']: - for field in OBJ_DESC[data]: - assert getattr(tfm, field) is None + for field in OBJ_DESC['data']: + assert getattr(tfm.data, field) is None + for field in OBJ_DESC['model_components']: + assert getattr(tfm.results, field) is None for field in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, field))) - assert tfm.freqs is None and tfm.modeled_spectrum_ is None + assert np.all(np.isnan(getattr(tfm.results, field))) + assert tfm.data.freqs is None and tfm.results.modeled_spectrum_ is None def test_report(skip_if_no_mpl): """Check that running the top level model method runs.""" @@ -334,36 +335,36 @@ def test_fit_failure(): ## Induce a runtime error, and check it runs through tfm = SpectralModel(verbose=False) - tfm._maxfev = 2 + tfm.algorithm._maxfev = 2 tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) + assert np.all(np.isnan(getattr(tfm.results, result))) ## Monkey patch to check errors in general # This mimics the main fit-failure, without requiring bad data / waiting for it to fail. 
tfm = SpectralModel(verbose=False) def raise_runtime_error(*args, **kwargs): raise FitError('Test-MonkeyPatch') - tfm._fit_peaks = raise_runtime_error + tfm.algorithm._fit_peaks = raise_runtime_error # Run a model fit - this should raise an error, but continue in try/except tfm.fit(*sim_power_spectrum(*default_spectrum_params())) # Check after failing out of fit, all results are reset for result in OBJ_DESC['results']: - assert np.all(np.isnan(getattr(tfm, result))) + assert np.all(np.isnan(getattr(tfm.results, result))) def test_debug(): """Test model object in debug state, including with fit failures.""" tfm = SpectralModel(verbose=False) - tfm._maxfev = 2 + tfm.algorithm._maxfev = 2 - tfm.set_debug(True) - assert tfm._debug is True + tfm.algorithm.set_debug(True) + assert tfm.algorithm._debug is True with raises(FitError): tfm.fit(*sim_power_spectrum(*default_spectrum_params())) @@ -373,28 +374,28 @@ def test_set_checks(): Note that testing for checks raising errors happens in test_checks.`""" tfm = SpectralModel(verbose=False) - tfm.set_checks(False, False) + tfm.data.set_checks(False, False) # Add bad frequency data, with check freqs turned off freqs = np.array([1, 2, 4]) powers = np.array([1, 2, 3]) tfm.add_data(freqs, powers) - assert tfm.has_data + assert tfm.data.has_data # Add bad power values data, with check data turned off freqs = gen_freqs([3, 30], 1) powers = np.ones_like(freqs) * np.nan tfm.add_data(freqs, powers) - assert tfm.has_data + assert tfm.data.has_data # Model fitting should execute, but return a null model fit, given the NaNs, without failing tfm.fit() - assert not tfm.has_model + assert not tfm.results.has_model # Reset checks to true - tfm.set_checks(True, True) - assert tfm._check_freqs is True - assert tfm._check_data is True + tfm.data.set_checks(True, True) + assert tfm.data._check_freqs is True + assert tfm.data._check_data is True def test_to_df(tfm, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_time.py 
b/specparam/tests/models/test_time.py index dea297dab..135b97b3d 100644 --- a/specparam/tests/models/test_time.py +++ b/specparam/tests/models/test_time.py @@ -31,16 +31,16 @@ def test_time_model(): def test_time_getitem(tft): - assert tft[0] + assert tft.results[0] def test_time_iter(tft): - for out in tft: + for out in tft.results: assert out def test_time_n_peaks(tft): - assert tft.n_peaks_ + assert tft.results.n_peaks_ def test_time_fit(): @@ -50,7 +50,7 @@ def test_time_fit(): tft = SpectralTimeModel(verbose=False) tft.fit(xs, ys) - results = tft.get_results() + results = tft.results.get_results() assert results assert isinstance(results, dict) @@ -86,17 +86,17 @@ def test_time_load(tbands): # Test loading results tft = SpectralTimeModel(verbose=False) tft.load(file_name_res, TEST_DATA_PATH, peak_org=tbands) - assert tft.time_results + assert tft.results.time_results # Test loading settings tft = SpectralTimeModel(verbose=False) tft.load(file_name_set, TEST_DATA_PATH) - assert tft.get_settings() + assert tft.algorithm.get_settings() # Test loading data tft = SpectralTimeModel(verbose=False) tft.load(file_name_dat, TEST_DATA_PATH) - assert np.all(tft.power_spectra) + assert np.all(tft.data.power_spectra) def test_time_drop(): @@ -106,11 +106,11 @@ def test_time_drop(): tft.fit(xs, ys) drop_inds = [0, 2] - tft.drop(drop_inds) - assert len(tft) == n_windows + tft.results.drop(drop_inds) + assert len(tft.results) == n_windows for dind in drop_inds: - for key in tft.time_results: - assert np.isnan(tft.time_results[key][dind]) + for key in tft.results.time_results: + assert np.isnan(tft.results.time_results[key][dind]) def test_time_get_group(tft): @@ -121,13 +121,13 @@ def test_time_get_group(tft): nft = tft.get_group(inds) assert isinstance(nft, SpectralTimeModel) - assert len(nft.group_results) == len(inds) - assert len(nft.time_results[list(nft.time_results.keys())[0]]) == len(inds) - assert nft.spectrogram.shape[-1] == len(inds) + assert 
len(nft.results.group_results) == len(inds) + assert len(nft.results.time_results[list(nft.results.time_results.keys())[0]]) == len(inds) + assert nft.data.spectrogram.shape[-1] == len(inds) nfg = tft.get_group(inds, 'group') assert not isinstance(nfg, SpectralTimeModel) - assert len(nfg.group_results) == len(inds) + assert len(nfg.results.group_results) == len(inds) def test_time_to_df(tft, tbands, skip_if_no_pandas): diff --git a/specparam/tests/models/test_utils.py b/specparam/tests/models/test_utils.py index 5712920af..6c018a9f5 100644 --- a/specparam/tests/models/test_utils.py +++ b/specparam/tests/models/test_utils.py @@ -15,6 +15,17 @@ ################################################################################################### ################################################################################################### +def test_initialize_model_from_source(tfm, tfg): + + for source in [tfm, tfg]: + for target in ['model', 'group', 'time', 'event']: + out = initialize_model_from_source(source, target) + assert out.algorithm.get_settings() == source.algorithm.get_settings() + assert out.data.get_meta_data() == source.data.get_meta_data() + assert out.modes.get_modes() == source.modes.get_modes() + assert not out.data.has_data + assert not out.results.has_model + def test_compare_model_objs(tfm, tfg): for f_obj in [tfm, tfg]: @@ -22,12 +33,12 @@ def test_compare_model_objs(tfm, tfg): f_obj2 = f_obj.copy() assert compare_model_objs([f_obj, f_obj2], 'settings') - f_obj2.peak_width_limits = [2, 4] - f_obj2._reset_internal_settings() + f_obj2.algorithm.peak_width_limits = [2, 4] + f_obj2.algorithm._reset_internal_settings() assert not compare_model_objs([f_obj, f_obj2], 'settings') assert compare_model_objs([f_obj, f_obj2], 'meta_data') - f_obj2.freq_range = [5, 25] + f_obj2.data.freq_range = [5, 25] assert not compare_model_objs([f_obj, f_obj2], 'meta_data') def test_average_group(tfg, tbands): @@ -61,57 +72,57 @@ def test_combine_model_objs(tfm, 
tfg): # Check combining 2 model objects nfg1 = combine_model_objs([tfm, tfm2]) assert nfg1 - assert len(nfg1) == 2 + assert len(nfg1.results) == 2 assert compare_model_objs([nfg1, tfm], 'settings') - assert nfg1.group_results[0] == tfm.get_results() - assert nfg1.group_results[-1] == tfm2.get_results() + assert nfg1.results.group_results[0] == tfm.results.get_results() + assert nfg1.results.group_results[-1] == tfm2.results.get_results() # Check combining 3 model objects nfg2 = combine_model_objs([tfm, tfm2, tfm3]) assert nfg2 - assert len(nfg2) == 3 + assert len(nfg2.results) == 3 assert compare_model_objs([nfg2, tfm], 'settings') - assert nfg2.group_results[0] == tfm.get_results() - assert nfg2.group_results[-1] == tfm3.get_results() + assert nfg2.results.group_results[0] == tfm.results.get_results() + assert nfg2.results.group_results[-1] == tfm3.results.get_results() # Check combining 2 group objects nfg3 = combine_model_objs([tfg, tfg2]) assert nfg3 - assert len(nfg3) == len(tfg) + len(tfg2) + assert len(nfg3.results) == len(tfg.results) + len(tfg2.results) assert compare_model_objs([nfg3, tfg, tfg2], 'settings') - assert nfg3.group_results[0] == tfg.group_results[0] - assert nfg3.group_results[-1] == tfg2.group_results[-1] + assert nfg3.results.group_results[0] == tfg.results.group_results[0] + assert nfg3.results.group_results[-1] == tfg2.results.group_results[-1] # Check combining 3 group objects nfg4 = combine_model_objs([tfg, tfg2, tfg3]) assert nfg4 - assert len(nfg4) == len(tfg) + len(tfg2) + len(tfg3) + assert len(nfg4.results) == len(tfg.results) + len(tfg2.results) + len(tfg3.results) assert compare_model_objs([nfg4, tfg, tfg2, tfg3], 'settings') - assert nfg4.group_results[0] == tfg.group_results[0] - assert nfg4.group_results[-1] == tfg3.group_results[-1] + assert nfg4.results.group_results[0] == tfg.results.group_results[0] + assert nfg4.results.group_results[-1] == tfg3.results.group_results[-1] # Check combining a mixture of model & group 
objects nfg5 = combine_model_objs([tfg, tfm, tfg2, tfm2]) assert nfg5 - assert len(nfg5) == len(tfg) + 1 + len(tfg2) + 1 + assert len(nfg5.results) == len(tfg.results) + 1 + len(tfg2.results) + 1 assert compare_model_objs([nfg5, tfg, tfm, tfg2, tfm2], 'settings') - assert nfg5.group_results[0] == tfg.group_results[0] - assert nfg5.group_results[-1] == tfm2.get_results() + assert nfg5.results.group_results[0] == tfg.results.group_results[0] + assert nfg5.results.group_results[-1] == tfm2.results.get_results() # Check combining objects with no data tfm2._reset_data_results(False, True, True) tfg2._reset_data_results(False, True, True, True) nfg6 = combine_model_objs([tfm2, tfg2]) - assert len(nfg6) == 1 + len(tfg2) - assert nfg6.power_spectra is None + assert len(nfg6.results) == 1 + len(tfg2.results) + assert nfg6.data.power_spectra is None def test_combine_errors(tfm, tfg): # Incompatible settings for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - f_obj2.peak_width_limits = [2, 4] - f_obj2._reset_internal_settings() + f_obj2.algorithm.peak_width_limits = [2, 4] + f_obj2.algorithm._reset_internal_settings() with raises(IncompatibleSettingsError): combine_model_objs([f_obj, f_obj2]) @@ -119,7 +130,7 @@ def test_combine_errors(tfm, tfg): # Incompatible data information for f_obj in [tfm, tfg]: f_obj2 = f_obj.copy() - f_obj2.freq_range = [5, 30] + f_obj2.data.freq_range = [5, 30] with raises(IncompatibleSettingsError): combine_model_objs([f_obj, f_obj2]) @@ -138,5 +149,5 @@ def test_fit_models_3d(tfg): assert len(fgs) == n_groups == spectra_shape[0] for fg in fgs: assert fg - assert len(fg) == n_spectra - assert fg.power_spectra.shape == spectra_shape[1:] + assert len(fg.results) == n_spectra + assert fg.data.power_spectra.shape == spectra_shape[1:] diff --git a/specparam/tests/modes/test_info.py b/specparam/tests/modes/test_info.py index ca6176608..2bebeacf2 100644 --- a/specparam/tests/modes/test_info.py +++ b/specparam/tests/modes/test_info.py @@ -5,15 +5,16 @@ 
################################################################################################### ################################################################################################### -def test_get_description(tfm): +## TEMP: REMOVE? +# def test_get_description(tfm): - desc = get_description() - objs = dir(tfm) +# desc = get_description() +# objs = dir(tfm) - # Test that everything in dict is a valid component of the model object - for ke, va in desc.items(): - for it in va: - assert it in objs +# # Test that everything in dict is a valid component of the model object +# for ke, va in desc.items(): +# for it in va: +# assert it in objs def test_get_peak_indices(): @@ -46,9 +47,10 @@ def test_get_indices(): all_indices_knee = get_indices('knee') assert len(all_indices_knee) == 6 -def test_get_info(tfm, tfg): +# TEMP: TO DROP? +# def test_get_info(tfm, tfg): - for f_obj in [tfm, tfg]: - assert get_info(f_obj, 'settings') - assert get_info(f_obj, 'meta_data') - assert get_info(f_obj, 'results') +# for f_obj in [tfm, tfg]: +# assert get_info(f_obj, 'settings') +# assert get_info(f_obj, 'meta_data') +# assert get_info(f_obj, 'results') diff --git a/specparam/tests/modes/test_modes.py b/specparam/tests/modes/test_modes.py new file mode 100644 index 000000000..b09a9467b --- /dev/null +++ b/specparam/tests/modes/test_modes.py @@ -0,0 +1,37 @@ +"""Tests for specparam.modes.modes.""" + +from specparam.data import ModelModes +from specparam.modes.definitions import AP_MODES, PE_MODES + +from specparam.modes.modes import * + +################################################################################################### +################################################################################################### + +def test_modes(): + + modes = Modes(aperiodic='fixed', periodic='gaussian') + assert modes + assert isinstance(modes.aperiodic, Mode) + assert isinstance(modes.periodic, Mode) + +def test_modes_get_modes(): + + ap_mode_name = 'fixed' + 
pe_mode_name = 'gaussian' + + modes = Modes(aperiodic=ap_mode_name, periodic=pe_mode_name) + mode_names = modes.get_modes() + assert isinstance(mode_names, ModelModes) + assert mode_names.aperiodic_mode == ap_mode_name + assert mode_names.periodic_mode == pe_mode_name + +def test_check_mode_definition(): + + for ap_mode in AP_MODES.keys(): + mode = check_mode_definition(ap_mode, AP_MODES) + assert isinstance(mode, Mode) + + for pe_mode in PE_MODES.keys(): + mode = check_mode_definition(pe_mode, PE_MODES) + assert isinstance(mode, Mode) diff --git a/specparam/tests/objs/test_base.py b/specparam/tests/objs/test_base.py index 16dc89032..01c00fed9 100644 --- a/specparam/tests/objs/test_base.py +++ b/specparam/tests/objs/test_base.py @@ -11,12 +11,12 @@ def test_common_base(): - tobj = CommonBase() + tobj = CommonBase(verbose=False) assert isinstance(tobj, CommonBase) def test_common_base_copy(): - tobj = CommonBase() + tobj = CommonBase(verbose=False) ntobj = tobj.copy() assert ntobj != tobj @@ -36,8 +36,6 @@ def test_base2d(): tobj2d = BaseObject2D() assert isinstance(tobj2d, CommonBase) assert isinstance(tobj2d, BaseObject2D) - assert isinstance(tobj2d, BaseResults2D) - assert isinstance(tobj2d, BaseObject2D) ## 2DT Base Object @@ -46,8 +44,6 @@ def test_base2dt(): tobj2dt = BaseObject2DT() assert isinstance(tobj2dt, CommonBase) assert isinstance(tobj2dt, BaseObject2DT) - assert isinstance(tobj2dt, BaseResults2DT) - assert isinstance(tobj2dt, BaseObject2DT) ## 3D Base Object @@ -56,6 +52,4 @@ def test_base3d(): tobj3d = BaseObject3D() assert isinstance(tobj3d, CommonBase) assert isinstance(tobj3d, BaseObject2DT) - assert isinstance(tobj3d, BaseResults2DT) - assert isinstance(tobj3d, BaseObject2DT) assert isinstance(tobj3d, BaseObject3D) diff --git a/specparam/tests/objs/test_metrics.py b/specparam/tests/objs/test_metrics.py new file mode 100644 index 000000000..37ba6bb79 --- /dev/null +++ b/specparam/tests/objs/test_metrics.py @@ -0,0 +1,53 @@ +"""Tests for 
specparam.objs.metrics.""" + +from pytest import raises + +from specparam.measures.error import compute_mean_abs_error +from specparam.measures.gof import compute_r_squared + +from specparam.objs.metrics import * + +################################################################################################### +################################################################################################### + +def test_metric(tfm): + + metric = Metric('error', 'mae', compute_mean_abs_error) + assert isinstance(metric, Metric) + assert isinstance(metric.label, str) + + metric.compute_metric(tfm.data, tfm.results) + assert isinstance(metric.output, float) + +def test_metrics_null(): + + metrics = Metrics() + assert isinstance(metrics, Metrics) + +def test_metrics_obj(tfm): + + er_metric = Metric('error', 'mae', compute_mean_abs_error) + gof_metric = Metric('gof', 'r_squared', compute_r_squared) + + metrics = Metrics([er_metric, gof_metric]) + assert isinstance(metrics, Metrics) + + metrics.compute_metrics(tfm.data, tfm.results) + assert isinstance(metrics.outputs, dict) + + # Check indexing + met_out = metrics['error-mae'] + assert isinstance(met_out, Metric) + with raises(ValueError): + metrics['bad-label'] + +def test_metrics_dict(tfm): + + er_met_def = {'measure' : 'error', 'metric' : 'mae', 'func' : compute_mean_abs_error} + gof_met_def = {'measure' : 'gof', 'metric' : 'r_squared', 'func' : compute_r_squared} + + metrics = Metrics([er_met_def, gof_met_def]) + assert isinstance(metrics, Metrics) + + metrics.compute_metrics(tfm.data, tfm.results) + assert isinstance(metrics.outputs, dict) diff --git a/specparam/tests/objs/test_results.py b/specparam/tests/objs/test_results.py index c23c54d7f..e162cef6f 100644 --- a/specparam/tests/objs/test_results.py +++ b/specparam/tests/objs/test_results.py @@ -1,7 +1,6 @@ """Tests for specparam.objs.results, including the data object and it's methods.""" from specparam.modes.items import OBJ_DESC -from specparam.data 
import ModelSettings from specparam.objs.results import * @@ -12,28 +11,12 @@ def test_base_results(): - tres1 = BaseResults(None, None) + tres1 = BaseResults() assert isinstance(tres1, BaseResults) - tres2 = BaseResults(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tres2, BaseResults) - -def test_base_results_settings(): - - tres = BaseResults(None, None) - - settings = ModelSettings([1, 4], 6, 0, 2) - tres.add_settings(settings) - for setting in OBJ_DESC['settings']: - assert getattr(tres, setting) == getattr(settings, setting) - - settings_out = tres.get_settings() - assert isinstance(settings, ModelSettings) - assert settings_out == settings - def test_base_results_results(tresults): - tres = BaseResults(None, None) + tres = BaseResults() tres.add_results(tresults) assert tres.has_model @@ -48,16 +31,13 @@ def test_base_results_results(tresults): def test_base_results2d(): - tres2d1 = BaseResults2D(None, None) + tres2d1 = BaseResults2D() assert isinstance(tres2d1, BaseResults) assert isinstance(tres2d1, BaseResults2D) - tres2d2 = BaseResults2D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tres2d2, BaseResults2D) - def test_base_results2d_results(tresults): - tres2d = BaseResults2D(None, None) + tres2d = BaseResults2D() results = [tresults, tresults] tres2d.add_results(results) @@ -70,17 +50,14 @@ def test_base_results2d_results(tresults): def test_base_results2dt(): - tres2dt1 = BaseResults2DT(None, None) + tres2dt1 = BaseResults2DT() assert isinstance(tres2dt1, BaseResults) assert isinstance(tres2dt1, BaseResults2D) assert isinstance(tres2dt1, BaseResults2DT) - tres2dt2 = BaseResults2DT(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tres2dt2, BaseResults2DT) - def test_base_results2d_results(tresults): - tres2dt = BaseResults2DT(None, None) + tres2dt = BaseResults2DT() results = [tresults, tresults] tres2dt.add_results(results) @@ -94,18 +71,15 @@ def test_base_results2d_results(tresults): 
def test_base_results3d(): - tres3d1 = BaseResults3D(None, None) + tres3d1 = BaseResults3D() assert isinstance(tres3d1, BaseResults) assert isinstance(tres3d1, BaseResults2D) assert isinstance(tres3d1, BaseResults2DT) assert isinstance(tres3d1, BaseResults3D) - tres3d2 = BaseResults3D(aperiodic_mode='fixed', periodic_mode='gaussian') - assert isinstance(tres3d2, BaseResults3D) - def test_base_results3d_results(tresults): - tres3d = BaseResults3D(None, None) + tres3d = BaseResults3D() eresults = [[tresults, tresults], [tresults, tresults]] tres3d.add_results(eresults) diff --git a/specparam/tests/plts/test_model.py b/specparam/tests/plts/test_model.py index c94dd3154..e30e3f120 100644 --- a/specparam/tests/plts/test_model.py +++ b/specparam/tests/plts/test_model.py @@ -22,8 +22,8 @@ def test_plot_model(tfm, skip_if_no_mpl): def test_plot_model_custom(tfm, skip_if_no_mpl): # Extract broader range of data available in the object - custom_freqs = tfm.freqs - custom_power_spectrum = np.power(10, tfm.power_spectrum) + custom_freqs = tfm.data.freqs + custom_power_spectrum = np.power(10, tfm.data.power_spectrum) # Make sure model has been fit - set custom frequency range tfm.fit(custom_freqs, custom_power_spectrum, freq_range=[5, 35]) diff --git a/specparam/tests/plts/test_spectra.py b/specparam/tests/plts/test_spectra.py index a138851c3..960f8972e 100644 --- a/specparam/tests/plts/test_spectra.py +++ b/specparam/tests/plts/test_spectra.py @@ -16,44 +16,44 @@ def test_plot_spectra(tfm, tfg, skip_if_no_mpl): # Test with 1d inputs - 1d freq array & list of 1d power spectra - plot_spectra(tfm.freqs, tfm.power_spectrum, + plot_spectra(tfm.data.freqs, tfm.data.power_spectrum, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png') # Test with 1d inputs - 1d freq array & list of 1d power spectra - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + plot_spectra(tfg.data.freqs, [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]], 
file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d.png') # Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra - plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + plot_spectra([tfg.data.freqs, tfg.data.freqs], [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]], file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_list_1d_freqs.png') # Test with multiple lists - list of 1d freqs & list of 1d power spectra (different f ranges) - plot_spectra([tfg.freqs, tfg.freqs[:-5]], - [tfg.power_spectra[0, :], tfg.power_spectra[1, :-5]], + plot_spectra([tfg.data.freqs, tfg.data.freqs[:-5]], + [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :-5]], file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_lists_1d.png') # Test with 2d array inputs - plot_spectra(np.vstack([tfg.freqs, tfg.freqs]), - np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]), + plot_spectra(np.vstack([tfg.data.freqs, tfg.data.freqs]), + np.vstack([tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]]), file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png') # Test with labels - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'], + plot_spectra(tfg.data.freqs, [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]], labels=['A', 'B'], file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_labels.png') @plot_test def test_plot_spectra_shading(tfm, tfg, skip_if_no_mpl): - plot_spectra_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True, + plot_spectra_shading(tfm.data.freqs, tfm.data.power_spectrum, shades=[8, 12], add_center=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrum_shading1.png') - plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + plot_spectra_shading(tfg.data.freqs, [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]], shades=[8, 12], 
add_center=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_shading2.png') # Test with **kwargs that pass into plot_spectra - plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + plot_spectra_shading(tfg.data.freqs, [tfg.data.power_spectra[0, :], tfg.data.power_spectra[1, :]], shades=[8, 12], add_center=True, log_freqs=True, log_powers=True, labels=['A', 'B'], file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_shading_kwargs.png') @@ -61,8 +61,8 @@ def test_plot_spectra_shading(tfm, tfg, skip_if_no_mpl): @plot_test def test_plot_spectra_yshade(skip_if_no_mpl, tfg): - freqs = tfg.freqs - powers = tfg.power_spectra + freqs = tfg.data.freqs + powers = tfg.data.power_spectra # Invalid 1d array, without shade with raises(ValueError): @@ -87,8 +87,8 @@ def test_plot_spectra_yshade(skip_if_no_mpl, tfg): @plot_test def test_plot_spectrogram(skip_if_no_mpl, tft): - freqs = tft.freqs - spectrogram = np.tile(tft.power_spectra.T, 50) + freqs = tft.data.freqs + spectrogram = np.tile(tft.data.power_spectra.T, 50) plot_spectrogram(freqs, spectrogram, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrogram.png') diff --git a/specparam/utils/select.py b/specparam/utils/select.py index 07fa1eb4e..229cb179f 100644 --- a/specparam/utils/select.py +++ b/specparam/utils/select.py @@ -55,6 +55,7 @@ def nearest_ind(array, value): return np.argmin(np.abs(array - value)) +# TEMP: TO DROP? # def get_freq_ind(freqs, freq): # """Get the index of the closest frequency value to a specified input frequency.