Propagate kwargs through update_coefficients! #143

Merged: 24 commits, May 29, 2023
Changes from 20 commits
1ed74b6
Recursively propagate kwargs through update_coefficients!
gaurav-arya Jan 30, 2023
293d5eb
Rename accepted_kwarg_fields -> accepted_kwargs
gaurav-arya Mar 12, 2023
1712594
Allow accepted_kwargs=nothing to indicate no wrapping
gaurav-arya Mar 12, 2023
cd70503
Tweak keyword filtering logic
gaurav-arya Mar 12, 2023
17298fb
Test operator update (including kwarg update) in operator algebra test
gaurav-arya Mar 12, 2023
dc3e31d
Support kwargs in function operator
gaurav-arya Mar 12, 2023
5b47ef9
Propagate kwargs for out-of-place function operator update_coefficien…
gaurav-arya Mar 12, 2023
88d5050
Catch function operator error for empty kwargs
gaurav-arya Mar 12, 2023
c7fcd51
Address code review suggestions on diag op construction
gaurav-arya Mar 12, 2023
24aef00
Improve logic for normalizing kwargs
gaurav-arya Mar 12, 2023
44d66bb
Test operator application form in operator algebra test set
gaurav-arya Mar 12, 2023
9d8fdff
Support kwargs in function operator functionals
gaurav-arya Mar 12, 2023
f793c60
Add example
gaurav-arya Mar 12, 2023
3312967
Remove unncessary function call
gaurav-arya Mar 12, 2023
69f0ecc
Rename kwargs_for_op -> accepted_kwargs
gaurav-arya Mar 12, 2023
fef7618
Fix function operator out-of-place update coefficients
gaurav-arya Mar 12, 2023
ceca67c
Use NoKwargFilter() to bypass keyword filtering (rather than nothing)
gaurav-arya Mar 12, 2023
0efa2ce
Remove debug line
gaurav-arya Mar 12, 2023
2adb43f
Merge branch 'master' into ag-kwargs
vpuri3 May 27, 2023
e63e445
fix diagonaloperator update
vpuri3 May 27, 2023
e89d5a1
function op working
vpuri3 May 29, 2023
4818a66
moved kwargs to FunctionOp.traits
vpuri3 May 29, 2023
3854874
tests passing
vpuri3 May 29, 2023
3a665ca
Base.Pairs notdef in LTS
vpuri3 May 29, 2023
23 changes: 23 additions & 0 deletions docs/src/interface.md
Expand Up @@ -55,3 +55,26 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
are still compatible with essentially any Krylov method which would otherwise be compatible with
matrix-free representations, hence their support in the SciMLOperators interface.

## Note about keyword arguments to `update_coefficients!`

In rare cases, an operator may be used in a context where additional state is expected to be provided
to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional
state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator.
For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwargs` argument (by default, none are passed).

In the example below, we create an operator that gleefully ignores `u`, `p`, and `t` and uses its own special scaling.
```@example
using SciMLOperators

γ = ScalarOperator(0.0; update_func=(a, u, p, t; my_special_scaling) -> my_special_scaling,
accepted_kwargs=(:my_special_scaling,))

# Update coefficients, then apply operator
update_coefficients!(γ, nothing, nothing, nothing; my_special_scaling=7.0)
@show γ * [2.0]

# Use operator application form
@show γ([2.0], nothing, nothing; my_special_scaling = 5.0)
nothing # hide
```
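
The recursive propagation described in the note above can be illustrated without the library. The following is a minimal sketch in plain Julia, assuming a toy model in which composites forward every keyword unchanged and each leaf keeps only the keywords it declared. The names `ToyLeaf`, `ToyComposed`, `filter_kwargs`, and `toy_update!` are hypothetical and are not part of SciMLOperators; the actual filtering in the package may be implemented differently.

```julia
# Toy stand-ins for illustration only; these are not SciMLOperators types.
struct ToyLeaf{F}
    value::Base.RefValue{Float64}      # mutable coefficient
    update_func::F                     # (value, u, p, t; kwargs...) -> new value
    accepted::Tuple{Vararg{Symbol}}    # keywords this leaf wants to receive
end

struct ToyComposed{T<:Tuple}
    ops::T
end

# Keep only the keywords that a leaf declared as accepted.
filter_kwargs(accepted, kwargs) =
    (k => kwargs[k] for k in accepted if haskey(kwargs, k))

function toy_update!(L::ToyLeaf, u, p, t; kwargs...)
    L.value[] = L.update_func(L.value[], u, p, t; filter_kwargs(L.accepted, kwargs)...)
    return L
end

# A composite forwards everything unchanged; filtering happens at the leaves.
function toy_update!(C::ToyComposed, u, p, t; kwargs...)
    foreach(op -> toy_update!(op, u, p, t; kwargs...), C.ops)
    return C
end

scaled = ToyLeaf(Ref(0.0), (a, u, p, t; scale) -> scale, (:scale,))
timed  = ToyLeaf(Ref(0.0), (a, u, p, t) -> t, ())

C = ToyComposed((scaled, timed))
toy_update!(C, nothing, nothing, 2.0; scale = 7.0, unused = 1)
```

In this sketch `scaled` receives only `scale`, while `timed` receives no keywords at all, so an update function that declares no keyword arguments is still safe to call when the caller passes extra state.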
57 changes: 30 additions & 27 deletions src/batch.jl
@@ -1,19 +1,20 @@
#
"""
BatchedDiagonalOperator(diag, [; update_func])
BatchedDiagonalOperator(diag; update_func, update_func!, accepted_kwargs)

Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
by `update_coefficients!` and is assumed to have the following signature:

update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractArray, u, p, t; <accepted kwarg fields>) -> [modifies diag]
"""
struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T}
diag::D
update_func::F
update_func!::F!

function BatchedDiagonalOperator(diag::AbstractArray, update_func, update_func!)

new{
eltype(diag),
typeof(diag),
Expand All @@ -25,15 +26,16 @@ struct BatchedDiagonalOperator{T,D,F,F!} <: AbstractSciMLOperator{T}
end
end

function BatchedDiagonalOperator(diag::AbstractArray;
update_func = DEFAULT_UPDATE_FUNC,
update_func! = DEFAULT_UPDATE_FUNC)
BatchedDiagonalOperator(diag, update_func, update_func!)
end
function DiagonalOperator(u::AbstractArray;
update_func = DEFAULT_UPDATE_FUNC,
update_func! = DEFAULT_UPDATE_FUNC,
accepted_kwargs = nothing
)

update_func = preprocess_update_func(update_func , accepted_kwargs)
update_func! = preprocess_update_func(update_func!, accepted_kwargs)

function DiagonalOperator(u::AbstractArray; update_func = DEFAULT_UPDATE_FUNC,
update_func! = DEFAULT_UPDATE_FUNC)
BatchedDiagonalOperator(u; update_func = update_func, update_func! = update_func!)
BatchedDiagonalOperator(u, update_func, update_func!)
end

# traits
Expand All @@ -46,38 +48,39 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
update_func = if isreal(L)
L.update_func
else
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
end
BatchedDiagonalOperator(diag; update_func=update_func)
end

function update_coefficients(L::BatchedDiagonalOperator,u,p,t)
@set! L.diag = L.update_func(L.diag,u,p,t)
LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
if isreal(L)
true
else
vec(L.diag) |> Diagonal |> ishermitian
end
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

function update_coefficients(L::BatchedDiagonalOperator,u ,p, t; kwargs...)
@set! L.diag = L.update_func(L.diag, u, p, t; kwargs...)
end

function update_coefficients!(L::BatchedDiagonalOperator, u, p, t; kwargs...)
L.update_func!(L.diag, u, p, t; kwargs...)
end
update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func!(L.diag,u,p,t); L)

getops(L::BatchedDiagonalOperator) = (L.diag,)

function isconstant(L::BatchedDiagonalOperator)
L.update_func == L.update_func! == DEFAULT_UPDATE_FUNC
update_func_isconstant(L.update_func) & update_func_isconstant(L.update_func!)
end
islinear(::BatchedDiagonalOperator) = true
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)

LinearAlgebra.issymmetric(L::BatchedDiagonalOperator) = true
function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
if isreal(L)
true
else
d = vec(L.diag)
D = Diagonal(d)
ishermitian(d)
end
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

# operator application
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
Base.:\(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .\ u
Expand Down
78 changes: 40 additions & 38 deletions src/func.jl
Expand Up @@ -2,7 +2,7 @@
"""
Matrix free operators (given by a function)
"""
mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <: AbstractSciMLOperator{T}
mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,K,C} <: AbstractSciMLOperator{T}
""" Function with signature op(u, p, t) and (if isinplace) op(du, u, p, t) """
op::F
""" Adjoint operator"""
Expand All @@ -17,6 +17,8 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
p::P
""" Time """
t::Tt
""" Keyword arguments """
kwargs::K
""" Cache """
cache::C

Expand All @@ -28,6 +30,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
traits,
p,
t,
accepted_kwargs,
cache
)

Expand All @@ -48,6 +51,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
typeof(traits),
typeof(p),
typeof(t),
typeof(accepted_kwargs),
typeof(cache),
}(
op,
Expand All @@ -57,6 +61,7 @@ mutable struct FunctionOperator{iip,oop,mul5,T<:Number,F,Fa,Fi,Fai,Tr,P,Tt,C} <:
traits,
p,
t,
accepted_kwargs,
cache,
)
end
Expand Down Expand Up @@ -84,6 +89,7 @@ function FunctionOperator(op,
FunctionOperator(op, input, output; kwargs...)
end

# TODO: document constructor and revisit design as needed (e.g. for "accepted_kwargs")
function FunctionOperator(op,
input::AbstractVecOrMat,
output::AbstractVecOrMat = input;
Expand All @@ -101,6 +107,7 @@ function FunctionOperator(op,

p=nothing,
t::Union{Number,Nothing}=nothing,
accepted_kwargs = (),

ifcache::Bool = true,

Expand Down Expand Up @@ -191,7 +198,8 @@ function FunctionOperator(op,
traits,
p,
t,
cache,
normalize_kwargs(accepted_kwargs),
cache
)

if ifcache & isnothing(L.cache)
Expand All @@ -201,31 +209,29 @@ function FunctionOperator(op,
L
end

function update_coefficients(L::FunctionOperator, u, p, t)

if isconstant(L)
return L
end

@set! L.op = update_coefficients(L.op, u, p, t)
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t)
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t)
@set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t)
function update_coefficients(L::FunctionOperator, u, p, t; kwargs...)

@set! L.p = p
@set! L.t = t

L
isconstant(L) && return L

filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg))

@set! L.op = update_coefficients(L.op, u, p, t; filtered_kwargs...)
@set! L.op_adjoint = update_coefficients(L.op_adjoint, u, p, t; filtered_kwargs...)
@set! L.op_inverse = update_coefficients(L.op_inverse, u, p, t; filtered_kwargs...)
@set! L.op_adjoint_inverse = update_coefficients(L.op_adjoint_inverse, u, p, t; filtered_kwargs...)
end

function update_coefficients!(L::FunctionOperator, u, p, t)
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)

if isconstant(L)
return L
end
isconstant(L) && return

filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in L.kwargs if haskey(kwargs, kwarg))

for op in getops(L)
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; filtered_kwargs...)
end

L.p = p
Expand Down Expand Up @@ -267,9 +273,6 @@ function Base.adjoint(L::FunctionOperator)
@set! traits.size = reverse(size(L))
@set! traits.eltypes = reverse(traits.eltypes)

p = L.p
t = L.t

cache = if iscached(L)
cache = reverse(L.cache)
else
Expand All @@ -281,8 +284,9 @@ function Base.adjoint(L::FunctionOperator)
op_inverse,
op_adjoint_inverse,
traits,
p,
t,
L.p,
L.t,
L.kwargs,
cache,
)
end
Expand Down Expand Up @@ -310,9 +314,6 @@ function Base.inv(L::FunctionOperator)
(p::Real) -> 1 / traits.opnorm(p)
end

p = L.p
t = L.t

cache = if iscached(L)
cache = reverse(L.cache)
else
Expand All @@ -324,8 +325,9 @@ function Base.inv(L::FunctionOperator)
op_inverse,
op_adjoint_inverse,
traits,
p,
t,
L.p,
L.t,
L.kwargs,
cache,
)
end
Expand Down Expand Up @@ -353,8 +355,8 @@ function LinearAlgebra.opnorm(L::FunctionOperator, p)
argument. E.g., `(p::Real) -> p == Inf ? 100 : error("only Inf norm is
defined")`
""")
opn = L.opnorm
return opn isa Number ? opn : L.opnorm(p)
opn = L.traits.opnorm
return opn isa Number ? opn : L.traits.opnorm(p)
end
LinearAlgebra.issymmetric(L::FunctionOperator) = L.traits.issymmetric
LinearAlgebra.ishermitian(L::FunctionOperator) = L.traits.ishermitian
Expand All @@ -373,31 +375,31 @@ end
islinear(L::FunctionOperator) = L.traits.islinear
isconstant(L::FunctionOperator) = L.traits.isconstant
has_adjoint(L::FunctionOperator) = !(L.op_adjoint isa Nothing)
has_mul(L::FunctionOperator{iip}) where{iip} = true
has_mul!(L::FunctionOperator{iip}) where{iip} = iip
has_mul(::FunctionOperator{iip}) where{iip} = true
has_mul!(::FunctionOperator{iip}) where{iip} = iip
has_ldiv(L::FunctionOperator{iip}) where{iip} = !(L.op_inverse isa Nothing)
has_ldiv!(L::FunctionOperator{iip}) where{iip} = iip & !(L.op_inverse isa Nothing)

# TODO - FunctionOperator, Base.conj, transpose

# operator application
Base.:*(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op(u, L.p, L.t)
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t)
Base.:\(L::FunctionOperator{iip,true}, u::AbstractVecOrMat) where{iip} = L.op_inverse(u, L.p, L.t; L.kwargs...)

function Base.:*(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
_, co = L.cache
du = zero(co)
L.op(du, u, L.p, L.t)
L.op(du, u, L.p, L.t; L.kwargs...)
end

function Base.:\(L::FunctionOperator{true,false}, u::AbstractVecOrMat)
ci, _ = L.cache
du = zero(ci)
L.op_inverse(du, u, L.p, L.t)
L.op_inverse(du, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
L.op(v, u, L.p, L.t)
L.op(v, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{false}, u::AbstractVecOrMat, args...)
Expand All @@ -414,11 +416,11 @@ function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop,
end

function LinearAlgebra.mul!(v::AbstractVecOrMat, L::FunctionOperator{true, oop, true}, u::AbstractVecOrMat, α, β) where{oop}
L.op(v, u, L.p, L.t, α, β)
L.op(v, u, L.p, L.t, α, β; L.kwargs...)
end

function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::FunctionOperator{true}, u::AbstractVecOrMat)
L.op_inverse(v, u, L.p, L.t)
L.op_inverse(v, u, L.p, L.t; L.kwargs...)
end

function LinearAlgebra.ldiv!(L::FunctionOperator{true}, u::AbstractVecOrMat)
Expand Down
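
The application methods in this hunk call the user-supplied function with `L.kwargs...`, so keyword state reaches the function when the operator is applied, not only when it is updated. A minimal sketch of that store-then-forward pattern follows, assuming a hypothetical `ToyFunctionOp` with a `Dict` of captured keyword values. It is not the `FunctionOperator` from `src/func.jl`, and the field layout here is an assumption made for illustration.

```julia
# Hypothetical toy type; not the FunctionOperator defined in src/func.jl.
mutable struct ToyFunctionOp{F}
    op::F                              # op(u, p, t; kwargs...)
    p::Any
    t::Any
    accepted::Tuple{Vararg{Symbol}}    # keywords the user function understands
    kwargs::Dict{Symbol,Any}           # keyword values captured by the last update
end

ToyFunctionOp(op; accepted = ()) =
    ToyFunctionOp(op, nothing, nothing, accepted, Dict{Symbol,Any}())

# Store p, t, and only the accepted keywords; everything else is dropped.
function toy_update!(L::ToyFunctionOp, u, p, t; kwargs...)
    L.p = p
    L.t = t
    empty!(L.kwargs)
    for k in L.accepted
        haskey(kwargs, k) && (L.kwargs[k] = kwargs[k])
    end
    return L
end

# Applying the operator forwards the stored state to the user function.
Base.:*(L::ToyFunctionOp, u::AbstractVecOrMat) = L.op(u, L.p, L.t; L.kwargs...)

f = (u, p, t; scale = 1.0) -> scale * t .* u
L = ToyFunctionOp(f; accepted = (:scale,))

toy_update!(L, nothing, nothing, 2.0; scale = 3.0, ignored = true)
v = L * [1.0, 2.0]
```

Storing the filtered values on the operator means a later `L * u` needs no extra arguments, at the cost of the operator carrying mutable state between updates.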