From b32695ef53af4905b0a2110f5f306a41fc160567 Mon Sep 17 00:00:00 2001
From: Ryan Senne <50930199+rsenne@users.noreply.github.com>
Date: Fri, 28 Nov 2025 11:12:34 -0500
Subject: [PATCH] Add StatsModels.jl formula macro support for GAMs

Introduces support for StatsModels.jl @formula macro and FormulaTerm parsing
for generalized additive models (GAMs). Adds SmoothTerm type and s()
constructor for smooth spline terms, updates FitGAM and GAMFormula to handle
FormulaTerm input, and extends tests for formula macro usage and smooth term
parsing. Project.toml updated to include StatsModels dependency.
---
 Project.toml                     |   1 +
 src/FitGAM.jl                    |  58 ++++++++++++++-
 src/GAMFormula.jl                |  99 +++++++++++++++++++++++++
 src/GeneralizedAdditiveModels.jl |   6 +-
 src/SmoothTerm.jl                |  64 +++++++++++++++++
 test/runtests.jl                 | 123 +++++++++++++++++++++++++++++++
 6 files changed, 348 insertions(+), 3 deletions(-)
 create mode 100644 src/SmoothTerm.jl

diff --git a/Project.toml b/Project.toml
index 78fdffc..c8a972b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -13,3 +13,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
diff --git a/src/FitGAM.jl b/src/FitGAM.jl
index f2e069f..b5d1285 100644
--- a/src/FitGAM.jl
+++ b/src/FitGAM.jl
@@ -4,10 +4,15 @@ Fit generalised additive model.
 
 Usage:
 ```julia-repl
-gam(ModelFormula, Data; Family, Link, Optimizer, maxIter, tol)
+# Using string formula (original syntax)
+gam("Y ~ s(MPG, k=5, degree=3) + WHT", Data)
+
+# Using @formula macro (new StatsModels.jl syntax)
+using StatsModels
+gam(@formula(Y ~ s(MPG, 5, 3) + WHT), Data)
 ```
 Arguments:
-- `ModelFormula` : `String` containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R, where `s()` has 3 parts: name of column, `k`` (integer denoting number of knots), and `degree` (polynomial degree of the spline). An example expression is `"Y ~ s(MPG, k=5, degree=3) + WHT + s(TRL, k=5, degree=2)"`
+- `ModelFormula` : Either a `String` or `FormulaTerm` (@formula macro) containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R. For strings, use `s(var, k=N, degree=D)` syntax. For @formula macro, use `s(var, N, D)` with positional arguments. An example string expression is `"Y ~ s(MPG, k=5, degree=3) + WHT"` and an example @formula is `@formula(Y ~ s(MPG, 5, 3) + WHT)`
 - `Data` : `DataFrame` containing the covariates and response variable to use.
 - `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal"
 - `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity"
@@ -38,6 +43,55 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident
     # Build basis
     Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs)
 
+    # Fit PIRLS procedure
+    gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol)
+    return gam
+end
+
+"""
+    gam(ModelFormula::FormulaTerm, Data; Family, Link, Optimizer, maxIter, tol)
+Fit generalised additive model using StatsModels.jl @formula macro.
+
+Usage:
+```julia-repl
+using StatsModels
+f = @formula(Y ~ s(MPG, 5, 3) + WHT)
+gam(f, Data; Family="Normal", Link="Identity")
+```
+Arguments:
+- `ModelFormula` : `FormulaTerm` from StatsModels @formula macro. Smooth terms take positional integer arguments, e.g. `s(MPG, 5, 3)`
+- `Data` : `DataFrame` containing the covariates and response variable to use.
+- `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal". With "Bernoulli" a response containing values other than 0 and 1 throws an `ArgumentError`
+- `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity"
+- `Optimizer` : Algorithm to use for optimisation. Defaults to `NelderMead()`.
+- `maxIter` : Maximum number of iterations for algorithm. Defaults to 25.
+- `tol` : Tolerance for solver. Defaults to 1e-6.
+"""
+function gam(ModelFormula::FormulaTerm, Data::DataFrame; Family="Normal", Link="Identity", Optimizer = NelderMead(), maxIter = 25, tol = 1e-6)
+    # Parse the FormulaTerm into a GAMFormula (response symbol + covariate table),
+    # then build the spline bases and run the PIRLS fit on it
+    GAMForm = ParseFormula(ModelFormula)
+
+    family_name = Dist_Map[Family]
+    family_name = Dists[family_name]
+    link_name = Link_Map[Link]
+    link_name = Links[link_name]
+
+    # Extract response and covariates
+    y = Data[!, GAMForm.y]
+
+    # Validate response for Bernoulli family (explicit throw: @assert may be disabled)
+    if Family == "Bernoulli" && !all(in((0, 1)), y)
+        throw(ArgumentError("Response must be binary (0 or 1) for Bernoulli family"))
+    end
+
+    x = Data[!, GAMForm.covariates.variable]
+    BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)]
+    x = [x[!, col] for col in names(x)]
+
+    # Build basis
+    Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs)
+
     # Fit PIRLS procedure
     gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol)
     return gam
diff --git a/src/GAMFormula.jl b/src/GAMFormula.jl
index d095e1c..d3b2d11 100644
--- a/src/GAMFormula.jl
+++ b/src/GAMFormula.jl
@@ -74,3 +74,102 @@ function ParseFormula(formula::String)
     outs = GAMFormula(Symbol(lhs), df)
     return outs
 end
+
+"""
+    ParseFormula(formula::FormulaTerm)
+Parse StatsModels FormulaTerm into GAMFormula structure.
+
+Usage:
+```julia-repl
+f = @formula(y ~ s(x1, 10, 3) + x2)
+ParseFormula(f)
+```
+Arguments:
+- `formula` : `FormulaTerm` from StatsModels.jl @formula macro
+"""
+function ParseFormula(formula::FormulaTerm)
+    # Extract response variable
+    y = formula.lhs.sym
+
+    # Process right-hand side terms
+    rhs = formula.rhs
+
+    # Create DataFrame to hold covariate information
+    df = DataFrame(variable = Symbol[], k = Int[], degree = Int[], smooth = Bool[])
+
+    # Extract terms from the right-hand side
+    terms = extract_terms(rhs)
+
+    for term in terms
+        if isa(term, SmoothTerm)
+            # Smooth term: extract k, degree, and variable name
+            push!(df, (term.term.sym, term.k, term.degree, true))
+        elseif isa(term, Term)
+            # Linear term: add with default k=0, degree=0, smooth=false
+            push!(df, (term.sym, 0, 0, false))
+        elseif isa(term, InterceptTerm) || isa(term, ConstantTerm)
+            # Intercept term - we handle this separately, skip for now
+            continue
+        else
+            @warn "Unsupported term type in formula: $(typeof(term)). Skipping."
+        end
+    end
+
+    return GAMFormula(y, df)
+end
+
+"""
+    extract_terms(rhs)
+Recursively extract individual terms from the right-hand side of a formula.
+
+Handles different StatsModels term types including tuples, individual terms, and smooth terms.
+"""
+function extract_terms(rhs)
+    terms = AbstractTerm[]
+
+    if isa(rhs, Tuple)
+        # Multiple terms: recursively extract from each
+        for term in rhs
+            append!(terms, extract_terms(term))
+        end
+    elseif isa(rhs, SmoothTerm) || isa(rhs, Term)
+        # Single term: add directly
+        push!(terms, rhs)
+    elseif isa(rhs, InterceptTerm) || isa(rhs, ConstantTerm)
+        # Intercept/constant: add directly
+        push!(terms, rhs)
+    elseif isa(rhs, StatsModels.FunctionTerm)
+        # Handle FunctionTerm (from @formula macro)
+        # Check if it's our s() function
+        if rhs.exorig.head == :call && rhs.exorig.args[1] == :s
+            # Extract arguments from the function call
+            # call_args[2] is the variable name
+            # call_args[3] is k (if present)
+            # call_args[4] is degree (if present)
+            call_args = rhs.exorig.args
+            var_sym = length(call_args) >= 2 ? call_args[2] : nothing
+            k = length(call_args) >= 3 ? call_args[3] : 10
+            degree = length(call_args) >= 4 ? call_args[4] : 3
+
+            # Keyword (`k=5`) or computed arguments arrive here as `Expr`s; reject
+            # them with a clear message instead of an opaque conversion error
+            if !(var_sym isa Symbol && k isa Integer && degree isa Integer)
+                throw(ArgumentError(
+                    "Invalid smooth term `$(rhs.exorig)`: use `s(var, k, degree)` " *
+                    "with a column name and positional integer arguments"))
+            end
+
+            # Create a SmoothTerm
+            push!(terms, SmoothTerm(Term(var_sym), k, degree))
+        else
+            @warn "Unsupported function in formula: $(rhs.exorig.args[1])"
+        end
+    elseif hasfield(typeof(rhs), :terms)
+        # Composite terms exposing a `.terms` field are flattened recursively
+        append!(terms, extract_terms(rhs.terms))
+    else
+        @warn "Unable to extract terms from type $(typeof(rhs))"
+    end
+
+    return terms
+end
diff --git a/src/GeneralizedAdditiveModels.jl b/src/GeneralizedAdditiveModels.jl
index 726aa2b..3b0ad78 100644
--- a/src/GeneralizedAdditiveModels.jl
+++ b/src/GeneralizedAdditiveModels.jl
@@ -1,6 +1,8 @@
 module GeneralizedAdditiveModels
 
 using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim
+using StatsModels
+using StatsModels: FormulaTerm, @formula, Term, ConstantTerm, InterceptTerm, AbstractTerm
 
 include("Links-Dists.jl")
 include("GAMData.jl")
@@ -15,8 +17,9 @@ include("alpha.jl")
 include("PIRLS.jl")
 include("Predictions.jl")
 include("Plots.jl")
-include("FitGAM.jl")
+include("SmoothTerm.jl")
 include("GAMFormula.jl")
+include("FitGAM.jl")
 
 export Links
 export Dists
@@ -26,5 +29,6 @@ export GAMData
 export PartialDependencePlot
 export plotGAM
 export gam
+export @formula, s, SmoothTerm, ParseFormula
 
 end
diff --git a/src/SmoothTerm.jl b/src/SmoothTerm.jl
new file mode 100644
index 0000000..72ca520
--- /dev/null
+++ b/src/SmoothTerm.jl
@@ -0,0 +1,64 @@
+#==
+    SmoothTerm
+
+Extension of StatsModels.jl for GAM smooth terms.
+
+This module defines custom term types for representing smooth functions in GAM formulas,
+allowing syntax like: @formula(y ~ s(x1, 10, 3) + x2)
+==#
+
+
+"""
+    SmoothTerm
+
+Represents a smooth spline term in a GAM formula.
+
+# Fields
+- `term::Term`: The variable to be smoothed
+- `k::Int`: Number of knots for the spline basis (default: 10)
+- `degree::Int`: Polynomial degree of the spline (default: 3)
+"""
+struct SmoothTerm <: AbstractTerm
+    term::Term
+    k::Int
+    degree::Int
+end
+
+# Constructor with default values
+SmoothTerm(term::Term; k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree)
+
+# Allow creating from a Symbol
+SmoothTerm(sym::Symbol; k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree)
+
+# Pretty printing
+Base.show(io::IO, st::SmoothTerm) = print(io, "s($(st.term.sym), k=$(st.k), degree=$(st.degree))")
+
+"""
+    s(variable, k=10, degree=3)
+
+Create a smooth spline term for use in GAM formulas.
+
+# Arguments
+- `variable`: The variable to be smoothed (Symbol or Term)
+- `k`: Number of knots for the spline basis (default: 10)
+- `degree`: Polynomial degree of the spline (default: 3)
+
+# Examples
+```julia
+using GeneralizedAdditiveModels, StatsModels
+
+# Using the @formula macro with smooth terms (positional arguments)
+f = @formula(y ~ s(x1, 10, 3) + s(x2, 5, 2) + x3)
+
+# Or define smooth terms before the formula
+s1 = s(:x1, 10, 3)
+s2 = s(:x2, 5, 2)
+# Note: You'll need to use the string formula syntax for pre-defined terms
+
+# Fit a GAM with the formula
+model = gam(f, data)
+```
+"""
+# Positional argument versions (for use with @formula macro)
+s(term::Term, k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree)
+s(sym::Symbol, k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree)
diff --git a/test/runtests.jl b/test/runtests.jl
index 644f0c1..e1f98e7 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -178,4 +178,127 @@ end
             @test abs(pred_p - true_p) < 0.4
         end
     end
+end
+
+using Random
+
+@testset "Formula Macro Tests" begin
+    @testset "SmoothTerm construction" begin
+        # Test creating smooth terms with positional arguments
+        st1 = s(:x1, 10, 3)
+        @test st1 isa SmoothTerm
+        @test st1.term.sym == :x1
+        @test st1.k == 10
+        @test st1.degree == 3
+
+        # Test default values
+        st2 = s(:x2)
+        @test st2.k == 10 # default
+        @test st2.degree == 3 # default
+    end
+
+    @testset "Formula parsing from FormulaTerm" begin
+        # Test parsing a simple formula with smooth terms
+        f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 5, 2))
+
+        gam_formula = ParseFormula(f)
+        @test gam_formula.y == :Volume
+        @test nrow(gam_formula.covariates) == 2
+        @test gam_formula.covariates.variable[1] == :Girth
+        @test gam_formula.covariates.k[1] == 10
+        @test gam_formula.covariates.degree[1] == 3
+        @test gam_formula.covariates.smooth[1] == true
+
+        @test gam_formula.covariates.variable[2] == :Height
+        @test gam_formula.covariates.k[2] == 5
+        @test gam_formula.covariates.degree[2] == 2
+        @test gam_formula.covariates.smooth[2] == true
+    end
+
+    @testset "Formula with mixed smooth and linear terms" begin
+        # Test formula with both smooth and linear terms
+        f = @formula(Volume ~ s(Girth, 10, 3) + Height)
+
+        gam_formula = ParseFormula(f)
+        @test gam_formula.y == :Volume
+        @test nrow(gam_formula.covariates) == 2
+
+        # First term is smooth
+        @test gam_formula.covariates.variable[1] == :Girth
+        @test gam_formula.covariates.smooth[1] == true
+
+        # Second term is linear
+        @test gam_formula.covariates.variable[2] == :Height
+        @test gam_formula.covariates.smooth[2] == false
+        @test gam_formula.covariates.k[2] == 0
+        @test gam_formula.covariates.degree[2] == 0
+    end
+
+    @testset "GAM fitting with @formula macro" begin
+        # Test fitting a GAM using the @formula macro
+        f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3))
+
+        mod = gam(f, df)
+        @test mod isa GAMData
+        @test length(mod.Fitted) == nrow(df)
+
+        # Compare with string formula version
+        mod_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df)
+
+        # Results should be very similar (allowing for numerical precision)
+        @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6)
+    end
+
+    @testset "GAM with @formula and different families" begin
+        # Test with Gamma family
+        f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3))
+
+        mod_gamma = gam(f, df; Family="Gamma", Link="Log")
+        @test mod_gamma isa GAMData
+        @test mod_gamma.Family[:Name] == "Gamma"
+        @test mod_gamma.Link[:Name] == "Log"
+
+        # Compare with string formula version
+        mod_gamma_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df;
+                               Family="Gamma", Link="Log")
+        @test isapprox(mod_gamma.Fitted, mod_gamma_string.Fitted, rtol=1e-6)
+    end
+
+    @testset "Bernoulli GAM with @formula" begin
+        # Create binary data (seeded so the simulated sample is reproducible)
+        Random.seed!(1234)
+        n = 200
+        x1 = range(-2, 2, length=n)
+        x2 = randn(n)
+
+        # Create true nonlinear effect
+        f1 = sin.(x1 * π/2)
+        f2 = x2.^2 .- 1
+        eta = f1 + f2
+        p = 1 ./ (1 .+ exp.(-eta))
+        y = rand.(Bernoulli.(p))
+
+        df_bern = DataFrame(y=y, x1=x1, x2=x2)
+
+        # Fit using @formula
+        f = @formula(y ~ s(x1, 8, 3) + s(x2, 8, 3))
+        mod = gam(f, df_bern; Family="Bernoulli", Link="Logit")
+
+        @test mod isa GAMData
+        @test mod.Family[:Name] == "Bernoulli"
+        @test all(0 .<= mod.Fitted .<= 1)
+
+        # Compare with string version
+        mod_string = gam("y ~ s(x1, k=8, degree=3) + s(x2, k=8, degree=3)", df_bern;
+                         Family="Bernoulli", Link="Logit")
+        @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6)
+    end
+
+    @testset "Plotting GAM fitted with @formula" begin
+        f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3))
+        mod = gam(f, df)
+
+        p = plotGAM(mod)
+        @test p isa Plots.Plot
+    end
 end
\ No newline at end of file