From 756cc252c1755f90cc9a0aebac961e1a80a4d8f5 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Mon, 29 Sep 2025 17:55:35 +0100 Subject: [PATCH 1/5] Replace Medata.flags with Metadata.trans --- HISTORY.md | 2 + docs/src/api.md | 7 +-- src/DynamicPPL.jl | 4 +- src/threadsafe.jl | 7 --- src/varinfo.jl | 114 ++++++++++------------------------------------ test/varinfo.jl | 19 ++++---- 6 files changed, 36 insertions(+), 117 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index aaa5ac1eb..f2c5e9988 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -54,6 +54,8 @@ The separation of these functions was primarily implemented to avoid performing Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. +The only other flag, other than `"del"`, that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed. One can simply use `istrans` and a newly exported function called `settrans!!` instead. + **Other changes** ### Reimplementation of functions using `InitContext` diff --git a/docs/src/api.md b/docs/src/api.md index e5c483bca..e9ce257fb 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -329,9 +329,8 @@ The [Transformations section below](#Transformations) describes the methods used In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. ```@docs -set_flag! -unset_flag! -is_flagged +istrans +settrans!! ``` ```@docs @@ -423,8 +422,6 @@ DynamicPPL.StaticTransformation ``` ```@docs -DynamicPPL.istrans -DynamicPPL.settrans!! DynamicPPL.transformation DynamicPPL.link DynamicPPL.invlink diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index b1b3bc3d9..acfce2f51 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -70,10 +70,8 @@ export AbstractVarInfo, acclogjac!!, acclogprior!!, accloglikelihood!!, - is_flagged, - set_flag!, - unset_flag!, istrans, + settrans!!, link, link!!, invlink, diff --git a/src/threadsafe.jl b/src/threadsafe.jl index f89a562e3..473fe67bf 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -185,13 +185,6 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) -end -function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return is_flagged(vi.varinfo, vn, flag) -end - function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 062cc236b..69e4e4589 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,13 +15,13 @@ not. Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, and `md.flags`. + `md.vns`, `md.ranges` `md.dists`, and `md.trans`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.flags` is a dictionary of true/false flags. `md.flags[flag][md.idcs[vn]]` is the - value of `flag` corresponding to `vn`. +- `md.trans` is a Bitvector of true/false flags for whether a variable has been transformed. + `md.trans[md.idcs[vn]]` is the value of `trans` corresponding to `vn`. To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a @@ -56,8 +56,7 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - # Each `flag` has a `BitVector` `flags[flag]`, where `flags[flag][i]` is the true/false flag value corresonding to `vns[i]` - flags::Dict{String,BitVector} + trans::BitVector end function Base.:(==)(md1::Metadata, md2::Metadata) @@ -67,7 +66,7 @@ function Base.:(==)(md1::Metadata, md2::Metadata) md1.ranges == md2.ranges && md1.vals == md2.vals && md1.dists == md2.dists && - md1.flags == md2.flags + md1.trans == md2.trans ) end @@ -246,8 +245,8 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New flags - sym_flags = Dict(a => meta.flags[a][inds] for a in keys(meta.flags)) + # New trans + sym_trans = meta.trans[inds] # Extract new ranges and vals _ranges = getindex.((meta.ranges,), inds) @@ -263,7 +262,7 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_flags), + Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_trans), ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) @@ -406,7 +405,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.flags) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.trans) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -422,8 +421,7 @@ Construct an empty type unstable instance of `Metadata`. """ function Metadata() vals = Vector{Real}() - flags = Dict{String,BitVector}() - flags["trans"] = BitVector() + trans = BitVector() return Metadata( Dict{VarName,Int}(), @@ -431,7 +429,7 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - flags, + trans, ) end @@ -448,10 +446,7 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - for k in keys(meta.flags) - empty!(meta.flags[k]) - end - + empty!(meta.trans) return meta end @@ -535,8 +530,8 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va offset = r[end] end - flags = Dict(k => v[indices_for_vns] for (k, v) in metadata.flags) - return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], flags) + trans = trans[indices_for_vns] + return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], trans) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -607,11 +602,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - flags = Dict{String,BitVector}() - # Initialize the `flags`. - for k in union(keys(metadata_left.flags), keys(metadata_right.flags)) - flags[k] = BitVector() - end + trans = BitVector() # Range offset. offset = 0 @@ -628,12 +619,10 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - for k in keys(flags) - push!(flags[k], is_flagged(metadata_for_vn, vn, k)) - end + push!(trans, is_trans(metadata_for_vn, vn)) end - return Metadata(idcs, vns, ranges, vals, dists, flags) + return Metadata(idcs, vns, ranges, vals, dists, trans) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -807,12 +796,7 @@ function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) return vi end function settrans!!(metadata::Metadata, trans::Bool, vn::VarName) - if trans - set_flag!(metadata, vn, "trans") - else - unset_flag!(metadata, vn, "trans") - end - + metadata.trans[getidx(metadata, vn)] = trans return metadata end @@ -870,25 +854,6 @@ all_varnames_grouped_by_symbol(vi::NTVarInfo) = all_varnames_grouped_by_symbol(v return expr end -# TODO(mhauru) These set_flag! methods return the VarInfo. They should probably be called -# set_flag!!. -""" - set_flag!(vi::VarInfo, vn::VarName, flag::String) - -Set `vn`'s value for `flag` to `true` in `vi`. -""" -function set_flag!(vi::VarInfo, vn::VarName, flag::String) - set_flag!(getmetadata(vi, vn), vn, flag) - return vi -end -function set_flag!(md::Metadata, vn::VarName, flag::String) - return md.flags[flag][getidx(md, vn)] = true -end - -function set_flag!(vnv::VarNamedVector, ::VarName, flag::String) - throw(ErrorException("VarNamedVector does not support flags; Tried to set $(flag).")) -end - #### #### APIs for typed and untyped VarInfo #### @@ -927,7 +892,7 @@ Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] end istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) -istrans(md::Metadata, vn::VarName) = is_flagged(md, vn, "trans") +istrans(md::Metadata, vn::VarName) = md.trans[getidx(md, vn)] getaccs(vi::VarInfo) = vi.accs setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs @@ -1300,7 +1265,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.flags, + metadata.trans, ), cumulative_logjac end @@ -1475,7 +1440,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.flags, + metadata.trans, ), cumulative_inv_logjac end @@ -1624,7 +1589,7 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) for accname in acckeys(vi) push!(lines, (string(accname), getacc(vi, Val(accname)))) end - push!(lines, ("flags", vi.metadata.flags)) + push!(lines, ("trans", vi.metadata.trans)) max_name_length = maximum(map(length ∘ first, lines)) fmt = Printf.Format("%-$(max_name_length)s") vi_str = ( @@ -1738,7 +1703,7 @@ function Base.push!(meta::Metadata, vn, r, dist) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.flags["trans"], false) + push!(meta.trans, false) return meta end @@ -1751,39 +1716,6 @@ end # Rand & replaying method for VarInfo # ####################################### -""" - is_flagged(vi::VarInfo, vn::VarName, flag::String) - -Check whether `vn` has a true value for `flag` in `vi`. -""" -function is_flagged(vi::VarInfo, vn::VarName, flag::String) - return is_flagged(getmetadata(vi, vn), vn, flag) -end -function is_flagged(metadata::Metadata, vn::VarName, flag::String) - return metadata.flags[flag][getidx(metadata, vn)] -end -function is_flagged(::VarNamedVector, ::VarName, flag::String) - throw(ErrorException("VarNamedVector does not support flags; Tried to read $(flag).")) -end - -""" - unset_flag!(vi::VarInfo, vn::VarName, flag::String - -Set `vn`'s value for `flag` to `false` in `vi`. -""" -function unset_flag!(vi::VarInfo, vn::VarName, flag::String) - unset_flag!(getmetadata(vi, vn), vn, flag) - return vi -end -function unset_flag!(metadata::Metadata, vn::VarName, flag::String) - metadata.flags[flag][getidx(metadata, vn)] = false - return metadata -end - -function unset_flag!(vnv::VarNamedVector, ::VarName, flag::String) - throw(ErrorException("VarNamedVector does not support flags; Tried to unset $(flag).")) -end - # TODO: Maybe rename or something? """ _apply!(kernel!, vi::VarInfo, values, keys) diff --git a/test/varinfo.jl b/test/varinfo.jl index dc09ff8da..c1ef11116 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -48,9 +48,7 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - for flag in keys(meta.flags) - @test meta.flags[flag][ind] == fmeta.flags[flag][tind] - end + @test meta.trans[ind] == fmeta.trans[tind] range = meta.ranges[ind] trange = fmeta.ranges[tind] @test all(meta.vals[range] .== fmeta.vals[trange]) @@ -285,9 +283,8 @@ end @test all_accs_same(vi, vi_orig) end - @testset "flags" begin - # Test flag setting: - # is_flagged, set_flag!, unset_flag! + @testset "trans flag" begin + # Test istrans and settrans!! function test_varinfo!(vi) vn_x = @varname x dist = Normal(0, 1) @@ -296,13 +293,13 @@ end push!!(vi, vn_x, r, dist) # trans is set by default - @test !is_flagged(vi, vn_x, "trans") + @test !istrans(vi, vn_x) - set_flag!(vi, vn_x, "trans") - @test is_flagged(vi, vn_x, "trans") + vi = settrans!!(vi, vn_x, true) + @test istrans(vi, vn_x) - unset_flag!(vi, vn_x, "trans") - @test !is_flagged(vi, vn_x, "trans") + vi = settrans!!!(vi, vn_x, false) + @test !istrans(vi, vn_x) end vi = VarInfo() test_varinfo!(vi) From 1091986c01891ecd5d66f15a91b86ddeadbe2366 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 30 Sep 2025 09:55:02 +0100 Subject: [PATCH 2/5] Fix a bug --- src/varinfo.jl | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 69e4e4589..25005289b 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -20,7 +20,7 @@ Let `md` be an instance of `Metadata`: - `md.dists[md.idcs[vn]]` is the distribution of `vn`. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.trans` is a Bitvector of true/false flags for whether a variable has been transformed. +- `md.trans` is a BitVector of true/false flags for whether a variable has been transformed. `md.trans[md.idcs[vn]]` is the value of `trans` corresponding to `vn`. To make `md::Metadata` type stable, all the `md.vns` must have the same symbol @@ -1663,14 +1663,7 @@ function BangBang.push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) if vi isa NTVarInfo && ~haskey(vi.metadata, sym) # The NamedTuple doesn't have an entry for this variable, let's add one. val = tovec(r) - md = Metadata( - Dict(vn => 1), - [vn], - [1:length(val)], - val, - [dist], - Dict{String,BitVector}("trans" => [false]), - ) + md = Metadata(Dict(vn => 1), [vn], [1:length(val)], val, [dist], BitVector([false])) vi = Accessors.@set vi.metadata[sym] = md else meta = getmetadata(vi, vn) From ece8fb5e4d4e3a15f9a6ade092b1e5b28915ea6e Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 30 Sep 2025 12:21:39 +0100 Subject: [PATCH 3/5] Fix a typo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 25005289b..20c40d73e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -619,7 +619,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - push!(trans, is_trans(metadata_for_vn, vn)) + push!(trans, istrans(metadata_for_vn, vn)) end return Metadata(idcs, vns, ranges, vals, dists, trans) From a011dd64775de7cba43669464664e8addd089e8f Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 30 Sep 2025 12:24:16 +0100 Subject: [PATCH 4/5] Fix two bugs --- src/varinfo.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 20c40d73e..77a82a8ea 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -530,8 +530,9 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va offset = r[end] end - trans = trans[indices_for_vns] - return Metadata(indices, vns, ranges, vals, metadata.dists[indices_for_vns], trans) + dists = metadata.dists[indices_for_vns] + trans = metadata.trans[indices_for_vns] + return Metadata(indices, vns, ranges, vals, dists, trans) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) From 0f8c9b1bbede9a7d44c4ddb56276a869a9f0a4f3 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Tue, 30 Sep 2025 15:15:03 +0100 Subject: [PATCH 5/5] Rename trans to is_transformed --- HISTORY.md | 2 +- docs/src/api.md | 4 +- ext/DynamicPPLEnzymeCoreExt.jl | 5 +- ext/DynamicPPLMooncakeExt.jl | 4 +- src/DynamicPPL.jl | 4 +- src/abstract_varinfo.jl | 34 ++++----- src/contexts/init.jl | 6 +- src/contexts/transformation.jl | 2 +- src/simple_varinfo.jl | 44 ++++++------ src/threadsafe.jl | 14 ++-- src/varinfo.jl | 113 +++++++++++++++--------------- src/varnamedvector.jl | 26 +++---- test/contexts.jl | 2 +- test/ext/DynamicPPLMooncakeExt.jl | 6 +- test/linking.jl | 4 +- test/simple_varinfo.jl | 18 ++--- test/varinfo.jl | 62 ++++++++-------- test/varnamedvector.jl | 4 +- 18 files changed, 181 insertions(+), 173 deletions(-) diff --git a/HISTORY.md b/HISTORY.md index f2c5e9988..6c1114554 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -54,7 +54,7 @@ The separation of these functions was primarily implemented to avoid performing Previously `VarInfo` (or more correctly, the `Metadata` object within a `VarInfo`), had a flag called `"del"` for all variables. If it was set to `true` the variable was to be overwritten with a new value at the next evaluation. The new `InitContext` and related changes above make this flag unnecessary, and it has been removed. -The only other flag, other than `"del"`, that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed. One can simply use `istrans` and a newly exported function called `settrans!!` instead. +The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus the generic functions `set_flag!`, `unset_flag!` and `is_flagged!` have also been removed in favour of more specific ones. We've also used this opportunity to name the `"trans"` flag and the corresponding `istrans` function to be more explicit. The new, exported interface consists of the `is_transformed` and `set_transformed!!` functions. **Other changes** diff --git a/docs/src/api.md b/docs/src/api.md index e9ce257fb..aac2bfaea 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -329,8 +329,8 @@ The [Transformations section below](#Transformations) describes the methods used In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions. ```@docs -istrans -settrans!! +is_transformed +set_transformed!! ``` ```@docs diff --git a/ext/DynamicPPLEnzymeCoreExt.jl b/ext/DynamicPPLEnzymeCoreExt.jl index 0088f8908..35159636f 100644 --- a/ext/DynamicPPLEnzymeCoreExt.jl +++ b/ext/DynamicPPLEnzymeCoreExt.jl @@ -8,8 +8,9 @@ else using ..EnzymeCore end -# Mark istrans as having 0 derivative. The `nothing` return value is not significant, Enzyme +# Mark is_transformed as having 0 derivative. The `nothing` return value is not significant, Enzyme # only checks whether such a method exists, and never runs it. -@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.istrans), args...) = nothing +@inline EnzymeCore.EnzymeRules.inactive(::typeof(DynamicPPL.is_transformed), args...) = + nothing end diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl index f6b352fab..23a3430eb 100644 --- a/ext/DynamicPPLMooncakeExt.jl +++ b/ext/DynamicPPLMooncakeExt.jl @@ -1,9 +1,9 @@ module DynamicPPLMooncakeExt -using DynamicPPL: DynamicPPL, istrans +using DynamicPPL: DynamicPPL, is_transformed using Mooncake: Mooncake # This is purely an optimisation. -Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(istrans),Vararg} +Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(is_transformed),Vararg} end # module diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acfce2f51..a98ffac5b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -70,8 +70,8 @@ export AbstractVarInfo, acclogjac!!, acclogprior!!, accloglikelihood!!, - istrans, - settrans!!, + is_transformed, + set_transformed!!, link, link!!, invlink, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index b3cf77121..7b83efc19 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -769,7 +769,7 @@ end # Transformations """ - istrans(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) + is_transformed(vi::AbstractVarInfo[, vns::Union{VarName, AbstractVector{<:Varname}}]) Return `true` if `vi` is working in unconstrained space, and `false` if `vi` is assuming realizations to be in support of the corresponding distributions. @@ -780,27 +780,27 @@ If `vns` is provided, then only check if this/these varname(s) are transformed. Not all implementations of `AbstractVarInfo` support transforming only a subset of the variables. """ -istrans(vi::AbstractVarInfo) = istrans(vi, collect(keys(vi))) -function istrans(vi::AbstractVarInfo, vns::AbstractVector) - # This used to be: `!isempty(vns) && all(Base.Fix1(istrans, vi), vns)`. +is_transformed(vi::AbstractVarInfo) = is_transformed(vi, collect(keys(vi))) +function is_transformed(vi::AbstractVarInfo, vns::AbstractVector) + # This used to be: `!isempty(vns) && all(Base.Fix1(is_transformed, vi), vns)`. # In theory that should work perfectly fine. For unbeknownst reasons, # Julia 1.10 fails to infer its return type correctly. Thus we use this # slightly longer definition. isempty(vns) && return false for vn in vns - istrans(vi, vn) || return false + is_transformed(vi, vn) || return false end return true end """ - settrans!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) + set_transformed!!(vi::AbstractVarInfo, trans::Bool[, vn::VarName]) -Return `vi` with `istrans(vi, vn)` evaluating to `true`. +Return `vi` with `is_transformed(vi, vn)` evaluating to `true`. -If `vn` is not specified, then `istrans(vi)` evaluates to `true` for all variables. +If `vn` is not specified, then `is_transformed(vi)` evaluates to `true` for all variables. """ -function settrans!! end +function set_transformed!! end # For link!!, invlink!!, link, and invlink, we deliberately do not provide a fallback # method for the case when no `vns` is provided, that would get all the keys from the @@ -833,7 +833,7 @@ function link!!(t::DynamicTransformation, vi::AbstractVarInfo, model::Model) ctx = DynamicTransformationContext{false}() model = contextualize(model, setleafcontext(model.context, ctx)) vi = last(evaluate!!(model, vi)) - return settrans!!(vi, t) + return set_transformed!!(vi, t) end function link!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model @@ -846,7 +846,7 @@ function link!!( if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, logjac) end - return settrans!!(vi, t) + return set_transformed!!(vi, t) end """ @@ -896,7 +896,7 @@ function invlink!!(::DynamicTransformation, vi::AbstractVarInfo, model::Model) ctx = DynamicTransformationContext{true}() model = contextualize(model, setleafcontext(model.context, ctx)) vi = last(evaluate!!(model, vi)) - return settrans!!(vi, NoTransformation()) + return set_transformed!!(vi, NoTransformation()) end function invlink!!( t::StaticTransformation{<:Bijectors.Transform}, vi::AbstractVarInfo, ::Model @@ -912,7 +912,7 @@ function invlink!!( if hasacc(vi, Val(:LogJacobian)) vi = acclogjac!!(vi, inv_logjac) end - return settrans!!(vi, NoTransformation()) + return set_transformed!!(vi, NoTransformation()) end """ @@ -1020,7 +1020,7 @@ function unflatten end """ to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) -Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. +Return reconstructed `val`, possibly linked if `is_transformed(vi, vn)` is `true`. """ function to_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) f = to_maybe_linked_internal_transform(vi, vn, dist) @@ -1030,7 +1030,7 @@ end """ from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) -Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. +Return reconstructed `val`, possibly invlinked if `is_transformed(vi, vn)` is `true`. """ function from_maybe_linked_internal(vi::AbstractVarInfo, vn::VarName, dist, val) f = from_maybe_linked_internal_transform(vi, vn, dist) @@ -1087,14 +1087,14 @@ in `varinfo` to a representation compatible with `dist`. If `dist` is not present, then it is assumed that `varinfo` knows the correct output for `vn`. """ function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName, dist) - return if istrans(varinfo, vn) + return if is_transformed(varinfo, vn) from_linked_internal_transform(varinfo, vn, dist) else from_internal_transform(varinfo, vn, dist) end end function from_maybe_linked_internal_transform(varinfo::AbstractVarInfo, vn::VarName) - return if istrans(varinfo, vn) + return if is_transformed(varinfo, vn) from_linked_internal_transform(varinfo, vn) else from_internal_transform(varinfo, vn) diff --git a/src/contexts/init.jl b/src/contexts/init.jl index 4baca1b57..6ed33489c 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -165,9 +165,9 @@ function tilde_assume!!( # If the VarInfo alrady had a value for this variable, we will # keep the same linked status as in the original VarInfo. If not, we # check the rest of the VarInfo to see if other variables are linked. - # istrans(vi) returns true if vi is nonempty and all variables in vi + # is_transformed(vi) returns true if vi is nonempty and all variables in vi # are linked. - insert_transformed_value = in_varinfo ? istrans(vi, vn) : istrans(vi) + insert_transformed_value = in_varinfo ? is_transformed(vi, vn) : is_transformed(vi) f = if insert_transformed_value link_transform(dist) else @@ -183,7 +183,7 @@ function tilde_assume!!( end # Neither of these set the `trans` flag so we have to do it manually if # necessary. - insert_transformed_value && settrans!!(vi, true, vn) + insert_transformed_value && set_transformed!!(vi, true, vn) # `accumulate_assume!!` wants untransformed values as the second argument. vi = accumulate_assume!!(vi, x, logjac, vn, dist) # We always return the untransformed value here, as that will determine diff --git a/src/contexts/transformation.jl b/src/contexts/transformation.jl index 720fa978f..5153f7857 100644 --- a/src/contexts/transformation.jl +++ b/src/contexts/transformation.jl @@ -21,7 +21,7 @@ function tilde_assume!!( # vi[vn, right] always provides the value in unlinked space. x = vi[vn, right] - if istrans(vi, vn) + if is_transformed(vi, vn) isinverse || @warn "Trying to link an already transformed variable ($vn)" else isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index f430755e7..79aa81983 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -96,23 +96,23 @@ julia> _, vi = DynamicPPL.init!!(rng, m, SimpleVarInfo()); julia> vi[@varname(x)] # (✓) 0 ≤ x < ∞ 1.8632965762164932 -julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)); +julia> _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ -0.21080155351918753 -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true julia> # And with `OrderedDict` of course! - _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); + _, vi = DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(OrderedDict{VarName,Any}()), true)); julia> vi[@varname(x)] # (✓) -∞ < x < ∞ 0.6225185067787314 -julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.settrans!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; +julia> xs = [last(DynamicPPL.init!!(rng, m, DynamicPPL.set_transformed!!(SimpleVarInfo(), true)))[@varname(x)] for i = 1:10]; julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers! true @@ -121,7 +121,7 @@ true Evaluation in transformed space of course also works: ```jldoctest simplevarinfo-general -julia> vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), true) +julia> vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), true) Transformed SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) julia> # (✓) Positive probability mass on negative numbers! @@ -129,7 +129,7 @@ julia> # (✓) Positive probability mass on negative numbers! -1.3678794411714423 julia> # While if we forget to indicate that it's transformed: - vi = DynamicPPL.settrans!!(SimpleVarInfo((x = -1.0,)), false) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo((x = -1.0,)), false) SimpleVarInfo((x = -1.0,), (LogPrior = LogPriorAccumulator(0.0), LogJacobian = LogJacobianAccumulator(0.0), LogLikelihood = LogLikelihoodAccumulator(0.0))) julia> # (✓) No probability mass on negative numbers! @@ -466,32 +466,34 @@ function Base.merge(varinfo_left::SimpleVarInfo, varinfo_right::SimpleVarInfo) return SimpleVarInfo(values, accs, transformation) end -function settrans!!(vi::SimpleVarInfo, trans) - return settrans!!(vi, trans ? DynamicTransformation() : NoTransformation()) +function set_transformed!!(vi::SimpleVarInfo, trans) + return set_transformed!!(vi, trans ? DynamicTransformation() : NoTransformation()) end -function settrans!!(vi::SimpleVarInfo, transformation::AbstractTransformation) +function set_transformed!!(vi::SimpleVarInfo, transformation::AbstractTransformation) return Accessors.@set vi.transformation = transformation end -function settrans!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans) +function set_transformed!!(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, trans) + return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, trans) end -function settrans!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) +function set_transformed!!(vi::SimpleOrThreadSafeSimple, trans::Bool, ::VarName) # We keep this method around just to obey the AbstractVarInfo interface. # However, note that this would only be a valid operation if it would be a # no-op, which we check here. - if trans != istrans(vi) + if trans != is_transformed(vi) error( - "Individual variables in SimpleVarInfo cannot have different `settrans` statuses.", + "Individual variables in SimpleVarInfo cannot have different `set_transformed` statuses.", ) end end -istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) -istrans(vi::SimpleVarInfo, ::VarName) = istrans(vi) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = istrans(vi.varinfo) +is_transformed(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) +is_transformed(vi::SimpleVarInfo, ::VarName) = is_transformed(vi) +function is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) + return is_transformed(vi.varinfo, vn) +end +is_transformed(vi::ThreadSafeVarInfo{<:SimpleVarInfo}) = is_transformed(vi.varinfo) -islinked(vi::SimpleVarInfo) = istrans(vi) +islinked(vi::SimpleVarInfo) = is_transformed(vi) values_as(vi::SimpleVarInfo) = vi.values values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values @@ -618,7 +620,7 @@ function link!!( if hasacc(vi_new, Val(:LogJacobian)) vi_new = acclogjac!!(vi_new, logjac) end - return settrans!!(vi_new, t) + return set_transformed!!(vi_new, t) end function invlink!!( @@ -636,7 +638,7 @@ function invlink!!( if hasacc(vi_new, Val(:LogJacobian)) vi_new = acclogjac!!(vi_new, inv_logjac) end - return settrans!!(vi_new, NoTransformation()) + return set_transformed!!(vi_new, NoTransformation()) end # With `SimpleVarInfo`, when we're not working with linked variables, there's no need to do anything. diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 473fe67bf..feaefff2e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -106,14 +106,14 @@ function link!!(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) model = contextualize( model, setleafcontext(model.context, DynamicTransformationContext{false}()) ) - return settrans!!(last(evaluate!!(model, vi)), t) + return set_transformed!!(last(evaluate!!(model, vi)), t) end function invlink!!(::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) model = contextualize( model, setleafcontext(model.context, DynamicTransformationContext{true}()) ) - return settrans!!(last(evaluate!!(model, vi)), NoTransformation()) + return set_transformed!!(last(evaluate!!(model, vi)), NoTransformation()) end function link(t::DynamicTransformation, vi::ThreadSafeVarInfo, model::Model) @@ -185,12 +185,14 @@ end values_as(vi::ThreadSafeVarInfo) = values_as(vi.varinfo) values_as(vi::ThreadSafeVarInfo, ::Type{T}) where {T} = values_as(vi.varinfo, T) -function settrans!!(vi::ThreadSafeVarInfo, trans::Bool, vn::VarName) - return Accessors.@set vi.varinfo = settrans!!(vi.varinfo, trans, vn) +function set_transformed!!(vi::ThreadSafeVarInfo, val::Bool, vn::VarName) + return Accessors.@set vi.varinfo = set_transformed!!(vi.varinfo, val, vn) end -istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) -istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) +is_transformed(vi::ThreadSafeVarInfo, vn::VarName) = is_transformed(vi.varinfo, vn) +function is_transformed(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) + return is_transformed(vi.varinfo, vns) +end getindex_internal(vi::ThreadSafeVarInfo, vn::VarName) = getindex_internal(vi.varinfo, vn) diff --git a/src/varinfo.jl b/src/varinfo.jl index 77a82a8ea..d47aa8243 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -15,13 +15,13 @@ not. Let `md` be an instance of `Metadata`: - `md.vns` is the vector of all `VarName` instances. - `md.idcs` is the dictionary that maps each `VarName` instance to its index in - `md.vns`, `md.ranges` `md.dists`, and `md.trans`. + `md.vns`, `md.ranges` `md.dists`, and `md.is_transformed`. - `md.vns[md.idcs[vn]] == vn`. - `md.dists[md.idcs[vn]]` is the distribution of `vn`. - `md.ranges[md.idcs[vn]]` is the index range of `vn` in `md.vals`. - `md.vals[md.ranges[md.idcs[vn]]]` is the vector of values of corresponding to `vn`. -- `md.trans` is a BitVector of true/false flags for whether a variable has been transformed. - `md.trans[md.idcs[vn]]` is the value of `trans` corresponding to `vn`. +- `md.is_transformed` is a BitVector of true/false flags for whether a variable has been + transformed. `md.is_transformed[md.idcs[vn]]` is the value corresponding to `vn`. To make `md::Metadata` type stable, all the `md.vns` must have the same symbol and distribution type. However, one can have a Julia variable, say `x`, that is a @@ -56,7 +56,7 @@ struct Metadata{ # Vector of distributions correpsonding to `vns` dists::TDists # AbstractVector{<:Distribution} - trans::BitVector + is_transformed::BitVector end function Base.:(==)(md1::Metadata, md2::Metadata) @@ -66,7 +66,7 @@ function Base.:(==)(md1::Metadata, md2::Metadata) md1.ranges == md2.ranges && md1.vals == md2.vals && md1.dists == md2.dists && - md1.trans == md2.trans + md1.is_transformed == md2.is_transformed ) end @@ -171,8 +171,8 @@ function metadata_to_varnamedvector(md::Metadata) vns = copy(md.vns) ranges = copy(md.ranges) vals = copy(md.vals) - is_unconstrained = map(Base.Fix1(istrans, md), md.vns) - transforms = map(md.dists, is_unconstrained) do dist, trans + is_trans = map(Base.Fix1(is_transformed, md), md.vns) + transforms = map(md.dists, is_trans) do dist, trans if trans return from_linked_vec_transform(dist) else @@ -181,12 +181,7 @@ function metadata_to_varnamedvector(md::Metadata) end return VarNamedVector( - OrderedDict{eltype(keys(idcs)),Int}(idcs), - vns, - ranges, - vals, - transforms, - is_unconstrained, + OrderedDict{eltype(keys(idcs)),Int}(idcs), vns, ranges, vals, transforms, is_trans ) end @@ -245,8 +240,8 @@ function typed_varinfo(vi::UntypedVarInfo) sym_idcs = Dict(a => i for (i, a) in enumerate(sym_vns)) # New dists sym_dists = getindex.((meta.dists,), inds) - # New trans - sym_trans = meta.trans[inds] + # New is_transformed + sym_is_transformed = meta.is_transformed[inds] # Extract new ranges and vals _ranges = getindex.((meta.ranges,), inds) @@ -262,7 +257,9 @@ function typed_varinfo(vi::UntypedVarInfo) push!( new_metas, - Metadata(sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_trans), + Metadata( + sym_idcs, sym_vns, sym_ranges, sym_vals, sym_dists, sym_is_transformed + ), ) end nt = NamedTuple{syms_tuple}(Tuple(new_metas)) @@ -405,7 +402,7 @@ end end function unflatten_metadata(md::Metadata, x::AbstractVector) - return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.trans) + return Metadata(md.idcs, md.vns, md.ranges, x, md.dists, md.is_transformed) end unflatten_metadata(vnv::VarNamedVector, x::AbstractVector) = unflatten(vnv, x) @@ -421,7 +418,7 @@ Construct an empty type unstable instance of `Metadata`. """ function Metadata() vals = Vector{Real}() - trans = BitVector() + is_transformed = BitVector() return Metadata( Dict{VarName,Int}(), @@ -429,7 +426,7 @@ function Metadata() Vector{UnitRange{Int}}(), vals, Vector{Distribution}(), - trans, + is_transformed, ) end @@ -446,7 +443,7 @@ function empty!(meta::Metadata) empty!(meta.ranges) empty!(meta.vals) empty!(meta.dists) - empty!(meta.trans) + empty!(meta.is_transformed) return meta end @@ -531,8 +528,8 @@ function subset(metadata::Metadata, vns_given::AbstractVector{VN}) where {VN<:Va end dists = metadata.dists[indices_for_vns] - trans = metadata.trans[indices_for_vns] - return Metadata(indices, vns, ranges, vals, dists, trans) + is_transformed = metadata.is_transformed[indices_for_vns] + return Metadata(indices, vns, ranges, vals, dists, is_transformed) end function Base.merge(varinfo_left::VarInfo, varinfo_right::VarInfo) @@ -603,7 +600,7 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) ranges = Vector{UnitRange{Int}}() vals = T[] dists = D[] - trans = BitVector() + transformed = BitVector() # Range offset. offset = 0 @@ -620,10 +617,10 @@ function merge_metadata(metadata_left::Metadata, metadata_right::Metadata) offset = r[end] dist = getdist(metadata_for_vn, vn) push!(dists, dist) - push!(trans, istrans(metadata_for_vn, vn)) + push!(transformed, is_transformed(metadata_for_vn, vn)) end - return Metadata(idcs, vns, ranges, vals, dists, trans) + return Metadata(idcs, vns, ranges, vals, dists, transformed) end const VarView = Union{Int,UnitRange,Vector{Int}} @@ -792,30 +789,30 @@ function setval!(md::Metadata, val, vn::VarName) return md.vals[getrange(md, vn)] = tovec(val) end -function settrans!!(vi::VarInfo, trans::Bool, vn::VarName) - settrans!!(getmetadata(vi, vn), trans, vn) +function set_transformed!!(vi::VarInfo, val::Bool, vn::VarName) + set_transformed!!(getmetadata(vi, vn), val, vn) return vi end -function settrans!!(metadata::Metadata, trans::Bool, vn::VarName) - metadata.trans[getidx(metadata, vn)] = trans +function set_transformed!!(metadata::Metadata, val::Bool, vn::VarName) + metadata.is_transformed[getidx(metadata, vn)] = val return metadata end -function settrans!!(vi::VarInfo, trans::Bool) +function set_transformed!!(vi::VarInfo, val::Bool) for vn in keys(vi) - settrans!!(vi, trans, vn) + set_transformed!!(vi, val, vn) end return vi end -settrans!!(vi::VarInfo, trans::NoTransformation) = settrans!!(vi, false) +set_transformed!!(vi::VarInfo, ::NoTransformation) = set_transformed!!(vi, false) # HACK: This is necessary to make something like `link!!(transformation, vi, model)` # work properly, which will transform the variables according to `transformation` -# and then call `settrans!!(vi, transformation)`. An alternative would be to add +# and then call `set_transformed!!(vi, transformation)`. An alternative would be to add # the `transformation` to the `VarInfo` object, but at the moment doesn't seem # worth it as `VarInfo` has its own way of handling transformations. -settrans!!(vi::VarInfo, trans::AbstractTransformation) = settrans!!(vi, true) +set_transformed!!(vi::VarInfo, ::AbstractTransformation) = set_transformed!!(vi, true) """ syms(vi::VarInfo) @@ -892,8 +889,8 @@ Base.keys(vi::NTVarInfo{<:NamedTuple{()}}) = VarName[] return expr end -istrans(vi::VarInfo, vn::VarName) = istrans(getmetadata(vi, vn), vn) -istrans(md::Metadata, vn::VarName) = md.trans[getidx(md, vn)] +is_transformed(vi::VarInfo, vn::VarName) = is_transformed(getmetadata(vi, vn), vn) +is_transformed(md::Metadata, vn::VarName) = md.is_transformed[getidx(md, vn)] getaccs(vi::VarInfo) = vi.accs setaccs!!(vi::VarInfo, accs::AccumulatorTuple) = Accessors.@set vi.accs = accs @@ -948,11 +945,11 @@ end function _link!!(vi::UntypedVarInfo, vns) # TODO: Change to a lazy iterator over `vns` - if ~istrans(vi, vns[1]) + if ~is_transformed(vi, vns[1]) for vn in vns f = internal_to_linked_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, true, vn) + vi = set_transformed!!(vi, true, vn) end return vi else @@ -993,12 +990,12 @@ end f_vns = vi.metadata.$f.vns f_vns = filter_subsumed(vns.$f, f_vns) if !isempty(f_vns) - if !istrans(vi, f_vns[1]) + if !is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = internal_to_linked_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, true, vn) + vi = set_transformed!!(vi, true, vn) end else @warn("[DynamicPPL] attempt to link a linked vi") @@ -1054,17 +1051,17 @@ end function maybe_invlink_before_eval!!(vi::VarInfo, model::Model) # Because `VarInfo` does not contain any information about what the transformation # other than whether or not it has actually been transformed, the best we can do - # is just assume that `default_transformation` is the correct one if `istrans(vi)`. - t = istrans(vi) ? default_transformation(model, vi) : NoTransformation() + # is just assume that `default_transformation` is the correct one if `is_transformed(vi)`. + t = is_transformed(vi) ? default_transformation(model, vi) : NoTransformation() return maybe_invlink_before_eval!!(t, vi, model) end function _invlink!!(vi::UntypedVarInfo, vns) - if istrans(vi, vns[1]) + if is_transformed(vi, vns[1]) for vn in vns f = linked_internal_to_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, false, vn) + vi = set_transformed!!(vi, false, vn) end return vi else @@ -1096,12 +1093,12 @@ end quote f_vns = vi.metadata.$f.vns f_vns = filter_subsumed(vns.$f, f_vns) - if istrans(vi, f_vns[1]) + if is_transformed(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns f = linked_internal_to_internal_transform(vi, vn) vi = _inner_transform!(vi, vn, f) - vi = settrans!!(vi, false, vn) + vi = set_transformed!!(vi, false, vn) end else @warn("[DynamicPPL] attempt to invlink an invlinked vi") @@ -1231,7 +1228,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ vals_new = map(vns) do vn # Return early if we're already in unconstrained space. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) + if is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -1245,7 +1242,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ # Accumulate the log-abs-det jacobian correction. cumulative_logjac += logjac # Mark as transformed. - settrans!!(varinfo, true, vn) + set_transformed!!(varinfo, true, vn) # Return the vectorized transformed value. return yvec end @@ -1266,7 +1263,7 @@ function _link_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, target_ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.trans, + metadata.is_transformed, ), cumulative_logjac end @@ -1291,7 +1288,7 @@ function _link_metadata!!( # Fix this when attending to issue #653. cumulative_logjac += logjac1 + logjac2 metadata = setindex_internal!!(metadata, val_new, vn, transform_from_linked) - settrans!(metadata, true, vn) + set_transformed!(metadata, true, vn) end return metadata, cumulative_logjac end @@ -1406,7 +1403,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Return early if we're already in constrained space OR if we're not # supposed to touch this `vn`. # HACK: if `target_vns` is `nothing`, we ignore the `target_vns` check. - if !istrans(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) + if !is_transformed(varinfo, vn) || (target_vns !== nothing && vn ∉ target_vns) return metadata.vals[getrange(metadata, vn)] end @@ -1420,7 +1417,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ # Accumulate the log-abs-det jacobian correction. cumulative_inv_logjac += inv_logjac # Mark as no longer transformed. - settrans!!(varinfo, false, vn) + set_transformed!!(varinfo, false, vn) # Return the vectorized transformed value. return xvec end @@ -1441,7 +1438,7 @@ function _invlink_metadata!!(::Model, varinfo::VarInfo, metadata::Metadata, targ ranges_new, reduce(vcat, vals_new), metadata.dists, - metadata.trans, + metadata.is_transformed, ), cumulative_inv_logjac end @@ -1459,7 +1456,7 @@ function _invlink_metadata!!( cumulative_inv_logjac += inv_logjac new_transform = from_vec_transform(new_val) metadata = setindex_internal!!(metadata, tovec(new_val), vn, new_transform) - settrans!(metadata, false, vn) + set_transformed!(metadata, false, vn) end return metadata, cumulative_inv_logjac end @@ -1483,7 +1480,7 @@ If some but only some of the variables in `vi` are linked, this function will re This behavior will likely change in the future. """ function islinked(vi::VarInfo) - return any(istrans(vi, vn) for vn in keys(vi)) + return any(is_transformed(vi, vn) for vn in keys(vi)) end # The default getindex & setindex!() for get & set values @@ -1590,7 +1587,7 @@ function Base.show(io::IO, ::MIME"text/plain", vi::UntypedVarInfo) for accname in acckeys(vi) push!(lines, (string(accname), getacc(vi, Val(accname)))) end - push!(lines, ("trans", vi.metadata.trans)) + push!(lines, ("is_transformed", vi.metadata.is_transformed)) max_name_length = maximum(map(length ∘ first, lines)) fmt = Printf.Format("%-$(max_name_length)s") vi_str = ( @@ -1697,7 +1694,7 @@ function Base.push!(meta::Metadata, vn, r, dist) push!(meta.ranges, (l + 1):(l + n)) append!(meta.vals, val) push!(meta.dists, dist) - push!(meta.trans, false) + push!(meta.is_transformed, false) return meta end @@ -1844,7 +1841,7 @@ function _setval_kernel!(vi::VarInfoOrThreadSafeVarInfo, vn::VarName, values, ke if !isempty(indices) val = reduce(vcat, values[indices]) setval!(vi, val, vn) - settrans!!(vi, false, vn) + set_transformed!!(vi, false, vn) end return indices diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index 2336b89b6..4b2791d19 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -322,28 +322,28 @@ getrange(vnv::VarNamedVector, vn::VarName) = getrange(vnv, getidx(vnv, vn)) gettransform(vnv::VarNamedVector, idx::Int) = vnv.transforms[idx] gettransform(vnv::VarNamedVector, vn::VarName) = gettransform(vnv, getidx(vnv, vn)) -# TODO(mhauru) Eventually I would like to rename the istrans function to is_unconstrained, -# but that's significantly breaking. +# TODO(mhauru) Eventually I would like to rename the is_transformed function to +# is_unconstrained, but that's significantly breaking. """ - istrans(vnv::VarNamedVector, vn::VarName) + is_transformed(vnv::VarNamedVector, vn::VarName) Return a boolean for whether `vn` is guaranteed to have been transformed so that its domain is all of Euclidean space. """ -istrans(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] +is_transformed(vnv::VarNamedVector, vn::VarName) = vnv.is_unconstrained[getidx(vnv, vn)] """ - settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) + set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) Set the value for whether `vn` is guaranteed to have been transformed so that all of Euclidean space is its domain. """ -function settrans!(vnv::VarNamedVector, val::Bool, vn::VarName) +function set_transformed!(vnv::VarNamedVector, val::Bool, vn::VarName) return vnv.is_unconstrained[vnv.varname_to_index[vn]] = val end -function settrans!!(vnv::VarNamedVector, val::Bool, vn::VarName) - settrans!(vnv, val, vn) +function set_transformed!!(vnv::VarNamedVector, val::Bool, vn::VarName) + set_transformed!(vnv, val, vn) return vnv end @@ -548,7 +548,7 @@ julia> vnv[@varname(x)] function reset!(vnv::VarNamedVector, val, vn::VarName) f = from_vec_transform(val) retval = setindex_internal!(vnv, tovec(val), vn, f) - settrans!(vnv, false, vn) + set_transformed!(vnv, false, vn) return retval end @@ -902,7 +902,7 @@ end function reset!!(vnv::VarNamedVector, val, vn::VarName) f = from_vec_transform(val) vnv = setindex_internal!!(vnv, tovec(val), vn, f) - vnv = settrans!!(vnv, false, vn) + vnv = set_transformed!!(vnv, false, vn) return vnv end @@ -1098,13 +1098,13 @@ function Base.merge(left_vnv::VarNamedVector, right_vnv::VarNamedVector) # `vn` is only in `left`. val = getindex_internal(left_vnv, vn) f = gettransform(left_vnv, vn) - is_unconstrained[idx] = istrans(left_vnv, vn) + is_unconstrained[idx] = is_transformed(left_vnv, vn) else # `vn` is either in both or just `right`. # Note that in a `merge` the right value has precedence. val = getindex_internal(right_vnv, vn) f = gettransform(right_vnv, vn) - is_unconstrained[idx] = istrans(right_vnv, vn) + is_unconstrained[idx] = is_transformed(right_vnv, vn) end n = length(val) r = (offset + 1):(offset + n) @@ -1153,7 +1153,7 @@ function subset(vnv::VarNamedVector, vns_given::AbstractVector{<:VarName}) for vn in vnv.varnames if any(subsumes(vn_given, vn) for vn_given in vns_given) insert_internal!(vnv_new, getindex_internal(vnv, vn), vn, gettransform(vnv, vn)) - settrans!(vnv_new, istrans(vnv, vn), vn) + set_transformed!(vnv_new, is_transformed(vnv, vn), vn) end end diff --git a/test/contexts.jl b/test/contexts.jl index 2687c4336..972d833a5 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -504,7 +504,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() vi = VarInfo(model) linked_vi = DynamicPPL.link!!(vi, model) _, new_vi = DynamicPPL.init!!(model, linked_vi, strategy) - @test DynamicPPL.istrans(new_vi) + @test DynamicPPL.is_transformed(new_vi) # this is the unlinked value, since it uses `getindex` a = new_vi[@varname(a)] # internal logjoint should correspond to the transformed value diff --git a/test/ext/DynamicPPLMooncakeExt.jl b/test/ext/DynamicPPLMooncakeExt.jl index 986057da0..971956542 100644 --- a/test/ext/DynamicPPLMooncakeExt.jl +++ b/test/ext/DynamicPPLMooncakeExt.jl @@ -1,5 +1,9 @@ @testset "DynamicPPLMooncakeExt" begin Mooncake.TestUtils.test_rule( - StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true, interface_only=true + StableRNG(123456), + is_transformed, + VarInfo(); + unsafe_perturb=true, + interface_only=true, ) end diff --git a/test/linking.jl b/test/linking.jl index cae101c72..2047b9d11 100644 --- a/test/linking.jl +++ b/test/linking.jl @@ -50,9 +50,9 @@ end # Specify the link-transform to use. Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) -function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Bool) +function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, is_transformed::Bool) lp = logpdf(dist, x) - if istrans + if is_transformed lp = lp - logabsdetjac(bijector(dist), x) end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 01cbfc593..488cb8941 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -150,9 +150,9 @@ ("Dict", svi_dict), ("VarNamedVector", svi_vnv), # TODO(mhauru) Fix linked SimpleVarInfos to work with our test models. - # DynamicPPL.settrans!!(deepcopy(svi_nt), true), - # DynamicPPL.settrans!!(deepcopy(svi_dict), true), - # DynamicPPL.settrans!!(deepcopy(svi_vnv), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_nt), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_dict), true), + # DynamicPPL.set_transformed!!(deepcopy(svi_vnv), true), ) # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. @@ -172,7 +172,7 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.rand_prior_true(model) - if DynamicPPL.istrans(svi) + if DynamicPPL.is_transformed(svi) _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... ) @@ -227,9 +227,11 @@ model = DynamicPPL.TestUtils.demo_dynamic_constraint() # Initialize. - svi_nt = DynamicPPL.settrans!!(SimpleVarInfo(), true) + svi_nt = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) svi_nt = last(DynamicPPL.init!!(model, svi_nt)) - svi_vnv = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) + svi_vnv = DynamicPPL.set_transformed!!( + SimpleVarInfo(DynamicPPL.VarNamedVector()), true + ) svi_vnv = last(DynamicPPL.init!!(model, svi_vnv)) for svi in (svi_nt, svi_vnv) @@ -270,13 +272,13 @@ vi_linked = DynamicPPL.link!!(vi, model) # Make sure `maybe_invlink_before_eval!!` results in `invlink!!`. - @test !DynamicPPL.istrans( + @test !DynamicPPL.is_transformed( DynamicPPL.maybe_invlink_before_eval!!(deepcopy(vi), model) ) # Resulting varinfo should no longer be transformed. vi_result = last(DynamicPPL.init!!(model, deepcopy(vi))) - @test !DynamicPPL.istrans(vi_result) + @test !DynamicPPL.is_transformed(vi_result) # Set the values to something that is out of domain if we're in constrained space. for vn in keys(vi) diff --git a/test/varinfo.jl b/test/varinfo.jl index c1ef11116..5b541e1dd 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -48,7 +48,7 @@ end ind = meta.idcs[vn] tind = fmeta.idcs[vn] @test meta.dists[ind] == fmeta.dists[tind] - @test meta.trans[ind] == fmeta.trans[tind] + @test meta.is_transformed[ind] == fmeta.is_transformed[tind] range = meta.ranges[ind] trange = fmeta.ranges[tind] @test all(meta.vals[range] .== fmeta.vals[trange]) @@ -283,8 +283,8 @@ end @test all_accs_same(vi, vi_orig) end - @testset "trans flag" begin - # Test istrans and settrans!! + @testset "is_transformed flag" begin + # Test is_transformed and set_transformed!! function test_varinfo!(vi) vn_x = @varname x dist = Normal(0, 1) @@ -292,14 +292,14 @@ end push!!(vi, vn_x, r, dist) - # trans is set by default - @test !istrans(vi, vn_x) + # is_transformed is set by default + @test !is_transformed(vi, vn_x) - vi = settrans!!(vi, vn_x, true) - @test istrans(vi, vn_x) + vi = set_transformed!!(vi, true, vn_x) + @test is_transformed(vi, vn_x) - vi = settrans!!!(vi, vn_x, false) - @test !istrans(vi, vn_x) + vi = set_transformed!!(vi, false, vn_x) + @test !is_transformed(vi, vn_x) end vi = VarInfo() test_varinfo!(vi) @@ -483,14 +483,14 @@ end vi = VarInfo() meta = vi.metadata _, vi = DynamicPPL.init!!(model, vi, InitFromUniform()) - @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !is_transformed(vi, x), meta.vns) - # Check that linking and invlinking set the `trans` flag accordingly + # Check that linking and invlinking set the `is_transformed` flag accordingly v = copy(meta.vals) vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.vns) + @test all(x -> is_transformed(vi, x), meta.vns) vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.vns) + @test all(x -> !is_transformed(vi, x), meta.vns) @test meta.vals ≈ v atol = 1e-10 # Check that linking and invlinking preserves the values @@ -501,14 +501,14 @@ end v_x = copy(meta.x.vals) v_y = copy(meta.y.vals) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> !is_transformed(vi, x), meta.s.vns) + @test all(x -> !is_transformed(vi, x), meta.m.vns) vi = link!!(vi, model) - @test all(x -> istrans(vi, x), meta.s.vns) - @test all(x -> istrans(vi, x), meta.m.vns) + @test all(x -> is_transformed(vi, x), meta.s.vns) + @test all(x -> is_transformed(vi, x), meta.m.vns) vi = invlink!!(vi, model) - @test all(x -> !istrans(vi, x), meta.s.vns) - @test all(x -> !istrans(vi, x), meta.m.vns) + @test all(x -> !is_transformed(vi, x), meta.s.vns) + @test all(x -> !is_transformed(vi, x), meta.m.vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 @@ -527,10 +527,10 @@ end @test !isempty(target_vns) @test !isempty(other_vns) vi = link!!(vi, (vn,), model) - @test all(x -> istrans(vi, x), target_vns) - @test all(x -> !istrans(vi, x), other_vns) + @test all(x -> is_transformed(vi, x), target_vns) + @test all(x -> !is_transformed(vi, x), other_vns) vi = invlink!!(vi, (vn,), model) - @test all(x -> !istrans(vi, x), all_vns) + @test all(x -> !is_transformed(vi, x), all_vns) @test meta.s.vals ≈ v_s atol = 1e-10 @test meta.m.vals ≈ v_m atol = 1e-10 @test meta.x.vals ≈ v_x atol = 1e-10 @@ -549,7 +549,7 @@ end vi = last(DynamicPPL.init!!(model, vi, InitFromPrior())) f = DynamicPPL.from_linked_internal_transform(vi, vn, dist) x = f(DynamicPPL.getindex_internal(vi, vn)) - @test istrans(vi, vn) + @test is_transformed(vi, vn) @test getlogjoint_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @test getlogprior_internal(vi) ≈ Bijectors.logpdf_with_trans(dist, x, true) @test getloglikelihood(vi) == 0.0 @@ -564,25 +564,25 @@ end ## `untyped_varinfo` vi = DynamicPPL.untyped_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) + vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) ## `typed_varinfo` vi = DynamicPPL.typed_varinfo(model) - vi = DynamicPPL.settrans!!(vi, true, vn) + vi = DynamicPPL.set_transformed!!(vi, true, vn) test_linked_varinfo(model, vi) ### `SimpleVarInfo` ## `SimpleVarInfo{<:NamedTuple}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:Dict}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(Dict{VarName,Any}()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(Dict{VarName,Any}()), true) test_linked_varinfo(model, vi) ## `SimpleVarInfo{<:VarNamedVector}` - vi = DynamicPPL.settrans!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) + vi = DynamicPPL.set_transformed!!(SimpleVarInfo(DynamicPPL.VarNamedVector()), true) test_linked_varinfo(model, vi) end @@ -670,7 +670,7 @@ end DynamicPPL.link(varinfo, model) end for vn in keys(varinfo) - @test DynamicPPL.istrans(varinfo_linked, vn) + @test DynamicPPL.is_transformed(varinfo_linked, vn) end @test length(varinfo[:]) > length(varinfo_linked[:]) varinfo_linked_unflattened = DynamicPPL.unflatten( @@ -928,7 +928,7 @@ end varinfo_left = VarInfo(model_left) varinfo_right = VarInfo(model_right) - varinfo_right = DynamicPPL.settrans!!(varinfo_right, true, @varname(x)) + varinfo_right = DynamicPPL.set_transformed!!(varinfo_right, true, @varname(x)) varinfo_merged = merge(varinfo_left, varinfo_right) vns = [@varname(x), @varname(y), @varname(z)] @@ -936,7 +936,7 @@ end # Right has precedence. @test varinfo_merged[@varname(x)] == varinfo_right[@varname(x)] - @test DynamicPPL.istrans(varinfo_merged, @varname(x)) + @test DynamicPPL.is_transformed(varinfo_merged, @varname(x)) end end diff --git a/test/varnamedvector.jl b/test/varnamedvector.jl index af24be86f..3fd76ffe2 100644 --- a/test/varnamedvector.jl +++ b/test/varnamedvector.jl @@ -570,9 +570,9 @@ end vn = @varname(t[1]) vns = vcat(test_vns, [vn]) vnv = DynamicPPL.setindex_internal!!(vnv, [2.0], vn, x -> x .^ 2) - DynamicPPL.settrans!(vnv, true, @varname(t[1])) + DynamicPPL.set_transformed!(vnv, true, @varname(t[1])) @test vnv[@varname(t[1])] == [4.0] - @test istrans(vnv, @varname(t[1])) + @test is_transformed(vnv, @varname(t[1])) @test subset(vnv, vns) == vnv end end