Skip to content

Commit 169dda2

Browse files
authored
Merge branch 'master' into patch-3
2 parents 0811bef + 8972b98 commit 169dda2

31 files changed

+861
-1098
lines changed

.github/workflows/CI.yml

+11-1
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@ permissions:
1717
actions: write
1818
contents: read
1919

20+
# Cancel existing tests on the same PR if a new commit is added to a pull request
21+
concurrency:
22+
group: ${{ github.workflow }}-${{ github.ref || github.run_id }}
23+
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
24+
2025
jobs:
2126
test:
2227
runs-on: ${{ matrix.runner.os }}
2328
strategy:
29+
fail-fast: false
30+
2431
matrix:
2532
runner:
2633
# Current stable version
@@ -58,6 +65,9 @@ jobs:
5865
os: macos-latest
5966
arch: aarch64
6067
num_threads: 2
68+
test_group:
69+
- Group1
70+
- Group2
6171

6272
steps:
6373
- uses: actions/checkout@v4
@@ -73,7 +83,7 @@ jobs:
7383

7484
- uses: julia-actions/julia-runtest@v1
7585
env:
76-
GROUP: All
86+
GROUP: ${{ matrix.test_group }}
7787
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}
7888

7989
- uses: julia-actions/julia-processcoverage@v1

.github/workflows/CompatHelper.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@ jobs:
1414
env:
1515
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
1616
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
17-
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
17+
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'

.github/workflows/JuliaPre.yml

-2
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,3 @@ jobs:
2525
- uses: julia-actions/cache@v2
2626
- uses: julia-actions/julia-buildpkg@v1
2727
- uses: julia-actions/julia-runtest@v1
28-
env:
29-
GROUP: DynamicPPL

Project.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.31.4"
3+
version = "0.32.0"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -29,6 +29,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2929
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3030
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
3131
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
32+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3233
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
3334
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3435
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
@@ -37,6 +38,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
3738
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3839
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3940
DynamicPPLForwardDiffExt = ["ForwardDiff"]
41+
DynamicPPLJETExt = ["JET"]
4042
DynamicPPLMCMCChainsExt = ["MCMCChains"]
4143
DynamicPPLMooncakeExt = ["Mooncake"]
4244
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
@@ -55,6 +57,7 @@ Distributions = "0.25"
5557
DocStringExtensions = "0.9"
5658
EnzymeCore = "0.6 - 0.8"
5759
ForwardDiff = "0.10"
60+
JET = "0.9"
5861
LinearAlgebra = "1.6"
5962
LogDensityProblems = "2"
6063
LogDensityProblemsAD = "1.7.0"

docs/Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
66
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
77
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
88
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
9+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
910
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1011
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1112
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
@@ -18,6 +19,7 @@ Documenter = "1"
1819
DocumenterMermaid = "0.1"
1920
FillArrays = "0.13, 1"
2021
ForwardDiff = "0.10"
22+
JET = "0.9"
2123
LogDensityProblems = "2"
2224
MCMCChains = "5, 6"
2325
StableRNGs = "1"

docs/src/api.md

+23-6
Original file line numberDiff line numberDiff line change
@@ -265,20 +265,24 @@ AbstractVarInfo
265265

266266
But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.
267267

268-
#### `VarInfo`
268+
For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:
269269

270270
```@docs
271-
VarInfo
272-
TypedVarInfo
271+
DynamicPPL.untyped_varinfo
272+
DynamicPPL.typed_varinfo
273273
```
274274

275-
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
275+
#### `VarInfo`
276276

277277
```@docs
278-
link!
279-
invlink!
278+
VarInfo
279+
TypedVarInfo
280280
```
281281

282+
One main characteristic of [`VarInfo`](@ref) is that samples are transformed to unconstrained Euclidean space and stored in a linearized form, as described in the [transformation page](internals/transformations.md).
283+
The [Transformations section below](#Transformations) describes the methods used for this.
284+
In the specific case of `VarInfo`, it keeps track of whether samples have been transformed by setting flags on them, using the following functions.
285+
282286
```@docs
283287
set_flag!
284288
unset_flag!
@@ -425,6 +429,19 @@ DynamicPPL.loadstate
425429
DynamicPPL.initialsampler
426430
```
427431

432+
Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a given [`Model`](@ref), this is specified by [`DynamicPPL.default_varinfo`](@ref) and can thus be overloaded for each `model`-`sampler` combination. This can be useful in cases where one has explicit knowledge that one type of varinfo will be more performant for the given `model` and `sampler`.
433+
434+
```@docs
435+
DynamicPPL.default_varinfo
436+
```
437+
438+
There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:
439+
440+
```@docs
441+
DynamicPPL.Experimental.determine_suitable_varinfo
442+
DynamicPPL.Experimental.is_suitable_varinfo
443+
```
444+
428445
### [Model-Internal Functions](@id model_internal)
429446

430447
```@docs

docs/src/internals/varinfo.md

+1-3
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,7 @@ For example, with the model above we have
7979

8080
```@example varinfo-design
8181
# Type-unstable `VarInfo`
82-
varinfo_untyped = DynamicPPL.untyped_varinfo(
83-
demo(), SampleFromPrior(), DefaultContext(), DynamicPPL.Metadata()
84-
)
82+
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
8583
typeof(varinfo_untyped.metadata)
8684
```
8785

ext/DynamicPPLJETExt.jl

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
module DynamicPPLJETExt
2+
3+
using DynamicPPL: DynamicPPL
4+
using JET: JET
5+
6+
function DynamicPPL.Experimental.is_suitable_varinfo(
7+
model::DynamicPPL.Model,
8+
context::DynamicPPL.AbstractContext,
9+
varinfo::DynamicPPL.AbstractVarInfo;
10+
only_ddpl::Bool=true,
11+
)
12+
# Let's make sure that both evaluation and sampling doesn't result in type errors.
13+
f, argtypes = DynamicPPL.DebugUtils.gen_evaluator_call_with_types(
14+
model, varinfo, context
15+
)
16+
# If specified, we only check errors originating somewhere in the DynamicPPL.jl.
17+
# This way we don't just fall back to untyped if the user's code is the issue.
18+
result = if only_ddpl
19+
JET.report_call(f, argtypes; target_modules=(JET.AnyFrameModule(DynamicPPL),))
20+
else
21+
JET.report_call(f, argtypes)
22+
end
23+
return length(JET.get_reports(result)) == 0, result
24+
end
25+
26+
function DynamicPPL.Experimental._determine_varinfo_jet(
27+
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
28+
)
29+
# First we try with the typed varinfo.
30+
varinfo = DynamicPPL.typed_varinfo(model, context)
31+
32+
# Let's make sure that both evaluation and sampling doesn't result in type errors.
33+
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
34+
model, context, varinfo; only_ddpl
35+
)
36+
37+
if !issuccess
38+
# Useful information for debugging.
39+
@debug "Evaluaton with typed varinfo failed with the following issues:"
40+
@debug result
41+
end
42+
43+
# If we didn't fail anywhere, we return the type stable one.
44+
return if issuccess
45+
varinfo
46+
else
47+
# Warn the user that we can't use the type stable one.
48+
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
49+
DynamicPPL.untyped_varinfo(model, context)
50+
end
51+
end
52+
53+
end

src/DynamicPPL.jl

+22-19
Original file line numberDiff line numberDiff line change
@@ -199,32 +199,35 @@ include("values_as_in_model.jl")
199199
include("debug_utils.jl")
200200
using .DebugUtils
201201

202+
include("experimental.jl")
202203
include("deprecated.jl")
203204

204205
if !isdefined(Base, :get_extension)
205206
using Requires
206207
end
207208

208-
@static if !isdefined(Base, :get_extension)
209+
# Better error message if users forget to load JET
210+
if isdefined(Base.Experimental, :register_error_hint)
209211
function __init__()
210-
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
211-
"../ext/DynamicPPLChainRulesCoreExt.jl"
212-
)
213-
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
214-
"../ext/DynamicPPLEnzymeCoreExt.jl"
215-
)
216-
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include(
217-
"../ext/DynamicPPLForwardDiffExt.jl"
218-
)
219-
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
220-
"../ext/DynamicPPLMCMCChainsExt.jl"
221-
)
222-
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include(
223-
"../ext/DynamicPPLReverseDiffExt.jl"
224-
)
225-
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
226-
"../ext/DynamicPPLZygoteRulesExt.jl"
227-
)
212+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
213+
requires_jet =
214+
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
215+
length(argtypes) >= 2 &&
216+
argtypes[1] <: Model &&
217+
argtypes[2] <: AbstractContext
218+
requires_jet |=
219+
exc.f === DynamicPPL.Experimental.is_suitable_varinfo &&
220+
length(argtypes) >= 3 &&
221+
argtypes[1] <: Model &&
222+
argtypes[2] <: AbstractContext &&
223+
argtypes[3] <: AbstractVarInfo
224+
if requires_jet
225+
print(
226+
io,
227+
"\n$(exc.f) requires JET.jl to be loaded. Please run `using JET` before calling $(exc.f).",
228+
)
229+
end
230+
end
228231
end
229232
end
230233

src/experimental.jl

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
module Experimental
2+
3+
using DynamicPPL: DynamicPPL
4+
5+
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
6+
"""
7+
is_suitable_varinfo(model::Model, context::AbstractContext, varinfo::AbstractVarInfo; kwargs...)
8+
9+
Check if the `model` supports evaluation using the provided `context` and `varinfo`.
10+
11+
!!! warning
12+
Loading JET.jl is required before calling this function.
13+
14+
# Arguments
15+
- `model`: The model to verify the support for.
16+
- `context`: The context to use for the model evaluation.
17+
- `varinfo`: The varinfo to verify the support for.
18+
19+
# Keyword Arguments
20+
- `only_ddpl`: If `true`, only consider error reports occuring in the tilde pipeline. Default: `true`.
21+
22+
# Returns
23+
- `issuccess`: `true` if the model supports the varinfo, otherwise `false`.
24+
- `report`: The result of `report_call` from JET.jl.
25+
"""
26+
function is_suitable_varinfo end
27+
28+
# Internal hook for JET.jl to overload.
29+
function _determine_varinfo_jet end
30+
31+
"""
32+
determine_suitable_varinfo(model[, context]; only_ddpl::Bool=true)
33+
34+
Return a suitable varinfo for the given `model`.
35+
36+
See also: [`DynamicPPL.Experimental.is_suitable_varinfo`](@ref).
37+
38+
!!! warning
39+
For full functionality, this requires JET.jl to be loaded.
40+
If JET.jl is not loaded, this function will assume the model is compatible with typed varinfo.
41+
42+
# Arguments
43+
- `model`: The model for which to determine the varinfo.
44+
- `context`: The context to use for the model evaluation. Default: `SamplingContext()`.
45+
46+
# Keyword Arguments
47+
- `only_ddpl`: If `true`, only consider error reports within DynamicPPL.jl.
48+
49+
# Examples
50+
51+
```jldoctest
52+
julia> using DynamicPPL.Experimental: determine_suitable_varinfo
53+
54+
julia> using JET: JET # needs to be loaded for full functionality
55+
56+
julia> @model function model_with_random_support()
57+
x ~ Bernoulli()
58+
if x
59+
y ~ Normal()
60+
else
61+
z ~ Normal()
62+
end
63+
end
64+
model_with_random_support (generic function with 2 methods)
65+
66+
julia> model = model_with_random_support();
67+
68+
julia> # Typed varinfo cannot handle this random support model properly
69+
# as using a single execution of the model will not see all random variables.
70+
# Hence, this this model requires untyped varinfo.
71+
vi = determine_suitable_varinfo(model);
72+
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
73+
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48
74+
75+
julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
76+
true
77+
78+
julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
79+
@model model_with_static_support() = x ~ Normal()
80+
model_with_static_support (generic function with 2 methods)
81+
82+
julia> vi = determine_suitable_varinfo(model_with_static_support());
83+
84+
julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
85+
true
86+
```
87+
"""
88+
function determine_suitable_varinfo(
89+
model::DynamicPPL.Model,
90+
context::DynamicPPL.AbstractContext=DynamicPPL.SamplingContext();
91+
only_ddpl::Bool=true,
92+
)
93+
# If JET.jl has been loaded, and thus `determine_varinfo` has been defined, we use that.
94+
return if Base.get_extension(DynamicPPL, :DynamicPPLJETExt) !== nothing
95+
_determine_varinfo_jet(model, context; only_ddpl)
96+
else
97+
# Warn the user.
98+
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
99+
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
100+
DynamicPPL.typed_varinfo(model, context)
101+
end
102+
end
103+
104+
end

0 commit comments

Comments
 (0)