diff --git a/src/pted/pted.py b/src/pted/pted.py
index a252c69..1bd99ba 100644
--- a/src/pted/pted.py
+++ b/src/pted/pted.py
@@ -1,9 +1,9 @@
-from typing import Union
+from typing import Union, Optional

 import numpy as np
 from scipy.stats import chi2 as chi2_dist
 from torch import Tensor

-from .utils import _pted_torch, _pted_numpy
+from .utils import _pted_torch, _pted_numpy, _pted_chunk_torch, _pted_chunk_numpy

 __all__ = ["pted", "pted_coverage_test"]
@@ -14,6 +14,8 @@ def pted(
     permutations: int = 1000,
     metric: str = "euclidean",
     return_all: bool = False,
+    chunk_size: Optional[int] = None,
+    chunk_iter: Optional[int] = None,
 ):
     """
     Two sample test using a permutation test on the energy distance.
@@ -25,12 +27,32 @@ def pted(
         permutations (int): number of permutations to run. This determines how
             accurately the p-value is computed.
         metric (str): distance metric to use. See scipy.spatial.distance.cdist
-            for the list of available metrics with numpy. See torch.cdist when using
-            PyTorch, note that the metric is passed as the "p" for torch.cdist and
-            therefore is a float from 0 to inf.
+            for the list of available metrics with numpy. See torch.cdist when
+            using PyTorch; note that the metric is passed as the "p" for
+            torch.cdist and therefore is a float from 0 to inf.
         return_all (bool): if True, return the test statistic and the permuted
-            statistics. If False, just return the p-value. bool (False by
-            default)
+            statistics. If False, just return the p-value (False by default).
+        chunk_size (Optional[int]): if not None, use chunked energy distance
+            estimation. This is useful for large datasets. The chunk size is the
+            number of samples to use for each chunk. If None, use the full
+            dataset.
+        chunk_iter (Optional[int]): the number of chunked iterations to run
+            with the given chunk size.
+
+    Note
+    ----
+    PTED has O(n^2 * D * P) time complexity, where n is the number of
+    samples in x and y, D is the number of dimensions, and P is the number
+    of permutations. For large datasets this can get unwieldy, so chunking
+    is recommended. With chunking, the energy distance is estimated at
+    each iteration rather than fully computed. To estimate the energy
+    distance, we take `chunk_size` sub-samples from x and y, and compute the
+    energy distance on those sub-samples. This is repeated `chunk_iter`
+    times, and the average is taken. This is a trade-off between speed and
+    accuracy: the larger chunk_size and chunk_iter are, the more
+    accurate the estimate, but the slower the computation. PTED remains an
+    exact p-value test even when chunking; it simply becomes less sensitive
+    to the difference between x and y.
""" assert type(x) == type(y), f"x and y must be of the same type, not {type(x)} and {type(y)}" assert len(x.shape) >= 2, f"x must be at least 2D, not {x.shape}" @@ -43,9 +65,34 @@ def pted( if len(y.shape) > 2: y = y.reshape(y.shape[0], -1) - if isinstance(x, Tensor) and isinstance(y, Tensor): - return _pted_torch(x, y, permutations=permutations, metric=metric, return_all=return_all) - return _pted_numpy(x, y, permutations=permutations, metric=metric, return_all=return_all) + if isinstance(x, Tensor) and chunk_size is not None: + test, permute = _pted_chunk_torch( + x, + y, + permutations=permutations, + metric=metric, + chunk_size=chunk_size, + chunk_iter=chunk_iter, + ) + elif isinstance(x, Tensor): + test, permute = _pted_torch(x, y, permutations=permutations, metric=metric) + elif chunk_size is not None: + test, permute = _pted_chunk_numpy( + x, + y, + permutations=permutations, + metric=metric, + chunk_size=chunk_size, + chunk_iter=chunk_iter, + ) + else: + test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric) + + if return_all: + return test, permute + + # Compute p-value + return np.mean(np.array(permute) > test) def pted_coverage_test( @@ -54,6 +101,8 @@ def pted_coverage_test( permutations: int = 1000, metric: str = "euclidean", return_all: bool = False, + chunk_size: Optional[int] = None, + chunk_iter: Optional[int] = None, ): """ Coverage test using a permutation test on the energy distance. @@ -71,6 +120,27 @@ def pted_coverage_test( return_all (bool): if True, return the test statistic and the permuted statistics. If False, just return the p-value. bool (False by default) + chunk_size (Optional[int]): if not None, use chunked energy distance + estimation. This is useful for large datasets. The chunk size is the + number of samples to use for each chunk. If None, use the full + dataset. + chunk_iter (Optional[int]): The chunk iter is the number of iterations + to use with the given chunk size. + + Note + ---- + PTED has O(n^2 * D * P) time complexity, where n is the number of + samples in x and y, D is the number of dimensions, and P is the number + of permutations. For large datasets this can get unwieldy, so chunking + is recommended. For chunking, the energy distance will be estimated at + each iteration rather than fully computed. To estimate the energy + distance, we take `chunk_size` sub-samples from x and y, and compute the + energy distance on those sub-samples. This is repeated `chunk_iter` + times, and the average is taken. This is a trade-off between speed and + accuracy. The larger the chunk size and larger chunk_iter, the more + accurate the estimate, but the slower the computation. PTED remains an + exact p-value test even when chunking, it simply becomes less sensitive + to the difference between x and y. 
""" nsamp, nsim, *D = s.shape assert ( @@ -79,18 +149,27 @@ def pted_coverage_test( if len(s.shape) > 3: s = s.reshape(nsamp, nsim, -1) g = g.reshape(1, nsim, -1) + test_stats = [] permute_stats = [] for i in range(nsim): test, permute = pted( - g[:, i], s[:, i], permutations=permutations, metric=metric, return_all=True + g[:, i], + s[:, i], + permutations=permutations, + metric=metric, + return_all=True, + chunk_size=chunk_size, + chunk_iter=chunk_iter, ) test_stats.append(test) permute_stats.append(permute) test_stats = np.array(test_stats) permute_stats = np.array(permute_stats) + if return_all: return test_stats, permute_stats + # Compute p-values pvals = np.mean(permute_stats > test_stats[:, None], axis=1) pvals[pvals == 0] = 1.0 / permutations # handle pvals == 0 diff --git a/src/pted/utils.py b/src/pted/utils.py index 8390c27..484566d 100644 --- a/src/pted/utils.py +++ b/src/pted/utils.py @@ -2,19 +2,76 @@ from scipy.spatial.distance import cdist import torch -__all__ = ["_pted_numpy", "_pted_torch"] +__all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"] -def _energy_distance_precompute(D, ix, iy): - nx = len(ix) - ny = len(iy) - Exx = (D[ix.reshape(nx, 1), ix.reshape(1, nx)]).sum() / nx**2 - Eyy = (D[iy.reshape(ny, 1), iy.reshape(1, ny)]).sum() / ny**2 - Exy = (D[ix.reshape(nx, 1), iy.reshape(1, ny)]).sum() / (nx * ny) +def _energy_distance_precompute(D, nx, ny): + Exx = D[:nx, :nx].sum() / nx**2 + Eyy = D[nx:, nx:].sum() / ny**2 + Exy = D[:nx, nx:].sum() / (nx * ny) return 2 * Exy - Exx - Eyy -def _pted_numpy(x, y, permutations=100, metric="euclidean", return_all=False): +def _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric="euclidean"): + is_torch = isinstance(x, torch.Tensor) + + E_est = [] + for _ in range(chunk_iter): + # Randomly sample a chunk of data + idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False) + if is_torch: + idx = torch.tensor(idx, device=x.device) + x_chunk = x[idx] + idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False) + if is_torch: + idy = torch.tensor(idy, device=y.device) + y_chunk = y[idy] + + # Compute the distance matrix + if is_torch: + z_chunk = torch.cat((x_chunk, y_chunk), dim=0) + else: + z_chunk = np.concatenate((x_chunk, y_chunk), axis=0) + dmatrix = cdist(z_chunk, z_chunk, metric=metric) + + # Compute the energy distance + E_est.append(_energy_distance_precompute(dmatrix, len(x_chunk), len(y_chunk))) + if is_torch: + E_est[-1] = E_est[-1].item() + return np.mean(E_est) + + +def _pted_chunk_numpy(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10): + assert np.all(np.isfinite(x)) and np.all(np.isfinite(y)), "Input contains NaN or Inf!" + nx = len(x) + + test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric) + permute_stats = [] + for _ in range(permutations): + z = np.concatenate((x, y), axis=0) + z = z[np.random.permutation(len(z))] + x, y = z[:nx], z[nx:] + permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)) + return test_stat, permute_stats + + +def _pted_chunk_torch(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10): + assert torch.all(torch.isfinite(x)) and torch.all( + torch.isfinite(y) + ), "Input contains NaN or Inf!" 
+    nx = len(x)
+
+    test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)
+    permute_stats = []
+    for _ in range(permutations):
+        z = torch.cat((x, y), dim=0)
+        z = z[torch.randperm(len(z))]
+        x, y = z[:nx], z[nx:]
+        permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric))
+    return test_stat, permute_stats
+
+
+def _pted_numpy(x, y, permutations=100, metric="euclidean"):
     z = np.concatenate((x, y), axis=0)
     assert np.all(np.isfinite(z)), "Input contains NaN or Inf!"
     dmatrix = cdist(z, z, metric=metric)
@@ -22,21 +79,19 @@ def _pted_numpy(x, y, permutations=100, metric="euclidean", return_all=False):
         np.isfinite(dmatrix)
     ), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
     nx = len(x)
-    I = np.arange(len(z))
+    ny = len(y)

-    test_stat = _energy_distance_precompute(dmatrix, I[:nx], I[nx:])
+    test_stat = _energy_distance_precompute(dmatrix, nx, ny)
     permute_stats = []
     for _ in range(permutations):
-        np.random.shuffle(I)
-        permute_stats.append(_energy_distance_precompute(dmatrix, I[:nx], I[nx:]))
-    if return_all:
-        return test_stat, permute_stats
-    # Compute p-value
-    return np.mean(np.array(permute_stats) > test_stat)
+        I = np.random.permutation(len(z))
+        dmatrix = dmatrix[I][:, I]
+        permute_stats.append(_energy_distance_precompute(dmatrix, nx, ny))
+    return test_stat, permute_stats


 @torch.no_grad()
-def _pted_torch(x, y, permutations=100, metric="euclidean", return_all=False):
+def _pted_torch(x, y, permutations=100, metric="euclidean"):
     z = torch.cat((x, y), dim=0)
     assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!"
     if metric == "euclidean":
@@ -46,14 +101,12 @@ def _pted_torch(x, y, permutations=100, metric="euclidean", return_all=False):
         torch.isfinite(dmatrix)
     ), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
     nx = len(x)
-    I = torch.arange(len(z))
+    ny = len(y)

-    test_stat = _energy_distance_precompute(dmatrix, I[:nx], I[nx:]).item()
+    test_stat = _energy_distance_precompute(dmatrix, nx, ny).item()
     permute_stats = []
     for _ in range(permutations):
-        I = I[torch.randperm(len(I))]
-        permute_stats.append(_energy_distance_precompute(dmatrix, I[:nx], I[nx:]).item())
-    if return_all:
-        return test_stat, permute_stats
-    # Compute p-value
-    return np.mean(np.array(permute_stats) > test_stat)
+        I = torch.randperm(len(z))
+        dmatrix = dmatrix[I][:, I]
+        permute_stats.append(_energy_distance_precompute(dmatrix, nx, ny).item())
+    return test_stat, permute_stats
diff --git a/tests/test_pted.py b/tests/test_pted.py
index 4a5de5a..01ac54d 100644
--- a/tests/test_pted.py
+++ b/tests/test_pted.py
@@ -56,3 +56,34 @@ def test_pted_coverage_full():
     test, permute = pted.pted_coverage_test(g, s, permutations=100, return_all=True)
     assert test.shape == (100,)
     assert permute.shape == (100, 100)
+
+
+def test_pted_chunk_torch():
+    np.random.seed(42)
+    torch.manual_seed(42)
+
+    # example 2 sample test
+    D = 10
+    x = torch.randn(1000, D)
+    y = torch.randn(1000, D)
+    p = pted.pted(x, y, chunk_size=100, chunk_iter=10)
+    assert p > 1e-4 and p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"
+
+    y = torch.rand(1000, D)
+    p = pted.pted(x, y, chunk_size=100, chunk_iter=10)
+    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
+
+
+def test_pted_chunk_numpy():
+    np.random.seed(42)
+
+    # example 2 sample test
+    D = 10
+    x = np.random.normal(size=(1000, D))
+    y = np.random.normal(size=(1000, D))
+    p = pted.pted(x, y, chunk_size=100, chunk_iter=10)
+    assert p > 1e-4 and p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"
+
+    y = np.random.uniform(size=(1000, D))
+    p = pted.pted(x, y, chunk_size=100, chunk_iter=10)
+    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
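
For reference, a minimal usage sketch of the chunked path this patch adds. The array sizes, chunk settings, and shapes below are illustrative (they mirror the new tests), not recommended defaults:

```python
import numpy as np
import pted

# Two-sample test: with chunk_size / chunk_iter set, the energy distance is
# estimated from chunk_size-sample subsets averaged over chunk_iter draws,
# rather than from the full 2000 x 2000 distance matrix.
x = np.random.normal(size=(1000, 10))
y = np.random.normal(size=(1000, 10))
p = pted.pted(x, y, permutations=1000, chunk_size=100, chunk_iter=10)
print(f"two-sample p-value: {p:.3f}")

# Coverage test: the same chunk_size / chunk_iter are forwarded to each
# per-simulation call of pted(). With return_all=True the per-simulation
# statistics come back shaped (nsim,) and (nsim, permutations).
g = np.random.normal(size=(1, 100, 10))    # one ground-truth draw per simulation
s = np.random.normal(size=(250, 100, 10))  # posterior samples: (nsamp, nsim, D)
test, permute = pted.pted_coverage_test(
    g, s, permutations=100, chunk_size=50, chunk_iter=5, return_all=True
)
print(test.shape, permute.shape)  # (100,), (100, 100)
```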