
Using a wrapper + mutation to "implicitly" update scheduled parameters? #35

Open
ToucheSir opened this issue Aug 14, 2022 · 3 comments
Labels: enhancement (New feature or request)

ToucheSir (Member) commented Aug 14, 2022

This is to write down a thought that came out of #34 and FluxML/Optimisers.jl#89. Presently, after each step we rely on mutably or immutably updating every object that depends on the schedule value. This is simple and easy to understand, but it could become unwieldy with more complex optimizer state trees.
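To make the current approach concrete, here is a minimal sketch with hypothetical toy types (ToyDescent, with_eta and exp_schedule are illustrative stand-ins, not the Optimisers.jl or ParameterSchedulers.jl API). The point is that every rule in the tree has to be located and rebuilt after each step:

```julia
# Hypothetical toy rule; not the Optimisers.jl API.
struct ToyDescent{T}
    eta::T
end

# Rebuild every rule in a (possibly nested) tuple tree with a new learning rate.
# This is the explicit bookkeeping: each rule that depends on the schedule
# has to be found and replaced after every step.
with_eta(::ToyDescent, eta) = ToyDescent(eta)
with_eta(tree::Tuple, eta) = map(t -> with_eta(t, eta), tree)

# Toy exponential decay: value at step t (1-based).
exp_schedule(t; start = 1.0, decay = 0.5) = start * decay^(t - 1)

tree = (ToyDescent(0.0), (ToyDescent(0.0), ToyDescent(0.0)))
for t in 1:3
    global tree = with_eta(tree, exp_schedule(t))
end
```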

What if we instead created a stateful type or wrapper that keeps track of the current schedule value? Then we make this type, or some type that holds a reference to it, subtype a number type (maybe Real? It could be parametric on the value type). Optimisers.jl rules could then manipulate this proxy number directly, but it would appear to update automatically whenever the schedule ticks.

Some pseudocode for the above:

Option 1: the wrapper itself is a mutable number proxy

```julia
using Optimisers, ParameterSchedulers  # Descent, Exp

mutable struct ScheduleValue{T<:Real} <: Real
  inner::T
end

# Overload basic math operations (much like Flux.Nil)
Base.:+(sv::ScheduleValue, x::Number) = sv.inner + x
....

eta = ScheduleValue(0f0)
d = Descent(eta)
schedule = Exp(...)

for s in schedule
  eta.inner = s  # probably want a proper function for this
  ...
end
```
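A self-contained sketch of Option 1, assuming hypothetical names (MutableScheduleValue, setvalue!, apply_rule and a toy ToyDescent rule in place of Optimisers.Descent) and a plain tuple of values standing in for a real schedule. Only a handful of operators are forwarded here; a usable version would need many more, plus care with method ambiguities:

```julia
# Hypothetical proxy number: a mutable box that behaves like a Real in arithmetic.
mutable struct MutableScheduleValue{T<:Real} <: Real
    inner::T
end

# Forward a few arithmetic operations to the wrapped value.
Base.:*(sv::MutableScheduleValue, x::Real) = sv.inner * x
Base.:*(x::Real, sv::MutableScheduleValue) = x * sv.inner
Base.:+(sv::MutableScheduleValue, x::Real) = sv.inner + x
Base.:+(x::Real, sv::MutableScheduleValue) = x + sv.inner

# Hypothetical "proper function" for updating the value in place.
setvalue!(sv::MutableScheduleValue, v) = (sv.inner = v; sv)

# Toy rule that stores the proxy; it is built once and never rebuilt.
struct ToyDescent{T}
    eta::T
end
apply_rule(rule::ToyDescent, grad) = rule.eta * grad

eta = MutableScheduleValue(0.0)
rule = ToyDescent(eta)
for s in (1.0, 0.5, 0.25)  # stand-in for iterating a schedule
    setvalue!(eta, s)
end
```

The rule is constructed once; only the proxy is mutated, and the rule sees the new value on its next use.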

Option 2: the number proxy is derived from the wrapper

```julia
using Optimisers, ParameterSchedulers  # Descent, Exp, Stateful

struct ScheduleValue{S<:Stateful} <: Real
  iter::S
end

_getval(sv::ScheduleValue) = sv.iter.schedule(sv.iter.state)

# Overload basic math operations (much like Flux.Nil)
Base.:+(sv::ScheduleValue, x::Number) = _getval(sv) + x
...

schedule = Stateful(Exp(...))

eta = ScheduleValue(schedule)
d = Descent(eta)

for _ in schedule  # no need for value here, just next! on the Stateful
  ...
end
```
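A similar sketch of Option 2. ToyStateful is a hypothetical stand-in for ParameterSchedulers' Stateful (the real type's fields and advancing function may differ), StatefulScheduleValue and apply_rule are illustrative names, and the schedule is a toy exponential decay:

```julia
# Hypothetical minimal stateful schedule.
mutable struct ToyStateful{F}
    schedule::F  # callable: step -> value
    state::Int   # current step, 1-based
end
ToyStateful(f) = ToyStateful(f, 1)
next!(s::ToyStateful) = (s.state += 1; s)

# The proxy only holds a reference to the stateful iterator and reads through it.
struct StatefulScheduleValue{S<:ToyStateful} <: Real
    iter::S
end
_getval(sv::StatefulScheduleValue) = sv.iter.schedule(sv.iter.state)

Base.:*(sv::StatefulScheduleValue, x::Real) = _getval(sv) * x
Base.:*(x::Real, sv::StatefulScheduleValue) = x * _getval(sv)
Base.:+(sv::StatefulScheduleValue, x::Real) = _getval(sv) + x
Base.:+(x::Real, sv::StatefulScheduleValue) = x + _getval(sv)

struct ToyDescent{T}
    eta::T
end
apply_rule(rule::ToyDescent, grad) = rule.eta * grad

sched = ToyStateful(t -> 0.5^(t - 1))  # toy exponential decay
rule = ToyDescent(StatefulScheduleValue(sched))
for _ in 1:2
    next!(sched)  # only the stateful iterator is advanced
end
```

Here nothing but the iterator is mutated; the proxy itself is immutable, which is the main difference from Option 1.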

Too magic? Perhaps. I could also see serialization being an issue because of the mutable references, though BSON/JLD2 at least should work. Still, this does seem more ergonomic than wrapping optimization rules when it comes to scheduling multiple hyperparameters simultaneously.
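On serialization: Julia's built-in Serialization standard library keeps track of mutable objects it has already written, so a reference shared between the proxy and a rule comes back as a shared reference when both are written in one serialize call. A minimal check using hypothetical toy types (this does not demonstrate BSON.jl or JLD2.jl behavior):

```julia
using Serialization

# Hypothetical mutable proxy and toy rule.
mutable struct HyperParam <: Real
    value::Float64
end
struct ToyDescent{T}
    eta::T
end

eta = HyperParam(0.1)
rule = ToyDescent(eta)

# Write the proxy and the rule together in one stream, then read them back.
io = IOBuffer()
serialize(io, (eta, rule))
seekstart(io)
eta2, rule2 = deserialize(io)

# Mutating the restored proxy is visible through the restored rule,
# while the originals are untouched.
eta2.value = 0.01
```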

ToucheSir changed the title Using a wrapper + mutation to Using a wrapper + mutation to "implicitly" update scheduled parameters? Aug 14, 2022
darsnack (Member) commented:

Sorry, I had a thought here, but I forgot to write it down.

I could see Option 2 being the more attractive one. We could even go one step further and merge the behavior into Stateful instead of wrapping it. I see this as a feature for people who like "magic" and will probably want the self-mutating iterator.
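Merging the behavior into Stateful might look roughly like this hypothetical sketch, in which the stateful iterator is itself the proxy number (StatefulNumber, current and next! are illustrative names, not existing ParameterSchedulers.jl API):

```julia
# Hypothetical: a self-mutating stateful schedule that is also a Real.
mutable struct StatefulNumber{F} <: Real
    schedule::F  # callable: step -> value
    state::Int   # current step, 1-based
end
StatefulNumber(f) = StatefulNumber(f, 1)

current(s::StatefulNumber) = s.schedule(s.state)
next!(s::StatefulNumber) = (s.state += 1; s)

# Arithmetic always reads the value at the current step.
Base.:*(s::StatefulNumber, x::Real) = current(s) * x
Base.:*(x::Real, s::StatefulNumber) = x * current(s)

eta = StatefulNumber(t -> 0.1 / t)  # toy inverse decay
g = 2.0
before = eta * g
next!(eta)
after = eta * g
```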

Note that the solution presented in #34 takes a different approach from both this issue and FluxML/Optimisers.jl#89. I see adjust as a very low-level function for the most stripped-down version of Optimisers.jl, appropriate for a package that might have its own opinions about scheduling. In contrast, Scheduler from #34 avoids adjust entirely. The "rule" at each leaf in the tree does not store the underlying optimization rule (e.g. Descent) at all, so there are no stale hyper-parameters to update in the tree. Tree-based functions like update are applied to a tree of Schedulers, and only when we reach a leaf do we construct the underlying optimization rule and then call apply! directly (assuming construction is cheap). This is akin to how OptimiserChain calls apply! directly.
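The leaf-construction idea can be sketched with toy types as follows. ToyScheduler, leaf_update, apply_rule and ToyDescent are hypothetical and only illustrate the control flow; the actual Scheduler in #34 may be structured differently:

```julia
# Hypothetical toy rule standing in for e.g. Descent.
struct ToyDescent{T}
    eta::T
end
apply_rule(rule::ToyDescent, grad) = rule.eta * grad

# The leaf stores the schedule and a constructor, never a concrete rule,
# so nothing in the tree can go stale.
struct ToyScheduler{S, F}
    schedule::S     # callable: step -> hyper-parameter value
    constructor::F  # hyper-parameter value -> rule
end

# At each step, build the rule from the current schedule value and apply it
# immediately (assumes construction is cheap).
function leaf_update(sched::ToyScheduler, t::Integer, grad)
    rule = sched.constructor(sched.schedule(t))
    return apply_rule(rule, grad)
end

sched = ToyScheduler(t -> 0.5^(t - 1), ToyDescent)
updates = [leaf_update(sched, t, 1.0) for t in 1:3]
```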

My thought is that both could be offered as options, since they co-exist peacefully, and people have opinions on mutability.

darsnack (Member) commented:

Re: multiple hyper-parameters

I wrote the solution in #34 for a single hyper-parameter, but you could easily make Scheduler hold multiple schedules whose values are passed to the constructor:

```julia
lr_sched = Exp(...)
momentum_sched = Exp(...)
opt = Scheduler(lr_sched, momentum_sched) do lr, momentum
    Momentum(lr, momentum) # could also have constant hyper-parameters here
end
```
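A runnable toy version of that multi-schedule constructor, assuming hypothetical MultiScheduler, rule_at and ToyMomentum names and plain functions of the step as schedules. With do-block syntax Julia passes the block as the first argument, so the call becomes MultiScheduler(f, schedules...):

```julia
# Hypothetical toy rule standing in for Momentum.
struct ToyMomentum{T}
    eta::T
    rho::T
end

# Holds one schedule per hyper-parameter plus the rule constructor.
struct MultiScheduler{F, S<:Tuple}
    constructor::F
    schedules::S
end
MultiScheduler(f, schedules...) = MultiScheduler(f, schedules)

# Rule for step t: evaluate every schedule at t, in order, then build the rule.
rule_at(s::MultiScheduler, t::Integer) =
    s.constructor(map(sch -> sch(t), s.schedules)...)

lr_sched = t -> 0.1 * 0.5^(t - 1)  # toy exponential decay
momentum_sched = t -> 0.9          # toy constant schedule
opt = MultiScheduler(lr_sched, momentum_sched) do lr, momentum
    ToyMomentum(lr, momentum)
end
```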

ToucheSir (Member, Author) commented:

Sounds good. I'll defer to you on which of these options are worthwhile then, since it seems like we have no shortage of choice :)

darsnack added the enhancement (New feature or request) label Oct 5, 2022