Fix mypy errors
ankitgola005 committed Sep 4, 2024
1 parent 894541a commit aa47854
Showing 2 changed files with 10 additions and 10 deletions.

src/lightning_habana/pytorch/plugins/deepspeed_precision.py (5 additions, 9 deletions)

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Any, Callable, Mapping, Optional, Union
 
 import torch
@@ -139,14 +138,11 @@ def clip_gradients(
 
     def _enable_fp8_inference(
         self,
-        quant: bool = True,
+        module: torch.nn.Module,
+        quant: Optional[Union[bool, str, dict]] = True,
         fp8_data_path: Optional[str] = None,
     ) -> None:
-        """Convert modules for fp8 inference.
-
-        This module cannot be used with trainer.fit.
-
-        """
+        """Enable fp8 inference."""
         htcore.hpu_set_env()
         self.quant = quant
         self.fp8_data_path = fp8_data_path
@@ -162,7 +158,7 @@ def convert_modules(
     ) -> torch.nn.Module:
         """Enable support for fp8."""
         if inference and self.fp8_inference_available:
-            self._enable_fp8_inference(quant, fp8_data_path)
+            self._enable_fp8_inference(module=module, quant=quant, fp8_data_path=fp8_data_path)
         if not inference and self.fp8_training_available:
-            self._enable_fp8_training(module, replace_layers, recipe)
+            self._enable_fp8_training(module=module, replace_layers=replace_layers, recipe=recipe)
         return module
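
Note on why this resolves a mypy error: `module` is now a required positional parameter ahead of `quant`, so the old positional call `self._enable_fp8_inference(quant, fp8_data_path)` would bind `quant` to `module`, exactly the kind of mismatch mypy flags; switching the callers to keyword arguments keeps the bindings explicit. The widened `quant: Optional[Union[bool, str, dict]]` annotation also matches how the flag is dispatched in precision.py (second file below). A minimal standalone sketch of that dispatch, with stand-in recipe values rather than the plugin's real configs:

from typing import Optional, Union

# Stand-ins for the plugin's MAXABS_QUANT / MAXABS_MEASURE recipe configs;
# the real values live inside the Habana plugin internals.
MAXABS_QUANT = {"mode": "QUANTIZE"}
MAXABS_MEASURE = {"mode": "MEASURE"}

def resolve_fp8_config(quant: Optional[Union[bool, str, dict]] = True) -> Optional[Union[str, dict]]:
    # Mirrors the dispatch in precision.py: True selects the quantization
    # recipe, False the measurement recipe, and a str path or dict config
    # (or None) is passed through unchanged.
    return MAXABS_QUANT if quant is True else MAXABS_MEASURE if quant is False else quant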
src/lightning_habana/pytorch/plugins/precision.py (5 additions, 1 deletion)

@@ -130,6 +130,7 @@ def _setup_fp8_inference_config(
 
         fp8_config = MAXABS_QUANT if quant is True else MAXABS_MEASURE if quant is False else quant
         fp8_data_path = fp8_data_path if fp8_data_path is not None else os.environ.get("HABANA_LOGS")
+        assert fp8_data_path is not None
 
         if isinstance(fp8_config, str):
             if os.path.isfile(fp8_config):
@@ -161,7 +162,10 @@ def _enable_fp8_inference(
         self._setup_fp8_inference_modules(module, quant, fp8_data_path)
 
     def _setup_fp8_inference_modules(
-        self, module: torch.nn.Module, quant: bool = True, fp8_data_path: Optional[str] = None
+        self,
+        module: torch.nn.Module,
+        quant: Optional[Union[bool, str, dict]] = True,
+        fp8_data_path: Optional[str] = None,
     ) -> None:
        """Convert module for fp8 inference."""
        from neural_compressor.torch.quantization import FP8Config, convert, prepare
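
Note on the added assert: `os.environ.get` is typed `Optional[str]`, so without it mypy infers `fp8_data_path` as `Optional[str]` and rejects any use that requires a plain `str`; the assert narrows the type for the checker (and fails fast at runtime if HABANA_LOGS is unset). A minimal standalone sketch of the pattern, with an illustrative function name that is not the plugin's:

import os
from typing import Optional

def resolve_data_path(fp8_data_path: Optional[str] = None) -> str:
    # os.environ.get returns Optional[str], so without narrowing mypy
    # would reject returning this value as str.
    path = fp8_data_path if fp8_data_path is not None else os.environ.get("HABANA_LOGS")
    assert path is not None  # narrows Optional[str] to str for the type checker
    return path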
