From cda4e1769ce5ceb58c003ca16cf0eda61fdca34e Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 6 May 2025 20:19:32 -0400 Subject: [PATCH 1/2] add type hints to more functions --- src/pted/pted.py | 24 ++++++---- src/pted/utils.py | 114 ++++++++++++++++++++++++++++++++++++---------- 2 files changed, 104 insertions(+), 34 deletions(-) diff --git a/src/pted/pted.py b/src/pted/pted.py index 4fd72fe..67d43f9 100644 --- a/src/pted/pted.py +++ b/src/pted/pted.py @@ -12,11 +12,11 @@ def pted( x: Union[np.ndarray, Tensor], y: Union[np.ndarray, Tensor], permutations: int = 1000, - metric: str = "euclidean", + metric: Union[str, float] = "euclidean", return_all: bool = False, chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, -): +) -> Union[float, Union[float, np.ndarray]]: """ Two sample null hypothesis test using a permutation test on the energy distance. @@ -100,6 +100,9 @@ def pted( assert type(x) == type(y), f"x and y must be of the same type, not {type(x)} and {type(y)}" assert len(x.shape) >= 2, f"x must be at least 2D, not {x.shape}" assert len(y.shape) >= 2, f"y must be at least 2D, not {y.shape}" + assert (chunk_size is not None) is ( + chunk_iter is not None + ), "chunk_size and chunk_iter must both be provided for chunked PTED test" assert ( x.shape[1:] == y.shape[1:] ), f"x and y samples must have the same shape (past first dim), not {x.shape} and {y.shape}" @@ -114,8 +117,8 @@ def pted( y, permutations=permutations, metric=metric, - chunk_size=chunk_size, - chunk_iter=chunk_iter, + chunk_size=int(chunk_size), + chunk_iter=int(chunk_iter), ) elif isinstance(x, Tensor): test, permute = _pted_torch(x, y, permutations=permutations, metric=metric) @@ -125,17 +128,18 @@ def pted( y, permutations=permutations, metric=metric, - chunk_size=chunk_size, - chunk_iter=chunk_iter, + chunk_size=int(chunk_size), + chunk_iter=int(chunk_iter), ) else: test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric) + permute = np.array(permute) if return_all: return test, permute # Compute p-value - return np.mean(np.array(permute) > test) + return np.mean(permute > test) def pted_coverage_test( @@ -146,7 +150,7 @@ def pted_coverage_test( return_all: bool = False, chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, -): +) -> Union[float, Union[np.ndarray, np.ndarray]]: """ Coverage test using a permutation test on the energy distance. @@ -229,7 +233,7 @@ def pted_coverage_test( number of iterations, D is the number of dimensions, and P is the number of permutations. For chunking to be worth it you should have c^2 * I << n^2. """ - nsamp, nsim, *D = s.shape + nsamp, nsim, *_ = s.shape assert ( g.shape == s.shape[1:] ), f"g and s must have the same shape (past first dim of s), not {g.shape} and {s.shape}" @@ -252,7 +256,7 @@ def pted_coverage_test( test_stats.append(test) permute_stats.append(permute) test_stats = np.array(test_stats) - permute_stats = np.array(permute_stats) + permute_stats = np.stack(permute_stats) if return_all: return test_stats, permute_stats diff --git a/src/pted/utils.py b/src/pted/utils.py index 484566d..55c17ec 100644 --- a/src/pted/utils.py +++ b/src/pted/utils.py @@ -1,3 +1,5 @@ +from typing import Optional, Union + import numpy as np from scipy.spatial.distance import cdist import torch @@ -5,73 +7,132 @@ __all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"] -def _energy_distance_precompute(D, nx, ny): +def _energy_distance_numpy(x, y, metric="euclidean"): + nx = len(x) + ny = len(y) + z = np.concatenate((x, y), axis=0) + D = cdist(z, z, metric=metric) + Exx = D[:nx, :nx].sum() / nx**2 + Eyy = D[nx:, nx:].sum() / ny**2 + Exy = D[:nx, nx:].sum() / (nx * ny) + return 2 * Exy - Exx - Eyy + + +def _energy_distance_torch(x, y, metric="euclidean"): + nx = len(x) + ny = len(y) + z = torch.cat((x, y), dim=0) + if metric == "euclidean": + metric = 2.0 + D = torch.cdist(z, z, p=metric) + Exx = D[:nx, :nx].sum() / nx**2 + Eyy = D[nx:, nx:].sum() / ny**2 + Exy = D[:nx, nx:].sum() / (nx * ny) + return (2 * Exy - Exx - Eyy).item() + + +def _energy_distance_precompute( + D: Union[np.ndarray, torch.Tensor], nx: int, ny: int +) -> Union[float, torch.Tensor]: Exx = D[:nx, :nx].sum() / nx**2 Eyy = D[nx:, nx:].sum() / ny**2 Exy = D[:nx, nx:].sum() / (nx * ny) return 2 * Exy - Exx - Eyy -def _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric="euclidean"): - is_torch = isinstance(x, torch.Tensor) +def _energy_distance_estimate_numpy( + x: np.ndarray, + y: np.ndarray, + chunk_size: int, + chunk_iter: int, + metric: Union[str, float] = "euclidean", +) -> float: E_est = [] for _ in range(chunk_iter): # Randomly sample a chunk of data idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False) - if is_torch: - idx = torch.tensor(idx, device=x.device) x_chunk = x[idx] idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False) - if is_torch: - idy = torch.tensor(idy, device=y.device) y_chunk = y[idy] - # Compute the distance matrix - if is_torch: - z_chunk = torch.cat((x_chunk, y_chunk), dim=0) - else: - z_chunk = np.concatenate((x_chunk, y_chunk), axis=0) - dmatrix = cdist(z_chunk, z_chunk, metric=metric) + # Compute the energy distance + E_est.append(_energy_distance_numpy(x_chunk, y_chunk, metric=metric)) + return np.mean(E_est) + + +def _energy_distance_estimate_torch( + x: torch.Tensor, + y: torch.Tensor, + chunk_size: int, + chunk_iter: int, + metric: Union[str, float] = "euclidean", +) -> float: + + E_est = [] + for _ in range(chunk_iter): + # Randomly sample a chunk of data + idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False) + x_chunk = x[torch.tensor(idx)] + idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False) + y_chunk = y[torch.tensor(idy)] # Compute the energy distance - E_est.append(_energy_distance_precompute(dmatrix, len(x_chunk), len(y_chunk))) - if is_torch: - E_est[-1] = E_est[-1].item() + E_est.append(_energy_distance_torch(x_chunk, y_chunk, metric=metric)) return np.mean(E_est) -def _pted_chunk_numpy(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10): +def _pted_chunk_numpy( + x: np.ndarray, + y: np.ndarray, + permutations: int = 100, + metric: str = "euclidean", + chunk_size: int = 100, + chunk_iter: int = 10, +) -> tuple[float, list[float]]: assert np.all(np.isfinite(x)) and np.all(np.isfinite(y)), "Input contains NaN or Inf!" nx = len(x) - test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric) + test_stat = _energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric) permute_stats = [] for _ in range(permutations): z = np.concatenate((x, y), axis=0) z = z[np.random.permutation(len(z))] x, y = z[:nx], z[nx:] - permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)) + permute_stats.append( + _energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric) + ) return test_stat, permute_stats -def _pted_chunk_torch(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10): +def _pted_chunk_torch( + x: torch.Tensor, + y: torch.Tensor, + permutations: int = 100, + metric: Union[str, float] = "euclidean", + chunk_size: int = 100, + chunk_iter: int = 10, +) -> tuple[float, list[float]]: assert torch.all(torch.isfinite(x)) and torch.all( torch.isfinite(y) ), "Input contains NaN or Inf!" nx = len(x) - test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric) + test_stat = _energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric) permute_stats = [] for _ in range(permutations): z = torch.cat((x, y), dim=0) z = z[torch.randperm(len(z))] x, y = z[:nx], z[nx:] - permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)) + permute_stats.append( + _energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric) + ) return test_stat, permute_stats -def _pted_numpy(x, y, permutations=100, metric="euclidean"): +def _pted_numpy( + x: np.ndarray, y: np.ndarray, permutations: int = 100, metric: str = "euclidean" +) -> tuple[float, list[float]]: z = np.concatenate((x, y), axis=0) assert np.all(np.isfinite(z)), "Input contains NaN or Inf!" dmatrix = cdist(z, z, metric=metric) @@ -91,7 +152,12 @@ def _pted_numpy(x, y, permutations=100, metric="euclidean"): @torch.no_grad() -def _pted_torch(x, y, permutations=100, metric="euclidean"): +def _pted_torch( + x: torch.Tensor, + y: torch.Tensor, + permutations: int = 100, + metric: Union[str, float] = "euclidean", +) -> tuple[float, list[float]]: z = torch.cat((x, y), dim=0) assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!" if metric == "euclidean": From 590e6c50db509544e260529735e1ae03018e4bad Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Tue, 6 May 2025 20:32:33 -0400 Subject: [PATCH 2/2] combine energy dist calculations --- src/pted/pted.py | 4 ++-- src/pted/utils.py | 34 +++++++++++++++------------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/src/pted/pted.py b/src/pted/pted.py index 67d43f9..16d05eb 100644 --- a/src/pted/pted.py +++ b/src/pted/pted.py @@ -16,7 +16,7 @@ def pted( return_all: bool = False, chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, -) -> Union[float, Union[float, np.ndarray]]: +) -> Union[float, tuple[float, np.ndarray]]: """ Two sample null hypothesis test using a permutation test on the energy distance. @@ -150,7 +150,7 @@ def pted_coverage_test( return_all: bool = False, chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, -) -> Union[float, Union[np.ndarray, np.ndarray]]: +) -> Union[float, tuple[np.ndarray, np.ndarray]]: """ Coverage test using a permutation test on the energy distance. diff --git a/src/pted/utils.py b/src/pted/utils.py index 55c17ec..7a31f91 100644 --- a/src/pted/utils.py +++ b/src/pted/utils.py @@ -7,37 +7,33 @@ __all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"] -def _energy_distance_numpy(x, y, metric="euclidean"): - nx = len(x) - ny = len(y) - z = np.concatenate((x, y), axis=0) - D = cdist(z, z, metric=metric) +def _energy_distance_precompute( + D: Union[np.ndarray, torch.Tensor], nx: int, ny: int +) -> Union[float, torch.Tensor]: Exx = D[:nx, :nx].sum() / nx**2 Eyy = D[nx:, nx:].sum() / ny**2 Exy = D[:nx, nx:].sum() / (nx * ny) return 2 * Exy - Exx - Eyy -def _energy_distance_torch(x, y, metric="euclidean"): +def _energy_distance_numpy(x: np.ndarray, y: np.ndarray, metric: str = "euclidean") -> float: + nx = len(x) + ny = len(y) + z = np.concatenate((x, y), axis=0) + D = cdist(z, z, metric=metric) + return _energy_distance_precompute(D, nx, ny) + + +def _energy_distance_torch( + x: torch.Tensor, y: torch.Tensor, metric: Union[str, float] = "euclidean" +) -> float: nx = len(x) ny = len(y) z = torch.cat((x, y), dim=0) if metric == "euclidean": metric = 2.0 D = torch.cdist(z, z, p=metric) - Exx = D[:nx, :nx].sum() / nx**2 - Eyy = D[nx:, nx:].sum() / ny**2 - Exy = D[:nx, nx:].sum() / (nx * ny) - return (2 * Exy - Exx - Eyy).item() - - -def _energy_distance_precompute( - D: Union[np.ndarray, torch.Tensor], nx: int, ny: int -) -> Union[float, torch.Tensor]: - Exx = D[:nx, :nx].sum() / nx**2 - Eyy = D[nx:, nx:].sum() / ny**2 - Exy = D[:nx, nx:].sum() / (nx * ny) - return 2 * Exy - Exx - Eyy + return _energy_distance_precompute(D, nx, ny).item() def _energy_distance_estimate_numpy(