9 changes: 7 additions & 2 deletions .github/workflows/ci.yml
@@ -22,7 +22,7 @@ jobs:
    strategy:
      fail-fast: true
      matrix:
        python-version: ["3.9", "3.10", "3.11"]
        python-version: ["3.10", "3.11", "3.12"]
        os: [ubuntu-latest, windows-latest, macOS-latest]

    steps:
@@ -49,7 +49,12 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pytest pytest-cov torch wheel
          pip install pytest pytest-cov wheel

      - name: Install torch # no torch on 3.11 to test the no-torch scenario
        if: ${{ matrix.python-version != '3.11' }}
        run: |
          pip install torch

      # We only want to install this on one run, because otherwise we'll have
      # duplicate annotations.
78 changes: 69 additions & 9 deletions README.md
@@ -21,18 +21,32 @@ To install PTED, run the following:
pip install pted
```

## Usage
If you want to run PTED on GPUs using PyTorch, then also install torch:

```bash
pip install torch
```

The two functions are ``pted.pted`` and ``pted.pted_coverage_test``. For
information about each argument, just use ``help(pted.pted)`` or
``help(pted.pted_coverage_test)``.
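
For a quick sense of the call signature, here is a minimal sketch (assuming two
NumPy arrays with the shape convention `(n_samples, n_features)`; see the
docstrings for the full set of arguments):

```python
import numpy as np
from pted import pted

# Two datasets with the same shape convention: (n_samples, n_features)
x = np.random.normal(size=(200, 3))
y = np.random.normal(size=(200, 3))

# Returns a p-value; under the null hypothesis it is ~U(0,1)
p = pted(x, y, permutations=1000)
print(f"p-value: {p:.3f}")
```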

## What does PTED do?

PTED (pronounced "ted") takes in two datasets, `x` and `y`, and determines
whether they come from the same underlying distribution.

The returned value is a p-value: an estimate of the probability of a more
extreme instance occurring. Under the null hypothesis, the p-value is drawn
from a uniform distribution (range 0 to 1). If the null hypothesis is false,
one expects very low p-values, so one can set a threshold such as `p=0.01`
below which the null hypothesis is rejected. With this threshold, we will
reject the null `1/100`th of the time even when it is true.

PTED is useful for:

* Answering "were these two samples drawn from the same distribution?" This
  works even with noise, so long as the noise distribution is the same for
  each sample.
* Evaluating the coverage of a posterior sampling procedure.
* Checking for MCMC chain convergence. Split the chain in half or take two
  chains: that gives two samples, and if the chain is well mixed they ought to
  be drawn from the same distribution (see the sketch after this list).
* Evaluating the performance of a generative model. PTED is powerful here as
  it can detect overfitting to the training sample.
* Evaluating whether a simulator generates truly "data-like" samples.
* Serving as a distance metric for Approximate Bayesian Computation posteriors.
* Checking for drift in a time series, comparing samples before/after some
  cutoff time.

And much more!
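
For example, here is a minimal sketch of the MCMC convergence check from the
list above (assuming `chain` is an array of shape `(n_steps, n_params)` from
your sampler; the sizes and threshold are arbitrary choices):

```python
import numpy as np
from pted import pted

# Stand-in chain; in practice this comes from your MCMC sampler.
chain = np.random.normal(size=(2000, 4))

# Split the chain in half: if the chain is well mixed, the two halves
# should be draws from the same distribution.
first_half, second_half = chain[:1000], chain[1000:]

# Note: permutation tests assume exchangeable samples, so thinning the
# chain first to reduce autocorrelation is advisable.
p = pted(first_half, second_half)
if p < 0.01:
    print("Halves differ significantly: the chain may not be converged.")
```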

## Example: Two-Sample-Test

@@ -123,6 +137,52 @@ test on this chi2 distribution meaning that if your posterior is underconfident
or overconfident, you will get a small p-value that can be used to reject the
null.
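
The chi-squared combination referred to here is Fisher's method. As a minimal
sketch of the idea (one-sided version, with `pvals` standing in for the
per-simulation p-values; the actual coverage test checks both tails so that
under- and overconfidence each yield a small combined p-value):

```python
import numpy as np
from scipy.stats import chi2 as chi2_dist

# Stand-in per-simulation p-values from the coverage test.
pvals = np.array([0.42, 0.07, 0.61, 0.90, 0.33])

# Fisher's method: under the null, -2*ln(p) summed over n independent
# uniform p-values follows a chi-squared distribution with 2n dof.
stat = np.sum(-2.0 * np.log(pvals))
p_combined = chi2_dist.sf(stat, df=2 * len(pvals))
print(f"combined p-value: {p_combined:.3f}")
```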

## Example: Sensitivity comparison with the KS test

There is no single universally optimal two-sample test, but a widely used
method in 1D is the Kolmogorov-Smirnov (KS) test. The KS test operates
fundamentally differently from PTED and only really works in 1D. Here is a
basic comparison of the two methods: draw two samples of 100
Gaussian-distributed points, so that the null hypothesis is true, then slowly
bias one of the samples by increasing its standard deviation up to 2 sigma.
Tracking how the p-value drops shows which method is more sensitive to this
kind of mismatched sample. If you run this test a hundred times, you will find
that PTED is more sensitive to this kind of bias than the KS test. Observe
that both methods start around p=0.5 in the true null case (scale = 1), since
both are exact tests that truly sample U(0,1) under the null.

```python
from pted import pted
import numpy as np
from scipy.stats import kstest
import matplotlib.pyplot as plt

np.random.seed(0)

scale = np.linspace(1.0, 2.0, 10)
pted_p = np.zeros((10, 100))
ks_p = np.zeros((10, 100))
for i, s in enumerate(scale):
    for trial in range(100):
        x = np.random.normal(size=(100, 1))
        y = np.random.normal(scale=s, size=(100, 1))
        pted_p[i][trial] = pted(x, y, two_tailed=False)
        ks_p[i][trial] = kstest(x[:, 0], y[:, 0]).pvalue

plt.plot(scale, np.mean(pted_p, axis=1), linewidth=3, c="b", label="PTED")
plt.plot(scale, np.mean(ks_p, axis=1), linewidth=3, c="r", label="KS")
plt.legend()
plt.ylim(0, None)
plt.xlim(1, 2.0)
plt.xlabel("Out of distribution scale [*sigma]")
plt.ylabel("p-value")

plt.savefig("pted_demo.png", bbox_inches="tight")
plt.show()
```

![pted demo KS comparison](media/pted_ks.png)

## Interpreting the results

### Two sample test
Binary file added media/pted_ks.png
6 changes: 5 additions & 1 deletion pyproject.toml
@@ -22,7 +22,7 @@ keywords = [
    "pytorch"
]
classifiers=[
    "Development Status :: 1 - Planning",
    "Development Status :: 5 - Production/Stable",
    "Intended Audience :: Science/Research",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
@@ -40,6 +40,10 @@ dev = [
    "pytest>=8.0,<9",
    "pytest-cov>=4.1,<5",
    "pytest-mock>=3.12,<4",
    "torch>=2.0,<3",
]
torch = [
    "torch>=2.0,<3",
]

[tool.hatch.metadata.hooks.requirements_txt]
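As an aside, with the new `torch` optional-dependency group above (and
assuming the published package name `pted`), GPU support should be installable
via the extra in one step:

```bash
pip install "pted[torch]"
```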
3 changes: 1 addition & 2 deletions requirements.txt
@@ -1,3 +1,2 @@
numpy
scipy
torch
scipy
48 changes: 27 additions & 21 deletions src/pted/pted.py
@@ -1,13 +1,12 @@
from typing import Union, Optional
import numpy as np
from scipy.stats import chi2 as chi2_dist
from torch import Tensor

from .utils import (
    _pted_torch,
    _pted_numpy,
    _pted_chunk_torch,
    _pted_chunk_numpy,
    is_torch_tensor,
    pted_torch,
    pted_numpy,
    pted_chunk_torch,
    pted_chunk_numpy,
    two_tailed_p,
    confidence_alert,
)
@@ -16,13 +15,14 @@


def pted(
    x: Union[np.ndarray, Tensor],
    y: Union[np.ndarray, Tensor],
    x: Union[np.ndarray, "Tensor"],
    y: Union[np.ndarray, "Tensor"],
    permutations: int = 1000,
    metric: Union[str, float] = "euclidean",
    return_all: bool = False,
    chunk_size: Optional[int] = None,
    chunk_iter: Optional[int] = None,
    two_tailed: bool = True,
) -> Union[float, tuple[float, np.ndarray]]:
"""
Two sample null hypothesis test using a permutation test on the energy
@@ -51,8 +51,8 @@ def pted(
        z = shuffle(z)
        x, y = z[:nx], z[nx:]
        permute_stats.append(energy_distance(x, y))
    p = mean(permute_stats > test_stat)
    return p
    p = sum(permute_stats > test_stat)
    return (1 + p) / (1 + permutations)

    Example
    -------
@@ -85,6 +85,9 @@ def pted(
        dataset.
    chunk_iter (Optional[int]): The chunk iter is the number of iterations
        to use with the given chunk size.
    two_tailed (bool): if True, compute a two-tailed p-value. This is useful
        if you want to reject the null hypothesis when x and y are either
        too similar or too different. Default is True.

    Note
    ----
@@ -118,19 +121,19 @@
    if len(y.shape) > 2:
        y = y.reshape(y.shape[0], -1)

    if isinstance(x, Tensor) and chunk_size is not None:
        test, permute = _pted_chunk_torch(
    if is_torch_tensor(x) and chunk_size is not None:
        test, permute = pted_chunk_torch(
            x,
            y,
            permutations=permutations,
            metric=metric,
            chunk_size=int(chunk_size),
            chunk_iter=int(chunk_iter),
        )
    elif isinstance(x, Tensor):
        test, permute = _pted_torch(x, y, permutations=permutations, metric=metric)
    elif is_torch_tensor(x):
        test, permute = pted_torch(x, y, permutations=permutations, metric=metric)
    elif chunk_size is not None:
        test, permute = _pted_chunk_numpy(
        test, permute = pted_chunk_numpy(
            x,
            y,
            permutations=permutations,
@@ -139,19 +142,23 @@
            chunk_iter=int(chunk_iter),
        )
    else:
        test, permute = _pted_numpy(x, y, permutations=permutations, metric=metric)
        test, permute = pted_numpy(x, y, permutations=permutations, metric=metric)

    permute = np.array(permute)
    if return_all:
        return test, permute

    # Compute p-value
    return np.mean(permute > test)
    if two_tailed:
        q = 2 * min(np.sum(permute >= test), np.sum(permute <= test))
    else:
        q = np.sum(permute >= test)
    return (1.0 + q) / (1.0 + permutations)


def pted_coverage_test(
    g: Union[np.ndarray, Tensor],
    s: Union[np.ndarray, Tensor],
    g: Union[np.ndarray, "Tensor"],
    s: Union[np.ndarray, "Tensor"],
    permutations: int = 1000,
    metric: str = "euclidean",
    warn_confidence: Optional[float] = 1e-3,
@@ -273,8 +280,7 @@ def pted_coverage_test(
    # Compute p-values
    if nsim == 1:
        return np.mean(permute_stats > test_stats[0])
    pvals = np.mean(permute_stats > test_stats[:, None], axis=1)
    pvals[pvals == 0] = 1.0 / permutations  # handle pvals == 0
    pvals = (1.0 + np.sum(permute_stats > test_stats[:, None], axis=1)) / (1.0 + permutations)
    chi2 = np.sum(-2 * np.log(pvals))
    if warn_confidence is not None:
        confidence_alert(chi2, 2 * nsim, warn_confidence)
12 changes: 6 additions & 6 deletions src/pted/tests.py
@@ -10,16 +10,16 @@ def test():
    x = np.random.normal(size=(100, D))
    y = np.random.normal(size=(100, D))
    p = pted(x, y)
    assert p > 1e-4 and p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"
    assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))"

    x = np.random.normal(size=(100, D))
    y = np.random.uniform(size=(100, D))
    p = pted(x, y)
    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

    x = np.random.normal(size=(100, D))
    p = pted(x, x)
    assert p > 0.9999, f"p-value {p} is not in the expected range (~1)"
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

    # example coverage
    n_sims = 100
@@ -43,12 +43,12 @@

    # correct
    p = pted_coverage_test(g, s_corr, permutations=200)
    assert p > 1e-4 and p < 0.9999, f"p-value {p} is not in the expected range (U(0,1))"
    assert p > 1e-2 and p < 0.99, f"p-value {p} is not in the expected range (U(0,1))"
    # overconfident
    p = pted_coverage_test(g, s_over, permutations=200, warn_confidence=None)
    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"
    # underconfident
    p = pted_coverage_test(g, s_under, permutations=200, warn_confidence=None)
    assert p < 1e-4, f"p-value {p} is not in the expected range (~0)"
    assert p < 1e-2, f"p-value {p} is not in the expected range (~0)"

    print("Tests passed!")
41 changes: 34 additions & 7 deletions src/pted/utils.py
@@ -3,12 +3,38 @@

import numpy as np
from scipy.spatial.distance import cdist
import torch
from scipy.stats import chi2 as chi2_dist
from scipy.optimize import root_scalar

try:
    import torch
except ImportError:

__all__ = ["_pted_numpy", "_pted_chunk_numpy", "_pted_torch", "_pted_chunk_torch"]

    class torch:
        __version__ = "null"
        Tensor = np.ndarray


__all__ = (
    "is_torch_tensor",
    "pted_numpy",
    "pted_chunk_numpy",
    "pted_torch",
    "pted_chunk_torch",
    "two_tailed_p",
    "confidence_alert",
)


def is_torch_tensor(o):
    t = type(o)
    return (
        hasattr(t, "__module__")
        and t.__module__.startswith("torch")
        and hasattr(o, "device")
        and hasattr(o, "dtype")
        and hasattr(o, "shape")
    )


def _energy_distance_precompute(
@@ -82,7 +108,7 @@ def _energy_distance_estimate_torch(
    return np.mean(E_est)


def _pted_chunk_numpy(
def pted_chunk_numpy(
    x: np.ndarray,
    y: np.ndarray,
    permutations: int = 100,
@@ -105,14 +131,15 @@ def _pted_chunk_numpy(
    return test_stat, permute_stats


def _pted_chunk_torch(
def pted_chunk_torch(
    x: torch.Tensor,
    y: torch.Tensor,
    permutations: int = 100,
    metric: Union[str, float] = "euclidean",
    chunk_size: int = 100,
    chunk_iter: int = 10,
) -> tuple[float, list[float]]:
    assert torch.__version__ != "null", "PyTorch is not installed! try: `pip install torch`"
    assert torch.all(torch.isfinite(x)) and torch.all(
        torch.isfinite(y)
    ), "Input contains NaN or Inf!"
@@ -130,7 +157,7 @@ def _pted_chunk_torch(
    return test_stat, permute_stats


def _pted_numpy(
def pted_numpy(
    x: np.ndarray, y: np.ndarray, permutations: int = 100, metric: str = "euclidean"
) -> tuple[float, list[float]]:
    z = np.concatenate((x, y), axis=0)
@@ -151,13 +178,13 @@ def _pted_numpy(
    return test_stat, permute_stats


@torch.no_grad()
def _pted_torch(
def pted_torch(
    x: torch.Tensor,
    y: torch.Tensor,
    permutations: int = 100,
    metric: Union[str, float] = "euclidean",
) -> tuple[float, list[float]]:
    assert torch.__version__ != "null", "PyTorch is not installed! try: `pip install torch`"
    z = torch.cat((x, y), dim=0)
    assert torch.all(torch.isfinite(z)), "Input contains NaN or Inf!"
    if metric == "euclidean":
if metric == "euclidean":
Expand Down