Skip to content

Commit a3cc03c

Browse files
committed
improve simulation code
1 parent cdaa78b commit a3cc03c

4 files changed

Lines changed: 31 additions & 16 deletions

File tree

src/base.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ function _rand(prior::P, Δparams::NTuple{2,T}) where {T,P}
8787
return rand(P(postparams...))
8888
end
8989

90+
_randn!(x::AbstractVecOrMat{T}, μ::AbstractVector{T}, σ::AbstractVector{T}) where {T} =
91+
x .= muladd.(randn!(x), σ, μ)
92+
9093
_randn!(x::AbstractArray{T}, σ::Union{T,AbstractMatrix{T}}) where {T} = randn!(x) .*= σ
9194

9295
function _randn!(x::AbstractArray{T,3}, σ::T, σ₀::AbstractVecOrMat{T}) where {T}
@@ -142,5 +145,15 @@ function mrw_propose(P::Union{BetaPrime{T},Gamma{T}}, xᵒ::T) where {T<:Abstrac
142145
return xᵖ, log(xᵒ) + logpdf(P, 1 / ϵ) - log(xᵖ) - logpdf(P, ϵ) #* Needs test
143146
end
144147

148+
function _resize3(
149+
x::AbstractArray{T,3},
150+
newsz::Integer;
151+
fill = Union{Nothing,T} = nothing,
152+
) where {T}
153+
y = similar(x, size(x, 1), size(x, 2), newsz)
154+
copyto!(y, x)
155+
!isnothing(fill) && fill!(view(y, :, :, size(x, 3)+1:newsz), fill)
156+
return y
157+
end
145158
# get_Δlogprior(xᵖ::T, xᵒ::T, distr::Gamma{T}) where {T<:AbstractFloat} =
146159
# (shape(distr) - 1) * log(xᵖ / xᵒ) - (xᵖ - xᵒ) / scale(distr)

src/chain.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ function extend!(
9393
logposterior = get_logposterior(loglikelihood, tracks, msd)
9494
push!(
9595
chain.samples,
96-
Sample(tracks.onpart, msd, brightness, iter, 𝑇, logposterior, loglikelihood),
96+
Sample(tracks.onchunk, msd, brightness, iter, 𝑇, logposterior, loglikelihood),
9797
)
9898
isfull(chain) && shrink!(chain)
9999
end
@@ -110,8 +110,8 @@ function get_loglikelihood!(
110110
detector::Detector{T},
111111
psf::PointSpreadFunction{T},
112112
) where {T}
113-
seteffvalue!(tracks.onpart)
114-
set_poisson_mean!(llarray, detector, tracks.onpart.effvalue, brightness.value, psf)
113+
seteffvalue!(tracks.onchunk)
114+
set_poisson_mean!(llarray, detector, tracks.onchunk.effvalue, brightness.value, psf)
115115
return get_loglikelihood!(llarray, detector)
116116
end
117117

@@ -121,7 +121,7 @@ get_logposterior(
121121
msd::MeanSquaredDisplacement{T},
122122
) where {T} =
123123
loglikelihood +
124-
logprior(tracks.onpart, msd.value) +
124+
logprior(tracks.onchunk, msd.value) +
125125
logprior(msd) +
126126
logprior(tracks.nemitters)
127127

@@ -134,10 +134,10 @@ function parametricMCMC!(
134134
psf::PointSpreadFunction{T},
135135
𝑇::T,
136136
) where {T}
137-
update_onpart!(tracks, msd.value, brightness.value, llarray, detector, psf, 𝑇)
138-
update!(brightness, tracks.onpart.effvalue, llarray, detector, psf, 𝑇)
139-
setdisplacement²!(tracks.onpart)
140-
update!(msd, tracks.onpart.displacement², 𝑇)
137+
update_onchunk!(tracks, msd.value, brightness.value, llarray, detector, psf, 𝑇)
138+
update!(brightness, tracks.onchunk.effvalue, llarray, detector, psf, 𝑇)
139+
setdisplacement²!(tracks.onchunk)
140+
update!(msd, tracks.onchunk.displacement², 𝑇)
141141
return tracks, msd
142142
end
143143

@@ -150,9 +150,9 @@ function nonparametricMCMC!(
150150
psf::PointSpreadFunction{T},
151151
𝑇::T,
152152
) where {T}
153-
simulate!(tracks.offpart, msd.value)
153+
simulate!(tracks.offchunk, msd.value)
154154
if any(tracks)
155-
update_onpart!(tracks, msd.value, brightness.value, llarray, detector, psf, 𝑇)
155+
update_onchunk!(tracks, msd.value, brightness.value, llarray, detector, psf, 𝑇)
156156
onshuffle!(tracks)
157157
end
158158

@@ -168,7 +168,7 @@ function nonparametricMCMC!(
168168
)
169169
reassign!(tracks)
170170

171-
update!(brightness, tracks.onpart.effvalue, llarray, detector, psf, 𝑇)
171+
update!(brightness, tracks.onchunk.effvalue, llarray, detector, psf, 𝑇)
172172

173173
setdisplacement²!(tracks)
174174
update!(msd, tracks.displacement²[1], 𝑇)
@@ -209,7 +209,7 @@ function runMCMC(;
209209
parametric::Bool = false,
210210
) where {T}
211211
isnothing(annealing) && (annealing = ConstantAnnealing{T}(1))
212-
chain = Chain([Sample(tracks.onpart, msd, brightness)], sizelimit, annealing)
212+
chain = Chain([Sample(tracks.onchunk, msd, brightness)], sizelimit, annealing)
213213
runMCMC!(chain, tracks, msd, brightness, detector, psf, niters, parametric)
214214
return chain
215215
end

src/distributions.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ function Base.getproperty(d::DNormal, s::Symbol)
1717
end
1818
end
1919

20-
Distributions.params(n::DNormal) = n.μ, n.σ
20+
Distributions.params(P::DNormal) = P.μ, P.σ
2121

2222
logprior(ℕ::DNormal{T}, x::AbstractArray{T}) where {T} =
23-
sum(vec(@. -(x -.μ) / (2 *.σ^2)))
23+
sum(vec(@. -(x -.μ) / (2 *.σ^2)))
24+
25+
Random.rand!(x::AbstractVecOrMat{T}, P::DNormal{T}) where {T} = _randn!(x, P.μ, P.σ)

src/permutation.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ function _permute!(
3030
end
3131

3232
function _permute!(tracks::Tracks{T}, p::AbstractVector{<:Integer}) where {T}
33-
_permute!(tracks.onpart.value, p, tracks.proposals.value)
34-
_permute!(tracks.onpart.active, p, tracks.proposals.active)
33+
_permute!(tracks.onchunk.value, p, tracks.proposals.value)
34+
_permute!(tracks.onchunk.active, p, tracks.proposals.active)
3535
return tracks
3636
end
3737

0 commit comments

Comments
 (0)