From db567bb16dc4423272572fc8cf89b9cebb3431cd Mon Sep 17 00:00:00 2001
From: Carl Case
Date: Fri, 22 Sep 2023 16:57:46 +0000
Subject: [PATCH] Use master weights for bfloat16 FusedAdam when master_weights=True

---
 apex/optimizers/fused_adam.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/apex/optimizers/fused_adam.py b/apex/optimizers/fused_adam.py
index 841d5634a..26559378f 100644
--- a/apex/optimizers/fused_adam.py
+++ b/apex/optimizers/fused_adam.py
@@ -158,6 +158,7 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
             g_bf, p_bf, m_bf, v_bf = [], [], [], []
             g_32, p_32, m_32, v_32 = [], [], [], []
             p_16_master = []
+            p_bf_master = []
             p_32_master = []
 
             for p, p_master in zip(group['params'], group_master['params']):
@@ -182,6 +183,8 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                     m_16.append(state['exp_avg'])
                     v_16.append(state['exp_avg_sq'])
                 elif p.dtype == torch.bfloat16:
+                    if self.master_weights:
+                        p_bf_master.append(p_master.data)
                     g_bf.append(p.grad)
                     p_bf.append(p)
                     m_bf.append(state['exp_avg'])
@@ -232,10 +235,11 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                                          inv_scale)
 
                 if len(g_bf) > 0:
-                    multi_tensor_applier(
-                        self.multi_tensor_adam_capturable,
+                    multi_tensor_applier(self.multi_tensor_adam_capturable_master if self.master_weights
+                        else self.multi_tensor_adam_capturable,
                         self._dummy_overflow_buf,
-                        [g_bf, p_bf, m_bf, v_bf],
+                        [g_bf, p_bf, m_bf, v_bf, p_bf_master] if self.master_weights
+                            else [g_bf, p_bf, m_bf, v_bf],
                         group['lr'],
                         beta1,
                         beta2,
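
Below is a minimal usage sketch of the code path this patch touches: bfloat16 parameters optimized by FusedAdam with master_weights=True (and capturable=True, since the master-weight kernels live in the capturable branch). It is not part of the patch; it assumes a CUDA build of apex with the fused multi-tensor kernels installed, and the layer size, batch shape, and dummy loss are illustrative only.

import torch
from apex.optimizers import FusedAdam

# bfloat16 parameters on the GPU; with master_weights=True FusedAdam keeps an
# internal FP32 master copy of each parameter alongside the bf16 weights.
model = torch.nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
optimizer = FusedAdam(model.parameters(), lr=1e-3,
                      capturable=True, master_weights=True)

inp = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
loss = model(inp).float().pow(2).mean()
loss.backward()

# With this patch, the bf16 branch dispatches to multi_tensor_adam_capturable_master,
# so the Adam update is applied to the FP32 master weights and the bf16 parameters
# are written back from them, matching the existing fp16/fp32 master-weight paths.
optimizer.step()
optimizer.zero_grad()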