[Feature Request] Purely functional loss objectives #338
Comments
Thanks for this detailed description @XuehaiPan. Can you elaborate on this?
As far as I understand it, the parameters of the LSTM (if they're part of the model) will end up in the parameters returned by `functorch.make_functional_with_buffers`. But I guess I misunderstood your point.
@vmoens Sorry for the late reply and the wrong link for the LPG algorithm.

Background: the Policy Gradient Theorem (Ref: High-Dimensional Continuous Control Using Generalized Advantage Estimation (GAE)). (Using a screenshot here due to poor inline math rendering on GitHub.)

The main contribution of the GAE paper is introducing a new advantage-estimation rule: GAE is calculated by a carefully designed rule, which is fixed.

For Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794, a separate LSTM network (parameterized by its own meta-parameters) produces the targets used to update the policy. Then the policy-gradient-based RL policy (parameterized by its own weights) is updated against the meta-network's outputs rather than a fixed rule such as GAE.
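For reference, the fixed GAE rule being referred to, written in the standard notation of the GAE paper (my reconstruction, since the original math was a screenshot):

```latex
% TD(0) residual with value function V and discount factor \gamma:
\delta_t = r_t + \gamma V(s_{t+1}) - V(s_t)
% Generalized Advantage Estimate with exponential weight \lambda:
\hat{A}_t^{\mathrm{GAE}(\gamma, \lambda)} = \sum_{l=0}^{\infty} (\gamma \lambda)^l \, \delta_{t+l}
```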
The use case is:

Note that with the current API, the parameters of every sub-module end up flattened into a single sequence:

```python
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
```

We need to maintain indices for this and pass the parameters carefully:

```python
params = ...  # manipulate the parameters carefully
loss = fmodel(params, buffers, batch)
```

In the feature request, we are asking to calculate all necessary data before calling the loss function rather than inside it. Then we can do:

```python
meta_fmodel, meta_params, meta_buffers = functorch.make_functional_with_buffers(meta_module)
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

batch = meta_fmodel(meta_params, meta_buffers, batch)
batch = fmodel(params, buffers, batch)
loss = loss_fn(batch)
```
Thanks for the answer
If I understand well, what bothers you with the current API is that with functorch as it is now, you are working with a list of parameters/buffers, which makes it hard to assign them to a particular module, is that right? What if we were storing the params in a dictionary instead (or better: a TensorDict :D)?
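For instance, something along these lines (a hypothetical sketch of the idea, not an existing TorchRL/functorch API):

```python
import torch
import torch.nn as nn
import functorch

actor = nn.Linear(4, 2)
critic = nn.Linear(4, 1)

# Instead of one flat tuple of parameters for the whole loss module, keep one
# functional callable and one parameter tuple per sub-module, keyed by name.
functional, parameters = {}, {}
for name, module in {"actor": actor, "critic": critic}.items():
    fmodule, params = functorch.make_functional(module)
    functional[name] = fmodule
    parameters[name] = params

# Manipulating (or differentiating through) the parameters of one particular
# sub-module no longer requires tracking indices into a flat list.
value = functional["critic"](parameters["critic"], torch.randn(4))
```

A TensorDict would play the same role as the plain dict here, with the added benefit of nested keys and batched tensor operations.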
Yes, that's right.
Thanks for this. The new implementation would resolve both points 1 and 2 in #338 (comment).
Hey @XuehaiPan, are you still interested in this?
Motivation
1. Consistent style with `torch.nn.modules.loss.*Loss`

In `torch.nn.modules.loss`, there are many `*Loss` classes subclassing `nn.Module`. A loss's `__init__()` does not take other `nn.Module`s as arguments, and its `forward()` method is purely functional, directly calling the corresponding `nn.functional.*_loss`. I think the motivation for using the `torch.nn.modules.loss.*Loss` style is composing networks with `nn.Sequential(...)`.
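For illustration, this is the pattern point 1 refers to: the loss module holds only configuration, and `forward()` delegates to a stateless function (a minimal sketch in the style of `torch.nn.MSELoss`, not the actual PyTorch source):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MSELossSketch(nn.Module):
    # __init__ stores only configuration, never other modules or tensors.
    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    # forward() is purely functional: it just delegates to nn.functional.
    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return F.mse_loss(input, target, reduction=self.reduction)
```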
2. More straightforward implementation for functional-style algorithms, such as meta-RL algorithms
In many meta-RL algorithms, the policy is trained with meta-parameters that may not be registered to the `LossModule`.

Case.1 MGRL: Register leaf meta-parameters as buffers in the loss module
For Meta-Gradient Reinforcement Learning (MGRL) https://arxiv.org/abs/1805.09801, the discount factor `gamma` is the meta-parameter learned across RL updates. Use PPO for example: see https://github.com/metaopt/TorchOpt#torchopt-as-differentiable-optimizer-for-meta-learning for figures. We need to register our meta-parameter `gamma` as a buffer of the loss module instead of giving the user full control of the parameters. For integration with `functorch`, registering the meta-parameter as a module buffer works without problems, as sketched below.
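A minimal sketch of Case 1 (hypothetical module and names, not the actual TorchRL API): `gamma` lives as a buffer on the loss module, so `functorch.make_functional_with_buffers` exposes it and the outer loop can differentiate through it.

```python
import torch
import torch.nn as nn
import functorch

class ToyMGRLLoss(nn.Module):
    # Hypothetical loss module: gamma is registered as a buffer so that
    # make_functional_with_buffers() exposes it like any other module state.
    def __init__(self, gamma: float = 0.99) -> None:
        super().__init__()
        self.register_buffer("gamma", torch.tensor(gamma))

    def forward(self, reward, value, next_value):
        # One-step TD error as a stand-in for the real PPO objective.
        td_target = reward + self.gamma * next_value
        return (td_target - value).pow(2).mean()

loss_module = ToyMGRLLoss()
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)

# Swap the gamma buffer for a leaf tensor that requires grad, so the outer
# (meta) loop can differentiate through the inner RL update.
meta_gamma = torch.tensor(0.99, requires_grad=True)
buffers = (meta_gamma,)  # gamma is the only buffer of ToyMGRLLoss

reward, next_value = torch.randn(8), torch.randn(8)
value = torch.randn(8, requires_grad=True)
loss = fmodel(params, buffers, reward, value, next_value)
loss.backward()  # gradients reach both value and meta_gamma
```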
Case.2 LPG: Register non-leaf meta-parameters as buffers in the loss module on every outer update
For Learned Policy Gradient (LPG) https://arxiv.org/abs/2007.08794, the LSTM meta-network itself is the meta-parameter. Different from MGRL, on each update the meta-network output is not a leaf tensor anymore, so we need to register these outputs again and again before each call of `loss_module.forward`. This breaks the `functorch`-based workflow above.
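A rough illustration of the problem (hypothetical toy modules standing in for LPG): the buffer is now a non-leaf tensor that changes on every outer update, so the buffer registration and the functional extraction have to be repeated each time.

```python
import torch
import torch.nn as nn
import functorch

# Stand-in for the LPG LSTM meta-network: its output depends on meta-parameters,
# so it is a non-leaf tensor that changes on every outer (meta) update.
meta_net = nn.Linear(4, 1)
target = meta_net(torch.randn(4))  # non-leaf tensor (it has a grad_fn)

class ToyLPGLoss(nn.Module):
    def __init__(self, target: torch.Tensor) -> None:
        super().__init__()
        # The non-leaf target must be (re-)registered before every forward pass,
        # because the meta-network recomputes it on each outer update.
        self.register_buffer("target", target)

    def forward(self, prediction: torch.Tensor) -> torch.Tensor:
        return (prediction - self.target).pow(2).mean()

loss_module = ToyLPGLoss(target)
# The buffers captured here are only a snapshot; once the next outer update
# produces a new target, this extraction has to be redone as well.
fmodel, params, buffers = functorch.make_functional_with_buffers(loss_module)
loss = fmodel(params, buffers, torch.randn(1))
```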
cc @Benjamin-eecs @waterhorse1
Solution
A clear and concise description of what you want to happen.
Split the `forward` method in the loss module into a separate pure function, i.e., a stateless function that does not hold any parameters. The model parameters should be organized by other modules. The loss function only takes a `tensordict` as input and adds a new key `"loss_objective"` to the `tensordict`. All tensor inputs (e.g. `value = self.critic(...)`) should be calculated before calling the loss function, because the loss function is purely functional, i.e., it does not host parameters (e.g., `actor.parameters()`, `critic.parameters()`).

Here is a prototype example:
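To make the proposal concrete, here is a minimal sketch of such a stateless loss function (my illustration under assumed key names, using a plain dict as a stand-in for a `tensordict`; not the author's actual prototype or the TorchRL API):

```python
import torch

def ppo_objective(batch: dict, clip_epsilon: float = 0.2) -> dict:
    # Purely functional: no parameters, no sub-modules. Everything it needs
    # (log-probabilities and advantages) was computed before this call.
    ratio = (batch["sample_log_prob"] - batch["prev_log_prob"]).exp()
    surrogate1 = ratio * batch["advantage"]
    surrogate2 = ratio.clamp(1.0 - clip_epsilon, 1.0 + clip_epsilon) * batch["advantage"]
    batch["loss_objective"] = -torch.min(surrogate1, surrogate2).mean()
    return batch

# Usage: the actor/critic (or their functional versions) populate the batch first.
batch = {
    "sample_log_prob": torch.randn(8),
    "prev_log_prob": torch.randn(8),
    "advantage": torch.randn(8),
}
loss = ppo_objective(batch)["loss_objective"]
```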
For backward compatibility, the `PPOLoss` module can be refactored as a thin wrapper that first computes the required tensors with its actor and critic and then calls the pure loss function.

Alternatives
A clear and concise description of any alternative solutions or features you've considered.
Copy and paste the loss module source code, then do specific customizations.
Additional context
Add any other context or screenshots about the feature request here.
Checklist