Skip to content

Commit 455dcf9

Browse files
pstefanou12claude
andcommitted
Move truncated distribution configs to delphi/utils/configs.py.
TruncatedExponentialFamilyDistributionConfig and TruncatedMultivariateNormalConfig now live alongside TrainerConfig and OptimizerConfig so any module can import them without pulling in the full distribution implementation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 87d969a commit 455dcf9

6 files changed

Lines changed: 93 additions & 101 deletions

delphi/truncated/distributions/truncated_exponential_family_distributions.py

Lines changed: 1 addition & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from __future__ import annotations
77
from collections.abc import Callable
88

9-
import pydantic
109
import torch as ch
1110
from torch import nn
1211
from torch.distributions import exp_family
@@ -16,77 +15,6 @@
1615
from delphi.utils import configs, datasets
1716

1817

19-
class TruncatedExponentialFamilyDistributionConfig(
20-
configs.TrainerConfig, configs.OptimizerConfig
21-
):
22-
"""Configuration for truncated exponential family distribution algorithms.
23-
24-
Attributes:
25-
val: Fraction of data held out for validation.
26-
eps: Numerical stability constant for the NLL criterion.
27-
min_radius: Initial NLL budget above the empirical initialization
28-
for the sublevel-set projection (phase 1).
29-
max_radius: Maximum NLL budget; the procedure stops when reached.
30-
rate: Multiplicative budget expansion factor per phase.
31-
batch_size: Mini-batch size for training.
32-
num_samples: Monte Carlo samples drawn per NLL evaluation.
33-
max_phases: Maximum number of radius-expansion phases.
34-
loss_convergence_tol: Absolute loss improvement threshold for
35-
stopping between phases.
36-
relative_loss_tol: Relative loss improvement threshold between phases.
37-
loss_increase_tol: Loss increase threshold for detecting overshoot.
38-
project: Enable per-step sublevel-set projection.
39-
"""
40-
41-
model_config = pydantic.ConfigDict(extra="ignore")
42-
43-
# Override parent defaults for distribution training.
44-
tol: float = pydantic.Field(default=1e-1, ge=0.0)
45-
record_params_every: int = pydantic.Field(default=1, ge=1)
46-
epochs: int | None = pydantic.Field(default=1, ge=1)
47-
48-
# Distribution-specific fields.
49-
val: float = pydantic.Field(default=0.2, ge=0.0, le=1.0)
50-
eps: float = pydantic.Field(default=1e-5, gt=0.0)
51-
min_radius: float = pydantic.Field(default=3.0, ge=0.0)
52-
max_radius: float = pydantic.Field(default=10.0, ge=0.0)
53-
rate: float = pydantic.Field(default=1.1, gt=1.0)
54-
batch_size: int = pydantic.Field(default=10, ge=1)
55-
num_samples: int = pydantic.Field(default=10000, ge=1)
56-
max_phases: int = pydantic.Field(default=1, ge=1)
57-
loss_convergence_tol: float = pydantic.Field(default=1e-3, ge=0.0)
58-
relative_loss_tol: float = pydantic.Field(default=float("inf"), ge=0.0)
59-
loss_increase_tol: float = pydantic.Field(default=float("inf"), ge=0.0)
60-
project: bool = True
61-
62-
@pydantic.model_validator(mode="after")
63-
def check_radius(self) -> TruncatedExponentialFamilyDistributionConfig:
64-
"""Validate that min_radius does not exceed max_radius."""
65-
if self.min_radius > self.max_radius:
66-
raise ValueError(
67-
f"min_radius ({self.min_radius}) must be <= "
68-
f"max_radius ({self.max_radius})."
69-
)
70-
return self
71-
72-
@pydantic.model_validator(mode="after")
73-
def resolve_epochs_iterations(
74-
self,
75-
) -> TruncatedExponentialFamilyDistributionConfig:
76-
"""Clear the default epochs when iterations is explicitly provided.
77-
78-
The trainer uses exactly one stopping criterion. When the user
79-
supplies iterations, the per-phase epochs default is cleared so
80-
the trainer stops on iterations instead.
81-
"""
82-
if (
83-
"iterations" in self.model_fields_set
84-
and "epochs" not in self.model_fields_set
85-
):
86-
object.__setattr__(self, "epochs", None)
87-
return self
88-
89-
9018
class TruncatedExponentialFamilyDistribution(distributions.distributions): # pylint: disable=too-many-instance-attributes
9119
"""Base class for truncated exponential family distribution models.
9220
@@ -112,7 +40,7 @@ class TruncatedExponentialFamilyDistribution(distributions.distributions): # py
11240

11341
def __init__(
11442
self,
115-
args: TruncatedExponentialFamilyDistributionConfig,
43+
args: configs.TruncatedExponentialFamilyDistributionConfig,
11644
phi: Callable,
11745
alpha: float,
11846
dims: int,

delphi/truncated/distributions/truncated_multivariate_normal.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from collections.abc import Callable
55
import logging
66

7-
from pydantic import Field
87
import torch as ch
98
from torch import nn
109

@@ -14,22 +13,6 @@
1413
from delphi.utils import configs
1514

1615

17-
class TruncatedMultivariateNormalConfig(
18-
truncated_exponential_family_distributions.TruncatedExponentialFamilyDistributionConfig
19-
):
20-
"""Configuration for truncated multivariate normal distributions.
21-
22-
Attributes:
23-
eigenvalue_lower_bound: Minimum eigenvalue enforced during the
24-
negative-definite cone projection of the precision matrix T.
25-
covariance_matrix_lr: Optional separate learning rate for the
26-
covariance matrix parameter; falls back to lr when None.
27-
"""
28-
29-
eigenvalue_lower_bound: float = Field(default=1e-2, gt=0.0)
30-
covariance_matrix_lr: float | None = Field(default=None, gt=0.0)
31-
32-
3316
class TruncatedMultivariateNormal(
3417
truncated_exponential_family_distributions.TruncatedExponentialFamilyDistribution
3518
):
@@ -44,7 +27,7 @@ class TruncatedMultivariateNormal(
4427

4528
def __init__(
4629
self,
47-
args: dict | TruncatedMultivariateNormalConfig,
30+
args: dict | configs.TruncatedMultivariateNormalConfig,
4831
phi: Callable,
4932
alpha: float,
5033
dims: int,
@@ -59,7 +42,7 @@ def __init__(
5942
dims: Number of dimensions.
6043
sampler: Optional sampler override.
6144
"""
62-
args = configs.make_config(args, TruncatedMultivariateNormalConfig)
45+
args = configs.make_config(args, configs.TruncatedMultivariateNormalConfig)
6346
self.eigenvalue_lower_bound = args.eigenvalue_lower_bound
6447

6548
logger = (

delphi/truncated/distributions/truncated_multivariate_normal_known_covariance.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from delphi import delphi_logger
1212
from delphi.distributions import multivariate_normal
1313
from delphi.truncated.distributions import truncated_exponential_family_distributions
14-
from delphi.truncated.distributions import truncated_multivariate_normal
1514
from delphi.utils import configs
1615

1716

@@ -26,7 +25,7 @@ class TruncatedMultivariateNormalKnownCovariance(
2625

2726
def __init__(
2827
self,
29-
args: dict | truncated_multivariate_normal.TruncatedMultivariateNormalConfig,
28+
args: dict | configs.TruncatedMultivariateNormalConfig,
3029
phi: Callable,
3130
alpha: float,
3231
dims: int,
@@ -43,9 +42,7 @@ def __init__(
4342
covariance_matrix: Known covariance matrix.
4443
sampler: Optional sampler override.
4544
"""
46-
args = configs.make_config(
47-
args, truncated_multivariate_normal.TruncatedMultivariateNormalConfig
48-
)
45+
args = configs.make_config(args, configs.TruncatedMultivariateNormalConfig)
4946
logger = (
5047
delphi_logger.delphiLogger()
5148
if args.verbose

delphi/truncated/distributions/truncated_normal.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from collections.abc import Callable
55

66
from delphi.truncated.distributions import truncated_multivariate_normal
7+
from delphi.utils import configs
78

89

910
class TruncatedNormal(truncated_multivariate_normal.TruncatedMultivariateNormal):
1011
"""Truncated normal distribution with unknown variance."""
1112

1213
def __init__(
1314
self,
14-
args: dict | truncated_multivariate_normal.TruncatedMultivariateNormalConfig,
15+
args: dict | configs.TruncatedMultivariateNormalConfig,
1516
phi: Callable,
1617
alpha: float,
1718
sampler: Callable = None,

delphi/truncated/distributions/truncated_normal_known_variance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
import torch as ch
77

88
from delphi.truncated.distributions import (
9-
truncated_multivariate_normal,
109
truncated_multivariate_normal_known_covariance,
1110
)
11+
from delphi.utils import configs
1212

1313

1414
class TruncatedNormalKnownVariance(
@@ -18,7 +18,7 @@ class TruncatedNormalKnownVariance(
1818

1919
def __init__(
2020
self,
21-
args: dict | truncated_multivariate_normal.TruncatedMultivariateNormalConfig,
21+
args: dict | configs.TruncatedMultivariateNormalConfig,
2222
phi: Callable,
2323
alpha: float,
2424
covariance_matrix: ch.Tensor | None,

delphi/utils/configs.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from __future__ import annotations
55

6-
from pydantic import BaseModel, ConfigDict, Field
6+
from pydantic import BaseModel, ConfigDict, Field, model_validator
77

88

99
def make_config(args: dict | BaseModel, config_class: type[BaseModel]) -> BaseModel:
@@ -113,3 +113,86 @@ class OptimizerConfig(BaseModel):
113113
differentiable: bool = False
114114
fused: bool | None = None
115115
scheduler: str | None = None
116+
117+
118+
class TruncatedExponentialFamilyDistributionConfig(TrainerConfig, OptimizerConfig):
119+
"""Configuration for truncated exponential family distribution algorithms.
120+
121+
Attributes:
122+
val: Fraction of data held out for validation.
123+
eps: Numerical stability constant for the NLL criterion.
124+
min_radius: Initial NLL budget above the empirical initialization
125+
for the sublevel-set projection (phase 1).
126+
max_radius: Maximum NLL budget; the procedure stops when reached.
127+
rate: Multiplicative budget expansion factor per phase.
128+
batch_size: Mini-batch size for training.
129+
num_samples: Monte Carlo samples drawn per NLL evaluation.
130+
max_phases: Maximum number of radius-expansion phases.
131+
loss_convergence_tol: Absolute loss improvement threshold for
132+
stopping between phases.
133+
relative_loss_tol: Relative loss improvement threshold between phases.
134+
loss_increase_tol: Loss increase threshold for detecting overshoot.
135+
project: Enable per-step sublevel-set projection.
136+
"""
137+
138+
model_config = ConfigDict(extra="ignore")
139+
140+
# Override parent defaults for distribution training.
141+
tol: float = Field(default=1e-1, ge=0.0)
142+
record_params_every: int = Field(default=1, ge=1)
143+
epochs: int | None = Field(default=1, ge=1)
144+
145+
# Distribution-specific fields.
146+
val: float = Field(default=0.2, ge=0.0, le=1.0)
147+
eps: float = Field(default=1e-5, gt=0.0)
148+
min_radius: float = Field(default=3.0, ge=0.0)
149+
max_radius: float = Field(default=10.0, ge=0.0)
150+
rate: float = Field(default=1.1, gt=1.0)
151+
batch_size: int = Field(default=10, ge=1)
152+
num_samples: int = Field(default=10000, ge=1)
153+
max_phases: int = Field(default=1, ge=1)
154+
loss_convergence_tol: float = Field(default=1e-3, ge=0.0)
155+
relative_loss_tol: float = Field(default=float("inf"), ge=0.0)
156+
loss_increase_tol: float = Field(default=float("inf"), ge=0.0)
157+
project: bool = True
158+
159+
@model_validator(mode="after")
160+
def check_radius(self) -> TruncatedExponentialFamilyDistributionConfig:
161+
"""Validate that min_radius does not exceed max_radius."""
162+
if self.min_radius > self.max_radius:
163+
raise ValueError(
164+
f"min_radius ({self.min_radius}) must be <= "
165+
f"max_radius ({self.max_radius})."
166+
)
167+
return self
168+
169+
@model_validator(mode="after")
170+
def resolve_epochs_iterations(
171+
self,
172+
) -> TruncatedExponentialFamilyDistributionConfig:
173+
"""Clear the default epochs when iterations is explicitly provided.
174+
175+
The trainer uses exactly one stopping criterion. When the user
176+
supplies iterations, the per-phase epochs default is cleared so
177+
the trainer stops on iterations instead.
178+
"""
179+
if (
180+
"iterations" in self.model_fields_set
181+
and "epochs" not in self.model_fields_set
182+
):
183+
object.__setattr__(self, "epochs", None)
184+
return self
185+
186+
187+
class TruncatedMultivariateNormalConfig(TruncatedExponentialFamilyDistributionConfig):
188+
"""Configuration for truncated multivariate normal distributions.
189+
190+
Attributes:
191+
eigenvalue_lower_bound: Minimum eigenvalue enforced during the
192+
negative-definite cone projection of the precision matrix T.
193+
covariance_matrix_lr: Optional separate learning rate for the
194+
covariance matrix parameter; falls back to lr when None.
195+
"""
196+
197+
eigenvalue_lower_bound: float = Field(default=1e-2, gt=0.0)
198+
covariance_matrix_lr: float | None = Field(default=None, gt=0.0)

0 commit comments

Comments
 (0)