We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6cfbe7d · commit 7365341 — Copy full SHA for 7365341
torchao/float8/float8_utils.py
@@ -99,7 +99,7 @@ def amax_history_to_scale_stack(
99
100
@torch.no_grad()
101
def tensor_to_amax(
102
- x: torch.Tensor, reduce_amax: bool = False, device_mesh = None
+ x: torch.Tensor, reduce_amax: bool = False, device_mesh=None
103
) -> torch.Tensor:
104
amax = torch.max(torch.abs(x))
105
@@ -118,7 +118,7 @@ def tensor_to_scale(
118
x: torch.Tensor,
119
float8_dtype: torch.dtype,
120
reduce_amax: bool = False,
121
- device_mesh = None,
+ device_mesh=None,
122
123
amax = tensor_to_amax(x, reduce_amax=reduce_amax, device_mesh=device_mesh)
124
return amax_to_scale(amax, float8_dtype, x.dtype)
0 commit comments