This repository has been archived by the owner on Nov 1, 2024. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy path11-sparse_horseshoe.jl
91 lines (80 loc) · 5.08 KB
/
11-sparse_horseshoe.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
using Turing
using CSV
using DataFrames
using StatsBase
using LinearAlgebra
# reproducibility
using Random: seed!
seed!(123)  # fixed seed for the global RNG so reruns reproduce the same draws

# --- data ---
# read datasets/sparse_regression.csv into a DataFrame
df = CSV.read("datasets/sparse_regression.csv", DataFrame)

# design matrix: every column except :y, converted to floating point and
# z-scored column by column (dims=1)
X = standardize(ZScoreTransform, float(Matrix(select(df, Not(:y)))); dims=1)

# response: the :y column, converted to floating point and z-scored
y = standardize(ZScoreTransform, float(df[:, :y]); dims=1)
# define the model
#
# sparse_horseshoe_regression(X, y; predictors=size(X, 2))
#
# Linear regression of `y` on the columns of `X` with a horseshoe prior on the
# coefficients:
#   α        ~ Student-t(3) scaled by 2.5          (intercept)
#   λ[j]     ~ half-Cauchy(0, 1), j = 1:predictors (local shrinkage scales)
#   τ        ~ half-Cauchy(0, 1)                   (global shrinkage scale)
#   σ        ~ Exponential(1)                      (residual scale)
#   β        ~ zero-mean normal, sd(β[j]) = λ[j] * τ
#   y        ~ MvNormal(α .+ X * β, σ^2 * I)
# Returns the named tuple (; y, α, λ, τ, σ, β).
# The order of the ~ statements below is intentional: it fixes the parameter
# order in the sampler's state, so it is left unchanged.
@model function sparse_horseshoe_regression(X, y; predictors=size(X, 2))
    # --- priors ---
    α ~ TDist(3) * 2.5
    λ ~ filldist(truncated(Cauchy(0, 1); lower=0), predictors)
    τ ~ truncated(Cauchy(0, 1); lower=0)
    σ ~ Exponential(1)

    # per-coefficient standard deviation: local scale times global scale;
    # the diagonal covariance holds the squared scales
    coef_scale = λ .* τ
    β ~ MvNormal(Diagonal(coef_scale .^ 2))

    # --- likelihood ---
    μ = α .+ X * β
    y ~ MvNormal(μ, σ^2 * I)

    return (; y, α, λ, τ, σ, β)
end
# --- inference ---
# condition the model on the standardized data
model = sparse_horseshoe_regression(X, y)

# NUTS with 1_000 adaptation (warmup) steps and target acceptance rate 0.8;
# 4 chains run in parallel on threads, each keeping 1_000 draws
# (2k iterations per chain in total, the first 1k being warmup)
chn = sample(model, NUTS(1_000, 0.8), MCMCThreads(), 1_000, 4)

# tabulate per-parameter summary statistics (mean, std, mcse, ess, rhat, ...)
summary_table = DataFrame(summarystats(chn))
println(summary_table)
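# results:
# NOTE(review): this table looks stale relative to the model above. λ has a
# prior truncated at 0, so its posterior means cannot be negative, yet several
# λ rows show negative means. Many R-hat values are also far above 1 (up to
# ~1.94), so the chains shown should not be treated as converged. Re-run the
# script to regenerate these numbers.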
# parameters mean std mcse ess_bulk ess_tail rhat ess_per_sec
# Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
#
# α 0.0042 0.0305 0.0040 59.5643 80.2524 1.0706 0.0642
# λ[1] -2.1122 6.6393 2.0278 13.0743 46.8402 1.5634 0.0141
# λ[2] -0.1085 1.1720 0.1996 23.9859 146.5381 1.2223 0.0259
# λ[3] 0.6606 4.3584 1.4231 10.9612 30.3532 1.9050 0.0118
# λ[4] -0.3683 1.3287 0.3048 16.2733 93.6187 1.3564 0.0175
# λ[5] -0.0138 4.8649 1.5867 11.0779 30.2512 1.8539 0.0119
# λ[6] -0.0985 3.6917 1.2181 10.8100 60.5601 1.9403 0.0117
# λ[7] -0.3905 1.3529 0.3539 11.8725 58.3778 1.7623 0.0128
# λ[8] -1.0362 2.6053 0.8181 11.5641 40.9582 1.7344 0.0125
# λ[9] 0.9852 1.0252 0.1503 37.0615 71.6680 1.1240 0.0400
# λ[10] 0.0326 1.6489 0.3335 20.7106 113.6655 1.2161 0.0223
# λ[11] 0.7310 2.7803 0.8534 12.8535 20.8644 1.6288 0.0139
# λ[12] -0.7746 1.7907 0.5203 12.4086 37.8785 1.6285 0.0134
# λ[13] 1.1169 1.4141 0.2771 24.5536 33.7786 1.1852 0.0265
# λ[14] -0.3582 5.4360 1.8088 11.0894 46.2207 1.9098 0.0120
# λ[15] 0.6097 1.8694 0.4101 17.7370 57.8266 1.2889 0.0191
# λ[16] -0.6527 2.4383 0.5938 13.4484 52.2172 1.4904 0.0145
# λ[17] -0.4548 1.3797 0.2970 18.9404 75.0517 1.2441 0.0204
# λ[18] 0.1632 1.3832 0.3004 20.4743 105.9446 1.2071 0.0221
# λ[19] 0.2587 1.9653 0.5230 12.2939 39.3894 1.6290 0.0133
# λ[20] 1.0108 5.1065 1.7174 11.1679 47.7117 1.8312 0.0120
# τ 0.0733 0.0225 0.0041 29.0739 134.8200 1.1333 0.0314
# σ 0.3250 0.0225 0.0022 116.6290 207.2708 1.0186 0.1258
# β[1] 0.4853 0.0345 0.0052 45.8155 110.9592 1.0770 0.0494
# β[2] 0.0030 0.0246 0.0035 49.2652 215.6946 1.0856 0.0531
# β[3] 0.2613 0.0341 0.0045 58.8703 174.2487 1.0742 0.0635
# β[4] 0.0082 0.0286 0.0041 45.9931 99.5215 1.0843 0.0496
# β[5] -0.2896 0.0315 0.0030 109.3217 165.7488 1.0426 0.1179
# β[6] 0.2234 0.0314 0.0035 79.2594 103.0314 1.0828 0.0855
# β[7] 0.0172 0.0292 0.0044 45.4594 166.8813 1.1061 0.0490
# β[8] 0.1240 0.0342 0.0048 52.9156 193.5309 1.1070 0.0571
# β[9] 0.0031 0.0338 0.0062 31.2744 24.4042 1.1500 0.0337
# β[10] 0.0260 0.0338 0.0047 52.5279 115.6564 1.0818 0.0566
# β[11] 0.0813 0.0363 0.0042 68.8451 143.6499 1.0765 0.0742
# β[12] -0.0063 0.0305 0.0040 58.7870 120.4656 1.0386 0.0634
# β[13] -0.0202 0.0238 0.0025 93.3616 207.3711 1.0294 0.1007
# β[14] -0.4260 0.0333 0.0049 47.0491 66.6016 1.1070 0.0507
# β[15] -0.0153 0.0281 0.0039 47.5601 161.6239 1.1152 0.0513
# β[16] -0.0198 0.0306 0.0071 20.3811 126.5139 1.2857 0.0220
# β[17] -0.0078 0.0253 0.0031 65.9642 111.4478 1.0430 0.0711
# β[18] 0.0113 0.0250 0.0033 58.6306 55.4533 1.0512 0.0632
# β[19] -0.0046 0.0287 0.0034 73.1314 60.9919 1.0320 0.0789
# β[20] -0.3487 0.0348 0.0057 37.4172 51.0361 1.1133 0.0404