|  | 
| 237 | 237 |     end | 
| 238 | 238 | end | 
| 239 | 239 | 
 | 
| 240 |  | -@testset verbose = true "AD / SamplingContext" begin | 
| 241 |  | -    # AD tests for gradient-based samplers need to be run with SamplingContext | 
| 242 |  | -    # because samplers can potentially use this to define custom behaviour in | 
| 243 |  | -    # the tilde-pipeline and thus change the code executed during model | 
| 244 |  | -    # evaluation. | 
| 245 |  | -    @testset "adtype=$adtype" for adtype in ADTYPES | 
| 246 |  | -        @testset "alg=$alg" for alg in [ | 
| 247 |  | -            HMC(0.1, 10; adtype=adtype), | 
| 248 |  | -            HMCDA(0.8, 0.75; adtype=adtype), | 
| 249 |  | -            NUTS(1000, 0.8; adtype=adtype), | 
| 250 |  | -            SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype), | 
| 251 |  | -            SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype), | 
| 252 |  | -        ] | 
| 253 |  | -            @info "Testing AD for $alg" | 
| 254 |  | - | 
| 255 |  | -            @testset "model=$(model.f)" for model in DEMO_MODELS | 
| 256 |  | -                rng = StableRNG(123) | 
| 257 |  | -                spl_model = DynamicPPL.contextualize( | 
| 258 |  | -                    model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg)) | 
| 259 |  | -                ) | 
| 260 |  | -                @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any | 
| 261 |  | -            end | 
| 262 |  | -        end | 
| 263 |  | -    end | 
| 264 |  | -end | 
| 265 |  | - | 
| 266 | 240 | @testset verbose = true "AD / GibbsContext" begin | 
| 267 |  | -    # Gibbs sampling also needs extra AD testing because the models are | 
|  | 241 | +    # Gibbs sampling needs some extra AD testing because the models are | 
| 268 | 242 |     # executed with GibbsContext and a subsetted varinfo. (see e.g. | 
| 269 | 243 |     # `gibbs_initialstep_recursive` and `gibbs_step_recursive` in | 
| 270 | 244 |     # src/mcmc/gibbs.jl -- the code here mimics what happens in those | 
|  | 
| 283 | 257 |                     model, varnames, deepcopy(global_vi) | 
| 284 | 258 |                 ) | 
| 285 | 259 |                 rng = StableRNG(123) | 
| 286 |  | -                spl_model = DynamicPPL.contextualize( | 
| 287 |  | -                    model, DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(HMC(0.1, 10))) | 
| 288 |  | -                ) | 
| 289 |  | -                @test run_ad(spl_model, adtype; test=true, benchmark=false) isa Any | 
|  | 260 | +                @test run_ad(model, adtype; test=true, benchmark=false) isa Any | 
| 290 | 261 |             end | 
| 291 | 262 |         end | 
| 292 | 263 |     end | 
|  | 
0 commit comments