@@ -99,13 +99,13 @@ def short_str(self):

     def __post_init__(self):
         if self.scaling_type is ScalingType.STATIC:
-            assert (
-                self.static_scale is not None
-            ), "static_scale must be specified for static scaling"
+            assert self.static_scale is not None, (
+                "static_scale must be specified for static scaling"
+            )
         if self.scaling_granularity is ScalingGranularity.AXISWISE:
-            assert (
-                self.scaling_type is ScalingType.DYNAMIC
-            ), "only dynamic scaling type is supported for axiswise scaling granularity"
+            assert self.scaling_type is ScalingType.DYNAMIC, (
+                "only dynamic scaling type is supported for axiswise scaling granularity"
+            )
         assert self.target_dtype is None or (
             self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
         ), "must specify a 8-bit floating-point dtype"
@@ -130,9 +130,9 @@ class DelayedScalingConfig:
     scale_fn_name: str = "max"

     def __post_init__(self):
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        assert self.scale_fn_name == "max", (
+            f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+        )


 @dataclass(frozen=True)
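Note on the pattern this diff applies throughout (not part of the file itself): the parentheses are moved so they wrap only the message, never the condition together with the message. Wrapping both in a single pair of parentheses would turn the assert into a check on a two-element tuple, which is always truthy, so the assertion could never fire. A minimal standalone sketch of the difference, reusing the `scale_fn_name` check from the hunk above with a plain local variable:

```python
# Standalone illustration; `scale_fn_name` here is a local variable,
# not the DelayedScalingConfig field from the diff.
scale_fn_name = "amax"

# The style the diff moves to: only the message is parenthesized, so long
# messages can wrap without changing what is being asserted.
try:
    assert scale_fn_name == "max", (
        f"{scale_fn_name} is not implemented yet. Only max is supported for now."
    )
except AssertionError as e:
    print("fires as expected:", e)

# The pitfall this style avoids: parenthesizing condition and message together
# builds a non-empty tuple, which is always truthy, so this assert never fails
# (CPython warns that the assertion is always true).
assert (scale_fn_name == "max", "this message can never be shown")
```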
@@ -148,7 +148,6 @@ class Float8GemmConfig:

 # Pre-made recipes for common configurations
 class Float8LinearRecipeName(enum.Enum):
-
     # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
     TENSORWISE = "tensorwise"

@@ -291,7 +290,9 @@ def __post_init__(self):

         # float8 all-gather only supports tensorwise, in the future may support blockwise
         if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
-            assert not self.enable_fsdp_float8_all_gather, f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            assert not self.enable_fsdp_float8_all_gather, (
+                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
+            )

         # save some characters in the compatibility checks below
         cc_i = self.cast_config_input
@@ -310,9 +311,9 @@ def __post_init__(self):
         ):
             is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
             is_disabled_2 = cc1.scaling_type is ScalingType.DISABLED
-            assert (
-                is_disabled_1 == is_disabled_2
-            ), f"incompatible operand precision for {gemm_name}"
+            assert is_disabled_1 == is_disabled_2, (
+                f"incompatible operand precision for {gemm_name}"
+            )

         for cc1, cc2, operand_name, default_dtype in [
             (cc_i, cc_i_gw, "input", e4m3_dtype),
@@ -324,9 +325,9 @@ def __post_init__(self):
                 object.__setattr__(cc1, "target_dtype", default_dtype)
             if cc2.target_dtype is None:
                 object.__setattr__(cc2, "target_dtype", default_dtype)
-            assert (
-                cc1.target_dtype == cc2.target_dtype
-            ), f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            assert cc1.target_dtype == cc2.target_dtype, (
+                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
+            )

         # See the comments around `force_recompute_fp8_weight_in_bwd` for more details of this warning.
         if (
@@ -357,9 +358,9 @@ def from_recipe_name(
         """
         if type(recipe_name) == str:
             valid_names = [n.value for n in Float8LinearRecipeName]
-            assert (
-                recipe_name in valid_names
-            ), f"recipe_name {recipe_name} not in valid names {valid_names}"
+            assert recipe_name in valid_names, (
+                f"recipe_name {recipe_name} not in valid names {valid_names}"
+            )
             recipe_name = Float8LinearRecipeName(recipe_name)

         if recipe_name is Float8LinearRecipeName.TENSORWISE:
@@ -385,7 +386,6 @@ def from_recipe_name(
             )

         elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
-
             # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
             cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
             cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
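For context on the last two hunks: `from_recipe_name` accepts either a `Float8LinearRecipeName` member or its string value, and the reformatted assert is what rejects unknown strings. A rough usage sketch, assuming this diff is against `torchao/float8/config.py`, that `from_recipe_name` is a classmethod on `Float8LinearConfig`, and that both names are importable from that module:

```python
# Usage sketch under the assumptions stated above; the import path is an
# assumption, not something the diff confirms.
from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName

# String values are validated against the enum, then converted to it.
config = Float8LinearConfig.from_recipe_name("tensorwise")

# Passing the enum member directly skips the string-validation branch.
config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.TENSORWISE)

# An unknown string trips the reformatted assert with a message like
# "recipe_name not_a_recipe not in valid names [...]".
try:
    Float8LinearConfig.from_recipe_name("not_a_recipe")
except AssertionError as e:
    print(e)
```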