Skip to content

Commit

Permalink
addressing other open tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pat-alt committed Dec 3, 2024
1 parent 292304e commit 8a73fe2
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 38 deletions.
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ jobs:
fail-fast: false
matrix:
version:
- '1.9'
- '1.10'
- '1'
os:
Expand Down
8 changes: 5 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
*Note*: We try to adhere to these practices as of version [v0.2.1].


## Version [2.0.0] - 2024-11-26
## Version [2.0.0] - 2024-12-03

### Changed

- Largely removed unicode characters from code base. [#134]
- Removed legacy v1.9 from CI testing. [#134]

### Added

- added support to MLJ [#126] [#134]

- Added general support for MLJ [#126] [#134]

## Version [1.1.1] - 2024-09-12

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ MLJBase = "1"
MLJModelInterface = "1.8.0"
MLUtils = "0.4"
Optimisers = "0.2, 0.3"
Random = "1.9, 1.10"
Random = "1"
Statistics = "1"
Tables = "1.10.1"
Test = "1"
Expand Down
34 changes: 25 additions & 9 deletions src/baselaplace/core_struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,37 @@ Container for the parameters of a Laplace approximation.
- `hessian_structure::HessianStructure`: the structure of the Hessian. Possible values are `:full` and `:kron` or a concrete subtype of `HessianStructure`.
- `backend::Symbol`: the backend to use. Possible values are `:GGN` and `:Fisher`.
- `curvature::Union{Curvature.CurvatureInterface,Nothing}`: the curvature interface. Possible values are `nothing` or a concrete subtype of `CurvatureInterface`.
- `σ::Real`: the observation noise
- `μ₀::Real`: the prior mean
- `λ::Real`: the prior precision
- `P₀::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix
- `observational_noise::Real`: the observation noise
- `prior_mean::Real`: the prior mean of the network parameters.
- `prior_precision::Real`: the prior precision for the network parameters.
- `prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix for the network parameters.
"""
Base.@kwdef struct LaplaceParams
subset_of_weights::Symbol = :all
subnetwork_indices::Union{Nothing,Vector{Vector{Int}}} = nothing
hessian_structure::Union{HessianStructure,Symbol,String} = FullHessian()
backend::Symbol = :GGN
curvature::Union{Curvature.CurvatureInterface,Nothing} = nothing
σ::Real = 1.0
μ₀::Real = 0.0
λ::Real = 1.0
P₀::Union{Nothing,AbstractMatrix,UniformScaling} = nothing
observational_noise::Real = 1.0
prior_mean::Real = 0.0
prior_precision::Real = 1.0
prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling} = nothing
end

# Property access for `LaplaceParams` that keeps the legacy unicode names working.
#
# The fields were renamed from unicode symbols to descriptive ASCII names:
#   σ  -> observational_noise
#   μ₀ -> prior_mean
#   λ  -> prior_precision
#   P₀ -> prior_precision_matrix
# A request for one of the old names is redirected to the new field; any other
# symbol is passed through to `getfield` unchanged.
function Base.getproperty(ce::LaplaceParams, sym::Symbol)
    sym = sym === :σ ? :observational_noise : sym
    sym = sym === :μ₀ ? :prior_mean : sym
    sym = sym === :λ ? :prior_precision : sym
    sym = sym === :P₀ ? :prior_precision_matrix : sym
    return Base.getfield(ce, sym)
end

# Property assignment for `LaplaceParams` that keeps the legacy unicode names working.
#
# Legacy names are redirected to the renamed fields before calling `setfield!`:
#   σ  -> observational_noise
#   μ₀ -> prior_mean
#   λ  -> prior_precision
#   P₀ -> prior_precision_matrix
# NOTE(review): `LaplaceParams` is declared as a plain (non-`mutable`) struct, so
# `setfield!` is expected to throw for any field — confirm whether this method is
# meant to be reachable or exists only for symmetry with `Prior`.
function Base.setproperty!(ce::LaplaceParams, sym::Symbol, val)
    sym = sym === :σ ? :observational_noise : sym
    sym = sym === :μ₀ ? :prior_mean : sym
    sym = sym === :λ ? :prior_precision : sym
    sym = sym === :P₀ ? :prior_precision_matrix : sym
    return Base.setfield!(ce, sym, val)
end

include("estimation_params.jl")
Expand Down Expand Up @@ -96,7 +112,7 @@ la = Laplace(nn, likelihood=:regression)
"""
function Laplace(model::Any; likelihood::Symbol, kwargs...)
args = LaplaceParams(; kwargs...)
@assert !(args.σ != 1.0 && likelihood != :regression) "Observation noise σ ≠ 1 only available for regression."
@assert !(args.observational_noise != 1.0 && likelihood != :regression) "Observation noise σ ≠ 1 only available for regression."

# Unpack arguments and wrap in containers:
est_args = EstimationParams(args, model, likelihood)
Expand Down
4 changes: 2 additions & 2 deletions src/baselaplace/optimize_prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function optimize_prior!(
)

# Setup:
logP₀ = isnothing(λinit) ? log.(unique(diag(la.prior.P₀))) : log.([λinit]) # prior precision (scalar)
logσ = isnothing(σinit) ? log.([la.prior.σ]) : log.([σinit]) # noise (scalar)
logP₀ = isnothing(λinit) ? log.(unique(diag(la.prior.prior_precision_matrix))) : log.([λinit]) # prior precision (scalar)
logσ = isnothing(σinit) ? log.([la.prior.observational_noise]) : log.([σinit]) # noise (scalar)
opt = Adam(lr)
show_every = round(n_steps / 10)
i = 0
Expand Down
2 changes: 1 addition & 1 deletion src/baselaplace/predicting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ function predict(
if la.likelihood == :regression

# Add observational noise:
pred_var = fvar .+ la.prior.σ^2
pred_var = fvar .+ la.prior.observational_noise^2
fstd = sqrt.(pred_var)
pred_dist = [Normal(fμ[i], fstd[i]) for i in axes(fμ, 2)]

Expand Down
38 changes: 27 additions & 11 deletions src/baselaplace/prior.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,32 @@ Container for the prior parameters of a Laplace approximation.
# Fields
- `σ::Real`: the observation noise
- `μ₀::Real`: the prior mean
- `λ::Real`: the prior precision
- `P₀::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix
- `observational_noise::Real`: the observation noise
- `prior_mean::Real`: the prior mean
- `prior_precision::Real`: the prior precision
- `prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}`: the prior precision matrix
"""
mutable struct Prior
σ::Real
μ₀::Real
λ::Real
P₀::Union{Nothing,AbstractMatrix,UniformScaling}
observational_noise::Real
prior_mean::Real
prior_precision::Real
prior_precision_matrix::Union{Nothing,AbstractMatrix,UniformScaling}
end

# Property access for `Prior` that keeps the legacy unicode names working.
#
# The fields were renamed from unicode symbols to descriptive ASCII names:
#   σ  -> observational_noise
#   μ₀ -> prior_mean
#   λ  -> prior_precision
#   P₀ -> prior_precision_matrix
# A request for one of the old names is redirected to the new field; any other
# symbol is passed through to `getfield` unchanged.
function Base.getproperty(ce::Prior, sym::Symbol)
    sym = sym === :σ ? :observational_noise : sym
    sym = sym === :μ₀ ? :prior_mean : sym
    sym = sym === :λ ? :prior_precision : sym
    sym = sym === :P₀ ? :prior_precision_matrix : sym
    return Base.getfield(ce, sym)
end

# Property assignment for the mutable `Prior` container that keeps the legacy
# unicode names working.
#
# Legacy names are redirected to the renamed fields before calling `setfield!`:
#   σ  -> observational_noise
#   μ₀ -> prior_mean
#   λ  -> prior_precision
#   P₀ -> prior_precision_matrix
# The value is stored as given: `setfield!` does not `convert` it, so `val` must
# already be an instance of the declared field type.
function Base.setproperty!(ce::Prior, sym::Symbol, val)
    sym = sym === :σ ? :observational_noise : sym
    sym = sym === :μ₀ ? :prior_mean : sym
    sym = sym === :λ ? :prior_precision : sym
    sym = sym === :P₀ ? :prior_precision_matrix : sym
    return Base.setfield!(ce, sym, val)
end

"""
Expand All @@ -23,16 +39,16 @@ end
Extracts the prior parameters from a `LaplaceParams` object.
"""
function Prior(params::LaplaceParams, model::Any, likelihood::Symbol)
P₀ = params.P₀
P₀ = params.prior_precision_matrix
n = LaplaceRedux.n_params(model, EstimationParams(params, model, likelihood))
if typeof(P₀) <: UniformScaling
P₀ = P₀(n)
elseif isnothing(P₀)
P₀ = UniformScaling(params.λ)(n)
P₀ = UniformScaling(params.prior_precision)(n)
end
# Sanity:
if isa(P₀, AbstractMatrix)
@assert all(size(P₀) .== n) "Dimensions of prior Hessian $(size(P₀)) do not align with number of parameters ($n)"
end
return Prior(params.σ, params.μ₀, params.λ, P₀)
return Prior(params.observational_noise, params.prior_mean, params.prior_precision, P₀)
end
20 changes: 10 additions & 10 deletions src/baselaplace/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ LaplaceRedux.n_params(la::Laplace) = LaplaceRedux.n_params(la.model, la.est_para
Helper function to extract the prior mean of the parameters from a Laplace approximation.
"""
function get_prior_mean(la::Laplace)
return la.prior.μ₀
return la.prior.prior_mean
end

"""
Expand All @@ -27,7 +27,7 @@ end
Helper function to extract the prior precision matrix from a Laplace approximation.
"""
function prior_precision(la::Laplace)
return la.prior.P₀
return la.prior.prior_precision_matrix
end

"""
Expand All @@ -39,15 +39,15 @@ on the last layer of the NN, of a `Flux.Chain` with Laplace approximation.
outdim(la::AbstractLaplace) = outdim(la.model)

@doc raw"""
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)
posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix)
Computes the posterior precision ``P`` for a fitted Laplace Approximation as follows,
``P = \sum_{n=1}^N\nabla_{\theta}^2 \log p(\mathcal{D}_n|\theta)|_{\hat\theta} + \nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}``
where ``\sum_{n=1}^N\nabla_{\theta}^2\log p(\mathcal{D}_n|\theta)|_{\hat\theta}=H`` is the Hessian and ``\nabla_{\theta}^2 \log p(\theta)|_{\hat\theta}=P_0`` is the prior precision and ``\hat\theta`` is the MAP estimate.
"""
function posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.P₀)
function posterior_precision(la::AbstractLaplace, H=la.posterior.H, P₀=la.prior.prior_precision_matrix)
@assert !isnothing(H) "Hessian not available. Either no value supplied or Laplace Approximation has not yet been estimated."
return H + P₀
end
Expand All @@ -70,7 +70,7 @@ end
function log_likelihood(la::AbstractLaplace)
factor = -_H_factor(la)
if la.likelihood == :regression
c = la.posterior.n_data * la.posterior.n_out * log(la.prior.σ * sqrt(2 * pi))
c = la.posterior.n_data * la.posterior.n_out * log(la.prior.observational_noise * sqrt(2 * pi))
else
c = 0
end
Expand All @@ -82,7 +82,7 @@ end
Returns the factor σ⁻², where σ is used in the zero-centered Gaussian prior p(θ) = N(θ;0,σ²I)
"""
_H_factor(la::AbstractLaplace) = 1 / (la.prior.σ^2)
_H_factor(la::AbstractLaplace) = 1 / (la.prior.observational_noise^2)

"""
_init_H(la::AbstractLaplace)
Expand Down Expand Up @@ -120,14 +120,14 @@ function log_marginal_likelihood(

# update prior precision:
if !isnothing(P₀)
la.prior.P₀ =
la.prior.prior_precision_matrix =
typeof(P₀) <: AbstractFloat ? UniformScaling(P₀)(la.posterior.n_params) : P₀
end

# update observation noise:
if !isnothing(σ)
@assert (la.likelihood == :regression || la.prior.σ == σ) "Can only change observational noise σ for regression."
la.prior.σ = σ
@assert (la.likelihood == :regression || la.prior.observational_noise == σ) "Can only change observational noise σ for regression."
la.prior.observational_noise = σ
end

return log_likelihood(la) - 0.5 * (log_det_ratio(la) + _weight_penalty(la))
Expand All @@ -147,7 +147,7 @@ end
"""
log_det_prior_precision(la::AbstractLaplace) = sum(log.(diag(la.prior.P₀)))
log_det_prior_precision(la::AbstractLaplace) = sum(log.(diag(la.prior.prior_precision_matrix)))

"""
log_det_posterior_precision(la::AbstractLaplace)
Expand Down

0 comments on commit 8a73fe2

Please sign in to comment.