22"""Multivariate normal distribution in natural parameterization."""
33
44import torch as ch
5- from torch import Tensor
6- from torch .distributions import MultivariateNormal
5+ import torch .distributions
76
87
98class ExponentialFamilyMultivariateNormalKnownCovariance ( # pylint: disable=abstract-method
10- MultivariateNormal
9+ ch . distributions . MultivariateNormal
1110):
1211 """Multivariate normal parameterized by natural parameters with known covariance."""
1312
14- def __init__ (self , covariance_matrix : Tensor , theta : Tensor , dims : int ):
13+ def __init__ (self , covariance_matrix : ch . Tensor , theta : ch . Tensor , dims : int ):
1514 """Initialize with covariance matrix, natural parameter theta, and dimension."""
1615 self .dims = dims
1716 v = theta
1817 mu = (covariance_matrix @ v ).view (self .dims )
1918 super ().__init__ (mu , covariance_matrix )
2019
2120 @staticmethod
22- def calc_suff_stat (x : Tensor ) -> Tensor :
21+ def calc_suff_stat (x : ch . Tensor ) -> ch . Tensor :
2322 """Return sufficient statistics for multivariate normal with known covariance."""
2423 return x
2524
2625 @staticmethod
27- def to_natural (theta : Tensor , covariance_matrix : Tensor ) -> Tensor :
26+ def to_natural (theta : ch . Tensor , covariance_matrix : ch . Tensor ) -> ch . Tensor :
2827 """Convert canonical mean to natural parameters."""
2928 inv_cov = covariance_matrix .inverse ()
3029 v = theta @ inv_cov # pylint: disable=invalid-name
3130 return v .flatten ()
3231
3332 @staticmethod
34- def to_canonical (theta : Tensor , covariance_matrix : Tensor ) -> Tensor :
33+ def to_canonical (theta : ch . Tensor , covariance_matrix : ch . Tensor ) -> ch . Tensor :
3534 """Convert natural parameters to canonical mean."""
3635 loc = theta @ covariance_matrix
3736 return loc .flatten ()
3837
3938
40- class ExponentialFamilyMultivariateNormal (MultivariateNormal ): # pylint: disable=abstract-method
39+ class ExponentialFamilyMultivariateNormal ( # pylint: disable=abstract-method
40+ ch .distributions .MultivariateNormal
41+ ):
4142 """Multivariate normal parameterized by natural parameters."""
4243
43- def __init__ (self , theta : Tensor , dims : int ):
44+ def __init__ (self , theta : ch . Tensor , dims : int ):
4445 """Initialize with natural parameter theta and dimension."""
4546 self .dims = dims
4647 T , v = theta [: self .dims ** 2 ], theta [self .dims ** 2 :] # pylint: disable=invalid-name
@@ -49,12 +50,12 @@ def __init__(self, theta: Tensor, dims: int):
4950 super ().__init__ (mu , covariance_matrix )
5051
5152 @staticmethod
52- def calc_suff_stat (x : Tensor ) -> Tensor :
53+ def calc_suff_stat (x : ch . Tensor ) -> ch . Tensor :
5354 """Return sufficient statistics for multivariate normal."""
5455 return ch .cat ([ch .bmm (x .unsqueeze (2 ), x .unsqueeze (1 )).flatten (1 ), x ], 1 )
5556
5657 @staticmethod
57- def to_natural (theta : Tensor , dims : int ) -> Tensor :
58+ def to_natural (theta : ch . Tensor , dims : int ) -> ch . Tensor :
5859 """Convert canonical parameters to natural form."""
5960 cov_matrix = theta [: dims ** 2 ].view (dims , dims )
6061 loc = theta [dims ** 2 :]
@@ -63,7 +64,7 @@ def to_natural(theta: Tensor, dims: int) -> Tensor:
6364 return ch .cat ([- 0.5 * mat_t .flatten (), v .flatten ()])
6465
6566 @staticmethod
66- def to_canonical (theta : Tensor , dims : int ) -> Tensor :
67+ def to_canonical (theta : ch . Tensor , dims : int ) -> ch . Tensor :
6768 """Convert natural parameters to canonical form."""
6869 mat_t = theta [: dims ** 2 ].view (dims , dims ) # pylint: disable=invalid-name
6970 v = theta [dims ** 2 :]
0 commit comments