
Commit 812aece

vkuzo authored and amdfaa committed
fix local float8 tests on H100 (#1438)
Summary: float8 tests were failing my machine. I bisected the failure to #1344. Further investigation found that that PR was fine, but the tolerance for one of the tests was too tight, adding a test changed the random seed of the data, and things started failing. Switching to SQNR for a more robust measurement. Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
1 parent db07bb8 commit 812aece

File tree

1 file changed: +4 additions, -8 deletions

test/float8/test_base.py

Lines changed: 4 additions & 8 deletions
@@ -730,14 +730,10 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
             emulated_config,
             GemmInputRole.WEIGHT,
         )
-        out_emualted = a_fp8 @ b_fp8
-        out_emualted.to(compare_type)
-
-        if base_dtype in {torch.bfloat16, torch.float16}:
-            atol, rtol = 7e-2, 7e-2
-        else:
-            atol, rtol = 2e-3, 2e-3
-        torch.testing.assert_close(out_padded, out_emualted, atol=atol, rtol=rtol)
+        out_emulated = a_fp8 @ b_fp8
+        out_emulated.to(compare_type)
+        sqnr = compute_error(out_padded, out_emulated)
+        assert sqnr > 50.0
 
 
 class TestNumerics:
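The change above replaces fixed `atol`/`rtol` thresholds with an SQNR (signal-to-quantization-noise ratio) check, which measures error relative to the magnitude of the reference output instead of against absolute cutoffs. As a rough illustration, here is a hypothetical, stdlib-only sketch of what an SQNR-style `compute_error` can look like; the real torchao helper operates on torch tensors, so the list-based `compute_error` below is a stand-in, not the library's implementation:

```python
import math

def compute_error(x, y):
    # Hypothetical sketch: SQNR in decibels, 20 * log10(||x|| / ||x - y||).
    # Higher values mean y is closer to the reference x.
    signal = math.sqrt(sum(v * v for v in x))
    noise = math.sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
    return 20 * math.log10(signal / noise)

# A small *relative* error yields a high SQNR, so a check like
# `sqnr > 50.0` is insensitive to which random seed generated the data,
# unlike a tight absolute atol/rtol pair.
ref = [1.0] * 100
approx = [1.0001] * 100
print(compute_error(ref, approx))  # roughly 80 dB
```

Because the metric is relative, a re-seeded input that scales all values up or down leaves the SQNR essentially unchanged, which is why it is the more robust pass/fail criterion here.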
