Skip to content

Commit

Permalink
Recursively propagate kwargs through update_coefficients!
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Mar 11, 2023
1 parent 31275d4 commit 7143793
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 51 deletions.
7 changes: 7 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
are still compatible with essentially any Krylov method which would otherwise be compatible with
matrix-free representations, hence their support in the SciMLOperators interface.

## Note about keyword arguments to `update_coefficients!`

In rare cases, an operator may be used in a context where additional state is expected to be provided
to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional
state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator.
For the [premade SciMLOperators](premade_operators.md), one can specify the additional state used by an operator with an `accepted_kwarg_fields` argument that defaults to an empty tuple.
22 changes: 12 additions & 10 deletions src/batch.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
#
"""
BatchedDiagonalOperator(diag, [; update_func])
BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=())
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
by `update_coefficients!` and is assumed to have the following signature:
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractVector,u,p,t; <accepted kwarg fields>) -> [modifies diag]
"""
struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T}
diag::D
update_func::F

function BatchedDiagonalOperator(
diag::AbstractArray;
update_func=DEFAULT_UPDATE_FUNC
update_func=nothing,
accepted_kwarg_fields=()
)
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{
eltype(diag),
typeof(diag),
typeof(update_func)
typeof(_update_func)
}(
diag, update_func,
diag, _update_func,
)
end
end

function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC)
BatchedDiagonalOperator(u; update_func=update_func)
function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwarg_fields=())
BatchedDiagonalOperator(u; update_func, accepted_kwarg_fields)
end

# traits
Expand All @@ -40,7 +42,7 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
update_func = if isreal(L)
L.update_func
else
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
end
BatchedDiagonalOperator(diag; update_func=update_func)
end
Expand All @@ -57,15 +59,15 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
isconstant(L::BatchedDiagonalOperator) = update_func_isconstant(L.update_func)
islinear(::BatchedDiagonalOperator) = true
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)

getops(L::BatchedDiagonalOperator) = (L.diag,)

update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing)
update_coefficients!(L::BatchedDiagonalOperator,u,p,t; kwargs...) = (L.update_func(L.diag,u,p,t; kwargs...); nothing)

# operator application
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
Expand Down
4 changes: 2 additions & 2 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ function update_coefficients(L::FunctionOperator, u, p, t)
)
end

function update_coefficients!(L::FunctionOperator, u, p, t)
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
ops = getops(L)
for op in ops
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end

L.p = p
Expand Down
24 changes: 18 additions & 6 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,31 @@ out-of-place form B = update_coefficients(A,u,p,t).
"""
function (::AbstractSciMLOperator) end

# Utilities for update functions
# Default no-op update function: returns `A` unchanged and ignores `u`, `p`, and `t`.
DEFAULT_UPDATE_FUNC(A,u,p,t) = A
"""
    preprocess_update_func(update_func, accepted_kwarg_fields)

Normalize a user-supplied update function before it is stored in an operator.

`update_func === nothing` is replaced by `DEFAULT_UPDATE_FUNC`. The result is wrapped
in a `FilterKwargs` so that only the keyword arguments named in `accepted_kwarg_fields`
are forwarded to it.

If `update_func` is already a `FilterKwargs` and `accepted_kwarg_fields` is empty, it is
returned as-is. This lets a constructor pre-process an update function and then delegate
to another constructor that pre-processes it again with the default (empty) filter.
"""
function preprocess_update_func(update_func, accepted_kwarg_fields)
    # Re-wrapping an already-filtered function with an empty filter would strip every
    # keyword argument before the inner filter ever sees it.
    if update_func isa FilterKwargs && isempty(accepted_kwarg_fields)
        return update_func
    end
    f = update_func === nothing ? DEFAULT_UPDATE_FUNC : update_func
    return FilterKwargs(f, accepted_kwarg_fields)
end
"""
    update_func_isconstant(update_func) -> Bool

Return `true` when `update_func` is `DEFAULT_UPDATE_FUNC`, either directly or wrapped
in any number of nested `FilterKwargs` layers.
"""
function update_func_isconstant(update_func)
    # A pre-processed update function may be wrapped more than once (one constructor
    # delegating to another), so keep unwrapping until the underlying function is reached.
    if update_func isa FilterKwargs
        return update_func_isconstant(update_func.f)
    end
    return update_func == DEFAULT_UPDATE_FUNC
end

update_coefficients!(L,u,p,t) = nothing
update_coefficients(L,u,p,t) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
update_coefficients!(L,u,p,t; kwargs...) = nothing
update_coefficients(L,u,p,t; kwargs...) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...)
for op in getops(L)
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end
nothing
end

(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
(L::AbstractSciMLOperator)(u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); L * u)
(L::AbstractSciMLOperator)(du, u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u))

###
# caching interface
Expand Down
63 changes: 37 additions & 26 deletions src/matrix.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#
"""
MatrixOperator(A[; update_func])
MatrixOperator(A; update_func=nothing, accepted_kwarg_fields=())
Represents a time-dependent linear operator given by an AbstractMatrix. The
update function is called by `update_coefficients!` and is assumed to have
the following signature:
update_func(A::AbstractMatrix,u,p,t) -> [modifies A]
update_func(A::AbstractMatrix,u,p,t; <accepted kwarg fields>) -> [modifies A]
"""
struct MatrixOperator{T,AType<:AbstractMatrix{T},F} <: AbstractSciMLOperator{T}
A::AType
update_func::F
MatrixOperator(A::AType; update_func=DEFAULT_UPDATE_FUNC) where{AType} =
new{eltype(A),AType,typeof(update_func)}(A, update_func)
function MatrixOperator(A::AType; update_func=nothing, accepted_kwarg_fields=()) where {AType}
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{eltype(A),AType,typeof(_update_func)}(A, _update_func)
end
end

# constructors
Expand All @@ -39,21 +41,21 @@ for op in (
if isconstant(L)
MatrixOperator($op(L.A))
else
update_func = (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t))
update_func = (A,u,p,t; kwargs...) -> $op(L.update_func($op(L.A),u,p,t; kwargs...))
MatrixOperator($op(L.A); update_func = update_func)
end
end
end
Base.conj(L::MatrixOperator) = MatrixOperator(
conj(L.A);
update_func= (A,u,b,t) -> conj(L.update_func(conj(L.A),u,p,t))
update_func= (A,u,p,t; kwargs...) -> conj(L.update_func(conj(L.A),u,p,t; kwargs...))
)

has_adjoint(A::MatrixOperator) = has_adjoint(A.A)
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing)
update_coefficients!(L::MatrixOperator,u,p,t; kwargs...) = (L.update_func(L.A,u,p,t; kwargs...); nothing)

# Return the wrapped matrix as a 1-tuple. Without the trailing comma `(L.A)` is just
# `L.A`, so iterating the result would walk matrix entries instead of operands.
getops(L::MatrixOperator) = (L.A,)
isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
isconstant(L::MatrixOperator) = update_func_isconstant(L.update_func)
Base.iszero(L::MatrixOperator) = iszero(L.A)

SparseArrays.sparse(L::MatrixOperator) = sparse(L.A)
Expand Down Expand Up @@ -88,13 +90,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat)
LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u)

"""
DiagonalOperator(diag, [; update_func])
DiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=())
Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
The update function is called by `update_coefficients!` and is assumed to have
the following signature:
update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractVector,u,p,t; <accepted kwarg fields>) -> [modifies diag]
When `diag` is an `AbstractVector` of length N, `L=DiagonalOperator(diag, ...)`
can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u`
Expand All @@ -105,11 +107,12 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
with leading length `size(u, 1) = N`.
"""
function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC)
diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
DEFAULT_UPDATE_FUNC
function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
diag_update_func = if update_func_isconstant(_update_func)
_update_func
else
(A, u, p, t) -> (update_func(A.diag, u, p, t); A)
(A, u, p, t; kwargs...) -> (_update_func(A.diag, u, p, t; kwargs...); A)
end
MatrixOperator(Diagonal(diag); update_func=diag_update_func)
end
Expand Down Expand Up @@ -202,13 +205,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr
LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u)

"""
L = AffineOperator(A, B, b[; update_func])
L = AffineOperator(A, B, b; update_func=nothing, accepted_kwarg_fields=())
L(u) = A*u + B*b
Represents a time-dependent affine operator. The update function is called
by `update_coefficients!` and is assumed to have the following signature:
update_func(b::AbstractArray,u,p,t) -> [modifies b]
update_func(b::AbstractArray,u,p,t; <accepted kwarg fields>) -> [modifies b]
"""
struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
A::AType
Expand Down Expand Up @@ -236,44 +239,52 @@ end
function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator},
B::Union{AbstractMatrix,AbstractSciMLOperator},
b::AbstractArray;
update_func = DEFAULT_UPDATE_FUNC,
update_func=nothing,
accepted_kwarg_fields=()
)
@assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors
of same size"

_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

A = A isa AbstractMatrix ? MatrixOperator(A) : A
B = B isa AbstractMatrix ? MatrixOperator(B) : B
cache = B * b

AffineOperator(A, B, b, cache, update_func)
AffineOperator(A, B, b, cache, _update_func)
end

"""
L = AddVector(b[; update_func])
L = AddVector(b; update_func=nothing, accepted_kwarg_fields=())
L(u) = u + b
"""
function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

N = size(b, 1)
Id = IdentityOperator(N)

AffineOperator(Id, Id, b; update_func=update_func)
AffineOperator(Id, Id, b; update_func=_update_func)
end

"""
L = AddVector(B, b[; update_func])
L = AddVector(B, b; update_func=nothing, accepted_kwarg_fields=())
L(u) = u + B*b
"""
function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

N = size(B, 1)
Id = IdentityOperator(N)

AffineOperator(Id, B, b; update_func=update_func)
AffineOperator(Id, B, b; update_func=_update_func)
end

getops(L::AffineOperator) = (L.A, L.B, L.b)

update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing)
isconstant(L::AffineOperator) = (L.update_func == DEFAULT_UPDATE_FUNC) & all(isconstant, (L.A, L.B))
update_coefficients!(L::AffineOperator,u,p,t; kwargs...) = (L.update_func(L.b,u,p,t; kwargs...); nothing)
isconstant(L::AffineOperator) = update_func_isconstant(L.update_func) & all(isconstant, (L.A, L.B))

islinear(::AffineOperator) = false

Base.size(L::AffineOperator) = size(L.A)
Expand Down
16 changes: 9 additions & 7 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,24 @@ end
Base.:+(α::AbstractSciMLScalarOperator) = α

"""
ScalarOperator(val[; update_func])
ScalarOperator(val; update_func=nothing, accepted_kwarg_fields=())
(α::ScalarOperator)(a::Number) = α * a
Represents a time-dependent scalar/scaling operator. The update function
is called by `update_coefficients!` and is assumed to have the following
signature:
update_func(oldval,u,p,t) -> newval
update_func(oldval,u,p,t; <accepted kwarg fields>) -> newval
"""
mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T}
val::T
update_func::F

ScalarOperator(val::T; update_func=DEFAULT_UPDATE_FUNC) where{T} =
new{T,typeof(update_func)}(val, update_func)
function ScalarOperator(val::T; update_func=nothing, accepted_kwarg_fields=()) where {T}
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{T,typeof(_update_func)}(val, _update_func)
end
end

# constructors
Expand All @@ -118,7 +120,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ)
# traits
function Base.conj(α::ScalarOperator) # TODO - test
val = conj(α.val)
update_func = (oldval,u,p,t) -> α.update_func(oldval |> conj,u,p,t) |> conj
update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) |> conj
ScalarOperator(val; update_func=update_func)
end

Expand All @@ -132,11 +134,11 @@ Base.abs(α::ScalarOperator) = abs(α.val)
Base.iszero(α::ScalarOperator) = iszero(α.val)

getops(α::ScalarOperator) = (α.val,)
isconstant(α::ScalarOperator) = α.update_func == DEFAULT_UPDATE_FUNC
isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func)
has_ldiv(α::ScalarOperator) = !iszero(α.val)
has_ldiv!(α::ScalarOperator) = has_ldiv(α)

update_coefficients!(L::ScalarOperator,u,p,t) = (L.val = L.update_func(L.val,u,p,t); nothing)
update_coefficients!(L::ScalarOperator,u,p,t; kwargs...) = (L.val = L.update_func(L.val,u,p,t; kwargs...); nothing)

"""
Lazy addition of Scalar Operators
Expand Down
10 changes: 10 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@ end
dims(A) = length(size(A))
dims(::AbstractArray{<:Any,N}) where{N} = N
dims(::AbstractSciMLOperator) = 2

# Keyword argument filtering
"""
    FilterKwargs(f, accepted_kwarg_fields)

Callable wrapper around `f`. Calling it forwards every positional argument to `f`, but
only those keyword arguments whose names appear in `accepted_kwarg_fields`; all other
keyword arguments are silently dropped.
"""
struct FilterKwargs{F,K}
    f::F
    accepted_kwarg_fields::K
end

function (f_filter::FilterKwargs)(args...; kwargs...)
    # Skip accepted keywords the caller did not supply instead of indexing them: a
    # keyword that is genuinely required then produces the usual missing-keyword error
    # from `f` itself, and keywords with defaults in `f` keep working.
    filtered_kwargs = (
        kw => kwargs[kw] for kw in f_filter.accepted_kwarg_fields if haskey(kwargs, kw)
    )
    return f_filter.f(args...; filtered_kwargs...)
end
#
11 changes: 11 additions & 0 deletions test/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,16 @@ end
@test num(v,u,p,t) ≈ val * u

@test convert(Number, num) ≈ val

# Test scalar operator which expects keyword argument to update, modeled in the style of a DiffEq W-operator.
γ = ScalarOperator(0.0; update_func=(args...; dtgamma) -> dtgamma, accepted_kwarg_fields=(:dtgamma,))

dtgamma = rand()
@test γ(u,p,t; dtgamma) ≈ dtgamma * u
@test γ(v,u,p,t; dtgamma) ≈ dtgamma * u

γ_added = γ + α
@test γ_added(u,p,t; dtgamma) ≈ (dtgamma + p) * u
@test γ_added(v,u,p,t; dtgamma) ≈ (dtgamma + p) * u
end
#

0 comments on commit 7143793

Please sign in to comment.