Skip to content

Commit da85a95

Browse files
Merge pull request #1257 from kapple19/#1085
Added BasicSymbolic as input for function fed to derivative
2 parents c8a6047 + a58baa0 commit da85a95

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/diff.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -499,8 +499,8 @@ derivative(::typeof(+), args::NTuple{N,Any}, ::Val) where {N} = 1
499499
derivative(::typeof(*), args::NTuple{N,Any}, ::Val{i}) where {N,i} = *(deleteat!(collect(args), i)...)
500500
derivative(::typeof(one), args::Tuple{<:Any}, ::Val) = 0
501501

502-
derivative(f::Function, x::Num) = derivative(f(x), x)
503-
derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Num, typeof(x)) |> throw
502+
derivative(f::Function, x::Union{Num, <:BasicSymbolic}) = derivative(f(x), x)
503+
derivative(::Function, x::Any) = TypeError(:derivative, "2nd argument", Union{Num, <:BasicSymbolic}, x) |> throw
504504

505505
function count_order(x)
506506
@assert !(x isa Symbol) "The variable $x must have an order of differentiation that is greater or equal to 1!"

test/diff.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,20 @@ let
432432
end
433433
end
434434

435+
# Derivative of a `BasicSymbolic` (#1085)
436+
let
437+
x = Symbolics.Sym{Int}(:x)
438+
@testset for f in [sqrt, sin, acos, exp]
439+
@test isequal(
440+
Symbolics.derivative(f, x),
441+
Symbolics.derivative(
442+
f,
443+
Symbolics.BasicSymbolic(x)
444+
)
445+
)
446+
end
447+
end
448+
435449
# Check ssqrt, scbrt, slog
436450
let
437451
@variables x

0 commit comments

Comments
 (0)