Skip to content

Commit

Permalink
matmul fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Aug 3, 2018
1 parent 4cf6bac commit f5c9361
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,20 +276,29 @@ LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)

x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
y::AbstractMatrix * x::TrackedMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)

x::TrackedMatrix * y::AbstractVector = track(*, x, y)
y::AbstractMatrix * x::TrackedVector = track(*, x, y)
x::AbstractMatrix * y::TrackedVector = track(*, x, y)
x::TrackedMatrix * y::TrackedVector = track(*, x, y)

x::TrackedVector * y::AbstractVector = track(*, x, y)
y::AbstractVector * x::TrackedVector = track(*, x, y)
x::AbstractVector * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::TrackedVector = track(*, x, y)

@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ ->* transpose(b), transpose(a) * Δ)

# @grad function (a::AbstractMatrix * b::AbstractVecOrMat)
# # @show size(a) size(b)
# data(a)*data(b), function (Δ)
# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b)))
# @show typeof(Δ) typeof(b)
# (Δ * transpose(b), transpose(a) * Δ)
# end
# end

# NNlib

using NNlib
Expand Down

0 comments on commit f5c9361

Please sign in to comment.