Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ The separation of these functions was primarily implemented to avoid performing

Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed.

The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed in favour of more specific ones. We've also used this opportunity to name the `"trans"` flag and the corresponding `istrans` function to be more explicit. The new, exported interface consists of the `is_transformed` and `set_transformed!!` functions.

**Other changes**

### `setleafcontext(model, context)`
Expand Down
7 changes: 2 additions & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -345,9 +345,8 @@ The [Transformations section below](#Transformations) describes the methods used
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.

```@docs
set_flag!
unset_flag!
is_flagged
is_transformed
set_transformed!!
```

```@docs
Expand Down Expand Up @@ -439,8 +438,6 @@ DynamicPPL.StaticTransformation
```

```@docs
DynamicPPL.istrans
DynamicPPL.settrans!!
DynamicPPL.transformation
DynamicPPL.link
DynamicPPL.invlink
Expand Down
5 changes: 3 additions & 2 deletions ext/DynamicPPLEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ else
using ..EnzymeCore
end

# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme
# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme
# only checks whether such a method exists, and never runs it.
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing
@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) =
nothing

end
4 changes: 2 additions & 2 deletions ext/DynamicPPLMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module DynamicPPLMooncakeExt

using DynamicPPL: DynamicPPL, istrans
using DynamicPPL: DynamicPPL, is_transformed
using Mooncake: Mooncake

# This is purely an optimisation.
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg}
Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg}

end # module
6 changes: 2 additions & 4 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,8 @@ export AbstractVarInfo,
acclogjac!!,
acclogprior!!,
accloglikelihood!!,
is_flagged,
set_flag!,
unset_flag!,
istrans,
is_transformed,
set_transformed!!,
link,
link!!,
invlink,
Expand Down
34 changes: 17 additions & 17 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ end

# Transformations
"""
istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])
is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}])

Return `true` if `vi` is working in unconstrained space, and `false`
if `vi` is assuming realizations to be in support of the corresponding distributions.
Expand All @@ -780,27 +780,27 @@ If `vns` is provided, then only check if this/these varname(s) are transformed.
Not all implementations of `AbstractVarInfo` support transforming only a subset of
the variables.
"""
istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi)))
function istrans(vi::AbstractVarInfo, vns::AbstractVector)
# This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`.
is_transformed(vi::AbstractVarInfo) = is_transformed(vi, collect(keys(vi)))
function is_transformed(vi::AbstractVarInfo, vns::AbstractVector)
# This used to be: `!isempty(vns) && all(Base.Fix1(is_transformed, vi), vns)`.
# In theory that should work perfectly fine. For unbeknownst reasons,
# Julia 1.10 fails to infer its return type correctly. Thus we use this
# slightly longer definition.
isempty(vns) && return false
for vn in vns
istrans(vi, vn) || return false
is_transformed(vi, vn) || return false
end
return true
end

"""
settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])
set_transformed!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName])

Return `vi` with `istrans(vi, vn)` evaluating to `true`.
Return `vi` with `is_transformed(vi, vn)` evaluating to `true`.

If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables.
If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all variables.
"""
function settrans!! end
function set_transformed!! end

# For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback
# method for the case when no `vns` is provided, that would get all the keys from the
Expand Down Expand Up @@ -832,7 +832,7 @@ function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# has a dedicated implementation
model = setleafcontext(model, DynamicTransformationContext{false}())
vi = last(evaluate!!(model, vi))
return settrans!!(vi, t)
return set_transformed!!(vi, t)
end
function link!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
Expand All @@ -845,7 +845,7 @@ function link!!(
if hasacc(vi, Val(:LogJacobian))
vi = acclogjac!!(vi, logjac)
end
return settrans!!(vi, t)
return set_transformed!!(vi, t)
end

"""
Expand Down Expand Up @@ -894,7 +894,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model)
# has a dedicated implementation
model = setleafcontext(model, DynamicTransformationContext{true}())
vi = last(evaluate!!(model, vi))
return settrans!!(vi, NoTransformation())
return set_transformed!!(vi, NoTransformation())
end
function invlink!!(
t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model
Expand All @@ -910,7 +910,7 @@ function invlink!!(
if hasacc(vi, Val(:LogJacobian))
vi = acclogjac!!(vi, inv_logjac)
end
return settrans!!(vi, NoTransformation())
return set_transformed!!(vi, NoTransformation())
end

"""
Expand Down Expand Up @@ -1018,7 +1018,7 @@ function unflatten end
"""
to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
Return reconstructed `val`, possibly linked if `is_transformed(vi, vn)` is `true`.
"""
function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
f = to_maybe_linked_internal_transform(vi, vn, dist)
Expand All @@ -1028,7 +1028,7 @@ end
"""
from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
Return reconstructed `val`, possibly invlinked if `is_transformed(vi, vn)` is `true`.
"""
function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val)
f = from_maybe_linked_internal_transform(vi, vn, dist)
Expand Down Expand Up @@ -1085,14 +1085,14 @@ in `varinfo` to a representation compatible with `dist`.
If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`.
"""
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist)
return if istrans(varinfo, vn)
return if is_transformed(varinfo, vn)
from_linked_internal_transform(varinfo, vn, dist)
else
from_internal_transform(varinfo, vn, dist)
end
end
function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName)
return if istrans(varinfo, vn)
return if is_transformed(varinfo, vn)
from_linked_internal_transform(varinfo, vn)
else
from_internal_transform(varinfo, vn)
Expand Down
6 changes: 3 additions & 3 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,9 @@ function tilde_assume!!(
# If the VarInfo alrady had a value for this variable, we will
# keep the same linked status as in the original VarInfo. If not, we
# check the rest of the VarInfo to see if other variables are linked.
# istrans(vi) returns true if vi is nonempty and all variables in vi
# is_transformed(vi) returns true if vi is nonempty and all variables in vi
# are linked.
insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi)
insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi)
f = if insert_transformed_value
link_transform(dist)
else
Expand All @@ -183,7 +183,7 @@ function tilde_assume!!(
end
# Neither of these set the `trans` flag so we have to do it manually if
# necessary.
insert_transformed_value && settrans!!(vi, true, vn)
insert_transformed_value && set_transformed!!(vi, true, vn)
# `accumulate_assume!!` wants untransformed values as the second argument.
vi = accumulate_assume!!(vi, x, logjac, vn, dist)
# We always return the untransformed value here, as that will determine
Expand Down
2 changes: 1 addition & 1 deletion src/contexts/transformation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function tilde_assume!!(
# vi[vn, right] always provides the value in unlinked space.
x = vi[vn, right]

if istrans(vi, vn)
if is_transformed(vi, vn)
isinverse || @warn "Trying to link an already transformed variable ($vn)"
else
isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
Expand Down
44 changes: 23 additions & 21 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,23 +96,23 @@ julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo());
julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞
1.8632965762164932

julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true));
julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true));

julia> vi[@varname(x)] # (✓) -∞ < x < ∞
-0.21080155351918753

julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];

julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
true

julia> # And with `OrderedDict` of course!
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));
_, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true));

julia> vi[@varname(x)] # (✓) -∞ < x < ∞
0.6225185067787314

julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];
julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10];

julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
true
Expand All @@ -121,15 +121,15 @@ true
Evaluation in transformed space of course also works:

```jldoctest simplevarinfo-general
julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true)
julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true)
Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))

julia> # (✓) Positive probability mass on negative numbers!
getlogjoint_internal(last(DynamicPPL.evaluate!!(m, vi)))
-1.3678794411714423

julia> # While if we forget to indicate that it's transformed:
vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false)
vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false)
SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0)))

julia> # (✓) No probability mass on negative numbers!
Expand Down Expand Up @@ -466,32 +466,34 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo)
return SimpleVarInfo(values, accs, transformation)
end

function settrans!!(vi::SimpleVarInfo, trans)
return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation())
function set_transformed!!(vi::SimpleVarInfo, trans)
return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation())
end
function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation)
return Accessors.@set vi.transformation = transformation
end
function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans)
function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans)
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans)
end
function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName)
# We keep this method around just to obey the AbstractVarInfo interface.
# However, note that this would only be a valid operation if it would be a
# no-op, which we check here.
if trans != istrans(vi)
if trans != is_transformed(vi)
error(
"Individual variables in SimpleVarInfo cannot have different `settrans` statuses.",
"Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.",
)
end
end

istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo)
is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi)
function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName)
return is_transformed(vi.varinfo, vn)
end
is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo)

islinked(vi::SimpleVarInfo) = istrans(vi)
islinked(vi::SimpleVarInfo) = is_transformed(vi)
Comment on lines -494 to +496
Copy link
Member

@penelopeysm penelopeysm Sep 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

like this line is just the same function but duplicated. so it feels like to me we could just pick one and roll with it!


values_as(vi::SimpleVarInfo) = vi.values
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
Expand Down Expand Up @@ -618,7 +620,7 @@ function link!!(
if hasacc(vi_new, Val(:LogJacobian))
vi_new = acclogjac!!(vi_new, logjac)
end
return settrans!!(vi_new, t)
return set_transformed!!(vi_new, t)
end

function invlink!!(
Expand All @@ -636,7 +638,7 @@ function invlink!!(
if hasacc(vi_new, Val(:LogJacobian))
vi_new = acclogjac!!(vi_new, inv_logjac)
end
return settrans!!(vi_new, NoTransformation())
return set_transformed!!(vi_new, NoTransformation())
end

# With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything.
Expand Down
19 changes: 7 additions & 12 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,12 @@ end
# to define `getacc(vi)`.
function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
model = setleafcontext(model, DynamicTransformationContext{false}())
return settrans!!(last(evaluate!!(model, vi)), t)
return set_transformed!!(last(evaluate!!(model, vi)), t)
end

function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
model = setleafcontext(model, DynamicTransformationContext{true}())
return settrans!!(last(evaluate!!(model, vi)), NoTransformation())
return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation())
end

function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model)
Expand Down Expand Up @@ -181,20 +181,15 @@ end
values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo)
values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T)

function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return unset_flag!(vi.varinfo, vn, flag)
end
function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String)
return is_flagged(vi.varinfo, vn, flag)
function set_transformed!!(vi::ThreadSafeVarInfo, val::Bool, vn::VarName)
return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, val, vn)
end

function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName)
return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn)
is_transformed(vi::ThreadSafeVarInfo, vn::VarName) = is_transformed(vi.varinfo, vn)
function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName})
return is_transformed(vi.varinfo, vns)
end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn)

function unflatten(vi::ThreadSafeVarInfo, x::AbstractVector)
Expand Down
Loading
Loading