@@ -17,8 +17,8 @@ using DynamicPPL:
1717 setindex!!,
1818 push!!,
1919 setlogp!!,
20- getlogp,
2120 getlogjoint,
21+ getlogjoint_internal,
2222 VarName,
2323 getsym,
2424 getdist,
@@ -123,71 +123,94 @@ end
123123# #####################
124124# Default Transition #
125125# #####################
126- # Default
127- getstats (t) = nothing
126+ getstats (:: Any ) = NamedTuple ()
128127
128+ # TODO (penelopeysm): Remove this abstract type by converting SGLDTransition,
129+ # SMCTransition, and PGTransition to Turing.Inference.Transition instead.
129130abstract type AbstractTransition end
130131
131- struct Transition{T,F<: AbstractFloat ,S <: Union{ NamedTuple,Nothing} } <: AbstractTransition
132+ struct Transition{T,F<: AbstractFloat ,N <: NamedTuple } <: AbstractTransition
132133 θ:: T
133- lp:: F # TODO : merge `lp` with `stat`
134- stat:: S
135- end
136-
137- Transition (θ, lp) = Transition (θ, lp, nothing )
138- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , t)
139- θ = getparams (model, vi)
140- lp = getlogjoint (vi)
141- return Transition (θ, lp, getstats (t))
142- end
134+ logprior:: F
135+ loglikelihood:: F
136+ stat:: N
137+
138+ """
139+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
140+
141+ Construct a new `Turing.Inference.Transition` object using the outputs of a
142+ sampler step.
143+
144+ Here, `vi` represents a VarInfo _for which the appropriate parameters have
145+ already been set_. However, the accumulators (e.g. logp) may in general
146+ have junk contents. The role of this method is to re-evaluate `model` and
147+ thus set the accumulators to the correct values.
148+
149+ `sampler_transition` is the transition object returned by the sampler
150+ itself and is only used to extract statistics of interest.
151+ """
152+ function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
153+ vi = DynamicPPL. setaccs!! (
154+ vi,
155+ (
156+ DynamicPPL. ValuesAsInModelAccumulator (true ),
157+ DynamicPPL. LogPriorAccumulator (),
158+ DynamicPPL. LogLikelihoodAccumulator (),
159+ ),
160+ )
161+ _, vi = DynamicPPL. evaluate!! (model, vi)
162+
163+ # Extract all the information we need
164+ vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
165+ logprior = DynamicPPL. getlogprior (vi)
166+ loglikelihood = DynamicPPL. getloglikelihood (vi)
167+
168+ # Get additional statistics
169+ stats = getstats (sampler_transition)
170+ return new {typeof(vals_as_in_model),typeof(logprior),typeof(stats)} (
171+ vals_as_in_model, logprior, loglikelihood, stats
172+ )
173+ end
143174
144- function metadata (t:: Transition )
145- stat = t. stat
146- if stat === nothing
147- return (lp= t. lp,)
148- else
149- return merge ((lp= t. lp,), stat)
175+ function Transition (
176+ model:: DynamicPPL.Model ,
177+ untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
178+ sampler_transition,
179+ )
180+ # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181+ # much faster to convert it to a typed varinfo first, hence this method.
182+ # https://github.com/TuringLang/Turing.jl/issues/2604
183+ return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
150184 end
151185end
152186
153- DynamicPPL. getlogjoint (t:: Transition ) = t. lp
154-
155- # Metadata of VarInfo object
156- metadata (vi:: AbstractVarInfo ) = (lp= getlogjoint (vi),)
187+ function getstats_with_lp (t:: Transition )
188+ return merge (
189+ t. stat,
190+ (
191+ lp= t. logprior + t. loglikelihood,
192+ logprior= t. logprior,
193+ loglikelihood= t. loglikelihood,
194+ ),
195+ )
196+ end
197+ function getstats_with_lp (vi:: AbstractVarInfo )
198+ return (
199+ lp= DynamicPPL. getlogjoint (vi),
200+ logprior= DynamicPPL. getlogprior (vi),
201+ loglikelihood= DynamicPPL. getloglikelihood (vi),
202+ )
203+ end
157204
158205# #########################
159206# Chain making utilities #
160207# #########################
161208
162- """
163- getparams(model, t)
164-
165- Return a named tuple of parameters.
166- """
167- getparams (model, t) = t. θ
168- function getparams (model:: DynamicPPL.Model , vi:: DynamicPPL.VarInfo )
169- # NOTE: In the past, `invlink(vi, model)` + `values_as(vi, OrderedDict)` was used.
170- # Unfortunately, using `invlink` can cause issues in scenarios where the constraints
171- # of the parameters change depending on the realizations. Hence we have to use
172- # `values_as_in_model`, which re-runs the model and extracts the parameters
173- # as they are seen in the model, i.e. in the constrained space. Moreover,
174- # this means that the code below will work both of linked and invlinked `vi`.
175- # Ref: https://github.com/TuringLang/Turing.jl/issues/2195
176- # NOTE: We need to `deepcopy` here to avoid modifying the original `vi`.
177- return DynamicPPL. values_as_in_model (model, true , deepcopy (vi))
178- end
179- function getparams (
180- model:: DynamicPPL.Model , untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata}
181- )
182- # values_as_in_model is unconscionably slow for untyped VarInfo. It's
183- # much faster to convert it to a typed varinfo before calling getparams.
184- # https://github.com/TuringLang/Turing.jl/issues/2604
185- return getparams (model, DynamicPPL. typed_varinfo (untyped_vi))
209+ getparams (:: DynamicPPL.Model , t:: AbstractTransition ) = t. θ
210+ function getparams (model:: DynamicPPL.Model , vi:: AbstractVarInfo )
211+ t = Transition (model, vi, nothing )
212+ return getparams (model, t)
186213end
187- function getparams (:: DynamicPPL.Model , :: DynamicPPL.VarInfo{NamedTuple{(),Tuple{}}} )
188- return Dict {VarName,Any} ()
189- end
190-
191214function _params_to_array (model:: DynamicPPL.Model , ts:: Vector )
192215 names_set = OrderedSet {VarName} ()
193216 # Extract the parameter names and values from each transition.
@@ -203,7 +226,6 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
203226 iters = map (DynamicPPL. varname_and_value_leaves, keys (vals), values (vals))
204227 mapreduce (collect, vcat, iters)
205228 end
206-
207229 nms = map (first, nms_and_vs)
208230 vs = map (last, nms_and_vs)
209231 for nm in nms
@@ -218,14 +240,9 @@ function _params_to_array(model::DynamicPPL.Model, ts::Vector)
218240 return names, vals
219241end
220242
221- function get_transition_extras (ts:: AbstractVector{<:VarInfo} )
222- valmat = reshape ([getlogjoint (t) for t in ts], :, 1 )
223- return [:lp ], valmat
224- end
225-
226243function get_transition_extras (ts:: AbstractVector )
227- # Extract all metadata.
228- extra_data = map (metadata , ts)
244+ # Extract stats + log probabilities from each transition or VarInfo
245+ extra_data = map (getstats_with_lp , ts)
229246 return names_values (extra_data)
230247end
231248
@@ -334,7 +351,7 @@ function AbstractMCMC.bundle_samples(
334351 vals = map (values (sym_to_vns)) do vns
335352 map (Base. Fix1 (getindex, params), vns)
336353 end
337- return merge (NamedTuple (zip (keys (sym_to_vns), vals)), metadata (t))
354+ return merge (NamedTuple (zip (keys (sym_to_vns), vals)), getstats_with_lp (t))
338355 end
339356end
340357
@@ -396,84 +413,4 @@ function DynamicPPL.get_matching_type(
396413 return Array{T,N}
397414end
398415
399- # #############
400- # Utilities #
401- # #############
402-
403- """
404-
405- transitions_from_chain(
406- [rng::AbstractRNG,]
407- model::Model,
408- chain::MCMCChains.Chains;
409- sampler = DynamicPPL.SampleFromPrior()
410- )
411-
412- Execute `model` conditioned on each sample in `chain`, and return resulting transitions.
413-
414- The returned transitions are represented in a `Vector{<:Turing.Inference.Transition}`.
415-
416- # Details
417-
418- In a bit more detail, the process is as follows:
419- 1. For every `sample` in `chain`
420- 1. For every `variable` in `sample`
421- 1. Set `variable` in `model` to its value in `sample`
422- 2. Execute `model` with variables fixed as above, sampling variables NOT present
423- in `chain` using `SampleFromPrior`
424- 3. Return sampled variables and log-joint
425-
426- # Example
427- ```julia-repl
428- julia> using Turing
429-
430- julia> @model function demo()
431- m ~ Normal(0, 1)
432- x ~ Normal(m, 1)
433- end;
434-
435- julia> m = demo();
436-
437- julia> chain = Chains(randn(2, 1, 1), ["m"]); # 2 samples of `m`
438-
439- julia> transitions = Turing.Inference.transitions_from_chain(m, chain);
440-
441- julia> [Turing.Inference.getlogjoint(t) for t in transitions] # extract the logjoints
442- 2-element Array{Float64,1}:
443- -3.6294991938628374
444- -2.5697948166987845
445-
446- julia> [first(t.θ.x) for t in transitions] # extract samples for `x`
447- 2-element Array{Array{Float64,1},1}:
448- [-2.0844148956440796]
449- [-1.704630494695469]
450- ```
451- """
452- function transitions_from_chain (
453- model:: DynamicPPL.Model , chain:: MCMCChains.Chains ; kwargs...
454- )
455- return transitions_from_chain (Random. default_rng (), model, chain; kwargs... )
456- end
457-
458- function transitions_from_chain (
459- rng:: Random.AbstractRNG ,
460- model:: DynamicPPL.Model ,
461- chain:: MCMCChains.Chains ;
462- sampler= DynamicPPL. SampleFromPrior (),
463- )
464- vi = Turing. VarInfo (model)
465-
466- iters = Iterators. product (1 : size (chain, 1 ), 1 : size (chain, 3 ))
467- transitions = map (iters) do (sample_idx, chain_idx)
468- # Set variables present in `chain` and mark those NOT present in chain to be resampled.
469- DynamicPPL. setval_and_resample! (vi, chain, sample_idx, chain_idx)
470- model (rng, vi, sampler)
471-
472- # Convert `VarInfo` into `NamedTuple` and save.
473- Transition (model, vi)
474- end
475-
476- return transitions
477- end
478-
479416end # module
0 commit comments