Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@ julia> x = rand(4);
# and primal value based on the type and shape of `x`.
julia> result = DiffResults.HessianResult(x)

# Instead of passing an output buffer to `hessian!`, we pass `result`.
# Note that we re-alias to `result` - this is important! See `hessian!`
# docs for why we do this.
# Instead of passing an output buffer to `ForwardDiff.hessian!`, we pass `result`.
# Note that we re-alias to `result`:
# This is not required in this example since `ForwardDiff.hessian!` mutates `result`;
# however, in general it is important since immutable `DiffResult` instances
# (e.g. `DiffResult` objects with static arrays) cannot be updated in-place.
julia> result = ForwardDiff.hessian!(result, f, x);

# ...and now we can get all the computed data from `result`
Expand Down Expand Up @@ -65,12 +67,12 @@ DiffResults.jacobian
DiffResults.hessian
```

## Mutating a `DiffResult`
## Modifying a `DiffResult`

```@docs
DiffResults.value!
DiffResults.derivative!
DiffResults.gradient!
DiffResults.jacobian!
DiffResults.hessian!
DiffResults.value!!
DiffResults.derivative!!
DiffResults.gradient!!
DiffResults.jacobian!!
DiffResults.hessian!!
```
176 changes: 123 additions & 53 deletions src/DiffResults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,28 +145,36 @@ Note that this method returns a reference, not a copy.
value(r::DiffResult) = r.value

"""
value!(r::DiffResult, x)
value!!(r::DiffResult, x)

Return `s::DiffResult` with the same data as `r`, except for `value(s) == x`.

This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
mutated in-place and returned. Thus, this function should be called as `r = value!(r, x)`.
!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = value!!(r, x)`.
"""
value!(r::MutableDiffResult, x::Number) = (r.value = x; return r)
value!(r::MutableDiffResult, x::AbstractArray) = (copyto!(value(r), x); return r)
value!(r::ImmutableDiffResult{O,V}, x::Union{Number,AbstractArray}) where {O,V} = ImmutableDiffResult(convert(V, x), r.derivs)
value!!(r::MutableDiffResult, x::Number) = (r.value = x; return r)
value!!(r::MutableDiffResult, x::AbstractArray) = (copyto!(value(r), x); return r)
value!!(r::ImmutableDiffResult{O,V}, x::Union{Number,AbstractArray}) where {O,V} = ImmutableDiffResult(convert(V, x), r.derivs)

"""
value!(f, r::DiffResult, x)
value!!(f, r::DiffResult, x)

Equivalent to `value!(r::DiffResult, map(f, x))`, but without the implied temporary
Equivalent to `value!!(r::DiffResult, map(f, x))`, but without the implied temporary
allocation (when possible).

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = value!!(f, r, x)`.
"""
value!(f, r::MutableDiffResult, x::Number) = (r.value = f(x); return r)
value!(f, r::MutableDiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
value!(f, r::ImmutableDiffResult{O,V}, x::Number) where {O,V} = value!(r, convert(V, f(x)))
value!(f, r::ImmutableDiffResult{O,V}, x::AbstractArray) where {O,V} = value!(r, convert(V, map(f, x)))
value!!(f, r::MutableDiffResult, x::Number) = (r.value = f(x); return r)
value!!(f, r::MutableDiffResult, x::AbstractArray) = (map!(f, value(r), x); return r)
value!!(f, r::ImmutableDiffResult{O,V}, x::Number) where {O,V} = value!!(r, convert(V, f(x)))
value!!(f, r::ImmutableDiffResult{O,V}, x::AbstractArray) where {O,V} = value!!(r, convert(V, map(f, x)))

# derivative/derivative! #
#------------------------#
Expand All @@ -181,61 +189,68 @@ Note that this method returns a reference, not a copy.
derivative(r::DiffResult, ::Type{Val{i}} = Val{1}) where {i} = r.derivs[i]

"""
derivative!(r::DiffResult, x, ::Type{Val{i}} = Val{1})
derivative!!(r::DiffResult, x, ::Type{Val{i}} = Val{1})

Return `s::DiffResult` with the same data as `r`, except `derivative(s, Val{i}) == x`.

This function may or may not mutate `r`. If `r::ImmutableDiffResult`, a totally new
instance will be created and returned, whereas if `r::MutableDiffResult`, then `r` will be
mutated in-place and returned. Thus, this function should be called as
`r = derivative!(r, x, Val{i})`.
!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = derivative!!(r, x, Val{i})`.
"""
function derivative!(r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
r.derivs = tuple_setindex(r.derivs, x, Val{i})
return r
end

function derivative!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
copyto!(derivative(r, Val{i}), x)
return r
end

function derivative!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(r::ImmutableDiffResult, x::Union{Number,StaticArray}, ::Type{Val{i}} = Val{1}) where {i}
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, x, Val{i}))
end

function derivative!(r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
T = tuple_eltype(r.derivs, Val{i})
return ImmutableDiffResult(value(r), tuple_setindex(r.derivs, T(x), Val{i}))
end

"""
derivative!(f, r::DiffResult, x, ::Type{Val{i}} = Val{1})
derivative!!(f, r::DiffResult, x, ::Type{Val{i}} = Val{1})

Equivalent to `derivative!(r::DiffResult, map(f, x), Val{i})`, but without the implied
Equivalent to `derivative!!(r::DiffResult, map(f, x), Val{i})`, but without the implied
temporary allocation (when possible).

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = derivative!!(f, r, x, Val{i})`.
"""
function derivative!(f, r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(f, r::MutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
r.derivs = tuple_setindex(r.derivs, f(x), Val{i})
return r
end

function derivative!(f, r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(f, r::MutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
map!(f, derivative(r, Val{i}), x)
return r
end

function derivative!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
return derivative!(r, f(x), Val{i})
function derivative!!(f, r::ImmutableDiffResult, x::Number, ::Type{Val{i}} = Val{1}) where {i}
return derivative!!(r, f(x), Val{i})
end

function derivative!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i}
return derivative!(r, map(f, x), Val{i})
function derivative!!(f, r::ImmutableDiffResult, x::StaticArray, ::Type{Val{i}} = Val{1}) where {i}
return derivative!!(r, map(f, x), Val{i})
end

function derivative!(f, r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
function derivative!!(f, r::ImmutableDiffResult, x::AbstractArray, ::Type{Val{i}} = Val{1}) where {i}
T = tuple_eltype(r.derivs, Val{i})
return derivative!(r, map(f, T(x)), Val{i})
return derivative!!(r, map(f, T(x)), Val{i})
end

# special-cased methods #
Expand All @@ -251,23 +266,35 @@ Equivalent to `derivative(r, Val{1})`.
gradient(r::DiffResult) = derivative(r)

"""
gradient!(r::DiffResult, x)
gradient!!(r::DiffResult, x)

Return `s::DiffResult` with the same data as `r`, except `gradient(s) == x`.

Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(r, x, Val{1})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = gradient!!(r, x)`.
"""
gradient!(r::DiffResult, x) = derivative!(r, x)
gradient!!(r::DiffResult, x) = derivative!!(r, x)

"""
gradient!(f, r::DiffResult, x)
gradient!!(f, r::DiffResult, x)

Equivalent to `gradient!(r::DiffResult, map(f, x))`, but without the implied temporary
Equivalent to `gradient!!(r::DiffResult, map(f, x))`, but without the implied temporary
allocation (when possible).

Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(f, r, x, Val{1})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = gradient!!(f, r, x)`.
"""
gradient!(f, r::DiffResult, x) = derivative!(f, r, x)
gradient!!(f, r::DiffResult, x) = derivative!!(f, r, x)

"""
jacobian(r::DiffResult)
Expand All @@ -279,23 +306,35 @@ Equivalent to `derivative(r, Val{1})`.
jacobian(r::DiffResult) = derivative(r)

"""
jacobian!(r::DiffResult, x)
jacobian!!(r::DiffResult, x)

Return `s::DiffResult` with the same data as `r`, except `jacobian(s) == x`.

Equivalent to `derivative!(r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(r, x, Val{1})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = jacobian!!(r, x)`.
"""
jacobian!(r::DiffResult, x) = derivative!(r, x)
jacobian!!(r::DiffResult, x) = derivative!!(r, x)

"""
jacobian!(f, r::DiffResult, x)
jacobian!!(f, r::DiffResult, x)

Equivalent to `jacobian!(r::DiffResult, map(f, x))`, but without the implied temporary
Equivalent to `jacobian!!(r::DiffResult, map(f, x))`, but without the implied temporary
allocation (when possible).

Equivalent to `derivative!(f, r, x, Val{1})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(f, r, x, Val{1})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = jacobian!!(f, r, x)`.
"""
jacobian!(f, r::DiffResult, x) = derivative!(f, r, x)
jacobian!!(f, r::DiffResult, x) = derivative!!(f, r, x)

"""
hessian(r::DiffResult)
Expand All @@ -307,23 +346,35 @@ Equivalent to `derivative(r, Val{2})`.
hessian(r::DiffResult) = derivative(r, Val{2})

"""
hessian!(r::DiffResult, x)
hessian!!(r::DiffResult, x)

Return `s::DiffResult` with the same data as `r`, except `hessian(s) == x`.

Equivalent to `derivative!(r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(r, x, Val{2})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = hessian!(r, x)`.
"""
hessian!(r::DiffResult, x) = derivative!(r, x, Val{2})
hessian!!(r::DiffResult, x) = derivative!!(r, x, Val{2})

"""
hessian!(f, r::DiffResult, x)
hessian!!(f, r::DiffResult, x)

Equivalent to `hessian!(r::DiffResult, map(f, x))`, but without the implied temporary
Equivalent to `hessian!!(r::DiffResult, map(f, x))`, but without the implied temporary
allocation (when possible).

Equivalent to `derivative!(f, r, x, Val{2})`; see `derivative!` docs for aliasing behavior.
Equivalent to `derivative!!(f, r, x, Val{2})`.

!!! warn
This function may or may not mutate `r`.
If `r::ImmutableDiffResult`, a totally new instance will be created and returned,
whereas if `r::MutableDiffResult`, then `r` will be mutated in-place and returned.
Thus, this function should be called as `r = hessian!!(f, r, x)`.
"""
hessian!(f, r::DiffResult, x) = derivative!(f, r, x, Val{2})
hessian!!(f, r::DiffResult, x) = derivative!!(f, r, x, Val{2})

###################
# Pretty Printing #
Expand All @@ -333,4 +384,23 @@ Base.show(io::IO, r::ImmutableDiffResult) = print(io, "ImmutableDiffResult($(r.v

Base.show(io::IO, r::MutableDiffResult) = print(io, "MutableDiffResult($(r.value), $(r.derivs))")

################
# Deprecations #
################

Base.@deprecate value!(r::DiffResult, x::Union{Number,AbstractArray}) value!!(r, x) false
Base.@deprecate value!(f, r::DiffResult, x::Union{Number,AbstractArray}) value!!(f, r, x) false

Base.@deprecate derivative!(r::DiffResult, x::Union{Number,AbstractArray}, ::Type{Val{i}} = Val{1}) where {i} derivative!!(r, x, Val{i}) false
Base.@deprecate derivative!(f, r::DiffResult, x::Union{Number,AbstractArray}, ::Type{Val{i}} = Val{1}) where {i} derivative!!(f, r, x, Val{i}) false

Base.@deprecate gradient!(r::DiffResult, x) gradient!!(r, x) false
Base.@deprecate gradient!(f, r::DiffResult, x) gradient!!(f, r, x) false

Base.@deprecate jacobian!(r::DiffResult, x) jacobian!!(r, x) false
Base.@deprecate jacobian!(f, r::DiffResult, x) jacobian!!(f, r, x) false

Base.@deprecate hessian!(r::DiffResult, x) hessian!!(r, x) false
Base.@deprecate hessian!(f, r::DiffResult, x) hessian!!(f, r, x) false

end # module
Loading
Loading