diff --git a/src/pted/pted.py b/src/pted/pted.py index cfb4420..e4bca39 100644 --- a/src/pted/pted.py +++ b/src/pted/pted.py @@ -23,7 +23,7 @@ def pted( chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, two_tailed: bool = True, -) -> Union[float, tuple[float, np.ndarray]]: +) -> Union[float, tuple[float, np.ndarray, float]]: """ Two sample null hypothesis test using a permutation test on the energy distance. @@ -78,7 +78,8 @@ def pted( using PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. return_all (bool): if True, return the test statistic and the permuted - statistics. If False, just return the p-value. bool (False by default) + statistics with the p-value. If False, just return the p-value. + bool (default: False) chunk_size (Optional[int]): if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full @@ -145,15 +146,19 @@ def pted( test, permute = pted_numpy(x, y, permutations=permutations, metric=metric) permute = np.array(permute) - if return_all: - return test, permute # Compute p-value if two_tailed: q = 2 * min(np.sum(permute >= test), np.sum(permute <= test)) + q = min(q, permutations) else: q = np.sum(permute >= test) - return (1.0 + q) / (1.0 + permutations) + + p = (1.0 + q) / (1.0 + permutations) + + if return_all: + return test, permute, p + return p def pted_coverage_test( @@ -165,7 +170,7 @@ def pted_coverage_test( return_all: bool = False, chunk_size: Optional[int] = None, chunk_iter: Optional[int] = None, -) -> Union[float, tuple[np.ndarray, np.ndarray]]: +) -> Union[float, tuple[np.ndarray, np.ndarray, float]]: """ Coverage test using a permutation test on the energy distance. @@ -221,8 +226,8 @@ def pted_coverage_test( PyTorch, note that the metric is passed as the "p" for torch.cdist and therefore is a float from 0 to inf. return_all (bool): if True, return the test statistic and the permuted - statistics. If False, just return the p-value. bool (False by - default) + statistics with the p-value. If False, just return the p-value. bool + (default: False) chunk_size (Optional[int]): if not None, use chunked energy distance estimation. This is useful for large datasets. The chunk size is the number of samples to use for each chunk. If None, use the full @@ -259,29 +264,34 @@ def pted_coverage_test( test_stats = [] permute_stats = [] + pvals = [] for i in range(nsim): - test, permute = pted( + test, permute, p = pted( g[:, i], s[:, i], permutations=permutations, metric=metric, return_all=True, + two_tailed=False, chunk_size=chunk_size, chunk_iter=chunk_iter, ) test_stats.append(test) permute_stats.append(permute) + pvals.append(p) test_stats = np.array(test_stats) permute_stats = np.stack(permute_stats) - - if return_all: - return test_stats, permute_stats + pvals = np.array(pvals) # Compute p-values if nsim == 1: - return np.mean(permute_stats > test_stats[0]) - pvals = (1.0 + np.sum(permute_stats > test_stats[:, None], axis=1)) / (1.0 + permutations) + return pvals[0] chi2 = np.sum(-2 * np.log(pvals)) - if warn_confidence is not None: + if warn_confidence is not None and warn_confidence is not False: confidence_alert(chi2, 2 * nsim, warn_confidence) - return two_tailed_p(chi2, 2 * nsim) + + p = two_tailed_p(chi2, 2 * nsim) + + if return_all: + return test_stats, permute_stats, p + return p diff --git a/tests/test_pted.py b/tests/test_pted.py index 9f3358c..252350b 100644 --- a/tests/test_pted.py +++ b/tests/test_pted.py @@ -52,7 +52,7 @@ def test_pted_torch(): assert p < 1e-2, f"p-value {p} is not in the expected range (~0)" x = torch.randn(100, D) - t, p = pted.pted(x, x, return_all=True) + t, p, _ = pted.pted(x, x, return_all=True) q = 2 * min(np.sum(p > t), np.sum(p < t)) p = (1 + q) / (len(p) + 1) # add one to numerator and denominator to avoid p=0 assert p < 1e-2, f"p-value {p} is not in the expected range (~0)" @@ -64,7 +64,7 @@ def test_pted_coverage_full(): size=(200, 100, 10) ) # posterior samples (n_samples, n_simulations, n_dimensions) - test, permute = pted.pted_coverage_test(g, s, permutations=100, return_all=True) + test, permute, _ = pted.pted_coverage_test(g, s, permutations=100, return_all=True) assert test.shape == (100,) assert permute.shape == (100, 100)