
Commit e083a4b

Aoife committed: bumped chain length

1 parent 9b269a7 commit e083a4b

File tree: 1 file changed (+33, -23 lines)


usage/stochastic-gradient-samplers/index.qmd

Lines changed: 33 additions & 23 deletions
@@ -70,35 +70,43 @@ end
 model = gaussian_model(data)
 ```
 
-SGLD requires very small step sizes to ensure stability. We use a `PolynomialStepsize` that decreases over time. Note: Currently, `PolynomialStepsize` is the primary stepsize schedule available in Turing for SGLD:
+SGLD requires very small step sizes to ensure stability. We use a `PolynomialStepsize` that decreases over time. Note: Currently, `PolynomialStepsize` is the primary stepsize schedule available in Turing for SGLD.
+
+**Important Note on Convergence**: The examples below use longer chains (10,000-15,000 samples) with the first half discarded as burn-in to ensure proper convergence. This is typical for stochastic gradient samplers, which require more samples than standard HMC/NUTS to achieve reliable results:
 
 ```{julia}
 # SGLD with polynomial stepsize schedule
 # stepsize(t) = a / (b + t)^γ
-sgld_stepsize = Turing.PolynomialStepsize(0.0001, 10000, 0.55)
-chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 5000)
+# Using smaller step size and longer chain for better convergence
+sgld_stepsize = Turing.PolynomialStepsize(0.00005, 20000, 0.55)
+chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 10000)
 
-summarystats(chain_sgld)
+# Note: We use a longer chain (10000 samples) to ensure convergence
+# The first half can be considered burn-in
+summarystats(chain_sgld[5001:end])
 ```
 
 
 ```{julia}
-plot(chain_sgld)
+# Plot the second half of the chain to show converged behavior
+plot(chain_sgld[5001:end])
 ```
 
 ## SGHMC (Stochastic Gradient Hamiltonian Monte Carlo)
 
 SGHMC extends HMC to the stochastic gradient setting by incorporating friction to counteract the noise from stochastic gradients:
 
 ```{julia}
-# SGHMC with very small learning rate
-chain_sghmc = sample(model, SGHMC(learning_rate=0.00001, momentum_decay=0.1), 5000)
+# SGHMC with very small learning rate and longer chain
+chain_sghmc = sample(model, SGHMC(learning_rate=0.000005, momentum_decay=0.2), 10000)
 
-summarystats(chain_sghmc)
+# Using the second half of the chain after burn-in
+summarystats(chain_sghmc[5001:end])
 ```
 
 ```{julia}
-plot(chain_sghmc)
+# Plot the second half of the chain to show converged behavior
+plot(chain_sghmc[5001:end])
 ```
 
 ## Comparison with Standard HMC
@@ -115,10 +123,11 @@ summarystats(chain_hmc)
 Compare the trace plots to see how the different samplers explore the posterior:
 
 ```{julia}
-p1 = plot(chain_sgld[:μ], label="SGLD", title="μ parameter traces")
+# Compare converged portions of the chains
+p1 = plot(chain_sgld[5001:end][:μ], label="SGLD (after burn-in)", title="μ parameter traces")
 hline!([true_μ], label="True value", linestyle=:dash, color=:red)
 
-p2 = plot(chain_sghmc[:μ], label="SGHMC")
+p2 = plot(chain_sghmc[5001:end][:μ], label="SGHMC (after burn-in)")
 hline!([true_μ], label="True value", linestyle=:dash, color=:red)
 
 p3 = plot(chain_hmc[:μ], label="HMC")
@@ -127,10 +136,10 @@ hline!([true_μ], label="True value", linestyle=:dash, color=:red)
 plot(p1, p2, p3, layout=(3,1), size=(800,600))
 ```
 
-The comparison shows that:
-- **SGLD** exhibits slower convergence and higher variance due to the injected noise, requiring longer chains to achieve stable estimates
-- **SGHMC** shows slightly better mixing than SGLD due to the momentum term, but still requires careful tuning
-- **HMC** converges quickly and efficiently explores the posterior, demonstrating why it's preferred for small to medium-sized problems
+The comparison shows that (using converged portions after burn-in):
+- **SGLD** exhibits slower convergence and higher variance due to the injected noise, requiring longer chains (10,000+ samples) and discarding burn-in to achieve stable estimates
+- **SGHMC** shows slightly better mixing than SGLD due to the momentum term, but still requires careful tuning and a burn-in period
+- **HMC** converges quickly and efficiently explores the posterior from the start, demonstrating why it's preferred for small to medium-sized problems
 
 ## Bayesian Linear Regression Example
 
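As an aside to the comparison in the hunk above, here is a minimal sketch of how the post-burn-in portions of these chains could be checked quantitatively rather than only by inspecting trace plots. It assumes the `chain_sgld`, `chain_sghmc`, and `chain_hmc` objects from this tutorial are in scope and that a recent MCMCChains provides `ess_rhat`; the diagnostic API may differ in older versions.

```julia
using Turing, MCMCChains

# Drop the first half of each stochastic-gradient chain as burn-in,
# mirroring the slicing used in the tutorial above.
burned_sgld  = chain_sgld[5001:end]
burned_sghmc = chain_sghmc[5001:end]

# Effective sample size and R-hat per parameter; very low ESS or R-hat far
# from 1.0 suggests the chain has not converged or mixes poorly.
ess_rhat(burned_sgld)
ess_rhat(burned_sghmc)
ess_rhat(chain_hmc)
```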
@@ -162,11 +171,11 @@ lr_model = linear_regression(X, y)
 Sample using the stochastic gradient methods:
 
 ```{julia}
-# Very conservative parameters for stability
-sgld_lr_stepsize = Turing.PolynomialStepsize(0.00005, 10000, 0.55)
-chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 5000)
+# Very conservative parameters for stability with longer chains
+sgld_lr_stepsize = Turing.PolynomialStepsize(0.00002, 30000, 0.55)
+chain_lr_sgld = sample(lr_model, SGLD(stepsize=sgld_lr_stepsize), 15000)
 
-chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.00005, momentum_decay=0.1), 5000)
+chain_lr_sghmc = sample(lr_model, SGHMC(learning_rate=0.000002, momentum_decay=0.3), 15000)
 
 chain_lr_hmc = sample(lr_model, HMC(0.01, 10), 1000)
 ```
@@ -178,13 +187,14 @@ println("True β values: ", true_β)
 println("True σ value: ", true_σ_noise)
 println()
 
-println("SGLD estimates:")
-summarystats(chain_lr_sgld)
+println("SGLD estimates (after burn-in):")
+summarystats(chain_lr_sgld[7501:end])
 ```
 
 The linear regression example demonstrates that stochastic gradient samplers can recover the true parameters, but:
-- They require significantly longer chains (5000 vs 1000 for HMC)
-- The estimates may have higher variance
+- They require significantly longer chains (15000 vs 1000 for HMC)
+- We discard the first half as burn-in to ensure convergence
+- The estimates may still have higher variance than HMC
 - Convergence diagnostics should be carefully examined before trusting the results
 
 ## Automatic Differentiation Backends
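
One more hedged sketch related to the burn-in handling in this commit: instead of slicing chains manually with `chain[5001:end]`, the `sample` call can be asked to drop the warm-up draws itself. This assumes the `discard_initial` keyword from AbstractMCMC is forwarded by Turing's `sample` (check your installed version), and it reuses `model` and the SGLD stepsize settings from the Gaussian example above.

```julia
using Turing

# Same SGLD configuration as in the diff above.
sgld_stepsize = Turing.PolynomialStepsize(0.00005, 20000, 0.55)

# Ask the sampler to discard the first 5_000 draws as burn-in and keep the
# next 5_000, instead of slicing the resulting chain by hand.
chain_sgld = sample(model, SGLD(stepsize=sgld_stepsize), 5_000;
                    discard_initial=5_000)

summarystats(chain_sgld)
```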
