24 changes: 14 additions & 10 deletions src/pted/pted.py
@@ -12,11 +12,11 @@ def pted(
x: Union[np.ndarray, Tensor],
y: Union[np.ndarray, Tensor],
permutations: int = 1000,
metric: str = "euclidean",
metric: Union[str, float] = "euclidean",
return_all: bool = False,
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
):
) -> Union[float, tuple[float, np.ndarray]]:
"""
Two sample null hypothesis test using a permutation test on the energy
distance.
@@ -100,6 +100,9 @@ def pted(
assert type(x) == type(y), f"x and y must be of the same type, not {type(x)} and {type(y)}"
assert len(x.shape) >= 2, f"x must be at least 2D, not {x.shape}"
assert len(y.shape) >= 2, f"y must be at least 2D, not {y.shape}"
assert (chunk_size is not None) is (
chunk_iter is not None
), "chunk_size and chunk_iter must both be provided for chunked PTED test"
assert (
x.shape[1:] == y.shape[1:]
), f"x and y samples must have the same shape (past first dim), not {x.shape} and {y.shape}"
@@ -114,8 +117,8 @@
y,
permutations=permutations,
metric=metric,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
chunk_size=int(chunk_size),
chunk_iter=int(chunk_iter),
)
elif isinstance(x, Tensor):
test, permute = _pted_torch(x, y, permutations=permutations, metric=metric)
@@ -125,17 +128,18 @@
y,
permutations=permutations,
metric=metric,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
chunk_size=int(chunk_size),
chunk_iter=int(chunk_iter),
)
else:
test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric)

permute = np.array(permute)
if return_all:
return test, permute

# Compute p-value
return np.mean(np.array(permute) > test)
return np.mean(permute > test)
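
With the signature updated above, a call could look like the following sketch (illustrative only: the "from pted import pted" import path and the random data are assumptions; metric accepts a scipy cdist metric string for numpy inputs, or a float p-norm forwarded to torch.cdist for tensors):

import numpy as np
from pted import pted  # import path assumed from src/pted/pted.py

# Two batches of 500 samples in 3 dimensions from the same distribution,
# so the test should typically report a non-small p-value.
x = np.random.normal(size=(500, 3))
y = np.random.normal(size=(500, 3))

# Full test: the p-value is the fraction of permutation statistics above the observed one.
p_value = pted(x, y, permutations=1000, metric="euclidean")

# Chunked test: per the new assert, chunk_size and chunk_iter must be given together.
p_chunked = pted(x, y, permutations=1000, chunk_size=100, chunk_iter=10)

# return_all=True exposes the observed statistic and the permutation null samples.
test, permute = pted(x, y, return_all=True)
print(p_value, p_chunked, np.mean(permute > test))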


def pted_coverage_test(
@@ -146,7 +150,7 @@ def pted_coverage_test(
return_all: bool = False,
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
):
) -> Union[float, tuple[np.ndarray, np.ndarray]]:
"""
Coverage test using a permutation test on the energy distance.

@@ -229,7 +233,7 @@ def pted_coverage_test(
number of iterations, D is the number of dimensions, and P is the number
of permutations. For chunking to be worth it you should have c^2 * I << n^2.
"""
nsamp, nsim, *D = s.shape
nsamp, nsim, *_ = s.shape
assert (
g.shape == s.shape[1:]
), f"g and s must have the same shape (past first dim of s), not {g.shape} and {s.shape}"
@@ -252,7 +256,7 @@
test_stats.append(test)
permute_stats.append(permute)
test_stats = np.array(test_stats)
permute_stats = np.array(permute_stats)
permute_stats = np.stack(permute_stats)

if return_all:
return test_stats, permute_stats
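
The docstring's rule of thumb, c^2 * I << n^2, is easy to sanity-check numerically; a rough sketch with assumed sample counts (constants aside, the scaling matches the code, which pools c samples from x and c from y per chunk):

# Rough count of pairwise-distance evaluations per energy-distance estimate.
n = 10_000            # samples in each of x and y (assumed for illustration)
c, I = 100, 10        # chunk_size and chunk_iter

full_cost = (2 * n) ** 2          # exact test pools x and y into one (2n x 2n) distance matrix
chunked_cost = (2 * c) ** 2 * I   # chunked estimate: I small (2c x 2c) matrices

print(full_cost, chunked_cost)    # 400_000_000 vs 400_000 -> chunking is clearly worthwhile here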
110 changes: 86 additions & 24 deletions src/pted/utils.py
@@ -1,77 +1,134 @@
from typing import Optional, Union

import numpy as np
from scipy.spatial.distance import cdist
import torch

__all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"]


def _energy_distance_precompute(D, nx, ny):
def _energy_distance_precompute(
D: Union[np.ndarray, torch.Tensor], nx: int, ny: int
) -> Union[float, torch.Tensor]:
Exx = D[:nx, :nx].sum() / nx**2
Eyy = D[nx:, nx:].sum() / ny**2
Exy = D[:nx, nx:].sum() / (nx * ny)
return 2 * Exy - Exx - Eyy


def _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric="euclidean"):
is_torch = isinstance(x, torch.Tensor)
def _energy_distance_numpy(x: np.ndarray, y: np.ndarray, metric: str = "euclidean") -> float:
nx = len(x)
ny = len(y)
z = np.concatenate((x, y), axis=0)
D = cdist(z, z, metric=metric)
return _energy_distance_precompute(D, nx, ny)


def _energy_distance_torch(
x: torch.Tensor, y: torch.Tensor, metric: Union[str, float] = "euclidean"
) -> float:
nx = len(x)
ny = len(y)
z = torch.cat((x, y), dim=0)
if metric == "euclidean":
metric = 2.0
D = torch.cdist(z, z, p=metric)
return _energy_distance_precompute(D, nx, ny).item()


def _energy_distance_estimate_numpy(
x: np.ndarray,
y: np.ndarray,
chunk_size: int,
chunk_iter: int,
metric: Union[str, float] = "euclidean",
) -> float:

E_est = []
for _ in range(chunk_iter):
# Randomly sample a chunk of data
idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
if is_torch:
idx = torch.tensor(idx, device=x.device)
x_chunk = x[idx]
idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
if is_torch:
idy = torch.tensor(idy, device=y.device)
y_chunk = y[idy]

# Compute the distance matrix
if is_torch:
z_chunk = torch.cat((x_chunk, y_chunk), dim=0)
else:
z_chunk = np.concatenate((x_chunk, y_chunk), axis=0)
dmatrix = cdist(z_chunk, z_chunk, metric=metric)
# Compute the energy distance
E_est.append(_energy_distance_numpy(x_chunk, y_chunk, metric=metric))
return np.mean(E_est)


def _energy_distance_estimate_torch(
x: torch.Tensor,
y: torch.Tensor,
chunk_size: int,
chunk_iter: int,
metric: Union[str, float] = "euclidean",
) -> float:

E_est = []
for _ in range(chunk_iter):
# Randomly sample a chunk of data
idx = np.random.choice(len(x), size=min(len(x), chunk_size), replace=False)
x_chunk = x[torch.tensor(idx)]
idy = np.random.choice(len(y), size=min(len(y), chunk_size), replace=False)
y_chunk = y[torch.tensor(idy)]

# Compute the energy distance
E_est.append(_energy_distance_precompute(dmatrix, len(x_chunk), len(y_chunk)))
if is_torch:
E_est[-1] = E_est[-1].item()
E_est.append(_energy_distance_torch(x_chunk, y_chunk, metric=metric))
return np.mean(E_est)


def _pted_chunk_numpy(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
def _pted_chunk_numpy(
x: np.ndarray,
y: np.ndarray,
permutations: int = 100,
metric: str = "euclidean",
chunk_size: int = 100,
chunk_iter: int = 10,
) -> tuple[float, list[float]]:
assert np.all(np.isfinite(x)) and np.all(np.isfinite(y)), "Input contains NaN or Inf!"
nx = len(x)

test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)
test_stat = _energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric)
permute_stats = []
for _ in range(permutations):
z = np.concatenate((x, y), axis=0)
z = z[np.random.permutation(len(z))]
x, y = z[:nx], z[nx:]
permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric))
permute_stats.append(
_energy_distance_estimate_numpy(x, y, chunk_size, chunk_iter, metric=metric)
)
return test_stat, permute_stats


def _pted_chunk_torch(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
def _pted_chunk_torch(
x: torch.Tensor,
y: torch.Tensor,
permutations: int = 100,
metric: Union[str, float] = "euclidean",
chunk_size: int = 100,
chunk_iter: int = 10,
) -> tuple[float, list[float]]:
assert torch.all(torch.isfinite(x)) and torch.all(
torch.isfinite(y)
), "Input contains NaN or Inf!"
nx = len(x)

test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)
test_stat = _energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric)
permute_stats = []
for _ in range(permutations):
z = torch.cat((x, y), dim=0)
z = z[torch.randperm(len(z))]
x, y = z[:nx], z[nx:]
permute_stats.append(_energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric))
permute_stats.append(
_energy_distance_estimate_torch(x, y, chunk_size, chunk_iter, metric=metric)
)
return test_stat, permute_stats


def _pted_numpy(x, y, permutations=100, metric="euclidean"):
def _pted_numpy(
x: np.ndarray, y: np.ndarray, permutations: int = 100, metric: str = "euclidean"
) -> tuple[float, list[float]]:
z = np.concatenate((x, y), axis=0)
assert np.all(np.isfinite(z)), "Input contains NaN or Inf!"
dmatrix = cdist(z, z, metric=metric)
@@ -91,7 +148,12 @@ def _pted_numpy(x, y, permutations=100, metric="euclidean"):


@torch.no_grad()
def _pted_torch(x, y, permutations=100, metric="euclidean"):
def _pted_torch(
x: torch.Tensor,
y: torch.Tensor,
permutations: int = 100,
metric: Union[str, float] = "euclidean",
) -> tuple[float, list[float]]:
z = torch.cat((x, y), dim=0)
assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!"
if metric == "euclidean":
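
For reference, the statistic all of these helpers compute is the sample energy distance, 2*E|X - Y| - E|X - X'| - E|Y - Y'|, read off the blocks of the pooled pairwise-distance matrix exactly as _energy_distance_precompute does. A minimal self-contained numpy sketch (data and seed are illustrative assumptions):

import numpy as np
from scipy.spatial.distance import cdist

def energy_distance(x: np.ndarray, y: np.ndarray, metric: str = "euclidean") -> float:
    # Pool the two samples, build the full pairwise-distance matrix,
    # then read the within-x, within-y, and cross-term blocks.
    nx, ny = len(x), len(y)
    z = np.concatenate((x, y), axis=0)
    D = cdist(z, z, metric=metric)
    Exx = D[:nx, :nx].sum() / nx**2      # estimate of E|X - X'|
    Eyy = D[nx:, nx:].sum() / ny**2      # estimate of E|Y - Y'|
    Exy = D[:nx, nx:].sum() / (nx * ny)  # estimate of E|X - Y|
    return 2 * Exy - Exx - Eyy           # near zero when the two distributions agree

rng = np.random.default_rng(0)
same = energy_distance(rng.normal(size=(200, 2)), rng.normal(size=(200, 2)))
shifted = energy_distance(rng.normal(size=(200, 2)), rng.normal(2.0, 1.0, size=(200, 2)))
print(same, shifted)  # the shifted pair should give a clearly larger statistic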