Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions src/pted/pted.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def pted(
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
two_tailed: bool = True,
) -> Union[float, tuple[float, np.ndarray]]:
) -> Union[float, tuple[float, np.ndarray, float]]:
"""
Two sample null hypothesis test using a permutation test on the energy
distance.
Expand Down Expand Up @@ -78,7 +78,8 @@ def pted(
using PyTorch, note that the metric is passed as the "p" for
torch.cdist and therefore is a float from 0 to inf.
return_all (bool): if True, return the test statistic and the permuted
statistics. If False, just return the p-value. bool (False by default)
statistics with the p-value. If False, just return the p-value.
bool (default: False)
chunk_size (Optional[int]): if not None, use chunked energy distance
estimation. This is useful for large datasets. The chunk size is the
number of samples to use for each chunk. If None, use the full
Expand Down Expand Up @@ -145,15 +146,19 @@ def pted(
test, permute = pted_numpy(x, y, permutations=permutations, metric=metric)

permute = np.array(permute)
if return_all:
return test, permute

# Compute p-value
if two_tailed:
q = 2 * min(np.sum(permute >= test), np.sum(permute <= test))
q = min(q, permutations)
else:
q = np.sum(permute >= test)
return (1.0 + q) / (1.0 + permutations)

p = (1.0 + q) / (1.0 + permutations)

if return_all:
return test, permute, p
return p


def pted_coverage_test(
Expand All @@ -165,7 +170,7 @@ def pted_coverage_test(
return_all: bool = False,
chunk_size: Optional[int] = None,
chunk_iter: Optional[int] = None,
) -> Union[float, tuple[np.ndarray, np.ndarray]]:
) -> Union[float, tuple[np.ndarray, np.ndarray, float]]:
"""
Coverage test using a permutation test on the energy distance.

Expand Down Expand Up @@ -221,8 +226,8 @@ def pted_coverage_test(
PyTorch, note that the metric is passed as the "p" for torch.cdist and
therefore is a float from 0 to inf.
return_all (bool): if True, return the test statistic and the permuted
statistics. If False, just return the p-value. bool (False by
default)
statistics with the p-value. If False, just return the p-value. bool
(default: False)
chunk_size (Optional[int]): if not None, use chunked energy distance
estimation. This is useful for large datasets. The chunk size is the
number of samples to use for each chunk. If None, use the full
Expand Down Expand Up @@ -259,29 +264,34 @@ def pted_coverage_test(

test_stats = []
permute_stats = []
pvals = []
for i in range(nsim):
test, permute = pted(
test, permute, p = pted(
g[:, i],
s[:, i],
permutations=permutations,
metric=metric,
return_all=True,
two_tailed=False,
chunk_size=chunk_size,
chunk_iter=chunk_iter,
)
test_stats.append(test)
permute_stats.append(permute)
pvals.append(p)
test_stats = np.array(test_stats)
permute_stats = np.stack(permute_stats)

if return_all:
return test_stats, permute_stats
pvals = np.array(pvals)

# Compute p-values
if nsim == 1:
return np.mean(permute_stats > test_stats[0])
pvals = (1.0 + np.sum(permute_stats > test_stats[:, None], axis=1)) / (1.0 + permutations)
return pvals[0]
chi2 = np.sum(-2 * np.log(pvals))
if warn_confidence is not None:
if warn_confidence is not None and warn_confidence is not False:
confidence_alert(chi2, 2 * nsim, warn_confidence)
return two_tailed_p(chi2, 2 * nsim)

p = two_tailed_p(chi2, 2 * nsim)

if return_all:
return test_stats, permute_stats, p
return p
4 changes: 2 additions & 2 deletions tests/test_pted.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_pted_torch():
assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

x = torch.randn(100, D)
t, p = pted.pted(x, x, return_all=True)
t, p, _ = pted.pted(x, x, return_all=True)
q = 2 * min(np.sum(p > t), np.sum(p < t))
p = (1 + q) / (len(p) + 1) # add one to numerator and denominator to avoid p=0
assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"
Expand All @@ -64,7 +64,7 @@ def test_pted_coverage_full():
size=(200, 100, 10)
) # posterior samples (n_samples, n_simulations, n_dimensions)

test, permute = pted.pted_coverage_test(g, s, permutations=100, return_all=True)
test, permute, _ = pted.pted_coverage_test(g, s, permutations=100, return_all=True)
assert test.shape == (100,)
assert permute.shape == (100, 100)

Expand Down
Loading