
Commit 7668054

added variance partitioning and related code
1 parent 8d4c330 commit 7668054

File tree

4 files changed, +1229 -0 lines changed


Variance Partitioning.ipynb

+627 lines (large diffs are not rendered by default)

npp.py

+26 lines
@@ -0,0 +1,26 @@
"""This module contains one-line functions that should, by all rights, be in numpy.
"""
import numpy as np

## Demean -- remove the mean from each column
demean = lambda v: v-v.mean(0)
demean.__doc__ = """Removes the mean from each column of [v]."""
dm = demean

## Z-score -- z-score each column
zscore = lambda v: (v-v.mean(0))/v.std(0)
zscore.__doc__ = """Z-scores (standardizes) each column of [v]."""
zs = zscore

## Rescale -- make each column have unit variance
rescale = lambda v: v/v.std(0)
rescale.__doc__ = """Rescales each column of [v] to have unit variance."""
rs = rescale

## Matrix corr -- find correlation between each column of c1 and the corresponding column of c2
mcorr = lambda c1,c2: (zs(c1)*zs(c2)).mean(0)
mcorr.__doc__ = """Matrix correlation. Find the correlation between each column of [c1] and the corresponding column of [c2]."""

## Cross corr -- find corr. between each row of c1 and EACH row of c2
xcorr = lambda c1,c2: np.dot(zs(c1.T).T,zs(c2.T)) / (c1.shape[1])
xcorr.__doc__ = """Cross-column correlation. Finds the correlation between each row of [c1] and each row of [c2]."""
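
A minimal usage sketch for these helpers (illustrative, not part of the commit), assuming npp.py is importable and using small random arrays:

import numpy as np
import npp  # assumes npp.py is on the path

X = np.random.randn(100, 5) * 3.0 + 2.0    # 100 time points, 5 features
Xz = npp.zs(X)                             # each column now has mean ~0 and std ~1
print(Xz.mean(0), Xz.std(0))

Y = Xz + 0.5 * np.random.randn(100, 5)     # noisy copy of Xz
print(npp.mcorr(Xz, Y))                    # column-by-column correlations, shape (5,)
print(npp.xcorr(Xz.T, Y.T).shape)          # every-row-against-every-row correlations, (5, 5)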

ridge.py

+354 lines
@@ -0,0 +1,354 @@
#import scipy
import numpy as np
import logging
from functools import reduce  # reduce() is used below; a builtin on Python 2 but not Python 3
from utils import mult_diag, counter
import random
import itertools as itools

zs = lambda v: (v-v.mean(0))/v.std(0) ## z-score function

ridge_logger = logging.getLogger("ridge_corr")


def ridge(stim, resp, alpha, singcutoff=1e-10, normalpha=False, logger=ridge_logger):
    """Uses ridge regression to find a linear transformation of [stim] that approximates
    [resp]. The regularization parameter is [alpha].

    Parameters
    ----------
    stim : array_like, shape (T, N)
        Stimuli with T time points and N features.
    resp : array_like, shape (T, M)
        Responses with T time points and M separate responses.
    alpha : float or array_like, shape (M,)
        Regularization parameter. Can be given as a single value (which is applied to
        all M responses) or separate values for each response.
    normalpha : boolean
        Whether ridge parameters should be normalized by the largest singular value of stim. Good for
        comparing models with different numbers of parameters.

    Returns
    -------
    wt : array_like, shape (N, M)
        Linear regression weights.
    """
    try:
        U,S,Vh = np.linalg.svd(stim, full_matrices=False)
    except np.linalg.LinAlgError:
        logger.info("NORMAL SVD FAILED, trying more robust dgesvd..")
        from text.regression.svd_dgesvd import svd_dgesvd
        U,S,Vh = svd_dgesvd(stim, full_matrices=False)

    UR = np.dot(U.T, np.nan_to_num(resp))

    # Expand alpha to a collection if it's just a single value
    if isinstance(alpha, float):
        alpha = np.ones(resp.shape[1]) * alpha

    # Normalize alpha by the LSV norm
    norm = S[0]
    if normalpha:
        nalphas = alpha * norm
    else:
        nalphas = alpha

    # Compute weights for each alpha
    ualphas = np.unique(nalphas)
    wt = np.zeros((stim.shape[1], resp.shape[1]))
    for ua in ualphas:
        selvox = np.nonzero(nalphas==ua)[0]
        awt = reduce(np.dot, [Vh.T, np.diag(S/(S**2+ua**2)), UR[:,selvox]])
        wt[:,selvox] = awt

    return wt

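A minimal sketch of calling ridge() on synthetic data (illustrative, not part of the file); note that a single alpha should be passed as a plain Python float so the isinstance check above expands it to one value per response:

import numpy as np
# from ridge import ridge   # assuming this file is importable as ridge.py

T, N, M = 200, 20, 3                        # time points, features, responses
stim = np.random.randn(T, N)
true_wt = np.random.randn(N, M)
resp = np.dot(stim, true_wt) + np.random.randn(T, M)

wt = ridge(stim, resp, alpha=10.0)          # alpha given as a float
pred = np.dot(stim, wt)                     # in-sample predictions, shape (T, M)
print(((pred - resp) ** 2).mean())          # mean squared in-sample error
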
def ridge_corr(Rstim, Pstim, Rresp, Presp, alphas, normalpha=False, corrmin=0.2,
               singcutoff=1e-10, use_corr=True, logger=ridge_logger):
    """Uses ridge regression to find a linear transformation of [Rstim] that approximates [Rresp],
    then tests by comparing the transformation of [Pstim] to [Presp]. This procedure is repeated
    for each regularization parameter alpha in [alphas]. The correlation between each prediction and
    each response for each alpha is returned. The regression weights are NOT returned, because
    computing the correlations without computing regression weights is much, MUCH faster.

    Parameters
    ----------
    Rstim : array_like, shape (TR, N)
        Training stimuli with TR time points and N features. Each feature should be Z-scored across time.
    Pstim : array_like, shape (TP, N)
        Test stimuli with TP time points and N features. Each feature should be Z-scored across time.
    Rresp : array_like, shape (TR, M)
        Training responses with TR time points and M responses (voxels, neurons, what-have-you).
        Each response should be Z-scored across time.
    Presp : array_like, shape (TP, M)
        Test responses with TP time points and M responses.
    alphas : list or array_like, shape (A,)
        Ridge parameters to be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well.
    normalpha : boolean
        Whether ridge parameters should be normalized by the largest singular value (LSV) norm of
        Rstim. Good for comparing models with different numbers of parameters.
    corrmin : float in [0..1]
        Purely for display purposes. After each alpha is tested, the number of responses with correlation
        greater than corrmin minus the number of responses with correlation less than negative corrmin
        will be printed. For long-running regressions this vague metric of non-centered skewness can
        give you a rough sense of how well the model is working before it's done.
    singcutoff : float
        The first step in ridge regression is computing the singular value decomposition (SVD) of the
        stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal
        to zero and the corresponding singular vectors will be noise. These singular values/vectors
        should be removed both for speed (the fewer multiplications the better!) and accuracy. Any
        singular values less than singcutoff will be removed.
    use_corr : boolean
        If True, this function will use correlation as its metric of model fit. If False, this function
        will instead use variance explained (R-squared) as its metric of model fit. For ridge regression
        this can make a big difference -- highly regularized solutions will have very small norms and
        will thus explain very little variance while still leading to high correlations, as correlation
        is scale-free while R**2 is not.

    Returns
    -------
    Rcorrs : array_like, shape (A, M)
        The correlation between each predicted response and each column of Presp for each alpha.

    """
    ## Calculate SVD of stimulus matrix
    logger.info("Doing SVD...")
    try:
        U,S,Vh = np.linalg.svd(Rstim, full_matrices=False)
    except np.linalg.LinAlgError:
        logger.info("NORMAL SVD FAILED, trying more robust dgesvd..")
        from text.regression.svd_dgesvd import svd_dgesvd
        U,S,Vh = svd_dgesvd(Rstim, full_matrices=False)

    ## Truncate tiny singular values for speed
    origsize = S.shape[0]
    ngoodS = np.sum(S > singcutoff)
    nbad = origsize-ngoodS
    U = U[:,:ngoodS]
    S = S[:ngoodS]
    Vh = Vh[:ngoodS]
    logger.info("Dropped %d tiny singular values.. (U is now %s)"%(nbad, str(U.shape)))

    ## Normalize alpha by the LSV norm
    norm = S[0]
    logger.info("Training stimulus has LSV norm: %0.03f"%norm)
    if normalpha:
        nalphas = alphas * norm
    else:
        nalphas = alphas

    ## Precompute some products for speed
    UR = np.dot(U.T, Rresp) ## Precompute this matrix product for speed
    PVh = np.dot(Pstim, Vh.T) ## Precompute this matrix product for speed

    #Prespnorms = np.apply_along_axis(np.linalg.norm, 0, Presp) ## Precompute test response norms
    zPresp = zs(Presp)
    #Prespvar = Presp.var(0)
    Prespvar_actual = Presp.var(0)
    Prespvar = (np.ones_like(Prespvar_actual) + Prespvar_actual) / 2.0
    logger.info("Average difference between actual & assumed Prespvar: %0.3f" % (Prespvar_actual - Prespvar).mean())
    Rcorrs = [] ## Holds training correlations for each alpha
    for na, a in zip(nalphas, alphas):
        #D = np.diag(S/(S**2+a**2)) ## Reweight singular vectors by the ridge parameter
        D = S / (S ** 2 + na ** 2) ## Reweight singular vectors by the (normalized?) ridge parameter

        pred = np.dot(mult_diag(D, PVh, left=False), UR) ## Best (1.75 seconds to prediction in test)
        # pred = np.dot(mult_diag(D, np.dot(Pstim, Vh.T), left=False), UR) ## Better (2.0 seconds to prediction in test)

        # pvhd = reduce(np.dot, [Pstim, Vh.T, D]) ## Pretty good (2.4 seconds to prediction in test)
        # pred = np.dot(pvhd, UR)

        # wt = reduce(np.dot, [Vh.T, D, UR]).astype(dtype) ## Bad (14.2 seconds to prediction in test)
        # wt = reduce(np.dot, [Vh.T, D, U.T, Rresp]).astype(dtype) ## Worst
        # pred = np.dot(Pstim, wt) ## Predict test responses

        if use_corr:
            #prednorms = np.apply_along_axis(np.linalg.norm, 0, pred) ## Compute predicted test response norms
            #Rcorr = np.array([np.corrcoef(Presp[:,ii], pred[:,ii].ravel())[0,1] for ii in range(Presp.shape[1])]) ## Slowly compute correlations
            #Rcorr = np.array(np.sum(np.multiply(Presp, pred), 0)).squeeze()/(prednorms*Prespnorms) ## Efficiently compute correlations
            Rcorr = (zPresp * zs(pred)).mean(0)
        else:
            ## Compute variance explained
            resvar = (Presp - pred).var(0)
            Rsq = 1 - (resvar / Prespvar)
            Rcorr = np.sqrt(np.abs(Rsq)) * np.sign(Rsq)

        Rcorr[np.isnan(Rcorr)] = 0
        Rcorrs.append(Rcorr)

        log_template = "Training: alpha=%0.3f, mean corr=%0.5f, max corr=%0.5f, over-under(%0.2f)=%d"
        log_msg = log_template % (a,
                                  np.mean(Rcorr),
                                  np.max(Rcorr),
                                  corrmin,
                                  (Rcorr>corrmin).sum()-(-Rcorr>corrmin).sum())
        logger.info(log_msg)

    return Rcorrs

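The speed of ridge_corr comes from predicting directly through the SVD factors instead of forming the weight matrix. A small numerical check of that identity (illustrative, not part of the file; mult_diag is replaced here by plain broadcasting):

import numpy as np

Rstim = np.random.randn(300, 15)
Rresp = np.random.randn(300, 4)
Pstim = np.random.randn(50, 15)
a = 5.0

U, S, Vh = np.linalg.svd(Rstim, full_matrices=False)
D = S / (S**2 + a**2)

# Prediction without ever forming the (N, M) weight matrix...
pred_fast = np.dot(np.dot(Pstim, Vh.T) * D, np.dot(U.T, Rresp))
# ...matches prediction through explicit ridge weights.
wt = np.dot(Vh.T, np.dot(np.diag(D), np.dot(U.T, Rresp)))
pred_slow = np.dot(Pstim, wt)
print(np.allclose(pred_fast, pred_slow))    # True
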
def bootstrap_ridge(Rstim, Rresp, Pstim, Presp, alphas, nboots, chunklen, nchunks,
                    corrmin=0.2, joined=None, singcutoff=1e-10, normalpha=False, single_alpha=False,
                    use_corr=True, logger=ridge_logger):
    """Uses ridge regression with a bootstrapped held-out set to get optimal alpha values for each response.
    [nchunks] random chunks of length [chunklen] will be taken from [Rstim] and [Rresp] for each regression
    run. [nboots] total regression runs will be performed. The best alpha value for each response will be
    averaged across the bootstraps to estimate the best alpha for that response.

    If [joined] is given, it should be a list of lists where the STRFs for all the voxels in each sublist
    will be given the same regularization parameter (the one that is the best on average).

    Parameters
    ----------
    Rstim : array_like, shape (TR, N)
        Training stimuli with TR time points and N features. Each feature should be Z-scored across time.
    Rresp : array_like, shape (TR, M)
        Training responses with TR time points and M different responses (voxels, neurons, what-have-you).
        Each response should be Z-scored across time.
    Pstim : array_like, shape (TP, N)
        Test stimuli with TP time points and N features. Each feature should be Z-scored across time.
    Presp : array_like, shape (TP, M)
        Test responses with TP time points and M different responses. Each response should be Z-scored across
        time.
    alphas : list or array_like, shape (A,)
        Ridge parameters that will be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well.
    nboots : int
        The number of bootstrap samples to run. 15 to 30 works well.
    chunklen : int
        On each sample, the training data is broken into chunks of this length. This should be a few times
        longer than your delay/STRF. e.g. for a STRF with 3 delays, I use chunks of length 10.
    nchunks : int
        The number of training chunks held out to test ridge parameters for each bootstrap sample. The product
        of nchunks and chunklen is the total number of training samples held out for each sample, and this
        product should be about 20 percent of the total length of the training data.
    corrmin : float in [0..1]
        Purely for display purposes. After each alpha is tested for each bootstrap sample, the number of
        responses with correlation greater than this value will be printed. For long-running regressions this
        can give a rough sense of how well the model works before it's done.
    joined : None or list of array_like indices
        If you want the STRFs for two (or more) responses to be directly comparable, you need to ensure that
        the regularization parameter that they use is the same. To do that, supply a list of the response sets
        that should use the same ridge parameter here. For example, if you have four responses, joined could
        be [np.array([0,1]), np.array([2,3])], in which case responses 0 and 1 will use the same ridge parameter
        (which will be the parameter that is best on average for those two), and likewise for responses 2 and 3.
    singcutoff : float
        The first step in ridge regression is computing the singular value decomposition (SVD) of the
        stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal
        to zero and the corresponding singular vectors will be noise. These singular values/vectors
        should be removed both for speed (the fewer multiplications the better!) and accuracy. Any
        singular values less than singcutoff will be removed.
    normalpha : boolean
        Whether ridge parameters (alphas) should be normalized by the largest singular value (LSV)
        norm of Rstim. Good for rigorously comparing models with different numbers of parameters.
    single_alpha : boolean
        Whether to use a single alpha for all responses. Good for identification/decoding.
    use_corr : boolean
        If True, this function will use correlation as its metric of model fit. If False, this function
        will instead use variance explained (R-squared) as its metric of model fit. For ridge regression
        this can make a big difference -- highly regularized solutions will have very small norms and
        will thus explain very little variance while still leading to high correlations, as correlation
        is scale-free while R**2 is not.

    Returns
    -------
    wt : array_like, shape (N, M)
        Regression weights for N features and M responses.
    corrs : array_like, shape (M,)
        Validation set correlations. Predicted responses for the validation set are obtained using the regression
        weights: pred = np.dot(Pstim, wt), and then the correlation between each predicted response and each
        column in Presp is found.
    alphas : array_like, shape (M,)
        The regularization coefficient (alpha) selected for each voxel using bootstrap cross-validation.
    bootstrap_corrs : array_like, shape (A, M, B)
        Correlation between predicted and actual responses on randomly held out portions of the training set,
        for each of A alphas, M voxels, and B bootstrap samples.
    valinds : array_like, shape (TH, B)
        The indices of the training data that were used as "validation" for each bootstrap sample.
    """
    nresp, nvox = Rresp.shape
    valinds = [] # Will hold the indices into the validation data for each bootstrap

    Rcmats = []
    for bi in counter(range(nboots), countevery=1, total=nboots):
        logger.info("Selecting held-out test set..")
        allinds = range(nresp)
        indchunks = list(zip(*[iter(allinds)]*chunklen))
        random.shuffle(indchunks)
        heldinds = list(itools.chain(*indchunks[:nchunks]))
        notheldinds = list(set(allinds)-set(heldinds))
        valinds.append(heldinds)

        RRstim = Rstim[notheldinds,:]
        PRstim = Rstim[heldinds,:]
        RRresp = Rresp[notheldinds,:]
        PRresp = Rresp[heldinds,:]

        # Run ridge regression using this test set
        Rcmat = ridge_corr(RRstim, PRstim, RRresp, PRresp, alphas,
                           corrmin=corrmin, singcutoff=singcutoff,
                           normalpha=normalpha, use_corr=use_corr,
                           logger=logger)

        Rcmats.append(Rcmat)

    # Find best alphas
    if nboots>0:
        allRcorrs = np.dstack(Rcmats)
    else:
        allRcorrs = None

    if not single_alpha:
        if nboots==0:
            raise ValueError("You must run at least one cross-validation step to assign "
                             "different alphas to each response.")

        logger.info("Finding best alpha for each voxel..")
        if joined is None:
            # Find best alpha for each voxel
            meanbootcorrs = allRcorrs.mean(2)
            bestalphainds = np.argmax(meanbootcorrs, 0)
            valphas = alphas[bestalphainds]
        else:
            # Find best alpha for each group of voxels
            valphas = np.zeros((nvox,))
            for jl in joined:
                # Mean across voxels in the set, then mean across bootstraps
                jcorrs = allRcorrs[:,jl,:].mean(1).mean(1)
                bestalpha = np.argmax(jcorrs)
                valphas[jl] = alphas[bestalpha]
    else:
        logger.info("Finding single best alpha..")
        if nboots==0:
            if len(alphas)==1:
                bestalphaind = 0
                bestalpha = alphas[0]
            else:
                raise ValueError("You must run at least one cross-validation step "
                                 "to choose the best overall alpha, or only supply one "
                                 "possible alpha value.")
        else:
            meanbootcorr = allRcorrs.mean(2).mean(1)
            bestalphaind = np.argmax(meanbootcorr)
            bestalpha = alphas[bestalphaind]

        valphas = np.array([bestalpha]*nvox)
        logger.info("Best alpha = %0.3f"%bestalpha)

    # Find weights
    logger.info("Computing weights for each response using entire training set..")
    wt = ridge(Rstim, Rresp, valphas, singcutoff=singcutoff, normalpha=normalpha)

    # Predict responses on prediction set
    logger.info("Predicting responses for predictions set..")
    pred = np.dot(Pstim, wt)

    # Find prediction correlations
    nnpred = np.nan_to_num(pred)
    if use_corr:
        corrs = np.nan_to_num(np.array([np.corrcoef(Presp[:,ii], nnpred[:,ii].ravel())[0,1]
                                        for ii in range(Presp.shape[1])]))
    else:
        resvar = (Presp-pred).var(0)
        Rsqs = 1 - (resvar / Presp.var(0))
        corrs = np.sqrt(np.abs(Rsqs)) * np.sign(Rsqs)

    return wt, corrs, valphas, allRcorrs, valinds
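
An end-to-end sketch of the bootstrap procedure (illustrative, not part of the commit), assuming z-scored synthetic data, an alpha grid stored as a NumPy array, and that the utils module imported above (mult_diag, counter) is available:

import numpy as np
# from ridge import bootstrap_ridge   # assuming ridge.py and its utils dependency are on the path

zscore = lambda v: (v - v.mean(0)) / v.std(0)

TR, TP, N, M = 400, 100, 30, 10              # training/test lengths, features, responses
true_wt = np.random.randn(N, M)
Rstim = zscore(np.random.randn(TR, N))
Pstim = zscore(np.random.randn(TP, N))
Rresp = zscore(np.dot(Rstim, true_wt) + np.random.randn(TR, M))
Presp = zscore(np.dot(Pstim, true_wt) + np.random.randn(TP, M))

alphas = np.logspace(0, 3, 10)               # log-spaced ridge parameters, as the docstring suggests
wt, corrs, valphas, bscorrs, valinds = bootstrap_ridge(
    Rstim, Rresp, Pstim, Presp, alphas,
    nboots=15, chunklen=10, nchunks=8)       # 8 * 10 = 80 held-out samples, ~20% of TR=400
print(corrs.mean(), valphas)                 # test-set correlations and per-voxel alphas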
