Currently, the cdf function for InverseGamma returns obviously wrong values: the current implementation is monotonically decreasing, which cannot be right, because a CDF must be non-decreasing.
There is a recent PR, already merged, that should fix the issue. To confirm this, I ran my code on the nightly PyTorch release, and the issue indeed goes away.
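A quick sanity check that does not need PyTorch: with concentration 1, gammaincc(1, z) = exp(-z), so the InverseGamma(1, rate) CDF has the closed form F(x) = exp(-rate / x). The sketch below (helper names are hypothetical, and only the concentration-1 case is covered) confirms that F is non-decreasing on the same grid as the Code Snippet, while 1 - F(x), the value the affected release returns, is not.

```python
import math


def inverse_gamma1_cdf(x: float, rate: float = 1.0) -> float:
    """Analytic CDF of InverseGamma(concentration=1, rate) for x >= 0.

    With concentration 1, gammaincc(1, z) == exp(-z), so F(x) = exp(-rate / x),
    and F(0) is the limiting value 0.
    """
    if x <= 0.0:
        return 0.0
    return math.exp(-rate / x)


def is_non_decreasing(values: list[float]) -> bool:
    """True if every value is <= the next one, as any valid CDF must be."""
    return all(a <= b for a, b in zip(values, values[1:]))


# Same grid as torch.linspace(0, 3, 200) in the Code Snippet section.
xs = [3.0 * i / 199 for i in range(200)]
correct = [inverse_gamma1_cdf(x) for x in xs]
buggy = [1.0 - f for f in correct]  # 1 - cdf(x), the shape seen on the affected release

print(is_non_decreasing(correct))  # True
print(is_non_decreasing(buggy))    # False
```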
Environment
Pyro version: 1.8.6
PyTorch version:
contains the bug: 2.1.0+cu121
nightly version that fixes the bug: 2.2.0.dev20231107+cpu
OS: Ubuntu 22.04.2 LTS
Python version: 3.10.12
Code Snippet
import matplotlib.pyplot as plt
import pyro.distributions as dist
import torch
xs = torch.linspace(0, 3, 200)
cdfs = dist.InverseGamma(1.0, 1.0).cdf(xs)
# Visual inspection of the CDF: on the affected version the curve is monotonically decreasing.
plt.plot(xs, cdfs)
plt.show()
# gammaincc gives the analytic InverseGamma(a, b) CDF: F(x) = gammaincc(a, b / x).
a, b = torch.ones(200), torch.ones(200)
# Prints True on '2.1.0+cu121' but False on '2.2.0.dev20231107+cpu'.
print(torch.allclose(cdfs, 1 - torch.special.gammaincc(a, b / xs)))
# If the CDF values were correct, this should print True instead.
print(torch.allclose(cdfs, torch.special.gammaincc(a, b / xs)))
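The gammaincc comparisons in the snippet follow from the closed form of the InverseGamma CDF. Writing Y ~ Gamma(a, b) with concentration a and rate b, X = 1/Y, and P, Q for the regularized lower and upper incomplete gamma functions (torch.special.gammainc and torch.special.gammaincc), for x > 0:

```latex
F_X(x) = \Pr\left(Y^{-1} \le x\right) = \Pr\left(Y \ge x^{-1}\right)
       = 1 - P\left(a, \tfrac{b}{x}\right) = Q\left(a, \tfrac{b}{x}\right)
```

The affected release effectively returns the base Gamma CDF evaluated at 1/x, that is P(a, b/x) = 1 - F_X(x), which is why the first allclose check passes there. For a = b = 1 the correct value reduces to F_X(x) = exp(-1/x).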
On closer inspection, cdf(x) currently returns the value 1 - cdf(x). I think this is because, in current releases of PyTorch, PowerTransform has a hard-coded sign of +1 (offending line: https://github.com/pytorch/pytorch/blob/7bcf7da3a268b435777fe87c7794c382f444e86d/torch/distributions/transforms.py#L567). Pyro builds InverseGamma as a TransformedDistribution of a Gamma base distribution under a PowerTransform with exponent -1. Since the map x -> 1/x is decreasing on (0, inf), that transform's sign should be -1, which is what tells TransformedDistribution to flip the base CDF; with the sign hard-coded to +1, the flip never happens.
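To illustrate why the transform's sign matters: TransformedDistribution turns the base CDF into a monotone increasing CDF by flipping it (value -> 1 - value) when the product of its transforms' signs is -1. Below is a minimal pure-Python sketch of that logic, loosely modeled on the private _monotonize_cdf helper in torch.distributions. All names here are hypothetical, and only concentration 1 is handled so that the base Gamma CDF has the closed form 1 - exp(-rate * y).

```python
import math


def gamma1_cdf(y: float, rate: float = 1.0) -> float:
    """CDF of the base Gamma(concentration=1, rate): 1 - exp(-rate * y) for y > 0."""
    return 1.0 - math.exp(-rate * y) if y > 0 else 0.0


def monotonize_cdf(base_value: float, sign: int) -> float:
    """Flip value -> 1 - value when the combined transform is decreasing (sign == -1)."""
    if sign == 1:
        return base_value
    return sign * (base_value - 0.5) + 0.5


def transformed_cdf(x: float, transform_sign: int, rate: float = 1.0) -> float:
    """CDF at x >= 0 of X = Y ** -1 with Y ~ Gamma(1, rate), given the transform's sign."""
    # Pull x back through the power transform: y = x ** -1 (x = 0 maps to +inf).
    y = math.inf if x == 0 else 1.0 / x
    return monotonize_cdf(gamma1_cdf(y, rate), transform_sign)


x = 2.0
fixed = transformed_cdf(x, transform_sign=-1)  # correct sign for the decreasing map y -> 1/y
buggy = transformed_cdf(x, transform_sign=+1)  # hard-coded +1, as on the affected release
print(fixed)  # ≈ 0.6065, i.e. exp(-1 / x), the true InverseGamma(1, 1) CDF
print(buggy)  # ≈ 0.3935, i.e. 1 - exp(-1 / x)
```

With the sign left at +1, the base CDF passes through unflipped, which reproduces the 1 - cdf(x) behavior described above.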