What this cast is doing:

* reshape the tensor into shape `(-1, block_size)`, where `block_size` is usually 32 or 16
* for each block, calculate a single scale, then cast that block to `torch.float8_e4m3fn`
* rearrange the scale into the swizzled format expected by the gemm kernel
* return the casted elements and the swizzled scale

(A minimal sketch of this cast is included at the end of this section.)

What we currently see:

```
TORCH_LOGS_FORMAT=short TORCH_LOGS=aot_graphs,output_code python benchmarks/float8/profile_lowp_training.py ~/local/tmp/20250223_test --mx_recipe_name mxfp8_emulated --experiment_filter lowp --mode_filter cast_with_to_blocked
```

Output: https://gist.github.com/vkuzo/9bb4194b289003b6d8bf32d066e3f8e1

We currently get two kernels: (i) one kernel to calculate the unswizzled scale and cast the elements, and (ii) one kernel to convert the scale layout.
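
For reference, here is a minimal sketch of the blockwise cast described above. The helper name `blockwise_cast_to_fp8` and the plain float32 scales are assumptions for illustration; the real mxfp8 recipe uses e8m0 power-of-two scales and additionally swizzles the scale layout for the gemm kernel, which is omitted here.

```python
import torch

def blockwise_cast_to_fp8(x: torch.Tensor, block_size: int = 32):
    # Hypothetical sketch, not the actual torchao implementation.
    # Assumes x.numel() is divisible by block_size.
    # Reshape into (-1, block_size) blocks.
    x_blocks = x.reshape(-1, block_size)
    # One scale per block, chosen so the block's absolute max maps to the
    # largest representable torch.float8_e4m3fn value.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x_blocks.abs().amax(dim=1, keepdim=True).float()
    # Clamp to avoid division by zero for all-zero blocks.
    scale = torch.clamp(amax / fp8_max, min=torch.finfo(torch.float32).tiny)
    # Scale each block and cast it to float8.
    x_fp8 = (x_blocks / scale).to(torch.float8_e4m3fn)
    # The real cast would also rearrange `scale` into the swizzled layout
    # expected by the gemm kernel before returning.
    return x_fp8, scale
```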