diff --git a/src/complex.jl b/src/complex.jl index 64420f324..39f22d99e 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -1,11 +1,11 @@ SymbolicUtils.promote_symtype(::typeof(imag), ::Type{Complex{T}}) where {T} = T -Base.promote_rule(::Type{Complex{T}}, ::Type{S}) where {T<:Real, S<:Num} = Complex{S} # 283 +Base.promote_rule(::Type{Complex{T}}, ::Type{S}) where {T <: Real, S <: Num} = Complex{S} # 283 is_wrapper_type(::Type{Complex{Num}}) = true -has_symwrapper(::Type{<:Complex{T}}) where {T<:Real} = true +has_symwrapper(::Type{<:Complex{T}}) where {T <: Real} = true wraps_type(::Type{Complex{Num}}) = Complex{Real} iswrapped(::Complex{Num}) = true -function wrapper_type(::Type{Complex{T}}) where T +function wrapper_type(::Type{Complex{T}}) where {T} Symbolics.has_symwrapper(T) ? Complex{wrapper_type(T)} : Complex{T} end @@ -14,11 +14,13 @@ function SymbolicUtils.unwrap(a::Complex{<:Num}) if SymbolicUtils.isconst(re) && SymbolicUtils.isconst(img) return Const{VartypeT}(complex(unwrap_const(re), unwrap_const(img))) end - if iscall(re) && operation(re) === real && iscall(img) && operation(img) === imag && isequal(arguments(re)[1], arguments(img)[1]) + if iscall(re) && operation(re) === real && iscall(img) && operation(img) === imag && + isequal(arguments(re)[1], arguments(img)[1]) return arguments(re)[1] end sT = promote_type(symtype(re), symtype(img)) - return Term{VartypeT}(complex, SymbolicUtils.ArgsT{vartype(re)}((re, img)); type = Complex{sT}, shape = SymbolicUtils.ShapeVecT()) + return Term{VartypeT}(complex, SymbolicUtils.ArgsT{vartype(re)}((re, img)); + type = Complex{sT}, shape = SymbolicUtils.ShapeVecT()) end function Base.Complex{Num}(x::BasicSymbolic{VartypeT}) @@ -32,9 +34,8 @@ function Base.show(io::IO, a::Complex{Num}) ii = unwrap(imag(a)) if iscall(rr) && (operation(rr) === real) && - iscall(ii) && (operation(ii) === imag) && - isequal(arguments(rr)[1], arguments(ii)[1]) - + iscall(ii) && (operation(ii) === imag) && + isequal(arguments(rr)[1], arguments(ii)[1]) return print(io, arguments(rr)[1]) end @@ -42,5 +43,19 @@ function Base.show(io::IO, a::Complex{Num}) end function (s::SymbolicUtils.Substituter)(x::Complex{Num}) - Complex{Num}(s(real(x)), s(imag(x))) + re_sub = s(real(x)) + im_sub = s(imag(x)) + # Unwrap to check if the substituted values are complex constants + re_unwrapped = unwrap(re_sub) + im_unwrapped = unwrap(im_sub) + # If both parts are constants, we can evaluate the full complex expression + if SymbolicUtils.isconst(re_unwrapped) && SymbolicUtils.isconst(im_unwrapped) + re_val = SymbolicUtils.unwrap_const(re_unwrapped) + im_val = SymbolicUtils.unwrap_const(im_unwrapped) + # Properly handle complex arithmetic: (a + b*im) where a, b can be complex + # (a_re + a_im*im) + (b_re + b_im*im)*im = (a_re - b_im) + (a_im + b_re)*im + result = re_val + im_val * im + return Complex{Num}(wrap(Const{VartypeT}(real(result))), wrap(Const{VartypeT}(imag(result)))) + end + Complex{Num}(re_sub, im_sub) end diff --git a/test/complex.jl b/test/complex.jl index e9b77cbb3..d2bd15849 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -50,3 +50,46 @@ end @test !hasname(2x) @test !hasname(x + y) end + +# issue #1109: substituting complex values into expressions with complex coefficients +@testset "complex value substitution" begin + @variables x y + + # Basic case from issue #1109 + p1 = 0.4 + 1.7im * x + result1 = substitute(p1, Dict(x => 0.2 + 1.0im)) + expected1 = 0.4 + 1.7im * (0.2 + 1.0im) + @test result1 isa Complex{Num} + @test isapprox(unwrap_const(unwrap(real(result1))), real(expected1)) + @test isapprox(unwrap_const(unwrap(imag(result1))), imag(expected1)) + + # Both real and imag parts have the variable + p2 = x + 2.0im * x + result2 = substitute(p2, Dict(x => 1.0 + 0.5im)) + expected2 = (1.0 + 0.5im) + 2.0im * (1.0 + 0.5im) + @test isapprox(unwrap_const(unwrap(real(result2))), real(expected2)) + @test isapprox(unwrap_const(unwrap(imag(result2))), imag(expected2)) + + # Real value substitution still works + p3 = 0.4 + 1.7im * x + result3 = substitute(p3, Dict(x => 0.5)) + expected3 = 0.4 + 1.7im * 0.5 + @test isapprox(unwrap_const(unwrap(real(result3))), real(expected3)) + @test isapprox(unwrap_const(unwrap(imag(result3))), imag(expected3)) + + # Two variables with complex substitution + p4 = x + y * im + result4 = substitute(p4, Dict(x => 1.0 + 2.0im, y => 3.0 + 4.0im)) + expected4 = (1.0 + 2.0im) + (3.0 + 4.0im) * im + @test isapprox(unwrap_const(unwrap(real(result4))), real(expected4)) + @test isapprox(unwrap_const(unwrap(imag(result4))), imag(expected4)) + + # Symbolics.value should work on the result + result_val = Symbolics.value(result1) + @test result_val isa Complex + @test isapprox(result_val, expected1) + + # simplify and expand should work + @test_nowarn Symbolics.simplify(result1) + @test_nowarn Symbolics.expand(result1) +end