diff --git a/Project.toml b/Project.toml index 78fdffc..c8a972b 100644 --- a/Project.toml +++ b/Project.toml @@ -13,3 +13,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +StatsModels = "3eaba693-59b7-5ba5-a881-562e759f1c8d" diff --git a/src/FitGAM.jl b/src/FitGAM.jl index 35a148a..55dbc5a 100644 --- a/src/FitGAM.jl +++ b/src/FitGAM.jl @@ -4,10 +4,15 @@ Fit generalised additive model. Usage: ```julia-repl -gam(ModelFormula, Data; Family, Link, Optimizer, maxIter, tol) +# Using string formula (original syntax) +gam("Y ~ s(MPG, k=5, degree=3) + WHT", Data) + +# Using @formula macro (new StatsModels.jl syntax) +using StatsModels +gam(@formula(Y ~ s(MPG, 5, 3) + WHT), Data) ``` Arguments: -- `ModelFormula` : `String` containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R, where `s()` has 3 parts: name of column, `k`` (integer denoting number of knots), and `degree` (polynomial degree of the spline). An example expression is `"Y ~ s(MPG, k=5, degree=3) + WHT + s(TRL, k=5, degree=2)"` +- `ModelFormula` : Either a `String` or `FormulaTerm` (@formula macro) containing the expression of the model. Continuous covariates are wrapped in s() like `mgcv` in R. For strings, use `s(var, k=N, degree=D)` syntax. For @formula macro, use `s(var, N, D)` with positional arguments. An example string expression is `"Y ~ s(MPG, k=5, degree=3) + WHT"` and an example @formula is `@formula(Y ~ s(MPG, 5, 3) + WHT)` - `Data` : `DataFrame` containing the covariates and response variable to use. - `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal" - `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity" @@ -52,3 +57,52 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol) return gam end + +""" + gam(ModelFormula::FormulaTerm, Data; Family, Link, Optimizer, maxIter, tol) +Fit generalised additive model using StatsModels.jl @formula macro. + +Usage: +```julia-repl +using StatsModels +f = @formula(Y ~ s(MPG, 5, 3) + WHT) +gam(f, Data; Family="Normal", Link="Identity") +``` +Arguments: +- `ModelFormula` : `FormulaTerm` from StatsModels @formula macro +- `Data` : `DataFrame` containing the covariates and response variable to use. +- `Family` : `String` specifying Likelihood distribution. Should be one of "Normal", "Poisson", "Gamma", or "Bernoulli". Defaults to "Normal" +- `Link` : `String` specifying link function distribution. Should be one of "Identity", "Log", or "Logit". Defaults to "Identity" +- `Optimizer` : Algorithm to use for optimisation. Defaults to `NelderMead()`. +- `maxIter` : Maximum number of iterations for algorithm. Defaults to 25. +- `tol` : Tolerance for solver. Defaults to 1e-6. +""" +function gam(ModelFormula::FormulaTerm, Data::DataFrame; Family="Normal", Link="Identity", Optimizer = NelderMead(), maxIter = 25, tol = 1e-6) + # Delegate to the String version by parsing the FormulaTerm first + # This allows us to reuse all the existing logic + GAMForm = ParseFormula(ModelFormula) + + family_name = Dist_Map[Family] + family_name = Dists[family_name] + link_name = Link_Map[Link] + link_name = Links[link_name] + + # Extract response and covariates + y = Data[!, GAMForm.y] + + # Validate response for Bernoulli family + if Family == "Bernoulli" + @assert all(y .∈ Ref([0, 1])) "Response must be binary (0 or 1) for Bernoulli family" + end + + x = Data[!, GAMForm.covariates.variable] + BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)] + x = [x[!, col] for col in names(x)] + + # Build basis + Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs) + + # Fit PIRLS procedure + gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol) + return gam +end diff --git a/src/GAMFormula.jl b/src/GAMFormula.jl index d095e1c..d3b2d11 100644 --- a/src/GAMFormula.jl +++ b/src/GAMFormula.jl @@ -74,3 +74,101 @@ function ParseFormula(formula::String) outs = GAMFormula(Symbol(lhs), df) return outs end + +""" + ParseFormula(formula::FormulaTerm) +Parse StatsModels FormulaTerm into GAMFormula structure. + +Usage: +```julia-repl +f = @formula(y ~ s(x1, k=10, degree=3) + x2) +ParseFormula(f) +``` +Arguments: +- `formula` : `FormulaTerm` from StatsModels.jl @formula macro +""" +function ParseFormula(formula::FormulaTerm) + # Extract response variable + y = formula.lhs.sym + + # Process right-hand side terms + rhs = formula.rhs + + # Create DataFrame to hold covariate information + df = DataFrame(variable = Symbol[], k = Int[], degree = Int[], smooth = Bool[]) + + # Extract terms from the right-hand side + terms = extract_terms(rhs) + + for term in terms + if isa(term, SmoothTerm) + # Smooth term: extract k, degree, and variable name + push!(df, (term.term.sym, term.k, term.degree, true)) + elseif isa(term, Term) + # Linear term: add with default k=0, degree=0, smooth=false + push!(df, (term.sym, 0, 0, false)) + elseif isa(term, InterceptTerm) || isa(term, ConstantTerm) + # Intercept term - we handle this separately, skip for now + continue + else + @warn "Unsupported term type in formula: $(typeof(term)). Skipping." + end + end + + return GAMFormula(y, df) +end + +""" + extract_terms(rhs) +Recursively extract individual terms from the right-hand side of a formula. + +Handles different StatsModels term types including tuples, individual terms, and smooth terms. +""" +function extract_terms(rhs) + terms = [] + + if isa(rhs, Tuple) + # Multiple terms: recursively extract from each + for term in rhs + append!(terms, extract_terms(term)) + end + elseif isa(rhs, SmoothTerm) || isa(rhs, Term) + # Single term: add directly + push!(terms, rhs) + elseif isa(rhs, InterceptTerm) || isa(rhs, ConstantTerm) + # Intercept/constant: add directly + push!(terms, rhs) + elseif isa(rhs, StatsModels.FunctionTerm) + # Handle FunctionTerm (from @formula macro) + # Check if it's our s() function + if rhs.exorig.head == :call && rhs.exorig.args[1] == :s + # Extract arguments from the function call + # rhs.exorig.args[2] is the variable name + # rhs.exorig.args[3] is k (if present) + # rhs.exorig.args[4] is degree (if present) + var_sym = rhs.exorig.args[2] + k = length(rhs.exorig.args) >= 3 ? rhs.exorig.args[3] : 10 + degree = length(rhs.exorig.args) >= 4 ? rhs.exorig.args[4] : 3 + + # Create a SmoothTerm + push!(terms, SmoothTerm(Term(var_sym), k, degree)) + else + @warn "Unsupported function in formula: $(rhs.exorig.args[1])" + end + else + # Try to handle other StatsModels types + # For composite terms, try to extract nested terms + try + # If it has a .terms field (like CategoricalTerm, InteractionTerm, etc.) + if hasfield(typeof(rhs), :terms) + append!(terms, extract_terms(rhs.terms)) + else + @warn "Unable to extract terms from type $(typeof(rhs))" + end + catch e + @warn "Error extracting terms: $e" + end + end + + return terms +end diff --git a/src/GeneralizedAdditiveModels.jl b/src/GeneralizedAdditiveModels.jl index 726aa2b..3b0ad78 100644 --- a/src/GeneralizedAdditiveModels.jl +++ b/src/GeneralizedAdditiveModels.jl @@ -1,6 +1,8 @@ module GeneralizedAdditiveModels using Distributions, GLM, Optim, BSplines, LinearAlgebra, DataFrames, Plots, Optim +using StatsModels +using StatsModels: FormulaTerm, @formula, Term, ConstantTerm, InterceptTerm, AbstractTerm include("Links-Dists.jl") include("GAMData.jl") @@ -15,8 +17,9 @@ include("alpha.jl") include("PIRLS.jl") include("Predictions.jl") include("Plots.jl") -include("FitGAM.jl") +include("SmoothTerm.jl") include("GAMFormula.jl") +include("FitGAM.jl") export Links export Dists @@ -26,5 +29,6 @@ export GAMData export PartialDependencePlot export plotGAM export gam +export @formula, s, SmoothTerm, ParseFormula end diff --git a/src/SmoothTerm.jl b/src/SmoothTerm.jl new file mode 100644 index 0000000..72ca520 --- /dev/null +++ b/src/SmoothTerm.jl @@ -0,0 +1,64 @@ +#== + SmoothTerm + +Extension of StatsModels.jl for GAM smooth terms. + +This module defines custom term types for representing smooth functions in GAM formulas, +allowing syntax like: @formula(y ~ s(x1, k=10, degree=3) + x2) +==# + + +""" + SmoothTerm + +Represents a smooth spline term in a GAM formula. + +# Fields +- `term::Term`: The variable to be smoothed +- `k::Int`: Number of knots for the spline basis (default: 10) +- `degree::Int`: Polynomial degree of the spline (default: 3) +""" +struct SmoothTerm <: AbstractTerm + term::Term + k::Int + degree::Int +end + +# Constructor with default values +SmoothTerm(term::Term; k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree) + +# Allow creating from a Symbol +SmoothTerm(sym::Symbol; k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree) + +# Pretty printing +Base.show(io::IO, st::SmoothTerm) = print(io, "s($(st.term.sym), k=$(st.k), degree=$(st.degree))") + +""" + s(variable, k=10, degree=3) + +Create a smooth spline term for use in GAM formulas. + +# Arguments +- `variable`: The variable to be smoothed (Symbol or Term) +- `k`: Number of knots for the spline basis (default: 10) +- `degree`: Polynomial degree of the spline (default: 3) + +# Examples +```julia +using GeneralizedAdditiveModels, StatsModels + +# Using the @formula macro with smooth terms (positional arguments) +f = @formula(y ~ s(x1, 10, 3) + s(x2, 5, 2) + x3) + +# Or define smooth terms before the formula +s1 = s(:x1, 10, 3) +s2 = s(:x2, 5, 2) +# Note: You'll need to use the string formula syntax for pre-defined terms + +# Fit a GAM with the formula +model = gam(f, data) +``` +""" +# Positional argument versions (for use with @formula macro) +s(term::Term, k::Int=10, degree::Int=3) = SmoothTerm(term, k, degree) +s(sym::Symbol, k::Int=10, degree::Int=3) = SmoothTerm(Term(sym), k, degree) diff --git a/test/runtests.jl b/test/runtests.jl index f2f7b8a..1261d5b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -406,4 +406,113 @@ using Distributions, Random, Statistics, LinearAlgebra @test length(mod.Basis) == p end end + + # =========================== + # Formula Macro (@formula) Tests + # =========================== + @testset "Formula Macro Tests" begin + df = dataset("datasets", "trees") + + @testset "SmoothTerm construction" begin + st1 = s(:x1, 10, 3) + @test st1 isa SmoothTerm + @test st1.term.sym == :x1 + @test st1.k == 10 + @test st1.degree == 3 + + st2 = s(:x2) + @test st2.k == 10 + @test st2.degree == 3 + end + + @testset "Formula parsing from FormulaTerm" begin + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 5, 2)) + + gam_formula = ParseFormula(f) + @test gam_formula.y == :Volume + @test nrow(gam_formula.covariates) == 2 + @test gam_formula.covariates.variable[1] == :Girth + @test gam_formula.covariates.k[1] == 10 + @test gam_formula.covariates.degree[1] == 3 + @test gam_formula.covariates.smooth[1] == true + + @test gam_formula.covariates.variable[2] == :Height + @test gam_formula.covariates.k[2] == 5 + @test gam_formula.covariates.degree[2] == 2 + @test gam_formula.covariates.smooth[2] == true + end + + @testset "Formula with mixed smooth and linear terms" begin + f = @formula(Volume ~ s(Girth, 10, 3) + Height) + + gam_formula = ParseFormula(f) + @test gam_formula.y == :Volume + @test nrow(gam_formula.covariates) == 2 + + @test gam_formula.covariates.variable[1] == :Girth + @test gam_formula.covariates.smooth[1] == true + + @test gam_formula.covariates.variable[2] == :Height + @test gam_formula.covariates.smooth[2] == false + @test gam_formula.covariates.k[2] == 0 + @test gam_formula.covariates.degree[2] == 0 + end + + @testset "GAM fitting with @formula macro" begin + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + + mod = gam(f, df) + @test mod isa GAMData + @test length(mod.Fitted) == nrow(df) + + mod_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df) + @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6) + end + + @testset "GAM with @formula and different families" begin + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + + mod_gamma = gam(f, df; Family="Gamma", Link="Log") + @test mod_gamma isa GAMData + @test mod_gamma.Family[:Name] == "Gamma" + @test mod_gamma.Link[:Name] == "Log" + + mod_gamma_string = gam("Volume ~ s(Girth, k=10, degree=3) + s(Height, k=10, degree=3)", df; + Family="Gamma", Link="Log") + @test isapprox(mod_gamma.Fitted, mod_gamma_string.Fitted, rtol=1e-6) + end + + @testset "Bernoulli GAM with @formula" begin + n = 200 + x1 = range(-2, 2, length=n) + x2 = randn(n) + + f1 = sin.(x1 * π/2) + f2 = x2.^2 .- 1 + eta = f1 + f2 + p = 1 ./ (1 .+ exp.(-eta)) + y = rand.(Bernoulli.(p)) + + df_bern = DataFrame(y=y, x1=x1, x2=x2) + + f = @formula(y ~ s(x1, 8, 3) + s(x2, 8, 3)) + mod = gam(f, df_bern; Family="Bernoulli", Link="Logit") + + @test mod isa GAMData + @test mod.Family[:Name] == "Bernoulli" + @test all(0 .<= mod.Fitted .<= 1) + + mod_string = gam("y ~ s(x1, k=8, degree=3) + s(x2, k=8, degree=3)", df_bern; + Family="Bernoulli", Link="Logit") + @test isapprox(mod.Fitted, mod_string.Fitted, rtol=1e-6) + end + + @testset "Plotting GAM fitted with @formula" begin + f = @formula(Volume ~ s(Girth, 10, 3) + s(Height, 10, 3)) + mod = gam(f, df) + + p = plotGAM(mod) + @test p isa Plots.Plot + end + end end