-
Notifications
You must be signed in to change notification settings - Fork 536
9080 expose mat mul precision #9081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
86200a5
initial commit: binding and backends package, no tests.
yaoshiang 2901b77
Tests for default, high, and highest precision.
yaoshiang 564685a
clang-format
yaoshiang 4c0b52e
formatter
yaoshiang 0d3f44f
Updates to error messages
yaoshiang 3c333d8
fixed test class names
yaoshiang 967ccda
typo
yaoshiang 393257e
typo on error message. unit tested and yapfed.
yaoshiang f48d7f6
linter
yaoshiang 114456b
minor edits.
yaoshiang 87cff3d
Updated TODO per review
yaoshiang b96f671
Update todo and precision math per comment.
yaoshiang c26ad8f
Merge branch 'master' into 9080-expose-mat_mul_precision
yaoshiang 30d1c19
yapf
yaoshiang 959b0dc
linter
yaoshiang 8aaa979
parameterized, but in a process isolated way.
yaoshiang 04516c7
removed dead code
yaoshiang c3856c8
added issue for repeatable, unexpected behavior.
yaoshiang 79af416
Updated docstring.
yaoshiang 6ff7f59
changed naming of is_on_tpu
yaoshiang 89c72a8
Merge branch 'master' into 9080-expose-mat_mul_precision
yaoshiang 9db6e3e
CICD friendly version, hopefully.
yaoshiang File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
"""Numeric tests for default precision of mat mul.""" | ||
|
||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
|
||
import test_utils | ||
|
||
|
||
class TestMatMulPrecision(unittest.TestCase):
  """Numeric bounds for each mat mul precision setting on TPU."""

  def _make_input(self):
    """Returns a 1024x1024 float64 diagonal matrix with entries near 1.0."""
    eye = torch.eye(1024, device='cpu', dtype=torch.float64)
    rand_ = torch.testing.make_tensor((1024, 1024),
                                      dtype=torch.float64,
                                      device="cpu",
                                      low=0.99,
                                      high=1.01)
    return eye * rand_

  # TODO: Figure out why either PT/XLA or unittest
  # is unable to successfully run this test in a parameterized way.
  # https://github.com/pytorch/xla/issues/9129
  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  @unittest.expectedFailure
  def test_all(self):
    # The number of bits of precise mantissa expected in the result.
    parameters = [
        ('highest', 22),
        ('high', 14),
        ('default', 8),
    ]
    # Although pytest has a slightly more elegant parameterized testing
    # function, all TPU tests use unittest.
    for precision, bits in parameters:
      with self.subTest(precision=precision, bits=bits):
        self._test_parameterized(precision, bits)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_highest(self):
    self._test_parameterized('highest', 22)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_high(self):
    self._test_parameterized('high', 14)

  @unittest.skipIf(not test_utils.is_on_tpu(), 'Skipping, not on TPU.')
  def test_default(self):
    self._test_parameterized('default', 8)

  # DO NOT add epsilons to this test. These tests must be numerically exact.
  def _test_parameterized(self, precision, bits):
    """Asserts a matmul at `precision` is accurate to exactly `bits` bits.

    The result must land within the widest tolerance implied by `bits`
    mantissa bits, must NOT land within a tolerance 4x narrower, and must
    not be bit-identical to the float64 reference.
    """
    # Arrange
    torch_xla.backends.set_mat_mul_precision(precision)

    # Diagonal matrices force mat mul through MXU
    # but require only one non-zero accumulation.
    x = self._make_input()
    y = self._make_input()
    reference_float64 = torch.matmul(x, y)

    # TODO: Justify this logic. Why is it not
    # 1 - ((2**bits - 1) / 2**bits)**2 (the equation stated by a TPU expert)?
    widest_atol = torch.tensor(
        -1 + ((2**bits + 1) / 2**bits)**2, dtype=torch.float64)

    narrowest_atol = widest_atol / 4.0

    x = x.to(torch.float32).to('xla')
    y = y.to(torch.float32).to('xla')

    # Act
    actual = torch.matmul(x, y).to('cpu').to(torch.float64)

    # Assert
    # Disable rtol; we know exactly the atol for default, high, and highest.
    torch.testing.assert_close(
        actual,
        reference_float64,
        rtol=0.0,
        atol=widest_atol,
    )

    with self.assertRaises(AssertionError):
      torch.testing.assert_close(
          actual,
          reference_float64,
          rtol=0.0,
          atol=narrowest_atol,
      )

    # Use self.assertFalse rather than a bare assert: asserts are stripped
    # when Python runs with -O, silently weakening the test.
    self.assertFalse(
        torch.equal(actual, reference_float64),
        "Actual product and reference product should not be closer than equal, "
        f"but they are: {torch.diag(actual)} == {torch.diag(reference_float64)}"
    )
|
||
|
||
# There is no main function. This is designed to be run from | ||
# python -m unittest ... |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
"""Tests for get/set_mat_mul_precision from init_python_bindings.cpp""" | ||
|
||
import sys | ||
import unittest | ||
|
||
import torch | ||
import torch_xla | ||
import torch_xla.backends | ||
|
||
|
||
class TestMatMulPrecisionGetAndSet(unittest.TestCase):
  """Round-trip tests for get/set_mat_mul_precision."""

  def setUp(self):
    # Remember the precision in effect so tearDown can restore it.
    self._original = torch_xla.backends.get_mat_mul_precision()
    torch.set_printoptions(precision=20)
    torch_xla.sync()

  def tearDown(self):
    # Restore the original precision and print options.
    torch_xla.backends.set_mat_mul_precision(self._original)
    torch.set_printoptions(profile="default")
    torch_xla.sync()

  def _assert_round_trip(self, precision):
    """Sets `precision`, then asserts the getter reports it back."""
    # Arrange
    torch_xla.backends.set_mat_mul_precision(precision)

    # Act
    status = torch_xla.backends.get_mat_mul_precision()

    # Assert
    self.assertEqual(status, precision)

  def test_set_mat_mul_precision_error(self):
    # Assert
    with self.assertRaises(ValueError):
      # Act
      torch_xla.backends.set_mat_mul_precision('BAD VALUE')

  def test_get_and_set_mat_mul_precision_default(self):
    self._assert_round_trip('default')

  def test_get_and_set_mat_mul_precision_high(self):
    self._assert_round_trip('high')

  def test_get_and_set_mat_mul_precision_highest(self):
    self._assert_round_trip('highest')
|
||
|
||
if __name__ == '__main__':
  # exit=False is required here: by default unittest.main() raises
  # SystemExit itself, which would make the explicit sys.exit below
  # unreachable and always exit with unittest's own status.
  test = unittest.main(exit=False)
  sys.exit(0 if test.result.wasSuccessful() else 1)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
"""torch_xla.backends controls the behavior of the XLA backend. | ||
|
||
This subpackage parallels the torch.backends.{cuda, cpu, mps, etc} | ||
subpackages in PyTorch. | ||
""" | ||
|
||
# See https://github.com/pytorch/pytorch/blob/main/torch/backends/mps/__init__.py | ||
# for an example of how backends are implemented in PyTorch | ||
# in the __init__.py file, despite general style guidelines against this. | ||
|
||
# Literal is available from Python 3.8, | ||
# matching the Python versions for PyTorch and PyTorch/XLA. | ||
from typing import Final, Literal, TypeAlias | ||
|
||
import torch_xla | ||
|
||
__all__ = ["set_mat_mul_precision", "get_mat_mul_precision"] | ||
|
||
# Valid values for get_mat_mul_precision/set_mat_mul_precision | ||
# Note: it is idiomatic to PyTorch to use strings rather than enums. | ||
# See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/backends/cpu/__init__.py#L9 | ||
|
||
_DEFAULT: Final = "default" | ||
_HIGH: Final = "high" | ||
_HIGHEST: Final = "highest" | ||
|
||
# Use of variables with Final typehint instead of literals is valid. | ||
_PrecisionType: TypeAlias = Literal[ | ||
_DEFAULT, _HIGH, _HIGHEST] # pyright: ignore[reportInvalidTypeForm] | ||
|
||
|
||
# Some of this description adapted from Jax documentation. | ||
# TODO: Once the numerics tutorial is released, link from this docstring. | ||
def set_mat_mul_precision(precision: _PrecisionType) -> None: | ||
yaoshiang marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Control the default matmul and conv precision for 32bit inputs. | ||
|
||
Some platforms, like TPU, offer configurable precision levels for | ||
matrix multiplication and convolution computations, | ||
trading off accuracy for speed. | ||
|
||
This option controls the default precision level for | ||
computations involved in matrix multiplication and convolution on | ||
32bit inputs. The levels describe the precision at | ||
which scalar products are computed. | ||
|
||
On a TPU: | ||
* `default` is the fastest and least precise, | ||
downcasting an FP32 to BF16 before multiplying. | ||
|
||
* `high` takes three passes and generates approximately 14 bits of | ||
precision. | ||
|
||
* `highest` is the most precise, and the slowest. It takes six | ||
passes and generates approximately 22 bits of precision. | ||
|
||
Args: | ||
precision (str): The precision to set for matrix multiplication. | ||
Must be one of 'default', 'high', or 'highest'. | ||
""" | ||
if precision not in [_DEFAULT, _HIGH, _HIGHEST]: | ||
raise ValueError(f"Invalid precision: {precision}. " | ||
f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.") | ||
|
||
torch_xla._XLAC._xla_set_mat_mul_precision(precision) | ||
|
||
|
||
def get_mat_mul_precision() -> _PrecisionType:
  """Get the current mat mul precision for 32bit inputs.

  Returns:
      str: The current precision setting for matrix multiplication,
      one of 'default', 'high', or 'highest'.
  """
  # Query the active precision from the C++ binding, then sanity-check
  # that it is one of the values this module knows how to describe.
  precision = torch_xla._XLAC._xla_get_mat_mul_precision()
  valid_precisions = (_DEFAULT, _HIGH, _HIGHEST)
  assert precision in valid_precisions, (
      f"Invalid precision: {precision}. "
      f"Must be one of {_DEFAULT}, {_HIGH}, {_HIGHEST}.")
  return precision
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.