Skip to content

Commit

Permalink
Quote `tex.DType` return annotations as string forward references (`"tex.DType"`) so the annotation is not evaluated at import time
Browse files Browse the repository at this point in the history
  • Loading branch information
xrsrke committed Jan 17, 2025
1 parent 7c7949c commit d113dda
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/fp8/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class FP8Meta:
sync_amax: bool = False

@property
def te_dtype(self) -> tex.DType:
def te_dtype(self) -> "tex.DType":
from nanotron.fp8.tensor import convert_torch_dtype_to_te_dtype

return convert_torch_dtype_to_te_dtype(self.dtype)
Expand Down
2 changes: 1 addition & 1 deletion src/nanotron/fp8/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _quantize(tensor: torch.Tensor, fp8_meta: "FP8Meta") -> torch.Tensor:
return (tensor * fp8_meta.scale).to(torch.float16)


def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> tex.DType:
def convert_torch_dtype_to_te_dtype(dtype: torch.dtype) -> "tex.DType":
# NOTE: transformer engine maintains it own dtype mapping
# so we need to manually map torch dtypes to TE dtypes
TORCH_DTYPE_TE_DTYPE_NAME_MAPPING = {
Expand Down

0 comments on commit d113dda

Please sign in to comment.