From 5a7f7c0769b5c37268191252c56f2cd49802ae01 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 22 May 2018 12:11:28 +0800 Subject: [PATCH] wip julia 0.7 --- src/Flux.jl | 5 +++-- src/data/tree.jl | 7 ------- src/optimise/train.jl | 3 +-- src/oset.jl | 26 ++++++++++++++++++++++++++ src/tracker/array.jl | 20 ++++++++++---------- src/treelike.jl | 2 -- 6 files changed, 40 insertions(+), 23 deletions(-) create mode 100644 src/oset.jl diff --git a/src/Flux.jl b/src/Flux.jl index 7125630f69..843795c333 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -4,7 +4,7 @@ module Flux # Zero Flux Given -using Juno, Requires, Reexport +using Requires, Reexport using MacroTools: @forward export Chain, Dense, RNN, LSTM, GRU, Conv, @@ -25,6 +25,7 @@ using .Optimise: @epochs export SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad +include("oset.jl") include("utils.jl") include("onehot.jl") include("treelike.jl") @@ -35,7 +36,7 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") -include("data/Data.jl") +# include("data/Data.jl") @require CuArrays include("cuda/cuda.jl") diff --git a/src/data/tree.jl b/src/data/tree.jl index 5067714a08..c9db9e7ba8 100644 --- a/src/data/tree.jl +++ b/src/data/tree.jl @@ -21,13 +21,6 @@ function Base.show(io::IO, t::Tree) print_tree(io, t) end -using Juno - -@render Juno.Inline t::Tree begin - render(t) = Juno.Tree(t.value, render.(t.children)) - Juno.Tree(typeof(t), [render(t)]) -end - Base.getindex(t::Tree, i::Integer) = t.children[i] Base.getindex(t::Tree, i::Integer, is::Integer...) = t[i][is...] diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 401a1c51f9..bd21b45079 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -1,4 +1,3 @@ -using Juno using Flux.Tracker: back! runall(f) = f @@ -35,7 +34,7 @@ Multiple optimisers and callbacks can be passed to `opt` and `cb` as arrays. function train!(loss, data, opt; cb = () -> ()) cb = runall(cb) opt = runall(opt) - @progress for d in data + for d in data l = loss(d...) isinf(l) && error("Loss is Inf") isnan(l) && error("Loss is NaN") diff --git a/src/oset.jl b/src/oset.jl new file mode 100644 index 0000000000..e01d254581 --- /dev/null +++ b/src/oset.jl @@ -0,0 +1,26 @@ +const ASet{T} = Base.AbstractSet{T} +const ODict = ObjectIdDict + +struct ObjectIdSet{T} <: ASet{T} + dict::ObjectIdDict + ObjectIdSet{T}() where T = new(ObjectIdDict()) +end + +Base.eltype{T}(::ObjectIdSet{T}) = T + +ObjectIdSet() = ObjectIdSet{Any}() + +Base.push!{T}(s::ObjectIdSet{T}, x::T) = (s.dict[x] = nothing; s) +Base.delete!{T}(s::ObjectIdSet{T}, x::T) = (delete!(s.dict, x); s) +Base.in(x, s::ObjectIdSet) = haskey(s.dict, x) + +(::Type{ObjectIdSet{T}}){T}(xs) = push!(ObjectIdSet{T}(), xs...) + +ObjectIdSet(xs) = ObjectIdSet{eltype(xs)}(xs) + +Base.collect(s::ObjectIdSet) = collect(keys(s.dict)) +Base.similar(s::ObjectIdSet, T::Type) = ObjectIdSet{T}() + +@forward ObjectIdSet.dict Base.length + +const OSet = ObjectIdSet diff --git a/src/tracker/array.jl b/src/tracker/array.jl index e11296abf4..7ee9044fcb 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -27,16 +27,16 @@ Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = print(io, "TrackedArray{…,$A}") -function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) - if repr - print(io, "param(") - Base.showarray(io, data(X), true) - print(io, ")") - else - header && print(io, "Tracked ") - Base.showarray(io, data(X), false, header = header) - end -end +# function Base.showarray(io::IO, X::TrackedArray, repr::Bool = true; header = true) +# if repr +# print(io, "param(") +# Base.showarray(io, data(X), true) +# print(io, ")") +# else +# header && print(io, "Tracked ") +# Base.showarray(io, data(X), false, header = header) +# end +# end Base.setindex!(xs::TrackedArray, v, i...) = error("Can't differentiate `setindex!`") diff --git a/src/treelike.jl b/src/treelike.jl index fbe9fcad37..77f53b3b7b 100644 --- a/src/treelike.jl +++ b/src/treelike.jl @@ -20,8 +20,6 @@ function mapleaves(f, x; cache = ObjectIdDict()) cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) end -using DataFlow: OSet - function prefor(f, x; seen = OSet()) x ∈ seen && return f(x)