Skip to content

Commit 789f0cc

Browse files
pstefanou12claude
andcommitted
Fix truncated exponential family distributions: use configs, correct canonical/natural conversions.
- Replace Parameters + check_and_fill_args with configs.make_config throughout - Use to_canonical/to_natural staticmethods for parameter conversions in TruncatedExponential, TruncatedBooleanProduct, TruncatedPoisson, TruncatedWeibull - Add _calc_emp_model to TruncatedExponential for empirical initialization - Fix best_*/ema_*/avg_* properties to return canonical parameters - Update tests to use dict args instead of Parameters Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent cbcd9a1 commit 789f0cc

9 files changed

Lines changed: 97 additions & 135 deletions

delphi/truncated/distributions/truncated_boolean_product.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
import logging
55
from collections.abc import Callable
66

7-
import torch as ch
87

98
from delphi import delphi_logger
109
from delphi.exponential_family import boolean_product
1110
from delphi.truncated.distributions import truncated_exponential_family_distributions
12-
from delphi.utils import defaults, helpers
11+
from delphi.utils import configs
1312

1413

1514
class TruncatedBooleanProduct(
@@ -20,22 +19,23 @@ class TruncatedBooleanProduct(
2019
dist = boolean_product.ExponentialFamilyBooleanProduct
2120

2221
def __init__(
23-
self, args: helpers.Parameters, phi: Callable, alpha: float, dims: int
22+
self,
23+
args: dict | configs.TruncatedExponentialFamilyDistributionConfig,
24+
phi: Callable,
25+
alpha: float,
26+
dims: int,
2427
):
2528
"""Initialize TruncatedBooleanProduct.
2629
2730
Args:
28-
args: Parameter object holding hyperparameters.
31+
args: Hyperparameter dict or Pydantic config.
2932
phi: Truncation set oracle.
3033
alpha: Survival probability lower bound.
3134
dims: Number of dimensions.
32-
33-
Raises:
34-
TypeError: If args is not a Parameters instance.
3535
"""
36-
if not isinstance(args, helpers.Parameters):
37-
raise TypeError(f"args is type {type(args).__name__}; expected Parameters.")
38-
args = defaults.check_and_fill_args(args, defaults.TRUNC_BOOL_PROD_DEFAULTS)
36+
args = configs.make_config(
37+
args, configs.TruncatedExponentialFamilyDistributionConfig
38+
)
3939

4040
logger = (
4141
delphi_logger.delphiLogger()
@@ -50,33 +50,33 @@ def __init__(
5050
logger,
5151
)
5252

53-
def _reparameterize_nat_form(self, theta):
54-
"""Convert canonical probability parameter to natural log-odds form."""
55-
return ch.log(theta / (1 - theta))
56-
57-
def _reparameterize_canon_form(self, theta):
58-
"""Convert natural log-odds to canonical probability parameter."""
59-
return ch.exp(theta) / (1 + ch.exp(theta))
60-
6153
@property
6254
def best_p_(self):
6355
"""Return the best probability parameter estimate."""
64-
return self.best_params
56+
return boolean_product.ExponentialFamilyBooleanProduct.to_canonical(
57+
self.best_params
58+
)
6559

6660
@property
6761
def final_p_(self):
6862
"""Return the final probability parameter estimate."""
69-
return self.final_params
63+
return boolean_product.ExponentialFamilyBooleanProduct.to_canonical(
64+
self.final_params
65+
)
7066

7167
@property
7268
def ema_p_(self):
7369
"""Return the EMA probability parameter estimate."""
74-
return self.ema_params
70+
return boolean_product.ExponentialFamilyBooleanProduct.to_canonical(
71+
self.ema_params
72+
)
7573

7674
@property
7775
def avg_p_(self):
7876
"""Return the averaged probability parameter estimate."""
79-
return self.avg_params
77+
return boolean_product.ExponentialFamilyBooleanProduct.to_canonical(
78+
self.avg_params
79+
)
8080

8181
def __str__(self):
8282
"""Return a human-readable name for this distribution."""

delphi/truncated/distributions/truncated_exponential.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from collections.abc import Callable
66

77
import torch as ch
8+
from torch import nn
89

910
from delphi import delphi_logger
1011
from delphi.exponential_family import exponential
1112
from delphi.truncated.distributions import truncated_exponential_family_distributions
12-
from delphi.utils import defaults, helpers
13+
from delphi.utils import configs
1314

1415

1516
class TruncatedExponential(
@@ -21,25 +22,22 @@ class TruncatedExponential(
2122

2223
def __init__(
2324
self,
24-
args: helpers.Parameters,
25+
args: dict | configs.TruncatedExponentialFamilyDistributionConfig,
2526
phi: Callable,
2627
alpha: float,
2728
dims: int,
2829
):
2930
"""Initialize TruncatedExponential.
3031
3132
Args:
32-
args: Parameter object holding hyperparameters.
33+
args: Hyperparameter dict or Pydantic config.
3334
phi: Truncation set oracle.
3435
alpha: Survival probability lower bound.
3536
dims: Number of dimensions.
36-
37-
Raises:
38-
TypeError: If args is not a Parameters instance.
3937
"""
40-
if not isinstance(args, helpers.Parameters):
41-
raise TypeError(f"args is type {type(args).__name__}; expected Parameters.")
42-
args = defaults.check_and_fill_args(args, defaults.TRUNC_EXP_DEFAULTS)
38+
args = configs.make_config(
39+
args, configs.TruncatedExponentialFamilyDistributionConfig
40+
)
4341

4442
logger = (
4543
delphi_logger.delphiLogger()
@@ -54,37 +52,36 @@ def __init__(
5452
logger,
5553
)
5654

55+
def _calc_emp_model(self):
56+
"""Initialize theta at the natural parameter corresponding to the empirical rate."""
57+
dataset_s = self.train_loader_.dataset.S # pylint: disable=invalid-name
58+
emp_rate = 1.0 / dataset_s.mean(0)
59+
self.emp_theta = exponential.ExponentialFamilyExponential.to_natural(emp_rate)
60+
self.register_parameter("theta", nn.Parameter(self.emp_theta.clone()))
61+
5762
def _constraints(self, theta):
5863
"""Clamp theta to be strictly negative."""
5964
return ch.clamp(theta, max=-1e-6)
6065

61-
def _reparameterize_nat_form(self, theta):
62-
"""Convert canonical rate parameter to natural form."""
63-
return -theta
64-
65-
def _reparameterize_canon_form(self, theta):
66-
"""Convert natural parameters to canonical rate parameter."""
67-
return -theta
68-
6966
@property
7067
def best_lambda_(self):
71-
"""Return the best rate parameter estimate."""
72-
return self.best_params
68+
"""Return the best rate (canonical) parameter estimate."""
69+
return exponential.ExponentialFamilyExponential.to_canonical(self.best_params)
7370

7471
@property
7572
def final_lambda_(self):
76-
"""Return the final rate parameter estimate."""
77-
return self.final_params
73+
"""Return the final rate (canonical) parameter estimate."""
74+
return exponential.ExponentialFamilyExponential.to_canonical(self.final_params)
7875

7976
@property
8077
def ema_lambda_(self):
81-
"""Return the EMA rate parameter estimate."""
82-
return self.ema_params
78+
"""Return the EMA rate (canonical) parameter estimate."""
79+
return exponential.ExponentialFamilyExponential.to_canonical(self.ema_params)
8380

8481
@property
8582
def avg_lambda_(self):
86-
"""Return the averaged rate parameter estimate."""
87-
return self.avg_params
83+
"""Return the averaged rate (canonical) parameter estimate."""
84+
return exponential.ExponentialFamilyExponential.to_canonical(self.avg_params)
8885

8986
def __str__(self):
9087
"""Return a human-readable name for this distribution."""

delphi/truncated/distributions/truncated_exponential_family_distributions.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
"""Parent class for truncated exponential distribution model classes."""
33

44
import abc
5-
import functools
65
from collections.abc import Callable
76

87
import torch as ch
@@ -63,15 +62,12 @@ def __init__(
6362
self.alpha = alpha
6463
self.dims = dims
6564

66-
dist_cls = (
67-
self.dist.func if isinstance(self.dist, functools.partial) else self.dist
68-
)
6965
self.criterion = losses.TruncatedExponentialFamilyDistributionNLL.apply
7066
self.criterion_params = [
7167
self.phi,
7268
self.dims,
7369
self.dist,
74-
dist_cls.calc_suff_stat,
70+
self.dist.calc_suff_stat,
7571
self.args.num_samples,
7672
self.args.eps,
7773
]
@@ -115,20 +111,18 @@ def fit(self, S: ch.Tensor): # pylint: disable=invalid-name
115111
f"or equal to the number of samples ({self.args.num_samples})."
116112
)
117113

118-
dist_cls = (
119-
self.dist.func if isinstance(self.dist, functools.partial) else self.dist
120-
)
121114
self.train_loader_, self.val_loader_ = datasets.make_train_and_val_distr(
122115
self.args,
123116
S,
124117
datasets.TruncatedExponentialDistributionDataset,
125-
{"calc_suff_stat": dist_cls.calc_suff_stat},
118+
{"calc_suff_stat": self.dist.calc_suff_stat},
126119
)
127120

128121
self.prev_best_loss = None
129122
self.radius_history = []
130123
self.loss_history = []
131124
self._calc_emp_model()
125+
self.nll_init = self._compute_nll(self.emp_theta)
132126
self.radius = self.args.min_radius
133127

134128
phase = 0
@@ -224,13 +218,8 @@ def nll_threshold(self) -> float:
224218
def _calc_emp_model(self):
225219
"""Calculate empirical natural parameters from training data."""
226220
dataset_s = self.train_loader_.dataset.S # pylint: disable=invalid-name
227-
dist_cls = (
228-
self.dist.func if isinstance(self.dist, functools.partial) else self.dist
229-
)
230-
self.emp_theta = dist_cls.calc_suff_stat(dataset_s).mean(0)
221+
self.emp_theta = self.dist.calc_suff_stat(dataset_s).mean(0)
231222
self.register_parameter("theta", nn.Parameter(self.emp_theta.clone()))
232-
with ch.no_grad():
233-
self.nll_init = self._compute_nll(self.emp_theta)
234223

235224
def _compute_nll(self, theta: ch.Tensor) -> float:
236225
"""Compute non-truncated NLL of training samples under theta.
@@ -241,9 +230,10 @@ def _compute_nll(self, theta: ch.Tensor) -> float:
241230
Returns:
242231
Mean negative log-likelihood over training samples.
243232
"""
244-
S = self.train_loader_.dataset.S # pylint: disable=invalid-name
245-
D = self.dist(theta.detach(), self.dims) # pylint: disable=invalid-name
246-
return -D.log_prob(S).mean().item()
233+
with ch.no_grad():
234+
S = self.train_loader_.dataset.S # pylint: disable=invalid-name
235+
D = self.dist(theta, self.dims) # pylint: disable=invalid-name
236+
return -D.log_prob(S).mean().item()
247237

248238
def _project_onto_sublevel_set(self, theta: ch.Tensor) -> ch.Tensor:
249239
"""Project theta onto {θ : L(θ) ≤ nll_threshold} via bisection.
@@ -296,6 +286,12 @@ def step_post_hook(self, optimizer, args, kwargs) -> None:
296286
"""
297287
if self.args.project:
298288
with ch.no_grad():
289+
# Enforce constraints before projection so that _compute_nll
290+
# inside _project_onto_sublevel_set always receives a theta in
291+
# the valid domain (e.g. strictly negative for exponential /
292+
# Weibull). Constraints are applied again after projection to
293+
# keep the final value valid as well.
294+
self._write_theta(self._constraints(self.theta))
299295
proj_theta = self._project_onto_sublevel_set(self.theta)
300296
self._write_theta(self._constraints(proj_theta))
301297

delphi/truncated/distributions/truncated_multivariate_normal_known_covariance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
multivariate_normal.ExponentialFamilyMultivariateNormalKnownCovariance,
5454
covariance_matrix,
5555
)
56+
self.dist.calc_suff_stat = multivariate_normal.ExponentialFamilyMultivariateNormalKnownCovariance.calc_suff_stat
5657
super().__init__(
5758
args,
5859
phi,

delphi/truncated/distributions/truncated_poisson.py

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44
from collections.abc import Callable
55
import logging
66

7-
import torch as ch
87

98
from delphi import delphi_logger
109
from delphi.exponential_family import poisson
1110
from delphi.truncated.distributions import truncated_exponential_family_distributions
12-
from delphi.utils import defaults, helpers
11+
from delphi.utils import configs
1312

1413

1514
class TruncatedPoisson(
@@ -21,25 +20,22 @@ class TruncatedPoisson(
2120

2221
def __init__(
2322
self,
24-
args: helpers.Parameters,
23+
args: dict | configs.TruncatedExponentialFamilyDistributionConfig,
2524
phi: Callable,
2625
alpha: float,
2726
dims: int,
2827
):
2928
"""Initialize TruncatedPoisson.
3029
3130
Args:
32-
args: Parameter object holding hyperparameters.
31+
args: Hyperparameter dict or Pydantic config.
3332
phi: Truncation set oracle.
3433
alpha: Survival probability lower bound.
3534
dims: Number of dimensions.
36-
37-
Raises:
38-
TypeError: If args is not a Parameters instance.
3935
"""
40-
if not isinstance(args, helpers.Parameters):
41-
raise TypeError(f"args is type {type(args).__name__}; expected Parameters.")
42-
args = defaults.check_and_fill_args(args, defaults.TRUNC_POISS_DEFAULTS)
36+
args = configs.make_config(
37+
args, configs.TruncatedExponentialFamilyDistributionConfig
38+
)
4339

4440
logger = (
4541
delphi_logger.delphiLogger()
@@ -54,33 +50,25 @@ def __init__(
5450
logger,
5551
)
5652

57-
def _reparameterize_nat_form(self, theta):
58-
"""Convert canonical rate parameter to natural log form."""
59-
return ch.log(theta)
60-
61-
def _reparameterize_canon_form(self, theta):
62-
"""Convert natural parameters to canonical rate parameter."""
63-
return ch.exp(theta)
64-
6553
@property
6654
def best_lambda_(self):
6755
"""Return the best rate parameter estimate."""
68-
return self.best_params
56+
return poisson.ExponentialFamilyPoisson.to_canonical(self.best_params)
6957

7058
@property
7159
def final_lambda_(self):
7260
"""Return the final rate parameter estimate."""
73-
return self.final_params
61+
return poisson.ExponentialFamilyPoisson.to_canonical(self.final_params)
7462

7563
@property
7664
def ema_lambda_(self):
7765
"""Return the EMA rate parameter estimate."""
78-
return self.ema_params
66+
return poisson.ExponentialFamilyPoisson.to_canonical(self.ema_params)
7967

8068
@property
8169
def avg_lambda_(self):
8270
"""Return the averaged rate parameter estimate."""
83-
return self.avg_params
71+
return poisson.ExponentialFamilyPoisson.to_canonical(self.avg_params)
8472

8573
def __str__(self):
8674
"""Return a human-readable name for this distribution."""

0 commit comments

Comments
 (0)