Commit afca65c

fix linear weight initialization in paddle and resume
Signed-off-by: Guoxia Wang <[email protected]>
GuoxiaWang committed Jan 21, 2025
1 parent a65ad37 commit afca65c
Showing 2 changed files with 14 additions and 7 deletions.
7 changes: 6 additions & 1 deletion transformer_engine/paddle/layer/base.py
@@ -187,6 +187,8 @@ def _get_fp8_state(self) -> paddle.Tensor:
         state = {}
         state["scaling_fwd"] = self.fp8_meta["scaling_fwd"].to_numpy()
         state["scaling_bwd"] = self.fp8_meta["scaling_bwd"].to_numpy()
+        get_global_fp8_state().get_fp8_fwd_buffer().wait()
+        get_global_fp8_state().get_fp8_bwd_buffer().wait()
         state["global_fp8_fwd_buffer"] = get_global_fp8_state().get_fp8_fwd_buffer().to_numpy()
         state["global_fp8_bwd_buffer"] = get_global_fp8_state().get_fp8_bwd_buffer().to_numpy()
         # Store other pickelable values.
@@ -224,7 +226,10 @@ def _set_fp8_state(self, state: paddle.Tensor) -> None:
         if state is None:
             return

-        state = pickle.loads(state.numpy().tobytes())
+        if isinstance(state, paddle.Tensor):
+            state = pickle.loads(state.numpy().tobytes())
+        else:
+            state = pickle.loads(state.tobytes())
         if state is None:
             return

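Two things are going on here. The added wait() calls synchronize the global FP8 amax buffers (which wait() suggests are filled asynchronously) before to_numpy() snapshots them, so the checkpoint captures settled values. The new isinstance branch handles resume: after a round trip through a checkpoint file, the serialized state can come back as a plain numpy array rather than a paddle.Tensor. A minimal sketch of that round trip; pack_state and unpack_state are illustrative helper names, not part of the Transformer Engine API:

import pickle

import numpy as np
import paddle


def pack_state(state: dict) -> paddle.Tensor:
    # Pickle a state dict into a uint8 tensor, mirroring what _get_fp8_state produces.
    raw = np.frombuffer(pickle.dumps(state), dtype=np.uint8)
    return paddle.to_tensor(raw)


def unpack_state(state) -> dict:
    # Accept either a live paddle.Tensor or a numpy array loaded back from disk.
    if isinstance(state, paddle.Tensor):
        return pickle.loads(state.numpy().tobytes())
    return pickle.loads(state.tobytes())


packed = pack_state({"scaling_fwd": np.zeros(4, dtype=np.float32)})
assert unpack_state(packed)["scaling_fwd"].shape == (4,)          # tensor path
assert unpack_state(packed.numpy())["scaling_fwd"].shape == (4,)  # ndarray path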
14 changes: 8 additions & 6 deletions transformer_engine/paddle/layer/linear.py
@@ -786,17 +786,19 @@ def __init__(

         # Initialize weight parameter
         with track_rng_state(enable=self.tensor_parallel):
-            # TE linear weight is in column major
+            # Ensure that the random seed state is consistent with Paddle Linear.
             self.weight = self.create_parameter(
-                shape=(
-                    [self.out_features, self.in_features]
-                    if self.backend == "transformer_engine"
-                    else [self.in_features, self.out_features]
-                ),
+                shape=([self.in_features, self.out_features]),
                 attr=self._weight_attr,
                 dtype=self._dtype,
                 is_bias=False,
             )
+            # TE linear weight is in column major
+            if self.backend == "transformer_engine":
+                with paddle.no_grad():
+                    clone_weight = self.weight.clone().t().contiguous()
+                    self.weight.get_tensor()._share_data_with(clone_weight.get_tensor())
+
         set_weight_tensor_dist_attr(
             self.weight, self.tensor_parallel, self.parallel_mode, self.backend
         )
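Previously the transformer_engine backend created the weight directly in its column-major [out_features, in_features] shape, so the initializer consumed the RNG stream differently than paddle.nn.Linear and the resulting weights diverged. The fix always initializes in Paddle's [in_features, out_features] layout (per the new comment, keeping the random seed state consistent with Paddle Linear), then transposes and shares the transposed storage back into the parameter via the private _share_data_with so the parameter object itself is preserved. A standalone sketch of the idea; the explicit XavierNormal initializer and seeds are assumptions for illustration, not taken from the commit:

import paddle

in_features, out_features = 4, 3
init = paddle.nn.initializer.XavierNormal()

paddle.seed(1234)
weight = paddle.create_parameter(
    shape=[in_features, out_features],   # paddle.nn.Linear's layout
    dtype="float32",
    attr=paddle.ParamAttr(initializer=init),
)
with paddle.no_grad():
    te_weight = weight.t().contiguous()  # column-major [out, in] for TE kernels

paddle.seed(1234)
reference = paddle.create_parameter(
    shape=[in_features, out_features],
    dtype="float32",
    attr=paddle.ParamAttr(initializer=init),
)
# Same seed and same init shape give identical draws; only the layout differs.
assert bool(paddle.allclose(te_weight, reference.t()))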
