Bug description
There's a bug in the gradient accumulation calculation that affects multi-node training scenarios. The current implementation uses local device count instead of world size (total devices across all nodes), leading to incorrect gradient accumulation steps and training behavior.
Current Behavior
The gradient_accumulation_iters function calculates steps based on local batch size:
```python
def gradient_accumulation_iters(self, devices: int) -> int:
    """Number of iterations between gradient synchronizations"""
    gradient_accumulation_iters = self.batch_size(devices) // self.micro_batch_size
    return gradient_accumulation_iters

def batch_size(self, devices: int) -> int:
    """Number of samples between optimizer steps per data-parallel rank"""
    batch_size = self.global_batch_size // devices  # devices is local count only
    return batch_size
```
The `devices` parameter comes from `torch.cuda.device_count()`, which returns only the local GPU count (e.g., 8) rather than the total number of GPUs across all nodes (e.g., 8 * num_nodes = 128).
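A quick worked example shows how this throws the accumulation count off by a factor of num_nodes. The numbers here are illustrative assumptions (16 nodes x 8 GPUs as in the reproduction below; `global_batch_size=512` and `micro_batch_size=4` are made up), not values from any real config:

```python
# Illustrative numbers only -- not taken from any real config.
global_batch_size = 512
micro_batch_size = 4
local_devices = 8                        # torch.cuda.device_count() on one node
num_nodes = 16
world_size = local_devices * num_nodes   # 128 processes in total

# Current behavior: per-rank batch size computed from the local device count
wrong_batch_size = global_batch_size // local_devices      # 64
wrong_accum_iters = wrong_batch_size // micro_batch_size   # 16

# Intended behavior: per-rank batch size computed from the world size
right_batch_size = global_batch_size // world_size         # 4
right_accum_iters = right_batch_size // micro_batch_size   # 1

print(wrong_accum_iters // right_accum_iters)  # 16 == num_nodes
```

Each rank therefore accumulates num_nodes times too many micro-batches before stepping, which also inflates the effective global batch size by the same factor.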
Impact
This causes:
- Incorrect gradient accumulation frequency (off by a factor of num_nodes)
- Mismatch between steps and iterations in training logs
Example from logs:
```
Epoch 1 | iter 109472 step 13684 | loss train: 3.099, val: 3.075
```
Note the 8x difference between iterations and steps (109472/13684 ≈ 8)
Steps to Reproduce
1. Configure multi-node training (e.g., 16 nodes, 8 GPUs each)
2. Set global batch size and micro batch size
3. Observe gradient accumulation steps and training logs (a minimal check is sketched below)
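A minimal sketch of such a check, assuming the job is started by an external launcher (e.g., SLURM or torchrun) across 16 nodes with 8 GPUs each; the `devices`/`num_nodes` values are assumptions, not taken from a specific config:

```python
# Minimal sketch: compare the local GPU count with the total world size.
import torch
import lightning as L

fabric = L.Fabric(accelerator="cuda", devices=8, num_nodes=16)
fabric.launch()

local_gpus = torch.cuda.device_count()  # 8 on every node
world_size = fabric.world_size          # 128 across the whole job

fabric.print(
    f"torch.cuda.device_count() = {local_gpus}, fabric.world_size = {world_size}"
)
# Any batch-size math that divides by local_gpus instead of world_size
# will be off by a factor of num_nodes (16 here).
```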
Expected Behavior
The calculation should use the total world size (all devices across all nodes) instead of the local device count:
```python
def batch_size(self, devices: int) -> int:
    """Number of samples between optimizer steps per data-parallel rank"""
    batch_size = self.global_batch_size // fabric.world_size  # Use world_size instead of devices
    return batch_size
```
Proposed Solution
Use `fabric.world_size` instead of the local `devices` count to properly account for all processes across nodes.
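A minimal sketch of the caller side under that change. The `TrainArgs` class below is a simplified stand-in for litgpt's training arguments (field and method names mirror the snippets above), and the numbers are illustrative assumptions, not the actual litgpt code path:

```python
# Sketch only: pass the world size, not the local GPU count, into the batch-size math.
from dataclasses import dataclass

import lightning as L


@dataclass
class TrainArgs:  # simplified stand-in for litgpt's training arguments
    global_batch_size: int = 512
    micro_batch_size: int = 4

    def batch_size(self, num_processes: int) -> int:
        """Number of samples between optimizer steps per data-parallel rank"""
        return self.global_batch_size // num_processes

    def gradient_accumulation_iters(self, num_processes: int) -> int:
        """Number of iterations between gradient synchronizations"""
        return self.batch_size(num_processes) // self.micro_batch_size


fabric = L.Fabric(accelerator="cuda", devices=8, num_nodes=16)
fabric.launch()

train = TrainArgs()
# fabric.world_size == devices * num_nodes, so the accumulation count is correct on every rank.
accum_iters = train.gradient_accumulation_iters(fabric.world_size)
```

With the world size passed in, the iteration-to-step ratio in the logs matches the configured accumulation instead of being scaled by the number of nodes.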
Additional Context
This issue likely wasn't caught earlier because the tutorials primarily use single-node setups (e.g., litgpt/config_hub/pretrain/tinyllama.yaml, line 109 at commit a5021be).
What operating system are you using?
Linux
LitGPT Version