diff --git a/src/pted/pted.py b/src/pted/pted.py
index 4fd72fe..16d05eb 100644
--- a/src/pted/pted.py
+++ b/src/pted/pted.py
@@ -12,11 +12,11 @@ def pted(
     x: Union[np.ndarray, Tensor],
     y: Union[np.ndarray, Tensor],
     permutations: int = 1000,
-    metric: str = "euclidean",
+    metric: Union[str, float] = "euclidean",
     return_all: bool = False,
     chunk_size: Optional[int] = None,
     chunk_iter: Optional[int] = None,
-):
+) -> Union[float, tuple[float, np.ndarray]]:
     """
     Two sample null hypothesis test using a permutation test on the energy distance.
 
@@ -100,6 +100,9 @@ def pted(
     assert type(x) == type(y), f"x and y must be of the same type, not {type(x)} and {type(y)}"
     assert len(x.shape) >= 2, f"x must be at least 2D, not {x.shape}"
     assert len(y.shape) >= 2, f"y must be at least 2D, not {y.shape}"
+    assert (chunk_size is not None) is (
+        chunk_iter is not None
+    ), "chunk_size and chunk_iter must both be provided for chunked PTED test"
     assert (
         x.shape[1:] == y.shape[1:]
     ), f"x and y samples must have the same shape (past first dim), not {x.shape} and {y.shape}"
@@ -114,8 +117,8 @@ def pted(
             y,
             permutations=permutations,
             metric=metric,
-            chunk_size=chunk_size,
-            chunk_iter=chunk_iter,
+            chunk_size=int(chunk_size),
+            chunk_iter=int(chunk_iter),
         )
     elif isinstance(x, Tensor):
         test, permute = _pted_torch(x, y, permutations=permutations, metric=metric)
@@ -125,17 +128,18 @@ def pted(
             y,
             permutations=permutations,
             metric=metric,
-            chunk_size=chunk_size,
-            chunk_iter=chunk_iter,
+            chunk_size=int(chunk_size),
+            chunk_iter=int(chunk_iter),
         )
     else:
         test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric)
+    permute = np.array(permute)
 
     if return_all:
         return test, permute
 
     # Compute p-value
-    return np.mean(np.array(permute) > test)
+    return np.mean(permute > test)
 
 
 def pted_coverage_test(
@@ -146,7 +150,7 @@ def pted_coverage_test(
     return_all: bool = False,
     chunk_size: Optional[int] = None,
     chunk_iter: Optional[int] = None,
-):
+) -> Union[float, tuple[np.ndarray, np.ndarray]]:
     """
     Coverage test using a permutation test on the energy distance.
 
@@ -229,7 +233,7 @@ def pted_coverage_test(
         number of iterations, D is the number of dimensions, and P is the number of
        permutations. For chunking to be worth it you should have c^2 * I << n^2.
     """
-    nsamp, nsim, *D = s.shape
+    nsamp, nsim, *_ = s.shape
     assert (
         g.shape == s.shape[1:]
     ), f"g and s must have the same shape (past first dim of s), not {g.shape} and {s.shape}"
@@ -252,7 +256,7 @@ def pted_coverage_test(
         test_stats.append(test)
         permute_stats.append(permute)
     test_stats = np.array(test_stats)
-    permute_stats = np.array(permute_stats)
+    permute_stats = np.stack(permute_stats)
 
     if return_all:
         return test_stats, permute_stats
diff --git a/src/pted/utils.py b/src/pted/utils.py
index 484566d..7a31f91 100644
--- a/src/pted/utils.py
+++ b/src/pted/utils.py
@@ -1,3 +1,5 @@
+from typing import Optional, Union
+
 import numpy as np
 from scipy.spatial.distance import cdist
 import torch
@@ -5,73 +7,128 @@
 __all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"]
 
 
-def _energy_distance_precompute(D, nx, ny):
+def _energy_distance_precompute(
+    D: Union[np.ndarray, torch.Tensor], nx: int, ny: int
+) -> Union[float, torch.Tensor]:
     Exx = D[:nx, :nx].sum() / nx**2
     Eyy = D[nx:, nx:].sum() / ny**2
     Exy = D[:nx, nx:].sum() / (nx * ny)
     return 2 * Exy - Exx - Eyy
 
 
-def _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric="euclidean"):
-    is_torch = isinstance(x, torch.Tensor)
+def _energy_distance_numpy(x: np.ndarray, y: np.ndarray, metric: str = "euclidean") -> float:
+    nx = len(x)
+    ny = len(y)
+    z = np.concatenate((x, y), axis=0)
+    D = cdist(z, z, metric=metric)
+    return _energy_distance_precompute(D, nx, ny)
+
+
+def _energy_distance_torch(
+    x: torch.Tensor, y: torch.Tensor, metric: Union[str, float] = "euclidean"
+) -> float:
+    nx = len(x)
+    ny = len(y)
+    z = torch.cat((x, y), dim=0)
+    if metric == "euclidean":
+        metric = 2.0
+    D = torch.cdist(z, z, p=metric)
+    return _energy_distance_precompute(D, nx, ny).item()
+
+
+def _energy_distance_estimate_numpy(
+    x: np.ndarray,
+    y: np.ndarray,
+    chunk_size: int,
+    chunk_iter: int,
+    metric: Union[str, float] = "euclidean",
+) -> float:
     E_est = []
     for _ in range(chunk_iter):
         # Randomly sample a chunk of data
         idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
-        if is_torch:
-            idx = torch.tensor(idx, device=x.device)
         x_chunk = x[idx]
         idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
-        if is_torch:
-            idy = torch.tensor(idy, device=y.device)
         y_chunk = y[idy]
 
-        # Compute the distance matrix
-        if is_torch:
-            z_chunk = torch.cat((x_chunk, y_chunk), dim=0)
-        else:
-            z_chunk = np.concatenate((x_chunk, y_chunk), axis=0)
-        dmatrix = cdist(z_chunk, z_chunk, metric=metric)
+        # Compute the energy distance
+        E_est.append(_energy_distance_numpy(x_chunk, y_chunk, metric=metric))
+    return np.mean(E_est)
+
+
+def _energy_distance_estimate_torch(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    chunk_size: int,
+    chunk_iter: int,
+    metric: Union[str, float] = "euclidean",
+) -> float:
+
+    E_est = []
+    for _ in range(chunk_iter):
+        # Randomly sample a chunk of data
+        idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
+        x_chunk = x[torch.tensor(idx)]
+        idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
+        y_chunk = y[torch.tensor(idy)]
 
         # Compute the energy distance
-        E_est.append(_energy_distance_precompute(dmatrix, len(x_chunk), len(y_chunk)))
-        if is_torch:
-            E_est[-1] = E_est[-1].item()
+        E_est.append(_energy_distance_torch(x_chunk, y_chunk, metric=metric))
     return np.mean(E_est)
 
 
-def _pted_chunk_numpy(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
+def _pted_chunk_numpy(
+    x: np.ndarray,
+    y: np.ndarray,
+    permutations: int = 100,
+    metric: str = "euclidean",
+    chunk_size: int = 100,
+    chunk_iter: int = 10,
+) -> tuple[float, list[float]]:
     assert np.all(np.isfinite(x)) and np.all(np.isfinite(y)), "Input contains NaN or Inf!"
     nx = len(x)
-    test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)
+    test_stat = _energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric)
     permute_stats = []
     for _ in range(permutations):
         z = np.concatenate((x, y), axis=0)
         z = z[np.random.permutation(len(z))]
         x, y = z[:nx], z[nx:]
-        permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric))
+        permute_stats.append(
+            _energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric)
+        )
     return test_stat, permute_stats
 
 
-def _pted_chunk_torch(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
+def _pted_chunk_torch(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    permutations: int = 100,
+    metric: Union[str, float] = "euclidean",
+    chunk_size: int = 100,
+    chunk_iter: int = 10,
+) -> tuple[float, list[float]]:
     assert torch.all(torch.isfinite(x)) and torch.all(
         torch.isfinite(y)
     ), "Input contains NaN or Inf!"
     nx = len(x)
-    test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)
+    test_stat = _energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric)
     permute_stats = []
     for _ in range(permutations):
         z = torch.cat((x, y), dim=0)
         z = z[torch.randperm(len(z))]
         x, y = z[:nx], z[nx:]
-        permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric))
+        permute_stats.append(
+            _energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric)
+        )
    return test_stat, permute_stats
 
 
-def _pted_numpy(x, y, permutations=100, metric="euclidean"):
+def _pted_numpy(
+    x: np.ndarray, y: np.ndarray, permutations: int = 100, metric: str = "euclidean"
+) -> tuple[float, list[float]]:
     z = np.concatenate((x, y), axis=0)
     assert np.all(np.isfinite(z)), "Input contains NaN or Inf!"
     dmatrix = cdist(z, z, metric=metric)
@@ -91,7 +148,12 @@ def _pted_numpy(x, y, permutations=100, metric="euclidean"):
 
 
 @torch.no_grad()
-def _pted_torch(x, y, permutations=100, metric="euclidean"):
+def _pted_torch(
+    x: torch.Tensor,
+    y: torch.Tensor,
+    permutations: int = 100,
+    metric: Union[str, float] = "euclidean",
+) -> tuple[float, list[float]]:
     z = torch.cat((x, y), dim=0)
     assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!"
     if metric == "euclidean":
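
Usage sketch (not part of the diff): the example below exercises the updated `pted` signature, assuming the package re-exports `pted` at the top level; the sample shapes, permutation count, and chunk settings are illustrative only.

```python
import numpy as np
from pted import pted  # assumed top-level re-export

# Two sample sets of shape (n_samples, n_dims); under the null they share a distribution.
x = np.random.randn(500, 3)
y = np.random.randn(400, 3)

# Plain test: returns the permutation p-value as a float.
p = pted(x, y, permutations=1000, metric="euclidean")

# Chunked test: chunk_size and chunk_iter must now be passed together,
# otherwise the new assertion in pted() fires.
p_chunked = pted(x, y, permutations=1000, chunk_size=100, chunk_iter=10)

# return_all=True now always yields the permutation statistics as an np.ndarray,
# so the p-value can be recomputed without an extra np.array() call.
test_stat, permute_stats = pted(x, y, return_all=True)
p_manual = np.mean(permute_stats > test_stat)
```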
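A second sketch for the torch path with the widened `metric: Union[str, float]` parameter: for tensors a float metric is forwarded as the p-norm of `torch.cdist`, and `"euclidean"` is mapped to `p=2.0` inside the helpers. Shapes are again placeholders; the numpy path keeps accepting only scipy `cdist` metric strings.

```python
import torch
from pted import pted  # assumed top-level re-export

x = torch.randn(500, 3)
y = torch.randn(400, 3)

# String metric: "euclidean" is translated to p=2.0 in the torch helpers.
p_l2 = pted(x, y, permutations=1000, metric="euclidean")

# Float metric: forwarded directly to torch.cdist(z, z, p=metric),
# e.g. 1.0 gives a Manhattan-style distance matrix.
p_l1 = pted(x, y, permutations=1000, metric=1.0)
```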