Skip to content

Commit

Permalink
wip julia 0.7
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed May 29, 2018
1 parent e92f840 commit 5a7f7c0
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 23 deletions.
5 changes: 3 additions & 2 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ module Flux

# Zero Flux Given

using Juno, Requires, Reexport
using Requires, Reexport
using MacroTools: @forward

export Chain, Dense, RNN, LSTM, GRU, Conv,
Expand All @@ -25,6 +25,7 @@ using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad

include("oset.jl")
include("utils.jl")
include("onehot.jl")
include("treelike.jl")
Expand All @@ -35,7 +36,7 @@ include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")

include("data/Data.jl")
# include("data/Data.jl")

@require CuArrays include("cuda/cuda.jl")

Expand Down
7 changes: 0 additions & 7 deletions src/data/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@ function Base.show(io::IO, t::Tree)
print_tree(io, t)
end

using Juno

@render Juno.Inline t::Tree begin
render(t) = Juno.Tree(t.value, render.(t.children))
Juno.Tree(typeof(t), [render(t)])
end

Base.getindex(t::Tree, i::Integer) = t.children[i]
Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...]

Expand Down
3 changes: 1 addition & 2 deletions src/optimise/train.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Juno
using Flux.Tracker: back!

runall(f) = f
Expand Down Expand Up @@ -35,7 +34,7 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays.
function train!(loss, data, opt; cb = () -> ())
cb = runall(cb)
opt = runall(opt)
@progress for d in data
for d in data
l = loss(d...)
isinf(l) && error("Loss is Inf")
isnan(l) && error("Loss is NaN")
Expand Down
26 changes: 26 additions & 0 deletions src/oset.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
const ASet{T} = Base.AbstractSet{T}
const ODict = ObjectIdDict

struct ObjectIdSet{T} <: ASet{T}
dict::ObjectIdDict
ObjectIdSet{T}() where T = new(ObjectIdDict())
end

Base.eltype{T}(::ObjectIdSet{T}) = T

ObjectIdSet() = ObjectIdSet{Any}()

Base.push!{T}(s::ObjectIdSet{T}, x::T) = (s.dict[x] = nothing; s)
Base.delete!{T}(s::ObjectIdSet{T}, x::T) = (delete!(s.dict, x); s)
Base.in(x, s::ObjectIdSet) = haskey(s.dict, x)

(::Type{ObjectIdSet{T}}){T}(xs) = push!(ObjectIdSet{T}(), xs...)

ObjectIdSet(xs) = ObjectIdSet{eltype(xs)}(xs)

Base.collect(s::ObjectIdSet) = collect(keys(s.dict))
Base.similar(s::ObjectIdSet, T::Type) = ObjectIdSet{T}()

@forward ObjectIdSet.dict Base.length

const OSet = ObjectIdSet
20 changes: 10 additions & 10 deletions src/tracker/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
print(io, "TrackedArray{…,$A}")

function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
if repr
print(io, "param(")
Base.showarray(io, data(X), true)
print(io, ")")
else
header && print(io, "Tracked ")
Base.showarray(io, data(X), false, header = header)
end
end
# function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true)
# if repr
# print(io, "param(")
# Base.showarray(io, data(X), true)
# print(io, ")")
# else
# header && print(io, "Tracked ")
# Base.showarray(io, data(X), false, header = header)
# end
# end

Base.setindex!(xs::TrackedArray, v, i...) =
error("Can't differentiate `setindex!`")
Expand Down
2 changes: 0 additions & 2 deletions src/treelike.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ function mapleaves(f, x; cache = ObjectIdDict())
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end

using DataFlow: OSet

function prefor(f, x; seen = OSet())
x seen && return
f(x)
Expand Down

0 comments on commit 5a7f7c0

Please sign in to comment.