diff --git a/Project.toml b/Project.toml index a0c187f21..09bbc87dc 100644 --- a/Project.toml +++ b/Project.toml @@ -114,7 +114,7 @@ SymPy = "2.2" SymPyPythonCall = "0.5" SymbolicIndexingInterface = "0.3.14" SymbolicLimits = "0.2.2" -SymbolicUtils = "4.4" +SymbolicUtils = "4.7" TermInterface = "2" julia = "1.10" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index e3f53ee5a..5091df5ff 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -573,6 +573,9 @@ include("inverse.jl") export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity include("discontinuities.jl") +export SymStruct, @symstruct +include("symstruct.jl") + @public Arr, NAMESPACE_SEPARATOR, Unknown, VariableDefaultValue, VariableSource @public _parse_vars, derivative, gradient, jacobian, sparsejacobian, hessian, sparsehessian @public get_variables, get_variables!, get_differential_vars, option_to_metadata_type, scalarize, shape diff --git a/src/symstruct.jl b/src/symstruct.jl new file mode 100644 index 000000000..86fd57b19 --- /dev/null +++ b/src/symstruct.jl @@ -0,0 +1,280 @@ +""" + $TYPEDEF + +Wrapper type for symbolic structs. Requires that the wrapped struct type `T` be registered +with [`@symstruct`](@ref). After registration, `@variables` can be used to create the +symbolic struct. + +```julia +# Here, `record` has type `SymStruct{Record}` +@variables record::Record +``` + +`getproperty` access on this is a symbolic operation, and returns an expression performing +the appropriate field access. This can only wrap concrete struct types (`isconcretetype(T)` +must be `true`). `getproperty` on this struct leverages `fieldnames` and `fieldtypes`. +Thus, will thus not respect custom `getproperty` methods on the wrapped struct type. +""" +struct SymStruct{T} + sym::SymbolicT + + function SymStruct{T}(x::SymbolicT) where {T} + @assert isconcretetype(T) + # Ensure wrapped types have been registered as such + @assert wrapper_type(T) === SymStruct{T} + # Ensure that the symbolic represents the correct type + @assert symtype(x) === T + new{T}(x) + end +end + +is_wrapper_type(::Type{SymStruct}) = false +is_wrapper_type(::Type{S}) where {T, S <: SymStruct{<:T}} = true +wraps_type(::Type{S}) where {T, S <: SymStruct{T}} = T +iswrapped(::SymStruct{T}) where {T} = true + +SymbolicUtils.unwrap(x::SymStruct) = getfield(x, 1) + +function field_shape end + +""" + @symstruct Foo{T1, T2, ...} + @symstruct Foo{T1, T2, ...} begin + # options... + end + +A macro which enables using type `Foo` with `SymStruct` as a symbolic struct. The first +argument to the macro must be the struct type, with all type parameters named. The optional +second argument is an optional `begin..end` block containing options that influence the +behavior of the macro. The following options are allowed: + +- `shape(:field) = # expression`. For array fields, the shape of the field cannot be + inferred from the type. In case the type of the field can be inferred from the + type, it can be specified using this syntax. The expression must evaluate to an object of type + `Union{SymbolicUtils.Unknown, AbstractVector{UnitRange{Int}}, Tuple{Vararg{UnitRange{Int}}}}`. + The expression has access to the concrete type of the struct being accessed, with all + type parameters available as declared in the first argument. + +For example, given the following struct: + +```julia +struct Record{T} + x::Int + y::Real + z::T +end +``` + +It can be registered as + +```julia +# Note: the type parameter must be declared, but the name itself does not matter +@symstruct Record{V} begin +# If `V` is an `AbstractVector` then field `z` is a 3-vector. Otherwise, it is a scalar. + shape(:z) = V <: AbstractVector ? [1:3] : () +end +``` + +Now, + +```julia +@variables rec::Record{Int} rec2::Record{Vector{Int}} +``` + +`rec.x`, `rec2.x` will be `Num`s with symtype `Int`. `rec.y` and `rec2.y` will be `Num`s +with symtype `Real`. `rec.z` will be a `Num` with symtype `Int`. `rec2.z` will be an +`Arr{Num, 1}` with symtype `Vector{Int}` and shape `[1:2]`. + +In case the shape of a field is not provided, it will be inferred from the type. For +`AbstractArray` subtypes, it will be `SymbolicUtils.Unknown(ndims(arr_type))`. Otherwise, +it will be treated as a scalar. +""" +macro symstruct(T, opts = Expr(:block)) + block = Expr(:block) + where_args = Expr[] + nocurly_name = T + if Meta.isexpr(T, :curly) + for x in @view(T.args[2:end]) + push!(where_args, esc(x)) + end + nocurly_name = T.args[1] + end + T = esc(T) + nocurly_name = esc(nocurly_name) + temp_typevar = :S + push!(block.args, quote + function (::$(typeof(has_symwrapper)))(::Type{$temp_typevar}) where {$(where_args...), $temp_typevar <: $T} + true + end + function (::$(typeof(wrapper_type)))(::Type{$temp_typevar}) where {$(where_args...), $temp_typevar <: $T} + isconcretetype($temp_typevar) ? $SymStruct{$temp_typevar} : $SymStruct{<:$temp_typevar} + end + end) + + @assert Meta.isexpr(opts, :block) """ + Options to `@symstruct` must be specified as a `begin...end` block. Got $opts. + """ + for stmt in opts.args + stmt isa LineNumberNode && continue + @assert Meta.isexpr(stmt, :(=)) """ + Each option to `@symstruct` must be of the form `option(args...) = value`. \ + Got $stmt. + """ + head, val = stmt.args + @assert Meta.isexpr(head, :call) """ + Each option to `@symstruct` must be of the form `option(args...) = value`. \ + Got $head instead of `option(args...)`. + """ + opt = head.args[1] + args = @view(head.args[2:end]) + if opt === :shape + @assert length(args) == 1 """ + The `shape` option must be of the form `shape(:field_name) = value`. Instead \ + of a single argument `:field_name`, multiple arguments $args were found. + """ + @assert args[1] isa QuoteNode """ + The field name provided to the `shape` option must be a literal `Symbol`. + Found `$(args[1])`. + """ + field = args[1] + push!(block.args, __field_shape_expr(T, field, where_args, val)) + else + error("Unsupported option $opt.") + end + end + + return block +end + +function __field_shape_expr(T::Union{Symbol, Expr}, field::QuoteNode, + where_args::Vector{Expr}, val::Union{Expr, Symbol}) + quote + function (::$(typeof(field_shape)))(sym::Type{S}, ::Val{$field}) where {$(where_args...), S <: $T} + val = $(esc(val)) + if val isa $(SymbolicUtils.Unknown) + return val + elseif val isa $(SymbolicUtils.ShapeVecT) + return val + elseif val isa $(AbstractVector{UnitRange{Int}}) + return $(SymbolicUtils.ShapeVecT)(val) + elseif val isa $(Tuple{Vararg{UnitRange{Int}}}) + return $(SymbolicUtils.ShapeVecT)(val) + else + error(""" + Invalid usage of `@symstruct` macro for type $($T). The shape for field \ + $($field) was specified incorrectly. The result of the expression must be \ + one of `SymbolicUtils.Unknown`, `AbstractVector{UnitRange{Int}}` or \ + `Tuple{Vararg{UnitRange{Int}}}`. Found a value of type $(typeof(val)). + """) + end + end + end +end + +# Generated `if..elseif..else` chain for `getproperty`. +@generated function Base.getproperty(sym::SymStruct{T}, name::Symbol) where {T} + chain = Expr(:if) + cur = chain + for fname in fieldnames(T) + fname = Meta.quot(fname) + push!(cur.args, :(name === $fname)) + push!(cur.args, :(return $_literal_getproperty(sym, Val{$fname}()))) + push!(cur.args, Expr(:elseif)) + cur = cur.args[end] + end + cur.head = :block + push!(cur.args, quote + if @isdefined(FieldError) + throw(FieldError($T, name)) + else + error("type $($T) has no field $(name). Available fields are $($(fieldnames(T)))") + end + end) + return chain +end + +""" + $TYPEDEF + +Struct used as operation for symbolic getproperty on `SymStruct{T}` with field `field`. +""" +struct SymbolicGetproperty{T, field} end + +field_name(::SymbolicGetproperty{T, field}) where {T, field} = field + +function (f::SymbolicGetproperty{T})(x::SymbolicT) where {T} + unwrap(f(SymStruct{T}(x))) +end +function (::SymbolicGetproperty{T, field})(x::SymStruct{T}) where {T, field} + _literal_getproperty(x, Val{field}()) +end +function (::SymbolicGetproperty{T, field})(x::T) where {T, field} + getproperty(x, field) +end + +function SymbolicUtils.promote_type(::SymbolicGetproperty{T, field}, x::SymbolicUtils.TypeT) where {T, field} + @assert x == T + fieldtype(x, field) +end + +function SymbolicUtils.promote_shape(::SymbolicGetproperty{T, field}, + @nospecialize(x::SymbolicUtils.ShapeT)) where {T, field} + @assert x isa SymbolicUtils.ShapeVecT && isempty(x) + field_shape(T, Val{field}()) +end + +""" + $TYPEDSIGNATURES + +Called by the generated `getproperty` for `SymStruct`. Performs symbolic field access. +""" +function _literal_getproperty(sym::SymStruct{T}, ::Val{name}) where {T, name} + fT = fieldtype(T, name) + fShape = field_shape(T, Val{name}()) + fname = BSImpl.Const{VartypeT}(name) + _struct = unwrap(sym) + args = ArgsT{VartypeT}((_struct, fname)) + val = BSImpl.Term{VartypeT}(SymbolicGetproperty{T, name}(), args; type = fT, shape = fShape) + if has_symwrapper(fT) + return wrapper_type(fT)(val) + else + return val + end +end + +""" + $TYPEDSIGNATURES + +Obtain the shape of the value obtained by accessing field `name` of type `T`. Only +implemented by `@symstruct` via the `shape` option. +""" +function field_shape(::Type{T}, ::Val{name}) where {T, name} + shape_from_type(fieldtype(T, name)) +end + +shape_from_type(::Type{A}) where {T, N, A <: AbstractArray{T, N}} = SymbolicUtils.Unknown(N) +shape_from_type(::Type{T}) where {T} = SymbolicUtils.ShapeVecT() + +function SymbolicUtils.show_call(io::IO, @nospecialize(f::SymbolicGetproperty), x::SymbolicT) + fname = field_name(f)::Symbol + @match x begin + BSImpl.Term(; args) => print(io, args[1]) + end + print(io, ".") + print(io, fname) +end + +function Base.show(io::IO, x::SymStruct) + show(io, unwrap(x)) +end + +function SymbolicUtils.Code.function_to_expr(@nospecialize(f::SymbolicGetproperty), x::SymbolicT, st) + out = get(st.rewrites, x, nothing) + out === nothing || return out + + fname = field_name(f)::Symbol + args = @match x begin + BSImpl.Term(; args) => args + end + return :($(SymbolicUtils.Code.toexpr(args[1], st)).$fname) +end diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index db6ae09f9..db1aa6a65 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -114,8 +114,11 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) # expected to be defined outside Symbolics if arg isa Expr && arg.head == :(::) T = Base.eval(mod, arg.args[2]) - Ts = has_symwrapper(T) ? (T, BasicSymbolic{VartypeT}, wrapper_type(T)) : - (T, BasicSymbolic{VartypeT}) + if has_symwrapper(T) + Ts = (T, SymbolicT, wrapper_type(T)) + else + Ts = (T, SymbolicT) + end if T <: AbstractArray && wrap_arrays eT = eltype(T) if eT == Any @@ -127,10 +130,9 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) (elT) -> AbstractArray{S} where {S <: elT} end if has_symwrapper(eT) - Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}, - _arr_type_fn(wrapper_type(eT))) + Ts = (Ts..., _arr_type_fn(SymbolicT), _arr_type_fn(wrapper_type(eT))) else - Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}) + Ts = (Ts..., _arr_type_fn(SymbolicT)) end end Ts @@ -174,7 +176,7 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) for (i, T) in enumerate(Ts) if T === BasicSymbolic{VartypeT} push!(body.args, :(@assert $symtype($(names[i])) <: $(types[i][1]))) - elseif T === AbstractArray{BasicSymbolic{VartypeT}} && eltype(types[i][1]) !== Any + elseif T <: (AbstractArray{S} where {S <: SymbolicT}) && eltype(types[i][1]) !== Any push!(body.args, :(@assert $symtype($(names[i])[1]) <: $(eltype(types[i][1])))) end end diff --git a/test/runtests.jl b/test/runtests.jl index c5e15eebf..2886fea29 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Invalidations Test" begin include("invalidations.jl") end @safetestset "Macro Test" begin include("macro.jl") end @safetestset "Arrays" begin include("arrays.jl") end + @safetestset "SymStruct" begin include("symstruct.jl") end @safetestset "View-setting" begin include("stencils.jl") end @safetestset "Complex" begin include("complex.jl") end @safetestset "Semi-polynomial" begin include("semipoly.jl") end diff --git a/test/symstruct.jl b/test/symstruct.jl new file mode 100644 index 000000000..689461a66 --- /dev/null +++ b/test/symstruct.jl @@ -0,0 +1,197 @@ +using Symbolics +using SymbolicUtils, SymbolicUtils.Code +import SymbolicUtils as SU +using Test + +struct Record1 + x::Int + y::String + z::Vector{Real} +end + +@symstruct Record1 + +@testset "Basic record" begin + @variables rec::Record1 + + @test rec isa Symbolics.SymStruct{Record1} + ex = rec.x + @test ex isa Num + @test SU.symtype(SU.unwrap(ex)) === Int + + ex = rec.y + @test ex isa Symbolics.SymbolicT + @test SU.symtype(ex) === String + + ex = rec.z + @test ex isa Symbolics.Arr{Num, 1} + @test SU.symtype(SU.unwrap(ex)) === Vector{Real} + @test SU.shape(SU.unwrap(ex)) === SU.Unknown(1) + + @test toexpr(rec.x) == :(rec.x) + ex = [rec.y, rec.x + rec.z[1]] + val = eval(quote + let rec = Record1(1, "abc", [1.0, 2.0]) + $(toexpr(ex)) + end + end) + @test val == ["abc", 2.0] +end + +struct Record2{T} + x::Int + y::String + z::Vector{T} +end + +@symstruct Record2{T} begin + shape(:z) = [1:3] +end + +@testset "Parametric record with specified shape" begin + @variables rec::Record2{Int} + + ex = rec.z + @test ex isa Symbolics.Arr{Num, 1} + @test SU.symtype(SU.unwrap(ex)) === Vector{Int} + @test SU.shape(SU.unwrap(ex)) == [1:3] + + ex = rec.z[1] + @test ex isa Num + @test SU.symtype(SU.unwrap(ex)) === Int +end + +@testset "Recursive struct" begin + @variables rec::Record2{Record2{Record1}} + + ex = rec.z + @test ex isa Symbolics.Arr{Symbolics.SymStruct{Record2{Record1}}, 1} + @test SU.symtype(SU.unwrap(ex)) === Vector{Record2{Record1}} + @test SU.shape(SU.unwrap(ex)) == [1:3] + + ex = rec.z[1] + @test ex isa Symbolics.SymStruct{Record2{Record1}} + @test SU.symtype(SU.unwrap(ex)) === Record2{Record1} + + ex = rec.z[1].z + @test ex isa Symbolics.Arr{Symbolics.SymStruct{Record1}, 1} + @test SU.symtype(SU.unwrap(ex)) === Vector{Record1} + @test SU.shape(SU.unwrap(ex)) == [1:3] + + ex = rec.z[1].z[1] + @test ex isa Symbolics.SymStruct{Record1} + @test SU.symtype(SU.unwrap(ex)) === Record1 + + ex = rec.z[1].z[1].z + @test ex isa Symbolics.Arr{Num, 1} + @test SU.symtype(SU.unwrap(ex)) === Vector{Real} + @test SU.shape(SU.unwrap(ex)) === SU.Unknown(1) + + @test toexpr(rec.z[1].z[2].z[3]) == :($getindex($getindex($getindex(rec.z, 1).z, 2).z, 3)) + @variables rec::Record2{Record1} + val = eval(quote + let rec = Record2{Record1}(1, "A", + [Record1(2, "B", [2.0, 3.0]), + Record1(3, "C", [3.0, 4.0]), + Record1(4, "D", [4.0, 5.0])]) + $(toexpr(rec.x + rec.z[1].x + rec.z[2].z[1] + rec.z[3].z[2])) + end + end) + @test val == 1 + 2 + 3.0 + 5.0 +end + +abstract type AbstractRecord1 end + +@symstruct AbstractRecord1 begin + shape(:x) = [1:3] +end + +struct ConcreteRecord1_1 <: AbstractRecord1 + x::Vector{Int} +end + +struct ConcreteRecord1_2 <: AbstractRecord1 + x::Vector{Int} +end + +@symstruct ConcreteRecord1_2 begin + shape(:x) = [1:2] +end + +record_op1(x::ConcreteRecord1_1) = sum(x.x) +record_op1(x::ConcreteRecord1_2) = prod(x.x) + +@register_symbolic record_op1(x::AbstractRecord1) + +record_arrop1(x::Vector{<:AbstractRecord1}) = sum(record_op1, x) + +@register_symbolic record_arrop1(x::Vector{AbstractRecord1}) + +@testset "`@symstruct` of simple abstract type" begin + @test Symbolics.has_symwrapper(AbstractRecord1) + @test Symbolics.wrapper_type(AbstractRecord1) == SymStruct{<:AbstractRecord1} + @test Symbolics.wrapper_type(ConcreteRecord1_1) === SymStruct{ConcreteRecord1_1} + + @variables r1::ConcreteRecord1_1 r2::ConcreteRecord1_2 + @test SU.shape(r1.x) == [1:3] + @test SU.shape(r2.x) == [1:2] +end + +@testset "Function registration with symbolic structs" begin + @variables r1::ConcreteRecord1_1 r2::ConcreteRecord1_2 + ex = record_op1(r1) + @test operation(SU.unwrap(ex)) === record_op1 + + val = eval(quote + let r1 = ConcreteRecord1_1([2,2,3]) + $(toexpr(ex)) + end + end) + @test val == 7 + + ex = record_op1(r2) + @test operation(SU.unwrap(ex)) === record_op1 + + val = eval(quote + let r2 = ConcreteRecord1_2([2,2,3]) + $(toexpr(ex)) + end + end) + @test val == 12 + + @variables r3[1:2]::ConcreteRecord1_1 r4[1:2]::ConcreteRecord1_2 + + ex = record_arrop1(r3) + @test operation(SU.unwrap(ex)) === record_arrop1 + val = eval(quote + let r3 = [ConcreteRecord1_1([2,3,4]), ConcreteRecord1_1([3,4,5])] + $(toexpr(ex)) + end + end) + @test val == 21 + + ex = record_arrop1(r4) + @test operation(SU.unwrap(ex)) === record_arrop1 + val = eval(quote + let r4 = [ConcreteRecord1_2([2,3,4]), ConcreteRecord1_2([3,4,5])] + $(toexpr(ex)) + end + end) + @test val == 84 + + ex = record_arrop1([r3[1], r1]) + @test operation(SU.unwrap(ex)) === record_arrop1 + ex = record_arrop1([r4[1], r2]) + @test operation(SU.unwrap(ex)) === record_arrop1 +end + +abstract type AbstractRecord2{T} end + +@symstruct AbstractRecord2 + +@testset "Registering type without parameters works" begin + @test Symbolics.has_symwrapper(AbstractRecord2) + @test Symbolics.wrapper_type(AbstractRecord2) == SymStruct{<:AbstractRecord2} + @test Symbolics.has_symwrapper(AbstractRecord2{Int}) + @test Symbolics.wrapper_type(AbstractRecord2{Int}) == SymStruct{<:AbstractRecord2{Int}} +end