Patch notes (free text ahead of the first "diff --git" header; patch tools
skip it when applying).

File: modelopt/torch/quantization/backends/__init__.py

What the hunk does:
  * Drops the wildcard imports from .gemm_registry and .nvfp4_gemm.
  * Imports, by name, gemm_registry, Fp8PerTensorLinear, Nvfp4Linear and the
    two availability-check helpers (_fp8_availability_check,
    _nvfp4_availability_check).
  * Calls gemm_registry.register(...) twice while the package is being
    imported: first with Fp8PerTensorLinear.apply, then with
    Nvfp4Linear.apply, each paired with its own availability check.

NOTE(review): any other name that used to reach this package through the two
wildcard imports is no longer bound here -- confirm no caller depends on one.
NOTE(review): FP8 is registered before NVFP4; whether registration order
affects which GEMM gets selected is not visible in this patch -- confirm
against gemm_registry.

diff --git a/modelopt/torch/quantization/backends/__init__.py b/modelopt/torch/quantization/backends/__init__.py
index 92317d92b..c4c2fadd5 100644
--- a/modelopt/torch/quantization/backends/__init__.py
+++ b/modelopt/torch/quantization/backends/__init__.py
@@ -15,5 +15,18 @@
 
 """Quantization backends."""
 
-from .gemm_registry import *
-from .nvfp4_gemm import *
+from .fp8_per_tensor_gemm import Fp8PerTensorLinear, _fp8_availability_check
+from .gemm_registry import gemm_registry
+from .nvfp4_gemm import Nvfp4Linear, _nvfp4_availability_check
+
+# Register default implementations
+gemm_registry.register(
+    gemm_func=Fp8PerTensorLinear.apply,
+    availability_check=_fp8_availability_check,
+)
+
+# Register default implementations
+gemm_registry.register(
+    gemm_func=Nvfp4Linear.apply,
+    availability_check=_nvfp4_availability_check,
+)