Add fp8 training support with HPU (#149)
Add support for fp8 training

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jerome Anand <[email protected]>
3 people authored Feb 16, 2024
1 parent 85c8bbf commit 6dbd7c9
Showing 10 changed files with 484 additions and 140 deletions.
3 changes: 2 additions & 1 deletion .azure/hpu-tests.yml
@@ -114,11 +114,12 @@ jobs:
tests/test_pytorch/test_dynamic_shapes.py \
tests/test_pytorch/test_datamodule.py \
tests/test_pytorch/test_profiler.py \
tests/test_pytorch/test_precision.py \
--hpus 1 --junitxml=hpu_test-torch-results.xml
displayName: 'HPU General tests'
- bash: |
python -m pytest -sv tests/test_pytorch/test_compile.py \
python -m pytest -sv tests/test_pytorch/test_compile.py \
--hpus 1 --junitxml=hpu_compile_test-results.xml
env:
PT_HPU_LAZY_MODE: 0
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added DeepSpeed precision plugin for HPU ([#147](https://github.com/Lightning-AI/lightning-Habana/pull/147))
- Added support for fp8 training. ([#149](https://github.com/Lightning-AI/lightning-Habana/pull/149))

### Changed

44 changes: 42 additions & 2 deletions docs/source/intermediate.rst
@@ -79,13 +79,13 @@ These also allow for fine tuning with `enabled` for enabling and disabling mixed
class AutocastModelCM(nn.Module):
# Autocast can be used as a context manager to the required code block.
def forward(self, input):
with torch.autocast("device_type="hpu", dtype=torch.bfloat16):
with torch.autocast(device_type="hpu", dtype=torch.bfloat16):
...
return
class AutocastModelDecorator(nn.Module):
# Autocast can be used as a decorator to the required code block.
@torch.autocast("device_type="hpu", dtype=torch.bfloat16)
@torch.autocast(device_type="hpu", dtype=torch.bfloat16)
def forward(self, input):
...
return
@@ -111,6 +111,46 @@ and `Automatic Mixed Precision Package: torch.autocast <https://pytorch.org/docs

----

fp8 Training
----------------------------------------

Lightning supports fp8 training through the :class:`~lightning_habana.pytorch.plugins.precision.HPUPrecisionPlugin`.
fp8 training is only available on Gaudi2 and above. Output from fp8-supported modules is in `torch.bfloat16`.

The plugin accepts the following arguments for fp8 training:

1. `replace_layers` : Set to `True` to let the plugin replace supported `torch.nn` modules with their `transformer_engine` equivalents. You can also directly import and use modules from `transformer_engine`, as shown in the sketch after the example below.

2. `recipe` : The fp8 recipe to use during training.

.. code-block:: python

    from lightning import Trainer
    from lightning_habana.pytorch.accelerator import HPUAccelerator
    from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin
    from habana_frameworks.torch.hpex.experimental.transformer_engine import recipe

    model = BoringModel()

    # Init the precision plugin for fp8 training.
    plugin = HPUPrecisionPlugin(precision="fp8", replace_layers=True, recipe=recipe.DelayedScaling())

    # Replace torch.nn modules with transformer_engine equivalent modules.
    plugin.convert_modules(model)

    # Initialize a trainer with HPUPrecisionPlugin.
    trainer = Trainer(
        accelerator=HPUAccelerator(),
        plugins=plugin,
    )

    # Train the model ⚡
    trainer.fit(model)

For more details, available `recipes`, and the list of supported `transformer_engine` modules, refer to `FP8 Training with Intel Gaudi Transformer Engine <https://docs.habana.ai/en/latest/PyTorch/PyTorch_FP8_Training/index.html>`__.
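As noted for `replace_layers`, a model may also be composed directly from `transformer_engine` modules instead of being converted afterwards. A minimal sketch, assuming a single fp8-capable linear layer; the module name and layer sizes are illustrative:

.. code-block:: python

    import torch.nn as nn
    import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine


    class TEModel(nn.Module):
        """Toy model built directly from a transformer_engine layer."""

        def __init__(self):
            super().__init__()
            # fp8-capable drop-in replacement for torch.nn.Linear
            self.fc = tengine.Linear(32, 2, bias=True)

        def forward(self, x):
            return self.fc(x)

Such a model can then be trained with `HPUPrecisionPlugin(precision="fp8")` without setting `replace_layers`.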

----

Enabling DeviceStatsMonitor with HPUs
----------------------------------------

14 changes: 12 additions & 2 deletions examples/pytorch/mnist_trainer.py
@@ -38,6 +38,7 @@
"HPUPrecisionPlugin",
"MixedPrecisionPlugin",
"recipe_caching",
"fp8_training",
]

OPTIONAL_RUN_TYPE = [
@@ -148,9 +149,11 @@ def get_model(run_type):
def get_plugins(run_type):
"""Select plugin."""
if run_type == "HPUPrecisionPlugin":
return [HPUPrecisionPlugin(device="hpu", precision="bf16-mixed")]
return HPUPrecisionPlugin(device="hpu", precision="bf16-mixed")
if run_type == "MixedPrecisionPlugin":
return [MixedPrecision(device="hpu", precision="bf16-mixed")]
return MixedPrecision(device="hpu", precision="bf16-mixed")
if run_type == "fp8_training":
return HPUPrecisionPlugin(device="hpu", precision="fp8")
return None


@@ -187,9 +190,16 @@ def run_training(run_type, options, model, data_module, plugin):
print(f"Running MNIST mixed precision training with {options=}")
# Run model and print accuracy
for _run_type in options.run_types:
if HPUAccelerator.get_device_name() == "GAUDI" and _run_type == "fp8_training":
print("fp8 training not supported on GAUDI. Skipping.")
continue

seed_everything(42)
model, data_module = get_model(_run_type)
plugin = get_plugins(_run_type)
if _run_type == "fp8_training":
plugin.convert_modules(model)

if options.verbose:
print(f"Running {_run_type=} with {model=}, and {plugin=}")
run_training(_run_type, options, model, data_module, plugin)
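Condensed from the example above, the fp8 path in the trainer script amounts to the following flow. This is a sketch only: the `get_model` helper and data module come from the example script and are assumed here rather than redefined.

from lightning import Trainer, seed_everything
from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin

seed_everything(42)
model, data_module = get_model("fp8_training")  # helper from the example script

# fp8 requires Gaudi2 or newer; the example skips the run on first-gen GAUDI.
if HPUAccelerator.get_device_name() == "GAUDI":
    print("fp8 training not supported on GAUDI. Skipping.")
else:
    plugin = HPUPrecisionPlugin(device="hpu", precision="fp8")
    # Convert supported layers (a no-op unless the plugin was built with replace_layers=True).
    plugin.convert_modules(model)
    trainer = Trainer(accelerator=HPUAccelerator(), devices=1, plugins=plugin)
    trainer.fit(model, data_module)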
9 changes: 7 additions & 2 deletions src/lightning_habana/pytorch/accelerator.py
@@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from lightning_utilities import module_available

from lightning_habana.utils.imports import _HABANA_FRAMEWORK_AVAILABLE
from lightning_habana.utils.resources import _parse_hpus, device_count, get_device_stats
from lightning_habana.utils.resources import _parse_hpus, device_count, get_device_stats, is_fp8_available

if _HABANA_FRAMEWORK_AVAILABLE:
import habana_frameworks.torch.core as htcore
@@ -88,6 +88,11 @@ def get_device_name() -> str:
except (AttributeError, NameError):
return ""

@staticmethod
def is_fp8_available() -> Tuple[bool, str]:
"""Returns a bool indicating if fp8 is available, with reason if not available."""
return is_fp8_available()

@staticmethod
def is_lazy() -> bool:
"""Checks if lazy is enabled or not."""
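A small sketch of how a caller might consult the new static method to pick a precision before building the plugin. The bf16 fallback is illustrative and not part of this commit; note that the underlying helper raises an `OSError` on unsupported Synapse releases or a missing Habana framework rather than returning `False`.

from lightning_habana.pytorch.accelerator import HPUAccelerator
from lightning_habana.pytorch.plugins.precision import HPUPrecisionPlugin

available, reason = HPUAccelerator.is_fp8_available()
if available:
    plugin = HPUPrecisionPlugin(device="hpu", precision="fp8")
else:
    # Report why fp8 is unavailable and fall back to bf16 mixed precision.
    print(f"fp8 unavailable: {reason}")
    plugin = HPUPrecisionPlugin(device="hpu", precision="bf16-mixed")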
97 changes: 88 additions & 9 deletions src/lightning_habana/pytorch/plugins/precision.py
@@ -12,27 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from typing import Generator, Literal
from contextlib import _GeneratorContextManager, contextmanager
from typing import Any, Generator, Literal, Mapping, Optional, Union

import torch
from lightning_utilities import module_available
from typing_extensions import get_args

if module_available("lightning"):
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.plugins.precision import Precision
elif module_available("pytorch_lightning"):
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision import Precision
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
else:
raise ModuleNotFoundError("You are missing `lightning` or `pytorch-lightning` package, please install it.")

from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_EQUAL_1_11_0
from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_EQUAL_1_11_0, _HPU_SYNAPSE_GREATER_EQUAL_1_14_0
from lightning_habana.utils.resources import _HABANA_FRAMEWORK_AVAILABLE, is_fp8_available

_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed"]
_PRECISION_INPUT = Literal["32", "32-true", "bf16", "bf16-mixed", "fp8"]

if _HPU_SYNAPSE_GREATER_EQUAL_1_14_0 and _HABANA_FRAMEWORK_AVAILABLE:
# Required for training in fp8 using habana transformer engine
import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine
from habana_frameworks.torch.hpex.experimental.transformer_engine.recipe import DelayedScaling

class HPUPrecisionPlugin(PrecisionPlugin):
"""Plugin that enables bfloat support on HPUs.

class HPUPrecisionPlugin(Precision):
"""Plugin that enables mixed precision support on HPUs.
Args:
precision: to enable ``torch.bfloat16`` (``'bf16-mixed'``).
@@ -44,6 +52,8 @@ def __init__(
self,
precision: _PRECISION_INPUT,
device: str = "hpu",
recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
replace_layers: bool = False,
) -> None:
if not _HPU_SYNAPSE_GREATER_EQUAL_1_11_0:
raise OSError("HPU precision plugin requires `Synapse AI release >= 1.11.0`.")
@@ -54,13 +64,82 @@ def __init__(
f" `precision` must be one of: {supported_precision}."
)
self.precision = precision
self.replace_layers = False
self.device = device

def autocast_context_manager(self) -> torch.autocast:
if any([recipe, replace_layers]) and precision != "fp8":
rank_zero_warn(f"Precision is not 'fp8'. Params {recipe=} and {replace_layers=} will not be set.")

self.recipe = None
self.fp8_train_available = False

if self.precision == "fp8":
fp8_available, reason_no_fp8 = is_fp8_available()
if not fp8_available:
raise NotImplementedError(f"fp8 not supported: {reason_no_fp8}.")
self.recipe = recipe
self.fp8_train_available = fp8_available
self.replace_layers = replace_layers
rank_zero_info(f"fp8 training available: {self.fp8_train_available}.")

def convert_modules(self, module: torch.nn.Module) -> torch.nn.Module:
"""Replace layers of a module with Transformer engine equivalent layers."""
if self.replace_layers is True and self.fp8_train_available:
# In case model already contains a transformer engine modules,
# assume user responsibility for conversion of required layers.
if any(
"habana_frameworks.torch.hpex.experimental.transformer_engine" in m.__module__ for m in module.modules()
):
rank_zero_info(
f"Module {module} already contains transformer engine equivalent modules. Skipping conversion"
)
else:
_replace_layers(module)
return module

def autocast_context_manager(self) -> Union[_GeneratorContextManager[Any], torch.autocast]:
"""Return Autocast context manager."""
if self.fp8_train_available:
return _nested_precision_cm(fp8_enabled=(self.precision == "fp8"), recipe=self.recipe)
return torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True)

@contextmanager
def forward_context(self) -> Generator[None, None, None]:
"""Enable autocast context."""
with self.autocast_context_manager():
yield


def _replace_layers(module: torch.nn.Module) -> None:
"""Replace layers with Transformer engine equivalent layers.
Args: torch.nn.Module.
Return: transformer engine equivalent of torch.nn.Module.
List of supported modules: https://docs.habana.ai/en/latest/PyTorch/PyTorch_FP8_Training/index.html
Eg. torch.nn.Linear -> transformer_engine.Linear
"""
for name, child in module.named_children():
if isinstance(child, torch.nn.Linear):
has_bias = child.bias is not None
replacement = tengine.Linear(child.in_features, child.out_features, bias=has_bias)
rank_zero_info(f"Replacing layer {name} with transformer engine equivalent")
module.__setattr__(name, replacement)
else:
_replace_layers(child)


@contextmanager
def _nested_precision_cm(
fp8_enabled: bool, recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]]
) -> Generator[Any, Any, Any]:
"""CM to nest fp8 precision with torch.autocast.
This enables the ops that do not support fp8 to run with torch autocast.
"""
with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True), tengine.fp8_autocast(
enabled=fp8_enabled, fp8_recipe=recipe
):
yield
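For illustration, the nesting performed by `_nested_precision_cm` (and reached through `forward_context`) is equivalent to the hand-written composition below. `model` and `batch` are placeholders, and the default `DelayedScaling` recipe is used only as an example.

import torch
import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine
from habana_frameworks.torch.hpex.experimental.transformer_engine.recipe import DelayedScaling

recipe = DelayedScaling()  # example recipe; any supported fp8 recipe works

# bf16 autocast covers ops without fp8 kernels, while fp8_autocast
# drives the transformer_engine modules inside the model.
with torch.autocast(device_type="hpu", dtype=torch.bfloat16, enabled=True):
    with tengine.fp8_autocast(enabled=True, fp8_recipe=recipe):
        output = model(batch)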
1 change: 1 addition & 0 deletions src/lightning_habana/utils/imports.py
@@ -21,6 +21,7 @@
from lightning_habana.utils.resources import _HABANA_FRAMEWORK_AVAILABLE, get_hpu_synapse_version # noqa: F401

_HPU_SYNAPSE_GREATER_EQUAL_1_11_0 = Version(get_hpu_synapse_version()) >= Version("1.11.0")
_HPU_SYNAPSE_GREATER_EQUAL_1_14_0 = Version(get_hpu_synapse_version()) >= Version("1.14.0")
_TORCH_LESSER_EQUAL_1_13_1 = compare_version("torch", operator.le, "1.13.1")
_TORCH_GREATER_EQUAL_2_0_0 = compare_version("torch", operator.ge, "2.0.0")
_LIGHTNING_GREATER_EQUAL_2_0_0 = compare_version("lightning", operator.ge, "2.0.0") or compare_version(
15 changes: 14 additions & 1 deletion src/lightning_habana/utils/resources.py
@@ -69,7 +69,6 @@ def _parse_hpu_synapse_versions(line: str) -> Tuple[str, str]:
"""
hl = fw = ""

try:
# Item "None" of "Optional[Match[str]]" has no attribute "group"
hl = re.search(r"hl-([\d\.]+)", line).group(1) # type: ignore[union-attr]
@@ -125,3 +124,17 @@ def device_count() -> int:
except (AttributeError, NameError):
rank_zero_debug("Function `device_count` failed, returning default count of 8.")
return 8


@lru_cache
def is_fp8_available() -> Tuple[bool, str]:
"""Returns a bool indicating if fp8 is available."""
from lightning_habana.utils.imports import _HPU_SYNAPSE_GREATER_EQUAL_1_14_0

if not _HPU_SYNAPSE_GREATER_EQUAL_1_14_0:
raise OSError("fp8 training requires `Synapse AI release >= 1.14.0`.")
if not _HABANA_FRAMEWORK_AVAILABLE:
raise OSError("Habana Frameworks required for training on Habana devices.")
import habana_frameworks.torch.hpex.experimental.transformer_engine as tengine

return tengine.fp8.is_fp8_available()
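From a caller's perspective, the helper behaves roughly as sketched below. The `OSError` handling mirrors the version and framework checks above; the printout is only illustrative.

from lightning_habana.utils.resources import is_fp8_available

try:
    # Result is cached by lru_cache after the first call.
    available, reason = is_fp8_available()
except OSError as err:
    # Raised when Synapse AI < 1.14.0 or habana_frameworks is not installed.
    available, reason = False, str(err)

print(f"fp8 available: {available} ({reason})")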
1 change: 0 additions & 1 deletion tests/test_pytorch/test_compile.py
@@ -211,7 +211,6 @@ def test_all_stages_with_compile(tmpdir, hpus):
@pytest.mark.standalone()
@pytest.mark.skipif(HPUAccelerator.auto_device_count() <= 1, reason="Test requires multiple HPU devices")
@pytest.mark.usefixtures("_is_compile_allowed")
@pytest.mark.parametrize("hpus", [2])
def test_parallel_strategy_with_compile(tmp_path, hpus):
"""Tests compiled BoringModel on HPU."""
model = BoringModel()