Skip to content

Commit 87a4b08

Browse files
pstefanou12claude
andcommitted
Fix imports and add calc_suff_stat/to_natural/to_canonical staticmethods.
Replace from-imports of torch.distributions classes with module-level imports; use ch.distributions.* as base classes. Move calc_suff_stat, to_natural, and to_canonical as @staticmethods onto each distribution class. Remove redundant _calc_suff_stat overrides from truncated distribution subclasses and inline direct calls where needed. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 34d70ae commit 87a4b08

11 files changed

Lines changed: 89 additions & 73 deletions

delphi/distributions/boolean_product.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
"""Boolean product distribution in natural parameterization."""
33

44
import torch as ch
5-
from torch.distributions import Bernoulli
5+
import torch.distributions
66

77

8-
def calc_bool_prod_suff_stat(x):
9-
"""Return sufficient statistics for boolean product distribution."""
10-
return x
11-
12-
13-
class ExponentialFamilyBooleanProduct(Bernoulli): # pylint: disable=abstract-method
8+
class ExponentialFamilyBooleanProduct(ch.distributions.Bernoulli): # pylint: disable=abstract-method
149
"""Boolean product distribution parameterized by natural parameters."""
1510

1611
def __init__(self, theta: ch.Tensor, dims: int):
@@ -19,6 +14,21 @@ def __init__(self, theta: ch.Tensor, dims: int):
1914
p = ch.exp(theta) / (1 + ch.exp(theta))
2015
super().__init__(p)
2116

17+
@staticmethod
18+
def calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
19+
"""Return sufficient statistics for boolean product distribution."""
20+
return x
21+
22+
@staticmethod
23+
def to_natural(theta: ch.Tensor) -> ch.Tensor:
24+
"""Convert canonical probability to natural log-odds parameter."""
25+
return ch.log(theta / (1 - theta))
26+
27+
@staticmethod
28+
def to_canonical(theta: ch.Tensor) -> ch.Tensor:
29+
"""Convert natural log-odds to canonical probability parameter."""
30+
return ch.exp(theta) / (1 + ch.exp(theta))
31+
2232
def log_prob(self, value):
2333
"""Compute summed log probability over all dimensions."""
2434
result = super().log_prob(value)

delphi/distributions/exponential.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
"""Exponential distribution in natural parameterization."""
33

44
import torch as ch
5-
from torch.distributions import Exponential
5+
import torch.distributions
66

77

8-
def calc_exp_suff_stat(x):
9-
"""Return sufficient statistics for exponential distribution."""
10-
return x
11-
12-
13-
class ExponentialFamilyExponential(Exponential): # pylint: disable=abstract-method
8+
class ExponentialFamilyExponential(ch.distributions.Exponential): # pylint: disable=abstract-method
149
"""Exponential distribution parameterized by natural parameters."""
1510

1611
def __init__(self, theta: ch.Tensor, dims: int):
@@ -19,6 +14,21 @@ def __init__(self, theta: ch.Tensor, dims: int):
1914
lambda_ = -theta
2015
super().__init__(lambda_)
2116

17+
@staticmethod
18+
def calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
19+
"""Return sufficient statistics for exponential distribution."""
20+
return x
21+
22+
@staticmethod
23+
def to_natural(theta: ch.Tensor) -> ch.Tensor:
24+
"""Convert canonical rate parameter to natural form."""
25+
return -theta
26+
27+
@staticmethod
28+
def to_canonical(theta: ch.Tensor) -> ch.Tensor:
29+
"""Convert natural parameters to canonical rate parameter."""
30+
return -theta
31+
2232
def log_prob(self, value):
2333
"""Compute summed log probability over all dimensions."""
2434
result = super().log_prob(value)

delphi/distributions/multivariate_normal.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,46 @@
22
"""Multivariate normal distribution in natural parameterization."""
33

44
import torch as ch
5-
from torch import Tensor
6-
from torch.distributions import MultivariateNormal
5+
import torch.distributions
76

87

98
class ExponentialFamilyMultivariateNormalKnownCovariance( # pylint: disable=abstract-method
10-
MultivariateNormal
9+
ch.distributions.MultivariateNormal
1110
):
1211
"""Multivariate normal parameterized by natural parameters with known covariance."""
1312

14-
def __init__(self, covariance_matrix: Tensor, theta: Tensor, dims: int):
13+
def __init__(self, covariance_matrix: ch.Tensor, theta: ch.Tensor, dims: int):
1514
"""Initialize with covariance matrix, natural parameter theta, and dimension."""
1615
self.dims = dims
1716
v = theta
1817
mu = (covariance_matrix @ v).view(self.dims)
1918
super().__init__(mu, covariance_matrix)
2019

2120
@staticmethod
22-
def calc_suff_stat(x: Tensor) -> Tensor:
21+
def calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
2322
"""Return sufficient statistics for multivariate normal with known covariance."""
2423
return x
2524

2625
@staticmethod
27-
def to_natural(theta: Tensor, covariance_matrix: Tensor) -> Tensor:
26+
def to_natural(theta: ch.Tensor, covariance_matrix: ch.Tensor) -> ch.Tensor:
2827
"""Convert canonical mean to natural parameters."""
2928
inv_cov = covariance_matrix.inverse()
3029
v = theta @ inv_cov # pylint: disable=invalid-name
3130
return v.flatten()
3231

3332
@staticmethod
34-
def to_canonical(theta: Tensor, covariance_matrix: Tensor) -> Tensor:
33+
def to_canonical(theta: ch.Tensor, covariance_matrix: ch.Tensor) -> ch.Tensor:
3534
"""Convert natural parameters to canonical mean."""
3635
loc = theta @ covariance_matrix
3736
return loc.flatten()
3837

3938

40-
class ExponentialFamilyMultivariateNormal(MultivariateNormal): # pylint: disable=abstract-method
39+
class ExponentialFamilyMultivariateNormal( # pylint: disable=abstract-method
40+
ch.distributions.MultivariateNormal
41+
):
4142
"""Multivariate normal parameterized by natural parameters."""
4243

43-
def __init__(self, theta: Tensor, dims: int):
44+
def __init__(self, theta: ch.Tensor, dims: int):
4445
"""Initialize with natural parameter theta and dimension."""
4546
self.dims = dims
4647
T, v = theta[: self.dims**2], theta[self.dims**2 :] # pylint: disable=invalid-name
@@ -49,12 +50,12 @@ def __init__(self, theta: Tensor, dims: int):
4950
super().__init__(mu, covariance_matrix)
5051

5152
@staticmethod
52-
def calc_suff_stat(x: Tensor) -> Tensor:
53+
def calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
5354
"""Return sufficient statistics for multivariate normal."""
5455
return ch.cat([ch.bmm(x.unsqueeze(2), x.unsqueeze(1)).flatten(1), x], 1)
5556

5657
@staticmethod
57-
def to_natural(theta: Tensor, dims: int) -> Tensor:
58+
def to_natural(theta: ch.Tensor, dims: int) -> ch.Tensor:
5859
"""Convert canonical parameters to natural form."""
5960
cov_matrix = theta[: dims**2].view(dims, dims)
6061
loc = theta[dims**2 :]
@@ -63,7 +64,7 @@ def to_natural(theta: Tensor, dims: int) -> Tensor:
6364
return ch.cat([-0.5 * mat_t.flatten(), v.flatten()])
6465

6566
@staticmethod
66-
def to_canonical(theta: Tensor, dims: int) -> Tensor:
67+
def to_canonical(theta: ch.Tensor, dims: int) -> ch.Tensor:
6768
"""Convert natural parameters to canonical form."""
6869
mat_t = theta[: dims**2].view(dims, dims) # pylint: disable=invalid-name
6970
v = theta[dims**2 :]

delphi/distributions/poisson.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
"""Poisson distribution in natural parameterization."""
33

44
import torch as ch
5-
from torch.distributions import Poisson
5+
import torch.distributions
66

77

8-
def calc_poiss_suff_stat(x):
9-
"""Return sufficient statistics for Poisson distribution."""
10-
return x
11-
12-
13-
class ExponentialFamilyPoisson(Poisson): # pylint: disable=abstract-method
8+
class ExponentialFamilyPoisson(ch.distributions.Poisson): # pylint: disable=abstract-method
149
"""Poisson distribution parameterized by natural parameters."""
1510

1611
def __init__(self, theta: ch.Tensor, dims: int):
@@ -19,6 +14,21 @@ def __init__(self, theta: ch.Tensor, dims: int):
1914
lambda_ = ch.exp(theta)
2015
super().__init__(lambda_)
2116

17+
@staticmethod
18+
def calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
19+
"""Return sufficient statistics for Poisson distribution."""
20+
return x
21+
22+
@staticmethod
23+
def to_natural(theta: ch.Tensor) -> ch.Tensor:
24+
"""Convert canonical rate parameter to natural log form."""
25+
return ch.log(theta)
26+
27+
@staticmethod
28+
def to_canonical(theta: ch.Tensor) -> ch.Tensor:
29+
"""Convert natural parameters to canonical rate parameter."""
30+
return ch.exp(theta)
31+
2232
def log_prob(self, value):
2333
"""Compute summed log probability over all dimensions."""
2434
result = super().log_prob(value)

delphi/distributions/weibull.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,10 @@
22
"""Weibull distribution in natural parameterization."""
33

44
import torch as ch
5-
from torch.distributions import Weibull
5+
import torch.distributions
66

77

8-
def calc_weibull_suff_stat(k, x):
9-
"""Return sufficient statistics for Weibull distribution."""
10-
return x.pow(k)
11-
12-
13-
class ExponentialFamilyWeibull(Weibull): # pylint: disable=abstract-method
8+
class ExponentialFamilyWeibull(ch.distributions.Weibull): # pylint: disable=abstract-method
149
"""Weibull distribution parameterized by natural parameters."""
1510

1611
def __init__(self, k: ch.Tensor, theta: ch.Tensor, dims: int):
@@ -19,6 +14,21 @@ def __init__(self, k: ch.Tensor, theta: ch.Tensor, dims: int):
1914
lambda_ = (-1 / theta).pow(1 / k)
2015
super().__init__(lambda_, k)
2116

17+
@staticmethod
18+
def calc_suff_stat(k: ch.Tensor, x: ch.Tensor) -> ch.Tensor:
19+
"""Return sufficient statistics for Weibull distribution."""
20+
return x.pow(k)
21+
22+
@staticmethod
23+
def to_natural(k: ch.Tensor, theta: ch.Tensor) -> ch.Tensor:
24+
"""Convert canonical scale parameter to natural form."""
25+
return -1.0 / theta.pow(k)
26+
27+
@staticmethod
28+
def to_canonical(k: ch.Tensor, theta: ch.Tensor) -> ch.Tensor:
29+
"""Convert natural parameters to canonical scale parameter."""
30+
return (-1 / theta).pow(1 / k)
31+
2232
def log_prob(self, value):
2333
"""Compute summed log probability over all dimensions."""
2434
result = super().log_prob(value)

delphi/truncated/distributions/truncated_boolean_product.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@ def __init__(
4949
logger,
5050
)
5151

52-
@staticmethod
53-
def _calc_suff_stat(x):
54-
"""Compute sufficient statistics for the Boolean product distribution."""
55-
return boolean_product.calc_bool_prod_suff_stat(x)
56-
5752
def _reparameterize_nat_form(self, theta):
5853
"""Convert canonical probability parameter to natural log-odds form."""
5954
return ch.log(theta / (1 - theta))

delphi/truncated/distributions/truncated_exponential.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,6 @@ def __init__(
5353
logger,
5454
)
5555

56-
@staticmethod
57-
def _calc_suff_stat(x):
58-
"""Compute sufficient statistics for the exponential distribution."""
59-
return exponential.calc_exp_suff_stat(x)
60-
6156
def _constraints(self, theta):
6257
"""Clamp theta to be strictly negative."""
6358
return ch.clamp(theta, max=-1e-6)

delphi/truncated/distributions/truncated_multivariate_normal.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,14 @@ def __init__(
6565
self.emp_T = None # pylint: disable=invalid-name
6666
self.emp_v = None
6767

68-
@staticmethod
69-
def _calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
70-
"""Compute sufficient statistics for multivariate normal."""
71-
return multivariate_normal.ExponentialFamilyMultivariateNormal.calc_suff_stat(x)
72-
7368
def _calc_emp_model(self):
7469
"""Calculate empirical natural parameters and register T and v as nn.Parameters."""
7570
dataset_s = self.train_loader_.dataset.S # pylint: disable=invalid-name
76-
suff_stats = self._calc_suff_stat(dataset_s).mean(0)
71+
suff_stats = (
72+
multivariate_normal.ExponentialFamilyMultivariateNormal.calc_suff_stat(
73+
dataset_s
74+
).mean(0)
75+
)
7776
second_moment = suff_stats[: self.dims**2].view(self.dims, self.dims)
7877
loc = suff_stats[self.dims**2 :]
7978
# Center the second moment to get the empirical covariance: Σ = E[xx^T] - μμ^T.

delphi/truncated/distributions/truncated_multivariate_normal_known_covariance.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,12 @@ def __init__(
6262
self.covariance_matrix = covariance_matrix
6363
self._sampler = sampler
6464

65-
@staticmethod
66-
def _calc_suff_stat(x: ch.Tensor) -> ch.Tensor:
67-
"""Compute sufficient statistics for multivariate normal with known covariance."""
68-
return multivariate_normal.ExponentialFamilyMultivariateNormalKnownCovariance.calc_suff_stat(
69-
x
70-
)
71-
7265
def _calc_emp_model(self):
7366
"""Calculate empirical natural parameters and register theta as an nn.Parameter."""
7467
dataset_s = self.train_loader_.dataset.S # pylint: disable=invalid-name
75-
emp_mean = self._calc_suff_stat(dataset_s).mean(0)
68+
emp_mean = multivariate_normal.ExponentialFamilyMultivariateNormalKnownCovariance.calc_suff_stat(
69+
dataset_s
70+
).mean(0)
7671
self.emp_theta = multivariate_normal.ExponentialFamilyMultivariateNormalKnownCovariance.to_natural(
7772
emp_mean, self.covariance_matrix
7873
)

delphi/truncated/distributions/truncated_poisson.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,6 @@ def __init__(
5353
logger,
5454
)
5555

56-
@staticmethod
57-
def _calc_suff_stat(x):
58-
"""Compute sufficient statistics for the Poisson distribution."""
59-
return poisson.calc_poiss_suff_stat(x)
60-
6156
def _reparameterize_nat_form(self, theta):
6257
"""Convert canonical rate parameter to natural log form."""
6358
return ch.log(theta)

0 commit comments

Comments
 (0)