Skip to content

Commit

Permalink
Merge pull request #403 from rsagroup/402-rdmsrank_transform-treats-n…
Browse files Browse the repository at this point in the history
…ans-as-data

Omit `nan`s when running `rdms.rank_transform`.
  • Loading branch information
JasperVanDenBosch authored Jul 16, 2024
2 parents 4331bf5 + 0d6ee2d commit 82b105e
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 23 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
numpy>=1.21.2
scipy
scipy>=1.10.1
scikit-learn
scikit-image
pandas
Expand Down
21 changes: 12 additions & 9 deletions src/rsatoolbox/rdm/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from .rdms import RDMs


def rank_transform(rdms: RDMs, method='average'):
def rank_transform(rdms: RDMs, method='average') -> RDMs:
""" applies a rank_transform and generates a new RDMs object
This assigns a rank to each dissimilarity estimate in the RDM,
deals with rank ties and saves ranks as new dissimilarity estimates.
Expand All @@ -30,17 +30,20 @@ def rank_transform(rdms: RDMs, method='average'):
"""
dissimilarities = rdms.get_vectors()
dissimilarities = np.array([rankdata(dissimilarities[i], method=method)
for i in range(rdms.n_rdm)])
cfg = dict(method=method, nan_policy='omit')
dissimilarities = np.array(
[rankdata(dissimilarities[i], **cfg) for i in range(rdms.n_rdm)]
)
measure = rdms.dissimilarity_measure or ''
if '(ranks)' not in measure:
measure = (measure + ' (ranks)').strip()
rdms_new = RDMs(dissimilarities,
dissimilarity_measure=measure,
descriptors=deepcopy(rdms.descriptors),
rdm_descriptors=deepcopy(rdms.rdm_descriptors),
pattern_descriptors=deepcopy(rdms.pattern_descriptors))
return rdms_new
return RDMs(
dissimilarities,
dissimilarity_measure=measure,
descriptors=deepcopy(rdms.descriptors),
rdm_descriptors=deepcopy(rdms.rdm_descriptors),
pattern_descriptors=deepcopy(rdms.pattern_descriptors)
)


def sqrt_transform(rdms):
Expand Down
20 changes: 7 additions & 13 deletions tests/test_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,21 +234,15 @@ def square(x):
self.assertEqual(transformed_rdm.n_cond, rdms.n_cond)

def test_rank_transform(self):
from rsatoolbox.rdm import rank_transform
dis = np.zeros((8, 10))
mes = "Euclidean"
des = {'subj': 0}
pattern_des = {'type': np.array([0, 1, 2, 2, 4])}
rdm_des = {'session': np.array([0, 1, 2, 2, 4, 5, 6, 7])}
rdms = rsr.RDMs(dissimilarities=dis,
rdm_descriptors=rdm_des,
pattern_descriptors=pattern_des,
dissimilarity_measure=mes,
descriptors=des)
from rsatoolbox.rdm.transform import rank_transform
from rsatoolbox.rdm.rdms import RDMs
rdms = RDMs(
dissimilarities=np.array([[8, 6, 10, np.nan]]),
dissimilarity_measure="Euclidean",
)
rank_rdm = rank_transform(rdms)
self.assertEqual(rank_rdm.n_rdm, rdms.n_rdm)
self.assertEqual(rank_rdm.n_cond, rdms.n_cond)
self.assertEqual(rank_rdm.dissimilarity_measure, 'Euclidean (ranks)')
assert_array_equal(rank_rdm.dissimilarities, [[2, 1, 3, np.nan]])

def test_rank_transform_unknown_measure(self):
from rsatoolbox.rdm import rank_transform
Expand Down

0 comments on commit 82b105e

Please sign in to comment.