import numpy as np
import logging
from functools import reduce  # reduce is not a builtin under Python 3
from utils import mult_diag, counter
import random
import itertools as itools

zs = lambda v: (v-v.mean(0))/v.std(0) ## z-score function

ridge_logger = logging.getLogger("ridge_corr")

def ridge(stim, resp, alpha, singcutoff=1e-10, normalpha=False, logger=ridge_logger):
    """Uses ridge regression to find a linear transformation of [stim] that approximates
    [resp]. The regularization parameter is [alpha].

    Parameters
    ----------
    stim : array_like, shape (T, N)
        Stimuli with T time points and N features.
    resp : array_like, shape (T, M)
        Responses with T time points and M separate responses.
    alpha : float or array_like, shape (M,)
        Regularization parameter. Can be given as a single value (which is applied to
        all M responses) or separate values for each response.
    singcutoff : float
        Singular values of [stim] smaller than singcutoff will be dropped before the
        weights are computed, both for speed and for numerical stability.
    normalpha : boolean
        Whether ridge parameters should be normalized by the largest singular value of stim.
        Good for comparing models with different numbers of parameters.

    Returns
    -------
    wt : array_like, shape (N, M)
        Linear regression weights.
    """
    try:
        U,S,Vh = np.linalg.svd(stim, full_matrices=False)
    except np.linalg.LinAlgError:
        logger.info("NORMAL SVD FAILED, trying more robust dgesvd..")
        from text.regression.svd_dgesvd import svd_dgesvd
        U,S,Vh = svd_dgesvd(stim, full_matrices=False)

    # Drop tiny singular values for speed and stability, as in ridge_corr below
    ngoodS = np.sum(S > singcutoff)
    U = U[:,:ngoodS]
    S = S[:ngoodS]
    Vh = Vh[:ngoodS]

    UR = np.dot(U.T, np.nan_to_num(resp))

    # Expand alpha to a collection if it's just a single value (np.isscalar
    # also catches ints, which a bare isinstance(alpha, float) would miss)
    if np.isscalar(alpha):
        alpha = np.ones(resp.shape[1]) * alpha

    # Normalize alpha by the LSV norm
    norm = S[0]
    if normalpha:
        nalphas = alpha * norm
    else:
        nalphas = alpha

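    # With the SVD stim = U * diag(S) * Vh, the ridge solution computed below is
    #   wt = Vh.T * diag(S / (S**2 + alpha**2)) * U.T * resp
    # Note that alpha is squared inside the shrinkage term, so the alphas in
    # this codebase play the role of the square root of the usual ridge penalty.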
    # Compute weights for each unique alpha
    ualphas = np.unique(nalphas)
    wt = np.zeros((stim.shape[1], resp.shape[1]))
    for ua in ualphas:
        selvox = np.nonzero(nalphas==ua)[0]
        awt = reduce(np.dot, [Vh.T, np.diag(S/(S**2+ua**2)), UR[:,selvox]])
        wt[:,selvox] = awt

    return wt

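
# A minimal usage sketch of ridge() on synthetic data. Everything here (the
# shapes, the noise level, the alpha value) is an illustrative assumption,
# not a recommendation from the original code.
def _example_ridge():
    T, N, M = 200, 20, 5                    # time points, features, responses
    stim = np.random.randn(T, N)            # random "stimulus"
    wt_true = np.random.randn(N, M)         # ground-truth weights
    resp = np.dot(stim, wt_true) + np.random.randn(T, M)  # noisy "responses"
    wt = ridge(stim, resp, alpha=10.0)      # estimated weights, shape (N, M)
    return wt
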

def ridge_corr(Rstim, Pstim, Rresp, Presp, alphas, normalpha=False, corrmin=0.2,
               singcutoff=1e-10, use_corr=True, logger=ridge_logger):
    """Uses ridge regression to find a linear transformation of [Rstim] that approximates [Rresp],
    then tests by comparing the transformation of [Pstim] to [Presp]. This procedure is repeated
    for each regularization parameter alpha in [alphas]. The correlation between each prediction and
    each response for each alpha is returned. The regression weights are NOT returned, because
    computing the correlations without computing regression weights is much, MUCH faster.

    Parameters
    ----------
    Rstim : array_like, shape (TR, N)
        Training stimuli with TR time points and N features. Each feature should be Z-scored across time.
    Pstim : array_like, shape (TP, N)
        Test stimuli with TP time points and N features. Each feature should be Z-scored across time.
    Rresp : array_like, shape (TR, M)
        Training responses with TR time points and M responses (voxels, neurons, what-have-you).
        Each response should be Z-scored across time.
    Presp : array_like, shape (TP, M)
        Test responses with TP time points and M responses.
    alphas : list or array_like, shape (A,)
        Ridge parameters to be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well.
    normalpha : boolean
        Whether ridge parameters should be normalized by the largest singular value (LSV) norm of
        Rstim. Good for comparing models with different numbers of parameters.
    corrmin : float in [0..1]
        Purely for display purposes. After each alpha is tested, the number of responses with correlation
        greater than corrmin minus the number of responses with correlation less than negative corrmin
        will be printed. For long-running regressions this vague metric of non-centered skewness can
        give you a rough sense of how well the model is working before it's done.
    singcutoff : float
        The first step in ridge regression is computing the singular value decomposition (SVD) of the
        stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal
        to zero and the corresponding singular vectors will be noise. These singular values/vectors
        should be removed both for speed (the fewer multiplications the better!) and accuracy. Any
        singular values less than singcutoff will be removed.
    use_corr : boolean
        If True, this function will use correlation as its metric of model fit. If False, this function
        will instead use variance explained (R-squared) as its metric of model fit. For ridge regression
        this can make a big difference -- highly regularized solutions will have very small norms and
        will thus explain very little variance while still leading to high correlations, as correlation
        is scale-free while R**2 is not.

    Returns
    -------
    Rcorrs : array_like, shape (A, M)
        The correlation between each predicted response and each column of Presp for each alpha.
    """
    ## Calculate SVD of stimulus matrix
    logger.info("Doing SVD...")
    try:
        U,S,Vh = np.linalg.svd(Rstim, full_matrices=False)
    except np.linalg.LinAlgError:
        logger.info("NORMAL SVD FAILED, trying more robust dgesvd..")
        from text.regression.svd_dgesvd import svd_dgesvd
        U,S,Vh = svd_dgesvd(Rstim, full_matrices=False)

    ## Truncate tiny singular values for speed
    origsize = S.shape[0]
    ngoodS = np.sum(S > singcutoff)
    nbad = origsize-ngoodS
    U = U[:,:ngoodS]
    S = S[:ngoodS]
    Vh = Vh[:ngoodS]
    logger.info("Dropped %d tiny singular values.. (U is now %s)"%(nbad, str(U.shape)))

    ## Normalize alpha by the LSV norm
    norm = S[0]
    logger.info("Training stimulus has LSV norm: %0.03f"%norm)
    if normalpha:
        nalphas = alphas * norm
    else:
        nalphas = alphas

    ## Precompute some products for speed
    UR = np.dot(U.T, Rresp) ## Precompute this matrix product for speed
    PVh = np.dot(Pstim, Vh.T) ## Precompute this matrix product for speed

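    # With Rstim = U * diag(S) * Vh, the test-set prediction for each alpha is
    #   pred = Pstim * Vh.T * diag(S / (S**2 + alpha**2)) * U.T * Rresp,
    # which is why UR and PVh above can be computed once and then reused for
    # every alpha in the loop below.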
    #Prespnorms = np.apply_along_axis(np.linalg.norm, 0, Presp) ## Precompute test response norms
    zPresp = zs(Presp)
    ## Shrink the observed test-response variances halfway toward 1.0 (the
    ## variance of z-scored data); Prespvar is the R**2 denominator below
    Prespvar_actual = Presp.var(0)
    Prespvar = (np.ones_like(Prespvar_actual) + Prespvar_actual) / 2.0
    logger.info("Average difference between actual & assumed Prespvar: %0.3f" % (Prespvar_actual - Prespvar).mean())
    Rcorrs = [] ## Holds training correlations for each alpha
    for na, a in zip(nalphas, alphas):
        #D = np.diag(S/(S**2+a**2)) ## Reweight singular vectors by the ridge parameter
        D = S / (S ** 2 + na ** 2) ## Reweight singular vectors by the (possibly normalized) ridge parameter

        pred = np.dot(mult_diag(D, PVh, left=False), UR) ## Best (1.75 seconds to prediction in test)
        # pred = np.dot(mult_diag(D, np.dot(Pstim, Vh.T), left=False), UR) ## Better (2.0 seconds to prediction in test)

        # pvhd = reduce(np.dot, [Pstim, Vh.T, D]) ## Pretty good (2.4 seconds to prediction in test)
        # pred = np.dot(pvhd, UR)

        # wt = reduce(np.dot, [Vh.T, D, UR]).astype(dtype) ## Bad (14.2 seconds to prediction in test)
        # wt = reduce(np.dot, [Vh.T, D, U.T, Rresp]).astype(dtype) ## Worst
        # pred = np.dot(Pstim, wt) ## Predict test responses

        if use_corr:
            #prednorms = np.apply_along_axis(np.linalg.norm, 0, pred) ## Compute predicted test response norms
            #Rcorr = np.array([np.corrcoef(Presp[:,ii], pred[:,ii].ravel())[0,1] for ii in range(Presp.shape[1])]) ## Slowly compute correlations
            #Rcorr = np.array(np.sum(np.multiply(Presp, pred), 0)).squeeze()/(prednorms*Prespnorms) ## Efficiently compute correlations
            Rcorr = (zPresp * zs(pred)).mean(0)
        else:
            ## Compute variance explained
            resvar = (Presp - pred).var(0)
            Rsq = 1 - (resvar / Prespvar)
            Rcorr = np.sqrt(np.abs(Rsq)) * np.sign(Rsq)

        Rcorr[np.isnan(Rcorr)] = 0
        Rcorrs.append(Rcorr)

        log_template = "Training: alpha=%0.3f, mean corr=%0.5f, max corr=%0.5f, over-under(%0.2f)=%d"
        log_msg = log_template % (a,
                                  np.mean(Rcorr),
                                  np.max(Rcorr),
                                  corrmin,
                                  (Rcorr>corrmin).sum()-(-Rcorr>corrmin).sum())
        logger.info(log_msg)

    return Rcorrs

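
# A minimal usage sketch of ridge_corr() on synthetic data (the shapes and the
# alpha grid are illustrative assumptions). Inputs are z-scored with zs(), as
# the docstring recommends.
def _example_ridge_corr():
    TR, TP, N, M = 300, 100, 20, 5
    Rstim, Pstim = zs(np.random.randn(TR, N)), zs(np.random.randn(TP, N))
    Rresp, Presp = zs(np.random.randn(TR, M)), zs(np.random.randn(TP, M))
    alphas = np.logspace(0, 3, 10)      # log-spaced ridge parameters
    Rcorrs = ridge_corr(Rstim, Pstim, Rresp, Presp, alphas)
    return np.array(Rcorrs)             # shape (A, M): one row per alpha
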

def bootstrap_ridge(Rstim, Rresp, Pstim, Presp, alphas, nboots, chunklen, nchunks,
                    corrmin=0.2, joined=None, singcutoff=1e-10, normalpha=False, single_alpha=False,
                    use_corr=True, logger=ridge_logger):
    """Uses ridge regression with a bootstrapped held-out set to get optimal alpha values for each response.
    [nchunks] random chunks of length [chunklen] will be taken from [Rstim] and [Rresp] for each regression
    run. [nboots] total regression runs will be performed. The best alpha value for each response will be
    averaged across the bootstraps to estimate the best alpha for that response.

    If [joined] is given, it should be a list of lists where the STRFs for all the voxels in each sublist
    will be given the same regularization parameter (the one that is the best on average).

    Parameters
    ----------
    Rstim : array_like, shape (TR, N)
        Training stimuli with TR time points and N features. Each feature should be Z-scored across time.
    Rresp : array_like, shape (TR, M)
        Training responses with TR time points and M different responses (voxels, neurons, what-have-you).
        Each response should be Z-scored across time.
    Pstim : array_like, shape (TP, N)
        Test stimuli with TP time points and N features. Each feature should be Z-scored across time.
    Presp : array_like, shape (TP, M)
        Test responses with TP time points and M different responses. Each response should be Z-scored across
        time.
    alphas : list or array_like, shape (A,)
        Ridge parameters that will be tested. Should probably be log-spaced. np.logspace(0, 3, 20) works well.
    nboots : int
        The number of bootstrap samples to run. 15 to 30 works well.
    chunklen : int
        On each sample, the training data is broken into chunks of this length. This should be a few times
        longer than your delay/STRF; e.g., for a STRF with 3 delays, I use chunks of length 10.
    nchunks : int
        The number of training chunks held out to test ridge parameters for each bootstrap sample. The product
        of nchunks and chunklen is the total number of training samples held out for each sample, and this
        product should be about 20 percent of the total length of the training data.
    corrmin : float in [0..1]
        Purely for display purposes. After each alpha is tested for each bootstrap sample, the number of
        responses with correlation greater than this value will be printed. For long-running regressions this
        can give a rough sense of how well the model works before it's done.
    joined : None or list of array_like indices
        If you want the STRFs for two (or more) responses to be directly comparable, you need to ensure that
        the regularization parameter that they use is the same. To do that, supply a list of the response sets
        that should use the same ridge parameter here. For example, if you have four responses, joined could
        be [np.array([0,1]), np.array([2,3])], in which case responses 0 and 1 will use the same ridge parameter
        (which will be the parameter that is best on average for those two), and likewise for responses 2 and 3.
    singcutoff : float
        The first step in ridge regression is computing the singular value decomposition (SVD) of the
        stimulus Rstim. If Rstim is not full rank, some singular values will be approximately equal
        to zero and the corresponding singular vectors will be noise. These singular values/vectors
        should be removed both for speed (the fewer multiplications the better!) and accuracy. Any
        singular values less than singcutoff will be removed.
    normalpha : boolean
        Whether ridge parameters (alphas) should be normalized by the largest singular value (LSV)
        norm of Rstim. Good for rigorously comparing models with different numbers of parameters.
    single_alpha : boolean
        Whether to use a single alpha for all responses. Good for identification/decoding.
    use_corr : boolean
        If True, this function will use correlation as its metric of model fit. If False, this function
        will instead use variance explained (R-squared) as its metric of model fit. For ridge regression
        this can make a big difference -- highly regularized solutions will have very small norms and
        will thus explain very little variance while still leading to high correlations, as correlation
        is scale-free while R**2 is not.

    Returns
    -------
    wt : array_like, shape (N, M)
        Regression weights for N features and M responses.
    corrs : array_like, shape (M,)
        Validation set correlations. Predicted responses for the validation set are obtained using the regression
        weights: pred = np.dot(Pstim, wt), and then the correlation between each predicted response and each
        column in Presp is found.
    alphas : array_like, shape (M,)
        The regularization coefficient (alpha) selected for each voxel using bootstrap cross-validation.
    bootstrap_corrs : array_like, shape (A, M, B)
        Correlation between predicted and actual responses on randomly held out portions of the training set,
        for each of A alphas, M voxels, and B bootstrap samples.
    valinds : list of B index lists, each of length nchunks*chunklen
        The indices of the training data that were used as "validation" for each bootstrap sample.
    """
    nresp, nvox = Rresp.shape
    alphas = np.asarray(alphas)  # so the fancy indexing below also works for list input
    valinds = [] # Will hold the indices into the validation data for each bootstrap

    Rcmats = []
    for bi in counter(range(nboots), countevery=1, total=nboots):
        logger.info("Selecting held-out test set..")
        # Group the training indices into consecutive, non-overlapping chunks
        # of length chunklen, then hold out nchunks randomly chosen chunks
        allinds = list(range(nresp))
        indchunks = list(zip(*[iter(allinds)]*chunklen))
        random.shuffle(indchunks)
        heldinds = list(itools.chain(*indchunks[:nchunks]))
        notheldinds = list(set(allinds)-set(heldinds))
        valinds.append(heldinds)

        RRstim = Rstim[notheldinds,:]
        PRstim = Rstim[heldinds,:]
        RRresp = Rresp[notheldinds,:]
        PRresp = Rresp[heldinds,:]

        # Run ridge regression using this held-out set
        Rcmat = ridge_corr(RRstim, PRstim, RRresp, PRresp, alphas,
                           corrmin=corrmin, singcutoff=singcutoff,
                           normalpha=normalpha, use_corr=use_corr,
                           logger=logger)

        Rcmats.append(Rcmat)

    # Find best alphas
    if nboots>0:
        allRcorrs = np.dstack(Rcmats)
    else:
        allRcorrs = None

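    # allRcorrs has shape (A, M, B): alphas x voxels x bootstrap samples.
    # Averaging across axis 2 gives each alpha's mean held-out performance
    # per voxel, which drives the selection logic below.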
    if not single_alpha:
        if nboots==0:
            raise ValueError("You must run at least one cross-validation step to assign "
                             "different alphas to each response.")

        logger.info("Finding best alpha for each voxel..")
        if joined is None:
            # Find best alpha for each voxel
            meanbootcorrs = allRcorrs.mean(2)
            bestalphainds = np.argmax(meanbootcorrs, 0)
            valphas = alphas[bestalphainds]
        else:
            # Find best alpha for each group of voxels
            valphas = np.zeros((nvox,))
            for jl in joined:
                # Mean across voxels in the set, then mean across bootstraps
                jcorrs = allRcorrs[:,jl,:].mean(1).mean(1)
                bestalpha = np.argmax(jcorrs)
                valphas[jl] = alphas[bestalpha]
    else:
        logger.info("Finding single best alpha..")
        if nboots==0:
            if len(alphas)==1:
                bestalphaind = 0
                bestalpha = alphas[0]
            else:
                raise ValueError("You must run at least one cross-validation step "
                                 "to choose best overall alpha, or only supply one "
                                 "possible alpha value.")
        else:
            meanbootcorr = allRcorrs.mean(2).mean(1)
            bestalphaind = np.argmax(meanbootcorr)
            bestalpha = alphas[bestalphaind]

        valphas = np.array([bestalpha]*nvox)
        logger.info("Best alpha = %0.3f"%bestalpha)

    # Find weights
    logger.info("Computing weights for each response using entire training set..")
    wt = ridge(Rstim, Rresp, valphas, singcutoff=singcutoff, normalpha=normalpha)

    # Predict responses on prediction set
    logger.info("Predicting responses for prediction set..")
    pred = np.dot(Pstim, wt)

    # Find prediction correlations
    nnpred = np.nan_to_num(pred)
    if use_corr:
        corrs = np.nan_to_num(np.array([np.corrcoef(Presp[:,ii], nnpred[:,ii].ravel())[0,1]
                                        for ii in range(Presp.shape[1])]))
    else:
        resvar = (Presp-pred).var(0)
        Rsqs = 1 - (resvar / Presp.var(0))
        corrs = np.sqrt(np.abs(Rsqs)) * np.sign(Rsqs)

    return wt, corrs, valphas, allRcorrs, valinds
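

# A hedged end-to-end sketch of running bootstrap_ridge() on synthetic data.
# All sizes and hyperparameters below are illustrative assumptions, not
# recommendations from the original code; note that nchunks*chunklen = 80
# held-out samples is 20% of TR = 400, matching the docstring's guidance.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # show the progress logging
    TR, TP, N, M = 400, 100, 30, 10
    Rstim = zs(np.random.randn(TR, N))       # z-scored training stimuli
    Pstim = zs(np.random.randn(TP, N))       # z-scored test stimuli
    wt_true = np.random.randn(N, M)          # ground-truth weights
    Rresp = zs(np.dot(Rstim, wt_true) + np.random.randn(TR, M))
    Presp = zs(np.dot(Pstim, wt_true) + np.random.randn(TP, M))
    wt, corrs, valphas, allRcorrs, valinds = bootstrap_ridge(
        Rstim, Rresp, Pstim, Presp, alphas=np.logspace(0, 3, 10),
        nboots=5, chunklen=10, nchunks=8)
    print("mean validation correlation: %0.3f" % corrs.mean())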