Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 43 additions & 20 deletions src/BuildBasis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,32 +144,55 @@ Arguments:
- `Basis` : `AbstractArray` containing the basis matrix.
"""
function BuildPenaltyMatrix(y, x, sp, Basis)

n = length(y)
n_knots = map(x -> length(x.breakpoints), Basis)
X = map(BuildBasisMatrix, Basis, x)
D = map(BuildDifferenceMatrix, Basis)
# Drop one column from X and D for identifiability
X = map(DropCol, X, n_knots)
D = map(DropCol, D, n_knots)

# Center
ColMeans = map(BuildBasisMatrixColMeans, X)
X = map(CenterBasisMatrix, X, ColMeans)

# Store Coef index
CoefIndex = BuildCoefIndex(X)

# Build Design Matrix
X = hcat(repeat([1],n), hcat(X...)) # add intercept

# Build Penalty Matrix
D = dcat(map((p, d) -> (sqrt(p) * d), sp, D))
D = hcat(repeat([0], size(D,1)), D) # add 0 for intercept
# Prepare per-term blocks
Xblocks = Vector{AbstractMatrix}(undef, length(x))
Dblocks = Vector{AbstractMatrix}(undef, length(x))
ColMeans = Vector{AbstractArray}(undef, length(x))

for i in eachindex(x)
bi = Basis[i]
xi = x[i]
if bi === :linear
# centered 1-column linear block; no penalty
μ = mean(xi)
Xi = reshape(xi .- μ, :, 1)
Di = zeros(1, 1)
Xblocks[i] = Xi
Dblocks[i] = Di
ColMeans[i] = reshape([μ], 1, :)
else
# smooth block
nk = length(bi.breakpoints)
Xi0 = BuildBasisMatrix(bi, xi)
Di0 = BuildDifferenceMatrix(bi)
Xi = DropCol(Xi0, nk)
Di = DropCol(Di0, nk)
# drop for identifiability
cm = BuildBasisMatrixColMeans(Xi)
Xi = CenterBasisMatrix(Xi, cm)
Xblocks[i] = Xi
Dblocks[i] = Di
ColMeans[i] = cm
end
end

# Coefficient index per-term by block width
CoefIndex = BuildCoefIndex(Xblocks)

# Assembled design with intercept
X = hcat(repeat([1], n), hcat(Xblocks...))

# Block-diagonal penalty (sqrt(sp) scaling on smooth blocks; linear blocks are zero already)
Dscaled = map((p, d) -> sqrt(p) .* d, sp, Dblocks)
D = dcat(Dscaled)
D = hcat(repeat([0], size(D, 1)), D) # intercept column

return X, y, D, ColMeans, CoefIndex
end


"""
HatMatrix(X, D, W)
Builds a hat matrix.
Expand Down
22 changes: 16 additions & 6 deletions src/FitGAM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,24 @@ function gam(ModelFormula::String, Data::DataFrame; Family="Normal", Link="Ident
@assert all(y .∈ Ref([0, 1])) "Response must be binary (0 or 1) for Bernoulli family"
end

x = Data[!, GAMForm.covariates.variable]
BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)]
x = [x[!, col] for col in names(x)]
# Collect covariate columns and term meta
xdf = Data[!, GAMForm.covariates.variable]
x = [xdf[!, col] for col in names(xdf)]
BasisArgs = [(GAMForm.covariates.k[i], GAMForm.covariates.degree[i]) for i in 1:nrow(GAMForm.covariates)]
smoothmask = collect(GAMForm.covariates.smooth)

# Build basis
Basis = map((xi, argi) -> BuildUniformBasis(xi, argi[1], argi[2]), x, BasisArgs)
# Per-term basis: smooth -> BSplineBasis, linear -> :linear
Basis = Vector{Any}(undef, length(x))
for i in eachindex(x)
if smoothmask[i]
k, degree = BasisArgs[i]
Basis[i] = BuildUniformBasis(x[i], k, degree)
else
Basis[i] = :linear
end
end

# Fit PIRLS procedure
gam = OptimPIRLS(y, x, Basis, family_name, link_name; Optimizer, maxIter, tol)
return gam
end
end
4 changes: 2 additions & 2 deletions src/GAMData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ Arguments:
mutable struct GAMData
y::AbstractArray
x::AbstractArray
Basis::AbstractArray{BSplineBasis}
Basis::AbstractArray
Family::Dict
Link::Dict
Coef::AbstractArray
Expand All @@ -44,4 +44,4 @@ mutable struct GAMData
)
new(y, x, Basis, Family, Link, Coef, ColMeans, CoefIndex, Fitted, Diagnostics)
end
end
end
22 changes: 20 additions & 2 deletions src/PIRLS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,30 @@ function PIRLS(y, x, sp, Basis, Dist, Link; maxIter = 25, tol = 1e-6)
# Initial Predictions
n = length(y)

# μ init. Should replace with multiple dispatch logic
μmin = 1e-8
μmax = 1e12 # generous upper bound to avoid overflow in weights

if Dist[:Name] == "Bernoulli"
mu = clamp.(y, 1e-6, 1 - 1e-6)
p0 = clamp(mean(y), 1e-3, 1 - 1e-3)
mu = fill(p0, n)
elseif Dist[:Name] == "Poisson"
# Poisson needs strictly positive mu; start at max(y, small) or mean if all zeros
if all(==(0), y)
mu = fill(0.1, n)
else
mu = max.(Float64.(y), μmin)
end
elseif Dist[:Name] == "Gamma"
# Gamma also requires positive mu; start near the positive mean of y
ybar = max(mean(abs.(Float64.(y))), 1e-2)
mu = fill(ybar, n)
else
mu = y
# Gaussian and others: start at y (okay to be any real)
mu = Float64.(y)
end

mu = clamp.(mu, μmin, μmax)
eta = Link[:Function].(mu)

# Deviance
Expand Down
14 changes: 11 additions & 3 deletions src/Predictions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,15 @@ Arguments:
- `ix` : `Int` denoting the variable to plot.
"""
function PredictPartial(mod, ix)
predMatrix = BuildPredictionMatrix(mod.x[ix], mod.Basis[ix], mod.ColMeans[ix])
predBeta = mod.Coef[mod.CoefIndex[ix]]
return predMatrix * predBeta
bi = mod.Basis[ix]
if bi === :linear
μ = mod.ColMeans[ix][1] # stored as 1×1
Xi = reshape(mod.x[ix] .- μ, :, 1)
β = mod.Coef[mod.CoefIndex[ix]] # scalar
return Xi * β
else
predMatrix = BuildPredictionMatrix(mod.x[ix], bi, mod.ColMeans[ix])
predBeta = mod.Coef[mod.CoefIndex[ix]]
return predMatrix * predBeta
end
end
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
[deps]
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RDatasets = "ce6b1742-4840-55fa-b093-852dadbb1d8b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Loading