Skip to content

Commit 87d969a

Browse files
authored
Merge pull request #29 from pstefanou12/waterboi/migrate-distributions-to-truncated
Migrate distributions/ to truncated/distributions/ and refactor
2 parents 0804c8e + f97cefd commit 87d969a

41 files changed

Lines changed: 2706 additions & 2224 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

delphi/delphi.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
import torch as ch
10+
from pydantic import BaseModel
1011
from torch.optim import SGD, LBFGS, Adam, AdamW, lr_scheduler
1112

1213
from delphi.utils.constants import (
@@ -15,15 +16,7 @@
1516
PythonFrameworks,
1617
SchedulerType,
1718
)
18-
from delphi.utils.defaults import (
19-
check_and_fill_args,
20-
DELPHI_DEFAULTS,
21-
SGD_DEFAULTS,
22-
LBFGS_DEFAULTS,
23-
ADAM_DEFAULTS,
24-
ADAMW_DEFAULTS,
25-
)
26-
from delphi.utils.helpers import AverageMeter, Parameters
19+
from delphi.utils.helpers import AverageMeter
2720

2821

2922
class delphi(ch.nn.Module): # pylint: disable=invalid-name,too-many-instance-attributes,abstract-method
@@ -40,22 +33,23 @@ class delphi(ch.nn.Module): # pylint: disable=invalid-name,too-many-instance-at
4033

4134
_OPTIMIZER_REGISTRY: ClassVar[dict[str, Callable]] = {}
4235

43-
def __init__(self, args: Parameters):
36+
def __init__(self, args: BaseModel):
4437
"""Initialize the delphi model.
4538
4639
Args:
47-
args: Hyperparameter object; see DELPHI_DEFAULTS for supported keys.
40+
args: Fully constructed Pydantic config. Concrete subclasses are
41+
responsible for converting a user-supplied dict to their
42+
specific config before calling super().__init__.
4843
4944
Raises:
50-
TypeError: If args is not a Parameters instance.
45+
TypeError: If args is not a Pydantic BaseModel.
5146
"""
5247
super().__init__()
53-
if not isinstance(args, Parameters):
48+
if not isinstance(args, BaseModel):
5449
raise TypeError(
55-
f"args is type {type(args).__name__}; "
56-
"expected delphi.utils.helpers.Parameters"
50+
f"args is type {type(args).__name__}; expected a pydantic.BaseModel."
5751
)
58-
self.args: Parameters = check_and_fill_args(args, DELPHI_DEFAULTS)
52+
self.args = args
5953

6054
self.optimizer: ch.optim.Optimizer | None = None
6155
self.schedule: lr_scheduler.LRScheduler | None = None
@@ -149,7 +143,6 @@ def _remove_none_config(self, config: dict) -> dict:
149143

150144
def _create_sgd(self, params: list[dict]) -> SGD:
151145
"""Create an SGD optimizer from args."""
152-
check_and_fill_args(self.args, SGD_DEFAULTS)
153146
config = {
154147
"lr": self.args.lr,
155148
"momentum": getattr(self.args, "momentum", 0),
@@ -165,7 +158,6 @@ def _create_sgd(self, params: list[dict]) -> SGD:
165158

166159
def _create_lbfgs(self, params: list[dict]) -> LBFGS:
167160
"""Create an L-BFGS optimizer from args."""
168-
check_and_fill_args(self.args, LBFGS_DEFAULTS)
169161
config = {
170162
"lr": getattr(self.args, "lr", 1.0),
171163
"max_iter": getattr(self.args, "max_iter", 20),
@@ -179,7 +171,6 @@ def _create_lbfgs(self, params: list[dict]) -> LBFGS:
179171

180172
def _create_adam(self, params: list[dict]) -> Adam:
181173
"""Create an Adam optimizer from args."""
182-
check_and_fill_args(self.args, ADAM_DEFAULTS)
183174
config = {
184175
"lr": getattr(self.args, "lr", 1e-1),
185176
"betas": (
@@ -199,7 +190,6 @@ def _create_adam(self, params: list[dict]) -> Adam:
199190

200191
def _create_adamw(self, params: list[dict]) -> AdamW:
201192
"""Create an AdamW optimizer from args."""
202-
check_and_fill_args(self.args, ADAMW_DEFAULTS)
203193
config = {
204194
"lr": getattr(self.args, "lr", 1e-3),
205195
"betas": (

delphi/distributions/__init__.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,2 @@
11
# Author: pstefanou12@
2-
"""Distribution models subpackage for delphi."""
3-
4-
from delphi.distributions.truncated_normal import TruncatedNormal
5-
from delphi.distributions.truncated_multivariate_normal import (
6-
TruncatedMultivariateNormal,
7-
)
8-
from delphi.distributions.unknown_truncated_normal import UnknownTruncationNormal
9-
from delphi.distributions.unknown_truncated_multivariate_normal import (
10-
UnknownTruncationMultivariateNormal,
11-
Exp_h,
12-
)
13-
from delphi.distributions.truncated_boolean_product import TruncatedBooleanProduct
14-
from delphi.distributions.truncated_exponential import TruncatedExponential
15-
from delphi.distributions.truncated_poisson import TruncatedPoisson
16-
from delphi.distributions.truncated_weibull import TruncatedWeibull
17-
18-
__all__ = [
19-
"TruncatedNormal",
20-
"TruncatedMultivariateNormal",
21-
"UnknownTruncationNormal",
22-
"UnknownTruncationMultivariateNormal",
23-
"Exp_h",
24-
"TruncatedBooleanProduct",
25-
"TruncatedExponential",
26-
"TruncatedPoisson",
27-
"TruncatedWeibull",
28-
]
2+
"""Exponential family distributions in natural parameterization for delphi."""
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Author: pstefanou12@
2+
"""Boolean product distribution in natural parameterization."""
3+
4+
import torch as ch
5+
from torch.distributions import Bernoulli
6+
7+
8+
def calc_bool_prod_suff_stat(x):
9+
"""Return sufficient statistics for boolean product distribution."""
10+
return x
11+
12+
13+
class ExponentialFamilyBooleanProduct(Bernoulli): # pylint: disable=abstract-method
14+
"""Boolean product distribution parameterized by natural parameters."""
15+
16+
def __init__(self, theta: ch.Tensor, dims: int):
17+
"""Initialize with natural parameter theta and dimension."""
18+
self.dims = dims
19+
p = ch.exp(theta) / (1 + ch.exp(theta))
20+
super().__init__(p)
21+
22+
def log_prob(self, value):
23+
"""Compute summed log probability over all dimensions."""
24+
result = super().log_prob(value)
25+
return result.sum(-1)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Author: pstefanou12@
2+
"""Exponential distribution in natural parameterization."""
3+
4+
import torch as ch
5+
from torch.distributions import Exponential
6+
7+
8+
def calc_exp_suff_stat(x):
9+
"""Return sufficient statistics for exponential distribution."""
10+
return x
11+
12+
13+
class ExponentialFamilyExponential(Exponential): # pylint: disable=abstract-method
14+
"""Exponential distribution parameterized by natural parameters."""
15+
16+
def __init__(self, theta: ch.Tensor, dims: int):
17+
"""Initialize with natural parameter theta and dimension."""
18+
self.dims = dims
19+
lambda_ = -theta
20+
super().__init__(lambda_)
21+
22+
def log_prob(self, value):
23+
"""Compute summed log probability over all dimensions."""
24+
result = super().log_prob(value)
25+
return result.sum(-1)
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Author: pstefanou12@
2+
"""Multivariate normal distribution in natural parameterization."""
3+
4+
import torch as ch
5+
from torch import Tensor
6+
from torch.distributions import MultivariateNormal
7+
8+
9+
class ExponentialFamilyMultivariateNormalKnownCovariance( # pylint: disable=abstract-method
10+
MultivariateNormal
11+
):
12+
"""Multivariate normal parameterized by natural parameters with known covariance."""
13+
14+
def __init__(self, covariance_matrix: Tensor, theta: Tensor, dims: int):
15+
"""Initialize with covariance matrix, natural parameter theta, and dimension."""
16+
self.dims = dims
17+
v = theta
18+
mu = (covariance_matrix @ v).view(self.dims)
19+
super().__init__(mu, covariance_matrix)
20+
21+
@staticmethod
22+
def calc_suff_stat(x: Tensor) -> Tensor:
23+
"""Return sufficient statistics for multivariate normal with known covariance."""
24+
return x
25+
26+
@staticmethod
27+
def to_natural(theta: Tensor, covariance_matrix: Tensor) -> Tensor:
28+
"""Convert canonical mean to natural parameters."""
29+
inv_cov = covariance_matrix.inverse()
30+
v = theta @ inv_cov # pylint: disable=invalid-name
31+
return v.flatten()
32+
33+
@staticmethod
34+
def to_canonical(theta: Tensor, covariance_matrix: Tensor) -> Tensor:
35+
"""Convert natural parameters to canonical mean."""
36+
loc = theta @ covariance_matrix
37+
return loc.flatten()
38+
39+
40+
class ExponentialFamilyMultivariateNormal(MultivariateNormal): # pylint: disable=abstract-method
41+
"""Multivariate normal parameterized by natural parameters."""
42+
43+
def __init__(self, theta: Tensor, dims: int):
44+
"""Initialize with natural parameter theta and dimension."""
45+
self.dims = dims
46+
T, v = theta[: self.dims**2], theta[self.dims**2 :] # pylint: disable=invalid-name
47+
covariance_matrix = ch.inverse(-2 * T.view(self.dims, self.dims))
48+
mu = (covariance_matrix @ v).view(self.dims)
49+
super().__init__(mu, covariance_matrix)
50+
51+
@staticmethod
52+
def calc_suff_stat(x: Tensor) -> Tensor:
53+
"""Return sufficient statistics for multivariate normal."""
54+
return ch.cat([ch.bmm(x.unsqueeze(2), x.unsqueeze(1)).flatten(1), x], 1)
55+
56+
@staticmethod
57+
def to_natural(theta: Tensor, dims: int) -> Tensor:
58+
"""Convert canonical parameters to natural form."""
59+
cov_matrix = theta[: dims**2].view(dims, dims)
60+
loc = theta[dims**2 :]
61+
mat_t = cov_matrix.inverse() # pylint: disable=invalid-name
62+
v = loc @ mat_t # pylint: disable=invalid-name
63+
return ch.cat([-0.5 * mat_t.flatten(), v.flatten()])
64+
65+
@staticmethod
66+
def to_canonical(theta: Tensor, dims: int) -> Tensor:
67+
"""Convert natural parameters to canonical form."""
68+
mat_t = theta[: dims**2].view(dims, dims) # pylint: disable=invalid-name
69+
v = theta[dims**2 :]
70+
covariance_matrix = (-2 * mat_t).inverse()
71+
loc = v @ covariance_matrix
72+
return ch.cat([covariance_matrix.flatten(), loc.flatten()])

delphi/distributions/poisson.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Author: pstefanou12@
2+
"""Poisson distribution in natural parameterization."""
3+
4+
import torch as ch
5+
from torch.distributions import Poisson
6+
7+
8+
def calc_poiss_suff_stat(x):
9+
"""Return sufficient statistics for Poisson distribution."""
10+
return x
11+
12+
13+
class ExponentialFamilyPoisson(Poisson): # pylint: disable=abstract-method
14+
"""Poisson distribution parameterized by natural parameters."""
15+
16+
def __init__(self, theta: ch.Tensor, dims: int):
17+
"""Initialize with natural parameter theta and dimension."""
18+
self.dims = dims
19+
lambda_ = ch.exp(theta)
20+
super().__init__(lambda_)
21+
22+
def log_prob(self, value):
23+
"""Compute summed log probability over all dimensions."""
24+
result = super().log_prob(value)
25+
return result.sum(-1)

0 commit comments

Comments
 (0)