Skip to content

Commit

Permalink
Merge pull request #69 from ShouvikGhosh2048/adding_numerical_bayesia…
Browse files Browse the repository at this point in the history
…n_tests

Adding bayesian numerical tests
  • Loading branch information
sourish-cmi authored Dec 28, 2022
2 parents fcb1f2b + 433a5d1 commit 39ff832
Show file tree
Hide file tree
Showing 6 changed files with 133 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23 changes: 23 additions & 0 deletions test/numerical/bayesian/LinearRegression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
mtcars = dataset("datasets", "mtcars")

tests = [
(Prior_Ridge(), 20.080877893580514),
(Prior_Laplace(), 20.070783434589128),
(Prior_Cauchy(), 20.019759144845644),
(Prior_TDist(), 20.042614331921428),
(Prior_HorseShoe(), 20.042984550677183),
]

for (prior, test_mean) in tests
CRRao.set_rng(StableRNG(123))
model = fit(@formula(MPG ~ HP + WT + Gear), mtcars, LinearRegression(), prior)

@test mean(predict(model, mtcars)) test_mean
end

gauss_test = 20.0796026428345

CRRao.set_rng(StableRNG(123))
model = fit(@formula(MPG ~ HP + WT + Gear), mtcars, LinearRegression(), Prior_Gauss(), 30.0, [0.0,-3.0,1.0], 1000)

@test mean(predict(model, mtcars)) gauss_test
58 changes: 58 additions & 0 deletions test/numerical/bayesian/LogisticRegression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
turnout = dataset("Zelig", "turnout")[1:100,:] # Take a subset of rows to reduce input size

tests = [
(
Prior_Ridge(),
(
(Logit(), 0.7690822208626806),
(Probit(), 0.7685999218881091),
(Cloglog(), 0.7751111243871245),
(Cauchit(), 0.7730511118602764)
)
),
(
Prior_Laplace(),
(
(Logit(), 0.7718593681922629),
(Probit(), 0.7695587585010469),
(Cloglog(), 0.7714870967902365),
(Cauchit(), 0.7714839338283468)
)
),
(
Prior_Cauchy(),
(
(Logit(), 0.7678814727043146),
(Probit(), 0.764699194194744),
(Cloglog(), 0.7642369367775604),
(Cauchit(), 0.7692152829967064)
)
),
(
Prior_TDist(),
(
(Logit(), 0.588835403024102),
(Probit(), 0.7635642627091132),
(Cloglog(), 0.7609943137312546),
(Cauchit(), 0.772095066757767)
)
),
(
Prior_HorseShoe(),
(
(Logit(), 0.38683395333332327),
(Probit(), 0.38253233489484173),
(Cloglog(), 0.7667553778881738),
(Cauchit(), 0.7706755564626601)
)
),
]

for (prior, prior_testcases) in tests
for (link, test_mean) in prior_testcases
CRRao.set_rng(StableRNG(123))
model = fit(@formula(Vote ~ Age + Race + Income + Educate), turnout, LogisticRegression(), link, prior)

@test mean(predict(model, turnout)) test_mean
end
end
16 changes: 16 additions & 0 deletions test/numerical/bayesian/NegBinomialRegression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
sanction = dataset("Zelig", "sanction")

tests = [
(Prior_Ridge(), 6.8753100988051274),
(Prior_Laplace(), 6.908332048475347),
(Prior_Cauchy(), 6.9829255933233645),
(Prior_TDist(), 6.915515248823249),
(Prior_HorseShoe(), 6.703023191644206),
]

for (prior, test_mean) in tests
CRRao.set_rng(StableRNG(123))
model = fit(@formula(Num ~ Target + Coop + NCost), sanction, NegBinomRegression(), prior)

@test mean(predict(model, sanction)) test_mean
end
16 changes: 16 additions & 0 deletions test/numerical/bayesian/PoissonRegression.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
sanction = dataset("Zelig", "sanction")

tests = [
(Prior_Ridge(), 7.177578002644547),
(Prior_Laplace(), 7.1454141602741785),
(Prior_Cauchy(), 7.148699646242317),
(Prior_TDist(), 7.165968828611132),
(Prior_HorseShoe(), 7.144190707091213),
]

for (prior, test_mean) in tests
CRRao.set_rng(StableRNG(123))
model = fit(@formula(Num ~ Target + Coop + NCost), sanction, PoissonRegression(), prior)

@test mean(predict(model, sanction)) test_mean
end
20 changes: 19 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using CRRao, Test, StableRNGs, Logging, RDatasets, StatsModels
using CRRao, Test, StableRNGs, Logging, RDatasets, StatsModels, Statistics

Logging.disable_logging(Logging.Warn)

Expand All @@ -23,4 +23,22 @@ CRRao.set_rng(StableRNG(123))
include("basic/NegBinomialRegression.jl")
end
end

@testset "Numerical Tests" begin
@testset "Linear Regression" begin
include("numerical/bayesian/LinearRegression.jl")
end

@testset "Logistic Regression" begin
include("numerical/bayesian/LogisticRegression.jl")
end

@testset "Poisson Regression" begin
include("numerical/bayesian/PoissonRegression.jl")
end

@testset "Negative Binomial Regression" begin
include("numerical/bayesian/NegBinomialRegression.jl")
end
end
end

0 comments on commit 39ff832

Please sign in to comment.