AD with Tracker.jl versus Zygote.jl and Mooncake.jl. For BNN. #2454

Open
deveshjawla opened this issue Jan 6, 2025 · 1 comment
Labels
user issue Issues raised by, or actively affecting, users

Comments

deveshjawla commented Jan 6, 2025

Can someone help me understand why Tracker.jl fails in the "linked" case? And how could Zygote match Tracker's performance in the "standard" case when benchmarking with TuringBenchmarking?

I get the following outputs.

┌ Warning: Gradient computation (with linking) failed for AutoTracker(): MethodError(copyto!, (0.0, Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}}(+, (0.0, -0.1371315395009085))), 0x0000000000006a89)
└ @ TuringBenchmarking ~/.julia/packages/TuringBenchmarking/fc6o7/src/TuringBenchmarking.jl:243
"gradient" => 3-element BenchmarkTools.BenchmarkGroup:
          tags: []
          "AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["AutoMooncake{Mooncake.Config}(Mooncake.Config(false, false))"]
                  "linked" => Trial(531.125 μs)
                  "standard" => Trial(532.459 μs)
          "AutoZygote()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Zygote"]
                  "linked" => Trial(2.382 ms)
                  "standard" => Trial(2.252 ms)
          "AutoTracker()" => 2-element BenchmarkTools.BenchmarkGroup:
                  tags: ["Tracker"]
                  "linked" => 0-element BenchmarkTools.BenchmarkGroup:
                          tags: []
                  "standard" => Trial(99.125 μs)

MWE:

using Lux, Zygote, Tracker, Mooncake, Turing, Random, TuringBenchmarking, Functors
nn = Chain(Dense(10, 5, relu), Dense(5, 1, use_bias=false))
rng = Xoshiro(0)
ps, st = Lux.setup(rng, nn)
num_params = Lux.parameterlength(nn) # number of parameters in NN
const model = StatefulLuxLayer{true}(nn, nothing, st)

function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple)
    @assert length(ps_new) == Lux.parameterlength(ps)
    i = 1
    function get_ps(x)
        z = reshape(view(ps_new, i:(i+length(x)-1)), size(x))
        i += length(x)
        return z
    end
    return fmap(get_ps, ps)
end

@model function BNN(x, y, num_p)
    θ_p ~ MvNormal(zeros(num_p), ones(num_p))

    preds = Lux.apply(model, x, vector_to_parameters(θ_p, ps))

    sigma ~ Gamma(0.1, 1.0) # Prior for the noise standard deviation (Normal takes a std, not a variance)
 
    y[:] ~ Product(Normal.(vec(preds), sigma))
end

benchmark_result = benchmark_model(
    BNN(randn(10, 10), randn(1, 10), num_params);
    adbackends=[
        AutoZygote(),
        AutoTracker(),
        AutoMooncake(; config=Mooncake.Config(; debug_mode=false)),
    ],
)

  1. Mooncake works well on other problems too.
  2. Zygote is slower than both. It would be good to bring its speed up to Tracker's.
  3. Tracker is the fastest in the "standard" case but fails in the "linked" case.

Any help is appreciated. Thank you.
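
A note on the `vector_to_parameters` helper in the MWE: it relies on `fmap` visiting the leaves in a fixed order while a captured counter `i` is mutated. The sketch below shows the same flat-vector-to-nested-parameters mapping using only Base Julia, which may help when checking shapes and ordering by hand. The names `unflatten` and `template`, and the small template layout, are assumptions made for illustration. They are not part of Lux or Functors, and nothing here is claimed to change the Tracker behaviour reported above.

```julia
# Hypothetical helper: copy the layout of `template` (a nested NamedTuple of arrays)
# onto consecutive slices of `flat`, walking fields in declaration order.
function unflatten(flat::AbstractVector, template::NamedTuple)
    offset = Ref(0)  # number of entries of `flat` consumed so far
    out = _unflatten(offset, flat, template)
    @assert offset[] == length(flat) "flat vector length does not match template"
    return out
end

# Leaf case: take the next length(x) entries as a reshaped view (no copy).
function _unflatten(offset::Base.RefValue{Int}, flat::AbstractVector, x::AbstractArray)
    n = length(x)
    block = reshape(view(flat, (offset[] + 1):(offset[] + n)), size(x))
    offset[] += n
    return block
end

# Container case: recurse into each field, keeping the field names.
function _unflatten(offset::Base.RefValue{Int}, flat::AbstractVector, nt::NamedTuple)
    return map(x -> _unflatten(offset, flat, x), nt)
end

# Small assumed layout: two layers, the second without a bias.
template = (
    layer_1 = (weight = zeros(2, 3), bias = zeros(2)),
    layer_2 = (weight = zeros(1, 2),),
)
flat = collect(1.0:10.0)
ps_new = unflatten(flat, template)
```

Because the leaves are views, `ps_new.layer_1.weight` is filled column by column from the first six entries of `flat`, the bias takes the next two, and the last layer's weight takes the remaining two.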

@mhauru added the "user issue" (Issues raised by, or actively affecting, users) label on Jan 6, 2025
yebai (Member) commented Jan 7, 2025

Moving forward, we will proactively fix any Mooncake performance issues. Any slowdown of Mooncake compared to Tracker is likely a performance bug or poor rule implementation.

cc @willtebbutt

EDIT: Can we have a warning message for AutoTracker explicitly saying that Turing no longer supports its use, and that users should consider AutoForwardDiff / AutoReverseDiff / AutoMooncake / AutoEnzyme instead?


3 participants