Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 90 additions & 11 deletions src/pted/pted.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Union
from typing import Union, Optional
import numpy as np
from scipy.stats import chi2 as chi2_dist
from torch import Tensor

from .utils import _pted_torch, _pted_numpy
from .utils import _pted_torch, _pted_numpy, _pted_chunk_torch, _pted_chunk_numpy

__all__ = ["pted", "pted_coverage_test"]

Expand All @@ -14,6 +14,8 @@ def pted(
permutations: int = 1000,
metric: str = "euclidean",
return_all: bool = False,
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
):
"""
Two sample test using a permutation test on the energy distance.
Expand All @@ -25,12 +27,32 @@ def pted(
permutations (int): number of permutations to run. This determines how
accurately the p-value is computed.
metric (str): distance metric to use. See scipy.spatial.distance.cdist
for the list of available metrics with numpy. See torch.cdist when using
PyTorch, note that the metric is passed as the "p" for torch.cdist and
therefore is a float from 0 to inf.
for the list of available metrics with numpy. See torch.cdist when
using PyTorch, note that the metric is passed as the "p" for
torch.cdist and therefore is a float from 0 to inf.
return_all (bool): if True, return the test statistic and the permuted
statistics. If False, just return the p-value. bool (False by
default)
statistics. If False, just return the p-value. bool (False by default)
chunk_size (Optional[int]): if not None, use chunked energy distance
estimation. This is useful for large datasets. The chunk size is the
number of samples to use for each chunk. If None, use the full
dataset.
chunk_iter (Optional[int]): The chunk iter is the number of iterations
to use with the given chunk size.

Note
----
PTED has O(n^2 * D * P) time complexity, where n is the number of
samples in x and y, D is the number of dimensions, and P is the number
of permutations. For large datasets this can get unwieldy, so chunking
is recommended. For chunking, the energy distance will be estimated at
each iteration rather than fully computed. To estimate the energy
distance, we take `chunk_size` sub-samples from x and y, and compute the
energy distance on those sub-samples. This is repeated `chunk_iter`
times, and the average is taken. This is a trade-off between speed and
accuracy. The larger the chunk size and larger chunk_iter, the more
accurate the estimate, but the slower the computation. PTED remains an
exact p-value test even when chunking; it simply becomes less sensitive
to the difference between x and y.
"""
assert type(x) == type(y), f"x and y must be of the same type, not {type(x)} and {type(y)}"
assert len(x.shape) >= 2, f"x must be at least 2D, not {x.shape}"
Expand All @@ -43,9 +65,34 @@ def pted(
if len(y.shape) > 2:
y = y.reshape(y.shape[0], -1)

if isinstance(x, Tensor) and isinstance(y, Tensor):
return _pted_torch(x, y, permutations=permutations, metric=metric, return_all=return_all)
return _pted_numpy(x, y, permutations=permutations, metric=metric, return_all=return_all)
if isinstance(x, Tensor) and chunk_size is not None:
test, permute = _pted_chunk_torch(
x,
y,
permutations=permutations,
metric=metric,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
)
elif isinstance(x, Tensor):
test, permute = _pted_torch(x, y, permutations=permutations, metric=metric)
elif chunk_size is not None:
test, permute = _pted_chunk_numpy(
x,
y,
permutations=permutations,
metric=metric,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
)
else:
test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric)

if return_all:
return test, permute

# Compute p-value
return np.mean(np.array(permute) > test)


def pted_coverage_test(
Expand All @@ -54,6 +101,8 @@ def pted_coverage_test(
permutations: int = 1000,
metric: str = "euclidean",
return_all: bool = False,
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
):
"""
Coverage test using a permutation test on the energy distance.
Expand All @@ -71,6 +120,27 @@ def pted_coverage_test(
return_all (bool): if True, return the test statistic and the permuted
statistics. If False, just return the p-value. bool (False by
default)
chunk_size (Optional[int]): if not None, use chunked energy distance
estimation. This is useful for large datasets. The chunk size is the
number of samples to use for each chunk. If None, use the full
dataset.
chunk_iter (Optional[int]): The chunk iter is the number of iterations
to use with the given chunk size.

Note
----
PTED has O(n^2 * D * P) time complexity, where n is the number of
samples in x and y, D is the number of dimensions, and P is the number
of permutations. For large datasets this can get unwieldy, so chunking
is recommended. For chunking, the energy distance will be estimated at
each iteration rather than fully computed. To estimate the energy
distance, we take `chunk_size` sub-samples from x and y, and compute the
energy distance on those sub-samples. This is repeated `chunk_iter`
times, and the average is taken. This is a trade-off between speed and
accuracy. The larger the chunk size and larger chunk_iter, the more
accurate the estimate, but the slower the computation. PTED remains an
exact p-value test even when chunking; it simply becomes less sensitive
to the difference between x and y.
"""
nsamp, nsim, *D = s.shape
assert (
Expand All @@ -79,18 +149,27 @@ def pted_coverage_test(
if len(s.shape) > 3:
s = s.reshape(nsamp, nsim, -1)
g = g.reshape(1, nsim, -1)

test_stats = []
permute_stats = []
for i in range(nsim):
test, permute = pted(
g[:, i], s[:, i], permutations=permutations, metric=metric, return_all=True
g[:, i],
s[:, i],
permutations=permutations,
metric=metric,
return_all=True,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
)
test_stats.append(test)
permute_stats.append(permute)
test_stats = np.array(test_stats)
permute_stats = np.array(permute_stats)

if return_all:
return test_stats, permute_stats

# Compute p-values
pvals = np.mean(permute_stats > test_stats[:, None], axis=1)
pvals[pvals == 0] = 1.0 / permutations # handle pvals == 0
Expand Down
103 changes: 78 additions & 25 deletions src/pted/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,41 +2,96 @@
from scipy.spatial.distance import cdist
import torch

__all__ = ["_pted_numpy", "_pted_torch"]
__all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"]


def _energy_distance_precompute(D, ix, iy):
nx = len(ix)
ny = len(iy)
Exx = (D[ix.reshape(nx, 1), ix.reshape(1, nx)]).sum() / nx**2
Eyy = (D[iy.reshape(ny, 1), iy.reshape(1, ny)]).sum() / ny**2
Exy = (D[ix.reshape(nx, 1), iy.reshape(1, ny)]).sum() / (nx * ny)
def _energy_distance_precompute(D, nx, ny):
Exx = D[:nx, :nx].sum() / nx**2
Eyy = D[nx:, nx:].sum() / ny**2
Exy = D[:nx, nx:].sum() / (nx * ny)
return 2 * Exy - Exx - Eyy


def _pted_numpy(x, y, permutations=100, metric="euclidean", return_all=False):
def _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric="euclidean"):
    """
    Monte-Carlo estimate of the energy distance between x and y.

    Draws ``chunk_iter`` random sub-samples (without replacement) of at most
    ``chunk_size`` rows from each of x and y, computes the energy distance
    on every sub-sample pair, and averages the results.

    Args:
        x: first sample, numpy array or torch tensor indexed along axis 0.
        y: second sample, same container type as x.
        chunk_size (int): maximum number of rows drawn from each sample per
            iteration (capped at the sample length).
        chunk_iter (int): number of sub-sample draws to average over.
        metric: for numpy input, forwarded to scipy's ``cdist``. For torch
            input it is used as the ``p`` of ``torch.cdist``, with
            ``"euclidean"`` mapped to ``2.0``.

    Returns:
        The mean of the per-chunk energy distances (``np.mean`` result).
    """
    is_torch = isinstance(x, torch.Tensor)
    if is_torch and metric == "euclidean":
        # torch.cdist takes the norm order rather than a metric name
        metric = 2.0

    # The chunk sizes do not change between iterations
    n_x = min(len(x), chunk_size)
    n_y = min(len(y), chunk_size)

    E_est = []
    for _ in range(chunk_iter):
        # Randomly sample a chunk of data (x first, then y)
        idx = np.random.choice(len(x), size=n_x, replace=False)
        idy = np.random.choice(len(y), size=n_y, replace=False)

        if is_torch:
            idx = torch.tensor(idx, device=x.device)
            idy = torch.tensor(idy, device=y.device)
            z_chunk = torch.cat((x[idx], y[idy]), dim=0)
            # Stay in torch so any device and float `p` metrics work;
            # scipy's cdist accepts neither.
            dmatrix = torch.cdist(z_chunk, z_chunk, p=metric)
        else:
            z_chunk = np.concatenate((x[idx], y[idy]), axis=0)
            dmatrix = cdist(z_chunk, z_chunk, metric=metric)

        # Compute the energy distance on this chunk
        E = _energy_distance_precompute(dmatrix, n_x, n_y)
        E_est.append(E.item() if is_torch else E)
    return np.mean(E_est)


def _pted_chunk_numpy(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
    """
    Chunked PTED statistics for numpy inputs.

    The energy distance of (x, y) and of each random re-split of the pooled
    samples is estimated with ``_energy_distance_estimate`` rather than
    computed on the full distance matrix.

    Args:
        x: first sample, array indexed along axis 0.
        y: second sample, array indexed along axis 0.
        permutations (int): number of random re-splits of the pooled data.
        metric (str): distance metric forwarded to the estimator.
        chunk_size (int): rows drawn from each sample per estimator draw.
        chunk_iter (Optional[int]): estimator draws to average; ``None``
            falls back to the default of 10.

    Returns:
        ``(test_stat, permute_stats)``: the estimated statistic for the
        original split and a list with one estimate per permutation.

    Raises:
        AssertionError: if x or y contains NaN or Inf.
    """
    assert np.all(np.isfinite(x)) and np.all(np.isfinite(y)), "Input contains NaN or Inf!"
    if chunk_iter is None:
        # pted() forwards chunk_iter=None when only chunk_size was given;
        # range(None) would otherwise raise a TypeError in the estimator.
        chunk_iter = 10
    nx = len(x)

    test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)

    # Pool once, then keep re-shuffling the pool: re-concatenating the two
    # halves of the previous shuffle just rebuilds the same array.
    z = np.concatenate((x, y), axis=0)
    permute_stats = []
    for _ in range(permutations):
        z = z[np.random.permutation(len(z))]
        permute_stats.append(
            _energy_distance_estimate(z[:nx], z[nx:], chunk_size, chunk_iter, metric=metric)
        )
    return test_stat, permute_stats


@torch.no_grad()
def _pted_chunk_torch(x, y, permutations=100, metric="euclidean", chunk_size=100, chunk_iter=10):
    """
    Chunked PTED statistics for torch inputs.

    The energy distance of (x, y) and of each random re-split of the pooled
    samples is estimated with ``_energy_distance_estimate`` rather than
    computed on the full distance matrix. Runs under ``torch.no_grad()``,
    matching ``_pted_torch``.

    Args:
        x: first sample, tensor indexed along dim 0.
        y: second sample, tensor indexed along dim 0.
        permutations (int): number of random re-splits of the pooled data.
        metric: distance metric forwarded to the estimator.
        chunk_size (int): rows drawn from each sample per estimator draw.
        chunk_iter (Optional[int]): estimator draws to average; ``None``
            falls back to the default of 10.

    Returns:
        ``(test_stat, permute_stats)``: the estimated statistic for the
        original split and a list with one estimate per permutation.

    Raises:
        AssertionError: if x or y contains NaN or Inf.
    """
    assert torch.all(torch.isfinite(x)) and torch.all(
        torch.isfinite(y)
    ), "Input contains NaN or Inf!"
    if chunk_iter is None:
        # pted() forwards chunk_iter=None when only chunk_size was given;
        # range(None) would otherwise raise a TypeError in the estimator.
        chunk_iter = 10
    nx = len(x)

    test_stat = _energy_distance_estimate(x, y, chunk_size, chunk_iter, metric=metric)

    # Pool once, then keep re-shuffling the pool: re-concatenating the two
    # halves of the previous shuffle just rebuilds the same tensor.
    z = torch.cat((x, y), dim=0)
    permute_stats = []
    for _ in range(permutations):
        z = z[torch.randperm(len(z))]
        permute_stats.append(
            _energy_distance_estimate(z[:nx], z[nx:], chunk_size, chunk_iter, metric=metric)
        )
    return test_stat, permute_stats


def _pted_numpy(x, y, permutations=100, metric="euclidean"):
    """
    Full (un-chunked) PTED statistics for numpy inputs.

    The pairwise distance matrix of the pooled samples is computed once;
    each permutation re-orders its rows and columns and re-evaluates the
    energy distance on the same fixed nx / ny split.

    Args:
        x: first sample, array indexed along axis 0.
        y: second sample, array indexed along axis 0.
        permutations (int): number of random permutations to evaluate.
        metric (str): distance metric forwarded to scipy's ``cdist``.

    Returns:
        ``(test_stat, permute_stats)``: the energy distance of the original
        split and a list with one energy distance per permutation. The
        p-value is left to the caller.

    Raises:
        AssertionError: if the inputs or the distance matrix contain NaN or
            Inf.
    """
    z = np.concatenate((x, y), axis=0)
    assert np.all(np.isfinite(z)), "Input contains NaN or Inf!"
    dmatrix = cdist(z, z, metric=metric)
    assert np.all(
        np.isfinite(dmatrix)
    ), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
    nx = len(x)
    ny = len(y)

    test_stat = _energy_distance_precompute(dmatrix, nx, ny)
    permute_stats = []
    for _ in range(permutations):
        # Shuffle rows and columns with the same permutation so the matrix
        # stays a valid distance matrix of the re-ordered pool.
        perm = np.random.permutation(len(z))
        dmatrix = dmatrix[perm][:, perm]
        permute_stats.append(_energy_distance_precompute(dmatrix, nx, ny))
    return test_stat, permute_stats


@torch.no_grad()
def _pted_torch(x, y, permutations=100, metric="euclidean", return_all=False):
def _pted_torch(x, y, permutations=100, metric="euclidean"):
z = torch.cat((x, y), dim=0)
assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!"
if metric == "euclidean":
Expand All @@ -46,14 +101,12 @@ def _pted_torch(x, y, permutations=100, metric="euclidean", return_all=False):
torch.isfinite(dmatrix)
), "Distance matrix contains NaN or Inf! Consider using a different metric or normalizing values to be more stable (i.e. z-score norm)."
nx = len(x)
I = torch.arange(len(z))
ny = len(y)

test_stat = _energy_distance_precompute(dmatrix, I[:nx], I[nx:]).item()
test_stat = _energy_distance_precompute(dmatrix, nx, ny).item()
permute_stats = []
for _ in range(permutations):
I = I[torch.randperm(len(I))]
permute_stats.append(_energy_distance_precompute(dmatrix, I[:nx], I[nx:]).item())
if return_all:
return test_stat, permute_stats
# Compute p-value
return np.mean(np.array(permute_stats) > test_stat)
I = torch.randperm(len(z))
dmatrix = dmatrix[I][:, I]
permute_stats.append(_energy_distance_precompute(dmatrix, nx, ny).item())
return test_stat, permute_stats
31 changes: 31 additions & 0 deletions tests/test_pted.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,34 @@ def test_pted_coverage_full():
test, permute = pted.pted_coverage_test(g, s, permutations=100, return_all=True)
assert test.shape == (100,)
assert permute.shape == (100, 100)


def test_pted_chunk_torch():
    """Chunked PTED on torch tensors: null case is not rejected, shifted case is."""
    np.random.seed(42)
    torch.manual_seed(42)

    # example 2 sample test
    dim = 10
    chunking = {"chunk_size": 100, "chunk_iter": 10}
    x = torch.randn(1000, dim)

    # Both samples drawn from the same normal distribution
    y = torch.randn(1000, dim)
    p = pted.pted(x, y, **chunking)
    assert 1e-4 < p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"

    # Normal versus uniform samples
    y = torch.rand(1000, dim)
    p = pted.pted(x, y, **chunking)
    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"


def test_pted_chunk_numpy():
    """Chunked PTED on numpy arrays: null case is not rejected, shifted case is."""
    np.random.seed(42)

    # example 2 sample test
    dim = 10
    chunking = {"chunk_size": 100, "chunk_iter": 10}
    x = np.random.normal(size=(1000, dim))

    # Both samples drawn from the same normal distribution
    y = np.random.normal(size=(1000, dim))
    p = pted.pted(x, y, **chunking)
    assert 1e-4 < p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"

    # Normal versus uniform samples
    y = np.random.uniform(size=(1000, dim))
    p = pted.pted(x, y, **chunking)
    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
Loading