diff --git a/src/BuildBasis.jl b/src/BuildBasis.jl index 8ecef57..a2f79e2 100644 --- a/src/BuildBasis.jl +++ b/src/BuildBasis.jl @@ -144,32 +144,55 @@ Arguments: - `Basis` : `AbstractArray` containing the basis matrix. """ function BuildPenaltyMatrix(y, x, sp, Basis) - n = length(y) - n_knots = map(x -> length(x.breakpoints), Basis) - X = map(BuildBasisMatrix, Basis, x) - D = map(BuildDifferenceMatrix, Basis) - # Drop one column from X and D for identifiability - X = map(DropCol, X, n_knots) - D = map(DropCol, D, n_knots) - - # Center - ColMeans = map(BuildBasisMatrixColMeans, X) - X = map(CenterBasisMatrix, X, ColMeans) - - # Store Coef index - CoefIndex = BuildCoefIndex(X) - # Build Design Matrix - X = hcat(repeat([1],n), hcat(X...)) # add intercept - - # Build Penalty Matrix - D = dcat(map((p, d) -> (sqrt(p) * d), sp, D)) - D = hcat(repeat([0], size(D,1)), D) # add 0 for intercept + # Prepare per-term blocks + Xblocks = Vector{AbstractMatrix}(undef, length(x)) + Dblocks = Vector{AbstractMatrix}(undef, length(x)) + ColMeans = Vector{AbstractArray}(undef, length(x)) + + for i in eachindex(x) + bi = Basis[i] + xi = x[i] + if bi === :linear + # centered 1-column linear block; no penalty + μ = mean(xi) + Xi = reshape(xi .- μ, :, 1) + Di = zeros(1, 1) + Xblocks[i] = Xi + Dblocks[i] = Di + ColMeans[i] = reshape([μ], 1, :) + else + # smooth block + nk = length(bi.breakpoints) + Xi0 = BuildBasisMatrix(bi, xi) + Di0 = BuildDifferenceMatrix(bi) + Xi = DropCol(Xi0, nk) + Di = DropCol(Di0, nk) + # drop for identifiability + cm = BuildBasisMatrixColMeans(Xi) + Xi = CenterBasisMatrix(Xi, cm) + Xblocks[i] = Xi + Dblocks[i] = Di + ColMeans[i] = cm + end + end + + # Coefficient index per-term by block width + CoefIndex = BuildCoefIndex(Xblocks) + + # Assembled design with intercept + X = hcat(repeat([1], n), hcat(Xblocks...)) + + # Block-diagonal penalty (sqrt(sp) scaling on smooth blocks; linear blocks are zero already) + Dscaled = map((p, d) -> sqrt(p) .* d, sp, Dblocks) + D = dcat(Dscaled) + D = hcat(repeat([0], size(D, 1)), D) # intercept column return X, y, D, ColMeans, CoefIndex end + """ HatMatrix(X, D, W) Builds a hat matrix. diff --git a/src/FitGAM.jl b/src/FitGAM.jl index f2e069f..35a148a 100644 --- a/src/FitGAM.jl +++ b/src/FitGAM.jl @@ -31,14 +31,24 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident @assert all(y .∈ Ref([0, 1])) "Response must be binary (0 or 1) for Bernoulli family" end - x = Data[!, GAMForm.covariates.variable] - BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)] - x = [x[!, col] for col in names(x)] + # Collect covariate columns and term meta + xdf = Data[!, GAMForm.covariates.variable] + x = [xdf[!, col] for col in names(xdf)] + BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)] + smoothmask = collect(GAMForm.covariates.smooth) - # Build basis - Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs) + # Per-term basis: smooth -> BSplineBasis, linear -> :linear + Basis = Vector{Any}(undef, length(x)) + for i in eachindex(x) + if smoothmask[i] + k, degree = BasisArgs[i] + Basis[i] = BuildUniformBasis(x[i], k, degree) + else + Basis[i] = :linear + end + end # Fit PIRLS procedure gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol) return gam -end \ No newline at end of file +end diff --git a/src/GAMData.jl b/src/GAMData.jl index d71fdf1..fe38387 100644 --- a/src/GAMData.jl +++ b/src/GAMData.jl @@ -21,7 +21,7 @@ Arguments: mutable struct GAMData y::AbstractArray x::AbstractArray - Basis::AbstractArray{BSplineBasis} + Basis::AbstractArray Family::Dict Link::Dict Coef::AbstractArray @@ -44,4 +44,4 @@ mutable struct GAMData ) new(y, x, Basis, Family, Link, Coef, ColMeans, CoefIndex, Fitted, Diagnostics) end -end \ No newline at end of file +end diff --git a/src/PIRLS.jl b/src/PIRLS.jl index 1293994..a74c2e1 100644 --- a/src/PIRLS.jl +++ b/src/PIRLS.jl @@ -21,12 +21,30 @@ function PIRLS(y, x, sp, Basis, Dist, Link; maxIter = 25, tol = 1e-6) # Initial Predictions n = length(y) + # μ init. Should replace with multiple dispatch logic + μmin = 1e-8 + μmax = 1e12 # generous upper bound to avoid overflow in weights + if Dist[:Name] == "Bernoulli" - mu = clamp.(y, 1e-6, 1 - 1e-6) + p0 = clamp(mean(y), 1e-3, 1 - 1e-3) + mu = fill(p0, n) + elseif Dist[:Name] == "Poisson" + # Poisson needs strictly positive mu; start at max(y, small) or mean if all zeros + if all(==(0), y) + mu = fill(0.1, n) + else + mu = max.(Float64.(y), μmin) + end + elseif Dist[:Name] == "Gamma" + # Gamma also requires positive mu; start near the positive mean of y + ybar = max(mean(abs.(Float64.(y))), 1e-2) + mu = fill(ybar, n) else - mu = y + # Gaussian and others: start at y (okay to be any real) + mu = Float64.(y) end + mu = clamp.(mu, μmin, μmax) eta = Link[:Function].(mu) # Deviance diff --git a/src/Predictions.jl b/src/Predictions.jl index bf2f029..b0fd93b 100644 --- a/src/Predictions.jl +++ b/src/Predictions.jl @@ -29,7 +29,15 @@ Arguments: - `ix` : `Int` denoting the variable to plot. """ function PredictPartial(mod, ix) - predMatrix = BuildPredictionMatrix(mod.x[ix], mod.Basis[ix], mod.ColMeans[ix]) - predBeta = mod.Coef[mod.CoefIndex[ix]] - return predMatrix * predBeta + bi = mod.Basis[ix] + if bi === :linear + μ = mod.ColMeans[ix][1] # stored as 1×1 + Xi = reshape(mod.x[ix] .- μ, :, 1) + β = mod.Coef[mod.CoefIndex[ix]] # scalar + return Xi * β + else + predMatrix = BuildPredictionMatrix(mod.x[ix], bi, mod.ColMeans[ix]) + predBeta = mod.Coef[mod.CoefIndex[ix]] + return predMatrix * predBeta + end end diff --git a/test/Project.toml b/test/Project.toml index f6698d9..99c79a7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,9 @@ [deps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index d430a90..63d7365 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,181 +1,409 @@ using GAM using Test -using RDatasets, Plots -using Distributions - -#-------------------- Set up data ----------------- - -df = dataset("datasets", "trees"); +using RDatasets, DataFrames, Plots +using Distributions, Random, Statistics, LinearAlgebra #-------------------- Run tests ----------------- -@testset "GAM.jl" begin - - mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) - - p = plotGAM(mod) - @test p isa Plots.Plot - - # Gamma version - - mod2 = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; Family = "Gamma", Link = "Log") - - p1 = plotGAM(mod2) - @test p1 isa Plots.Plot -end - -@testset "Bernoulli GAM Tests" begin - - # Test 1: Basic Bernoulli GAM with known pattern - @testset "Basic Bernoulli GAM" begin - n = 200 - x1 = range(-2, 2, length=n) - x2 = randn(n) - - # Create true nonlinear effect - f1 = sin.(x1 * π/2) - f2 = x2.^2 .- 1 - eta = f1 + f2 - p = 1 ./ (1 .+ exp.(-eta)) - y = rand.(Bernoulli.(p)) - - df = DataFrame(y=y, x1=x1, x2=x2) - - # Fit model - mod = gam("y ~ s(x1, k=8, degree=3) + s(x2, k=8, degree=3)", df; - Family = "Bernoulli", Link = "Logit") - - # Basic tests - @test mod isa GAMData - @test mod.Family[:Name] == "Bernoulli" - @test mod.Link[:Name] == "Logit" - @test all(0 .<= mod.Fitted .<= 1) # Predictions should be probabilities - @test length(mod.Fitted) == n - - # Test plotting - p = plotGAM(mod) - @test p isa Plots.Plot - end +@testset "GAM.jl Test Suite" begin - # Test 2: Edge cases with extreme probabilities - @testset "Extreme probability cases" begin - n = 100 - x = range(-5, 5, length=n) + # =========================== + # Core Gaussian GAM Tests + # =========================== + @testset "Gaussian GAM (Normal Family)" begin + df = dataset("datasets", "trees") - # Create data with extreme probabilities - eta = 10 * x # Very strong effect to create extreme probabilities - p = 1 ./ (1 .+ exp.(-eta)) - y = Float64.(p .> 0.5) # Deterministic for testing + @testset "Basic fitting and prediction" begin + mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + + @test mod isa GAMData + @test mod.Family[:Name] == "Normal" + @test mod.Link[:Name] == "Identity" + @test length(mod.Fitted) == nrow(df) + + # Check residuals properties + residuals = df.Volume .- mod.Fitted + @test mean(residuals) ≈ 0 atol=0.5 + @test std(residuals) < std(df.Volume) # Model should explain some variance + end - df = DataFrame(y=y, x=x) + @testset "Model diagnostics" begin + mod = gam("Volume ~ s(Girth, k=8, degree=3) + s(Height, k=8, degree=3)", df) + + @test haskey(mod.Diagnostics, :RSS) + @test haskey(mod.Diagnostics, :EDF) + @test haskey(mod.Diagnostics, :GCV) + + @test mod.Diagnostics[:RSS] > 0 + @test 2 < mod.Diagnostics[:EDF] < 16 # Between minimum and maximum possible + @test mod.Diagnostics[:GCV] > 0 + end - # This should run without numerical errors - mod = gam("y ~ s(x, k=5, degree=3)", df; - Family = "Bernoulli", Link = "Logit") + @testset "Partial predictions" begin + mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + + # Test partial predictions for each smooth + for i in 1:length(mod.x) + partial = GAM.PredictPartial(mod, i) + @test length(partial) == length(mod.x[i]) + @test !any(isnan.(partial)) + @test !any(isinf.(partial)) + end + end - @test !any(isnan.(mod.Fitted)) - @test !any(isinf.(mod.Fitted)) - @test all(0 .<= mod.Fitted .<= 1) + @testset "Visualization" begin + mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + p = plotGAM(mod) + @test p isa Plots.Plot + end end - # Test 3: Binary validation - @testset "Binary response validation" begin - n = 50 - x = randn(n) + # =========================== + # Gamma GAM Tests + # =========================== + @testset "Gamma GAM" begin + df = dataset("datasets", "trees") - # Test with non-binary data (should throw error) - y_continuous = randn(n) - df_continuous = DataFrame(y=y_continuous, x=x) - - @test_throws AssertionError gam("y ~ s(x, k=5, degree=3)", df_continuous; - Family = "Bernoulli", Link = "Logit") - - # Test with proper binary data (0s and 1s) - y_binary = rand([0, 1], n) - df_binary = DataFrame(y=y_binary, x=x) + @testset "Basic fitting with log link" begin + mod = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; + Family = "Gamma", Link = "Log") + + @test mod.Family[:Name] == "Gamma" + @test mod.Link[:Name] == "Log" + @test all(mod.Fitted .> 0) # Gamma predictions must be positive + @test length(mod.Fitted) == nrow(df) + end - mod = gam("y ~ s(x, k=5, degree=3)", df_binary; - Family = "Bernoulli", Link = "Logit") - @test mod isa GAMData + @testset "Convergence and stability" begin + # Test with different starting values via different maxIter + mod1 = gam("Volume ~ s(Girth, k=8, degree=3)", df; + Family = "Gamma", Link = "Log", maxIter = 10) + mod2 = gam("Volume ~ s(Girth, k=8, degree=3)", df; + Family = "Gamma", Link = "Log", maxIter = 50) + + # Should converge to similar values + @test cor(mod1.Fitted, mod2.Fitted) > 0.99 + end end - # Test 4: Model diagnostics - @testset "Model diagnostics" begin - n = 150 - x = sort(randn(n)) - logit_p = 2*x - p = 1 ./ (1 .+ exp.(-logit_p)) - y = rand.(Bernoulli.(p)) - - df = DataFrame(y=y, x=x) - - mod = gam("y ~ s(x, k=10, degree=3)", df; - Family = "Bernoulli", Link = "Logit") - - # Check diagnostics exist and are reasonable - @test haskey(mod.Diagnostics, :EDF) - @test haskey(mod.Diagnostics, :GCV) - @test mod.Diagnostics[:EDF] > 1 # Should have some smoothing - @test mod.Diagnostics[:EDF] < 10 # But not too complex - @test mod.Diagnostics[:GCV] > 0 + # =========================== + # Bernoulli GAM Tests + # =========================== + @testset "Bernoulli GAM (Logistic)" begin + Random.seed!(123) # For reproducibility + + @testset "Nonlinear binary classification" begin + n = 300 + x1 = range(-2, 2, length=n) + x2 = randn(n) + + # Create true nonlinear effect + f1 = sin.(x1 * π/2) + f2 = 0.5 * (x2.^2 .- 1) + eta = f1 + f2 + p_true = 1 ./ (1 .+ exp.(-eta)) + y = rand.(Bernoulli.(p_true)) + + df = DataFrame(y=y, x1=x1, x2=x2) + + mod = gam("y ~ s(x1, k=8, degree=3) + s(x2, k=8, degree=3)", df; + Family = "Bernoulli", Link = "Logit") + + @test mod.Family[:Name] == "Bernoulli" + @test mod.Link[:Name] == "Logit" + @test all(0 .<= mod.Fitted .<= 1) + + # Check classification performance + predictions = mod.Fitted .> 0.5 + accuracy = mean(predictions .== y) + @test accuracy > 0.6 # Should do better than random + + # Check calibration: average predicted prob should ≈ observed proportion + @test abs(mean(mod.Fitted) - mean(y)) < 0.1 + end + + @testset "Numerical stability with extreme probabilities" begin + n = 100 + x = range(-5, 5, length=n) + + # Create separation - extreme but not complete + eta = 8 * x + p = 1 ./ (1 .+ exp.(-eta)) + y = Float64.(p .> 0.5) + + df = DataFrame(y=y, x=x) + + mod = gam("y ~ s(x, k=5, degree=3)", df; + Family = "Bernoulli", Link = "Logit") + + @test !any(isnan.(mod.Fitted)) + @test !any(isinf.(mod.Fitted)) + @test all(0 .<= mod.Fitted .<= 1) + + # Should still separate classes well + @test mean(mod.Fitted[y .== 1]) > 0.9 + @test mean(mod.Fitted[y .== 0]) < 0.1 + end + + @testset "Input validation" begin + n = 50 + x = randn(n) + + # Non-binary data should throw error + y_continuous = randn(n) + df_invalid = DataFrame(y=y_continuous, x=x) + @test_throws AssertionError gam("y ~ s(x, k=5, degree=3)", df_invalid; + Family = "Bernoulli", Link = "Logit") + + # Binary data should work + y_binary = rand([0, 1], n) + df_valid = DataFrame(y=y_binary, x=x) + mod = gam("y ~ s(x, k=5, degree=3)", df_valid; + Family = "Bernoulli", Link = "Logit") + @test mod isa GAMData + end + + @testset "Monotonic relationships" begin + n = 200 + x = sort(randn(n) * 2) + + # Strong monotonic relationship + eta = 3 * x + p = 1 ./ (1 .+ exp.(-eta)) + y = rand.(Bernoulli.(p)) + + df = DataFrame(y=y, x=x) + mod = gam("y ~ s(x, k=10, degree=3)", df; + Family = "Bernoulli", Link = "Logit") + + # Check monotonicity of fitted values + fitted_sorted = mod.Fitted[sortperm(x)] + differences = diff(fitted_sorted) + @test sum(differences .> 0) / length(differences) > 0.95 # Should be mostly increasing + + # Check effective degrees of freedom (should be relatively low for smooth monotonic) + @test 1.5 < mod.Diagnostics[:EDF] < 5 + end end - # Test 5: Convergence behavior - @testset "PIRLS convergence" begin - n = 100 - x1 = randn(n) - x2 = randn(n) + # =========================== + # Mixed Models Tests + # =========================== + @testset "Mixed Linear and Smooth Terms" begin - # Simple linear predictor for easier convergence - eta = 0.5 .+ x1 - 0.5*x2 - p = 1 ./ (1 .+ exp.(-eta)) - y = rand.(Bernoulli.(p)) + @testset "Gaussian with mixed terms" begin + df = dataset("datasets", "trees") + + mod = gam("Volume ~ s(Girth, k=8, degree=3) + Height", df) + + @test mod isa GAMData + @test length(mod.x) == 2 + @test length(mod.Basis) == 2 + + # Linear term should have minimal degrees of freedom + # (This assumes Height is the second term) + partial_height = GAM.PredictPartial(mod, 2) + @test length(unique(diff(partial_height[sortperm(df.Height)]))) < 5 # Should be nearly constant slope + end - df = DataFrame(y=y, x1=x1, x2=x2) + @testset "Bernoulli with mixed terms" begin + Random.seed!(456) + n = 500 + + # Linear and nonlinear effects + x_linear = randn(n) + x_smooth = sort(rand(n) * 2 .- 1) + + # True model: linear + smooth + beta_linear = 1.5 + f_smooth = 2 * sin.(3 * π * x_smooth) + eta = 0.5 .+ beta_linear * x_linear .+ f_smooth + p = 1 ./ (1 .+ exp.(-eta)) + y = rand.(Bernoulli.(p)) + + df = DataFrame(y=y, x_linear=x_linear, x_smooth=x_smooth) + + mod = gam("y ~ x_linear + s(x_smooth, k=10, degree=3)", df; + Family="Bernoulli", Link="Logit") + + @test all(0 .<= mod.Fitted .<= 1) + + # Check that linear effect is captured + partial_linear = GAM.PredictPartial(mod, 1) + linear_coef = cov(partial_linear, x_linear) / var(x_linear) + @test abs(linear_coef - beta_linear) < 0.5 # Should be close to true value + + # Check that smooth captures nonlinearity + partial_smooth = GAM.PredictPartial(mod, 2) + smooth_sorted = partial_smooth[sortperm(x_smooth)] + + # Should have multiple turning points (capturing the sine wave) + second_diff = diff(diff(smooth_sorted)) + sign_changes = sum(diff(sign.(second_diff)) .!= 0) + @test sign_changes > 2 # Sine wave should have multiple inflection points + end + end + + # =========================== + # Poisson GAM Tests + # =========================== + @testset "Poisson GAM" begin + Random.seed!(789) + n = 200 - # Test with different convergence parameters - mod1 = gam("y ~ s(x1, k=5, degree=3) + s(x2, k=5, degree=3)", df; - Family = "Bernoulli", Link = "Logit", maxIter = 5) + @testset "Count data modeling" begin + x1 = randn(n) + x2 = rand(n) * 3 + + # Create count data + eta = 1.5 .+ 0.8 * x1 .+ sin.(2π * x2) + lambda = exp.(eta) + lambda = clamp.(lambda, 0.1, 20) # Keep reasonable for testing + y = rand.(Poisson.(lambda)) + + df = DataFrame(y=y, x1=x1, x2=x2) + + mod = gam("y ~ s(x1, k=6, degree=3) + s(x2, k=8, degree=3)", df; + Family = "Poisson", Link = "Log") + + @test mod.Family[:Name] == "Poisson" + @test mod.Link[:Name] == "Log" + @test all(mod.Fitted .>= 0) # Poisson predictions must be non-negative + + # Check mean-variance relationship (Poisson property) + # Group predictions into bins and check mean ≈ variance + n_bins = 5 + fitted_quantiles = quantile(mod.Fitted, range(0, 1, length=n_bins+1)) + for i in 1:n_bins + mask = fitted_quantiles[i] .<= mod.Fitted .<= fitted_quantiles[i+1] + if sum(mask) > 10 + observed = y[mask] + predicted_mean = mean(mod.Fitted[mask]) + observed_mean = mean(observed) + observed_var = var(observed) + # For Poisson, mean ≈ variance + @test abs(log(observed_mean + 1) - log(observed_var + 1)) < 1 + end + end + end + end + + # =========================== + # Basis Function Tests + # =========================== + @testset "Basis Function Properties" begin + x = randn(100) - mod2 = gam("y ~ s(x1, k=5, degree=3) + s(x2, k=5, degree=3)", df; - Family = "Bernoulli", Link = "Logit", maxIter = 50, tol = 1e-8) + @testset "Uniform basis" begin + basis = GAM.BuildUniformBasis(x, 10, 3) + @test length(basis.breakpoints) == 10 + @test basis.order == 3 + @test minimum(basis.breakpoints) ≈ minimum(x) + @test maximum(basis.breakpoints) ≈ maximum(x) + end - @test mod1 isa GAMData - @test mod2 isa GAMData + @testset "Basis matrix properties" begin + basis = GAM.BuildUniformBasis(x, 8, 3) + X = GAM.BuildBasisMatrix(basis, x) + + @test size(X, 1) == length(x) + @test size(X, 2) == length(basis) + @test all(0 .<= X .<= 1) # B-splines are bounded + @test all(sum(X, dims=2) .≈ 1) # Partition of unity + end - # More iterations with tighter tolerance should give similar or better fit - @test mod2.Diagnostics[:GCV] <= mod1.Diagnostics[:GCV] * 1.1 # Allow 10% tolerance + @testset "Penalty matrix" begin + basis = GAM.BuildUniformBasis(x, 10, 3) + D = GAM.BuildDifferenceMatrix(basis) + + @test size(D, 2) == length(basis) + @test size(D, 1) == length(basis) - 2 # Second differences + @test rank(D) == size(D, 1) # Should be full rank + end end - # Test 6: Comparison with known logistic regression result - @testset "Linear model special case" begin - n = 2000 - x = randn(n) + # =========================== + # Edge Cases and Robustness + # =========================== + @testset "Edge Cases" begin + + @testset "Small sample sizes" begin + n = 20 + x = randn(n) + y = randn(n) + df = DataFrame(x=x, y=y) + + # Should work with small n but limited knots + mod = gam("y ~ s(x, k=4, degree=2)", df) + @test mod isa GAMData + @test mod.Diagnostics[:EDF] <= 4 + end - # True linear model - beta0 = 0.5 - beta1 = 1.5 - eta = beta0 .+ beta1 * x - p = 1 ./ (1 .+ exp.(-eta)) - y = rand.(Bernoulli.(p)) + @testset "Perfect fit scenario" begin + # Create data that can be fit perfectly + x = [1.0, 2.0, 3.0, 4.0, 5.0] + y = [2.0, 4.0, 6.0, 8.0, 10.0] # Perfect linear + df = DataFrame(x=x, y=y) + + mod = gam("y ~ s(x, k=3, degree=2)", df) + @test maximum(abs.(y .- mod.Fitted)) < 0.1 # Should fit nearly perfectly + end - df = DataFrame(y=y, x=x) + @testset "Constant response" begin + n = 50 + x = randn(n) + y = ones(n) * 5.0 # Constant response + df = DataFrame(x=x, y=y) + + mod = gam("y ~ s(x, k=5, degree=3)", df) + @test std(mod.Fitted) < 0.1 # Fitted values should be nearly constant + @test mean(mod.Fitted) ≈ 5.0 atol=0.1 + end + end + + # =========================== + # Performance Tests + # =========================== + @testset "Performance and Scaling" begin - # Fit with many knots to approximate linear function - mod = gam("y ~ s(x, k=20, degree=3)", df; - Family = "Bernoulli", Link = "Logit") + @testset "Large dataset handling" begin + n = 1000 + x = randn(n) + y = 2 * sin.(x) + 0.5 * randn(n) + df = DataFrame(x=x, y=y) + + # Should complete in reasonable time + t = @elapsed mod = gam("y ~ s(x, k=20, degree=3)", df) + @test t < 30.0 # Should finish within 30 seconds + @test mod isa GAMData + end - # Check that predictions are reasonable - x_test = [-1.0, 0.0, 1.0] - for xi in x_test - pred_mat = GAM.BuildPredictionMatrix([xi], mod.Basis[1], mod.ColMeans[1]) - pred_eta = mod.Coef[1] .+ pred_mat * mod.Coef[mod.CoefIndex[1]] - pred_p = 1 / (1 + exp(-pred_eta[1])) + @testset "Multiple smooths scaling" begin + n = 200 + p = 5 # Number of predictors + + df = DataFrame() + formula_parts = String[] - true_p = 1 / (1 + exp(-(beta0 + beta1 * xi))) + for i in 1:p + df[!, Symbol("x$i")] = randn(n) + push!(formula_parts, "s(x$i, k=5, degree=3)") + end - # Should be reasonably close - @test abs(pred_p - true_p) < 0.4 + # Response is sum of nonlinear functions + y = zeros(n) + for i in 1:p + y .+= sin.(df[!, Symbol("x$i")]) + end + y .+= 0.5 * randn(n) + df.y = y + + formula = "y ~ " * join(formula_parts, " + ") + + t = @elapsed mod = gam(formula, df) + @test t < 60.0 # Should handle multiple smooths efficiently + @test mod isa GAMData + @test length(mod.Basis) == p end end -end \ No newline at end of file +end