
Commit 14cfbc7

formating config.py

1 parent 78a3412 commit 14cfbc7

File tree

1 file changed

+21 -21 lines changed

torchao/float8/config.py

Lines changed: 21 additions & 21 deletions
@@ -99,13 +99,13 @@ def short_str(self):
 
     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
-            assert (
-                self.static_scale is not None
-            ), "static_scale must be specified for static scaling"
+            assert self.static_scale is not None, (
+                "static_scale must be specified for static scaling"
+            )
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
@@ -130,9 +130,9 @@ class DelayedScalingConfig:
     scale_fn_name: str = "max"
 
     def __post_init__(self):
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        assert self.scale_fn_name == "max", (
+            f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        )
 
 
 @dataclass(frozen=True)
@@ -148,7 +148,6 @@ class Float8GemmConfig:
 
 # Pre-made recipes for common configurations
 class Float8LinearRecipeName(enum.Enum):
-
     # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
     TENSORWISE = "tensorwise"
 
@@ -291,7 +290,9 @@ def __post_init__(self):
 
         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )
 
         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
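
The single-line assert made multi-line here guards the float8 all-gather path under FSDP. A hedged sketch of the combination it rejects, assuming Float8LinearConfig exposes cast_config_weight and enable_fsdp_float8_all_gather as constructor keywords (field names taken from this hunk) and that no earlier check in __post_init__ intercepts the call:

# Hedged sketch, not part of the commit: axiswise weight scaling plus
# enable_fsdp_float8_all_gather should reach the assert shown above.
from torchao.float8.config import (
    CastConfig,
    Float8LinearConfig,
    ScalingGranularity,
)

cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
try:
    Float8LinearConfig(
        cast_config_weight=cc_w,
        enable_fsdp_float8_all_gather=True,
    )
except AssertionError as err:
    print(err)  # enable_fsdp_float8_all_gather only supports tensorwise ...
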
@@ -310,9 +311,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )
 
         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -324,9 +325,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )
 
         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
@@ -357,9 +358,9 @@ def from_recipe_name(
         """
         if type(recipe_name) == str:
             valid_names = [n.value for n in Float8LinearRecipeName]
-            assert (
-                recipe_name in valid_names
-            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            assert recipe_name in valid_names, (
+                f"recipe_name {recipe_name} not in valid names {valid_names}"
+            )
             recipe_name = Float8LinearRecipeName(recipe_name)
 
         if recipe_name is Float8LinearRecipeName.TENSORWISE:
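
The assert reformatted at the top of this hunk validates string recipe names passed to from_recipe_name. A small usage sketch, assuming from_recipe_name is the classmethod of Float8LinearConfig shown in this hunk header and that "tensorwise" is the TENSORWISE recipe's string value defined earlier in the file:

# Hedged sketch, not part of the commit: the string and enum forms are
# both accepted; an unknown string trips the assert shown above.
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

cfg = Float8LinearConfig.from_recipe_name("tensorwise")
cfg = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)

try:
    Float8LinearConfig.from_recipe_name("not_a_recipe")
except AssertionError as err:
    print(err)  # recipe_name not_a_recipe not in valid names [...]
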
@@ -385,7 +386,6 @@ def from_recipe_name(
             )
 
         elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-
             # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
             cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
             cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
