@@ -136,7 +136,7 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
136136 stat:: N
137137
138138 """
139- Transition(model::Model, vi::AbstractVarInfo, sampler_transition)
139+ Transition(model::Model, vi::AbstractVarInfo, sampler_transition; reevaluate=true )
140140
141141 Construct a new `Turing.Inference.Transition` object using the outputs of a
142142 sampler step.
@@ -148,17 +148,38 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
148148
149149 `sampler_transition` is the transition object returned by the sampler
150150 itself and is only used to extract statistics of interest.
151+
152+ By default, the model is re-evaluated in order to obtain values of:
153+ - the values of the parameters as per user parameterisation (`vals_as_in_model`)
154+ - the various components of the log joint probability (`logprior`, `loglikelihood`)
155+ that are guaranteed to be correct.
156+
157+ If you **know** for a fact that the VarInfo `vi` already contains this information,
158+ then you can set `reevaluate=false` to skip the re-evaluation step.
159+
160+ !!! warning
161+ Note that in general this is unsafe and may lead to wrong results.
162+
163+ If `reevaluate` is set to `false`, it is the caller's responsibility to ensure that
164+ the `VarInfo` passed in has `ValuesAsInModelAccumulator`, `LogPriorAccumulator`,
165+ and `LogLikelihoodAccumulator` set up with the correct values. Note that the
166+ `ValuesAsInModelAccumulator` must also have `include_colon_eq == true`, i.e. it
167+ must be set up to track `x := y` statements.
151168 """
152- function Transition (model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition)
153- vi = DynamicPPL. setaccs!! (
154- vi,
155- (
156- DynamicPPL. ValuesAsInModelAccumulator (true ),
157- DynamicPPL. LogPriorAccumulator (),
158- DynamicPPL. LogLikelihoodAccumulator (),
159- ),
160- )
161- _, vi = DynamicPPL. evaluate!! (model, vi)
169+ function Transition (
170+ model:: DynamicPPL.Model , vi:: AbstractVarInfo , sampler_transition; reevaluate= true
171+ )
172+ if reevaluate
173+ vi = DynamicPPL. setaccs!! (
174+ vi,
175+ (
176+ DynamicPPL. ValuesAsInModelAccumulator (true ),
177+ DynamicPPL. LogPriorAccumulator (),
178+ DynamicPPL. LogLikelihoodAccumulator (),
179+ ),
180+ )
181+ _, vi = DynamicPPL. evaluate!! (model, vi)
182+ end
162183
163184 # Extract all the information we need
164185 vals_as_in_model = DynamicPPL. getacc (vi, Val (:ValuesAsInModel )). values
@@ -175,12 +196,18 @@ struct Transition{T,F<:AbstractFloat,N<:NamedTuple} <: AbstractTransition
175196 function Transition (
176197 model:: DynamicPPL.Model ,
177198 untyped_vi:: DynamicPPL.VarInfo{<:DynamicPPL.Metadata} ,
178- sampler_transition,
199+ sampler_transition;
200+ reevaluate= true ,
179201 )
180202 # Re-evaluating the model is unconscionably slow for untyped VarInfo. It's
181203 # much faster to convert it to a typed varinfo first, hence this method.
182204 # https://github.com/TuringLang/Turing.jl/issues/2604
183- return Transition (model, DynamicPPL. typed_varinfo (untyped_vi), sampler_transition)
205+ return Transition (
206+ model,
207+ DynamicPPL. typed_varinfo (untyped_vi),
208+ sampler_transition;
209+ reevaluate= reevaluate,
210+ )
184211 end
185212end
186213
0 commit comments