Skip to content

Commit

Permalink
Merge pull request #331 from rsagroup/unbalanced_movie
Browse files Browse the repository at this point in the history
added possibility to run calc_unbalanced for rdm movies
  • Loading branch information
JasperVanDenBosch authored Oct 31, 2023
2 parents 4c5aa1d + 15ae7a9 commit 330af9c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 59 deletions.
78 changes: 20 additions & 58 deletions src/rsatoolbox/rdm/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Optional, Tuple
import numpy as np
from rsatoolbox.rdm.rdms import RDMs
from rsatoolbox.rdm.rdms import concat
from rsatoolbox.rdm.calc_unbalanced import calc_rdm_unbalanced
from rsatoolbox.rdm.combine import from_partials
from rsatoolbox.data import average_dataset_by
from rsatoolbox.util.rdm_utils import _extract_triu_
from rsatoolbox.util.build_rdm import _build_rdms

if TYPE_CHECKING:
from rsatoolbox.data.base import DatasetBase
from numpy.typing import NDArray
Expand Down Expand Up @@ -101,7 +103,7 @@ def calc_rdm(dataset, method='euclidean', descriptor=None, noise=None,
def calc_rdm_movie(
dataset, method='euclidean', descriptor=None, noise=None,
cv_descriptor=None, prior_lambda=1, prior_weight=0.1,
time_descriptor='time', bins=None):
time_descriptor='time', bins=None, unbalanced=False):
"""
calculates an RDM movie from an input TemporalDataset
Expand All @@ -121,6 +123,8 @@ def calc_rdm_movie(
dimension in dataset.time_descriptors. Defaults to 'time'.
bins (array-like): list of bins, with bins[i] containing the vector
of time-points for the i-th bin. Defaults to no binning.
unbalanced (bool): if set to True use calc_rdm_unbalanced,
else and by default use calc_rdm
Returns:
rsatoolbox.rdm.rdms.RDMs: RDMs object with RDM movie
Expand Down Expand Up @@ -156,11 +160,20 @@ def calc_rdm_movie(
rdms = []
for dat in splited_data:
dat_single = dat.convert_to_dataset(time_descriptor)
rdms.append(calc_rdm(dat_single, method=method,
descriptor=descriptor, noise=noise,
cv_descriptor=cv_descriptor,
prior_lambda=prior_lambda,
prior_weight=prior_weight))
if unbalanced:
rdms.append(calc_rdm_unbalanced(
dat_single, method=method,
descriptor=descriptor, noise=noise,
cv_descriptor=cv_descriptor,
prior_lambda=prior_lambda,
prior_weight=prior_weight))
else:
rdms.append(calc_rdm(
dat_single, method=method,
descriptor=descriptor, noise=noise,
cv_descriptor=cv_descriptor,
prior_lambda=prior_lambda,
prior_weight=prior_weight))

rdm = concat(rdms)
rdm.rdm_descriptors[time_descriptor] = time
Expand Down Expand Up @@ -488,54 +501,3 @@ def _check_noise(noise, n_channel):
else:
raise ValueError('noise(s) must have shape n_channel x n_channel')
return noise


def _build_rdms(
    utv: NDArray,
    ds: DatasetBase,
    method: str,
    obs_desc_name: str | None,
    obs_desc_vals: Optional[NDArray] = None,
    cv: Optional[NDArray] = None,
    noise: Optional[NDArray] = None
) -> RDMs:
    """Wrap one computed dissimilarity vector in an RDMs object.

    Dataset-level descriptors are copied to the rdm descriptors, and the
    observation descriptors of the dataset are carried over as pattern
    descriptors. If observations were averaged per value of
    ``obs_desc_name``, only those observation descriptors that are constant
    within every averaged group are kept (reduced to one value per group).

    Args:
        utv: dissimilarities of a single RDM; it is wrapped in a leading
            axis of length one. Presumably the vectorized upper triangle
            (TODO confirm against callers).
        ds: dataset the dissimilarities were computed from; its
            ``descriptors`` and ``obs_descriptors`` are read.
        method: stored as the ``dissimilarity_measure`` of the result.
        obs_desc_name: name of the observation descriptor that defines the
            patterns, or None if no descriptor was used.
        obs_desc_vals: values of that descriptor, one per pattern. When
            omitted and ``obs_desc_name`` is given, they are obtained from
            ``average_dataset_by``.
        cv: if given, stored as rdm descriptor ``'cv_descriptor'``.
        noise: if given, stored as rdm descriptor ``'noise'``.

    Returns:
        RDMs: object holding the single RDM with its descriptors.
    """
    rdms = RDMs(
        dissimilarities=np.array([utv]),
        dissimilarity_measure=method,
        rdm_descriptors=deepcopy(ds.descriptors)
    )
    if (obs_desc_vals is None) and (obs_desc_name is not None):
        # obtain the unique values in the target obs descriptor
        _, obs_desc_vals, _ = average_dataset_by(ds, obs_desc_name)

    if _averaging_occurred(ds, obs_desc_name, obs_desc_vals):
        orig_obs_desc_vals = np.asarray(ds.obs_descriptors[obs_desc_name])
        for dname, dvals in ds.obs_descriptors.items():
            dvals = np.asarray(dvals)
            # Uninitialised buffer instead of a NaN fill: NaN has no
            # representation in integer/bool dtypes, and every slot is
            # written below before the array is ever stored.
            avg_dvals = np.empty_like(obs_desc_vals, dtype=dvals.dtype)
            for i, v in enumerate(obs_desc_vals):
                subset = dvals[orig_obs_desc_vals == v]
                if len(set(subset)) > 1:
                    # descriptor varies within this group: drop it
                    break
                avg_dvals[i] = subset[0]
            else:
                # loop finished without break: constant within all groups
                rdms.pattern_descriptors[dname] = avg_dvals
    else:
        rdms.pattern_descriptors = deepcopy(ds.obs_descriptors)
    # Additional rdm_descriptors
    if noise is not None:
        rdms.descriptors['noise'] = noise
    if cv is not None:
        rdms.descriptors['cv_descriptor'] = cv
    return rdms


def _averaging_occurred(
ds: DatasetBase,
obs_desc_name: str | None,
obs_desc_vals: NDArray | None
) -> bool:
if obs_desc_name is None:
return False
orig_obs_desc_vals = ds.obs_descriptors[obs_desc_name]
return len(obs_desc_vals) != len(orig_obs_desc_vals)
2 changes: 1 addition & 1 deletion src/rsatoolbox/rdm/calc_unbalanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
import numpy as np
from rsatoolbox.rdm.rdms import RDMs
from rsatoolbox.rdm.rdms import concat
from rsatoolbox.rdm.calc import _build_rdms
from rsatoolbox.util.data_utils import get_unique_inverse
from rsatoolbox.util.matrix import row_col_indicator_rdm
from rsatoolbox.util.build_rdm import _build_rdms
from rsatoolbox.cengine.similarity import calc_one, calc
if TYPE_CHECKING:
from rsatoolbox.data.base import DatasetBase
Expand Down
65 changes: 65 additions & 0 deletions src/rsatoolbox/util/build_rdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""helper methods to create RDMs at the end of calculations"""

from __future__ import annotations
from typing import TYPE_CHECKING, Optional
from copy import deepcopy
import numpy as np
from rsatoolbox.rdm.rdms import RDMs
from rsatoolbox.data import average_dataset_by

if TYPE_CHECKING:
from rsatoolbox.data.base import DatasetBase
from numpy.typing import NDArray


def _build_rdms(
    utv: NDArray,
    ds: DatasetBase,
    method: str,
    obs_desc_name: str | None,
    obs_desc_vals: Optional[NDArray] = None,
    cv: Optional[NDArray] = None,
    noise: Optional[NDArray] = None
) -> RDMs:
    """Wrap one computed dissimilarity vector in an RDMs object.

    Dataset-level descriptors are copied to the rdm descriptors, and the
    observation descriptors of the dataset are carried over as pattern
    descriptors. If observations were averaged per value of
    ``obs_desc_name``, only those observation descriptors that are constant
    within every averaged group are kept (reduced to one value per group).

    Args:
        utv: dissimilarities of a single RDM; it is wrapped in a leading
            axis of length one. Presumably the vectorized upper triangle
            (TODO confirm against callers).
        ds: dataset the dissimilarities were computed from; its
            ``descriptors`` and ``obs_descriptors`` are read.
        method: stored as the ``dissimilarity_measure`` of the result.
        obs_desc_name: name of the observation descriptor that defines the
            patterns, or None if no descriptor was used.
        obs_desc_vals: values of that descriptor, one per pattern. When
            omitted and ``obs_desc_name`` is given, they are obtained from
            ``average_dataset_by``.
        cv: if given, stored as rdm descriptor ``'cv_descriptor'``.
        noise: if given, stored as rdm descriptor ``'noise'``.

    Returns:
        RDMs: object holding the single RDM with its descriptors.
    """
    rdms = RDMs(
        dissimilarities=np.array([utv]),
        dissimilarity_measure=method,
        rdm_descriptors=deepcopy(ds.descriptors)
    )
    if (obs_desc_vals is None) and (obs_desc_name is not None):
        # obtain the unique values in the target obs descriptor
        _, obs_desc_vals, _ = average_dataset_by(ds, obs_desc_name)

    if _averaging_occurred(ds, obs_desc_name, obs_desc_vals):
        orig_obs_desc_vals = np.asarray(ds.obs_descriptors[obs_desc_name])
        for dname, dvals in ds.obs_descriptors.items():
            dvals = np.asarray(dvals)
            # Uninitialised buffer instead of a NaN fill: NaN has no
            # representation in integer/bool dtypes, and every slot is
            # written below before the array is ever stored.
            avg_dvals = np.empty_like(obs_desc_vals, dtype=dvals.dtype)
            for i, v in enumerate(obs_desc_vals):
                subset = dvals[orig_obs_desc_vals == v]
                if len(set(subset)) > 1:
                    # descriptor varies within this group: drop it
                    break
                avg_dvals[i] = subset[0]
            else:
                # loop finished without break: constant within all groups
                rdms.pattern_descriptors[dname] = avg_dvals
    else:
        rdms.pattern_descriptors = deepcopy(ds.obs_descriptors)
    # Additional rdm_descriptors
    if noise is not None:
        rdms.descriptors['noise'] = noise
    if cv is not None:
        rdms.descriptors['cv_descriptor'] = cv
    return rdms


def _averaging_occurred(
ds: DatasetBase,
obs_desc_name: str | None,
obs_desc_vals: NDArray | None
) -> bool:
if obs_desc_name is None:
return False
orig_obs_desc_vals = ds.obs_descriptors[obs_desc_name]
return len(obs_desc_vals) != len(orig_obs_desc_vals)

0 comments on commit 330af9c

Please sign in to comment.