float8 profiling script: filter out microbenchmarking overhead #629

Merged · 1 commit · Aug 8, 2024
benchmarks/float8/profile_linear_float8.py: 154 changes (79 additions, 75 deletions)
```diff
@@ -29,7 +29,7 @@
     kernel_name_to_category,
     parse_bw_and_kernel_name,
     profiler_output_to_gpu_time_for_key,
-    profiler_output_to_time_by_kernel_name,
+    profiler_output_to_filtered_time_by_kernel_name,
 )

 # don't truncate long kernel names
```
```diff
@@ -312,85 +312,89 @@ def float8_forw_backward_wrapper(x):
     # if the `TORCHINDUCTOR_PROFILE` env var is enabled, parse its output
     # to populate triton kernel bandwidth further down in the script
     f = io.StringIO()
-    with redirect_stdout(f):
-        # warm up
-        for _ in range(1):
-            if dtype_filter != "float8":
-                ref_forw_backward(input_tensor)
-            if dtype_filter != "bfloat16":
-                float8_forw_backward_wrapper(input_tensor)
+    try:
+        with redirect_stdout(f):
+            # warm up
+            for _ in range(1):
+                if dtype_filter != "float8":
+                    ref_forw_backward(input_tensor)
+                if dtype_filter != "bfloat16":
+                    float8_forw_backward_wrapper(input_tensor)

-        profile_iters = 5
-        ref_times, float8_times = None, None
-        data = []
+            profile_iters = 5
+            ref_times, float8_times = None, None
+            data = []
+
+            num_leaf_tensors = 1 + len(list(m_ref.parameters()))

-        if dtype_filter != "float8":
-            # Profile Reference Model
-            print("profiling ref")
-            ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
-            ref_path = profile_path_prefix + ref_suffix
-            profile_config = ProfileConfig(
-                ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
-            )
-            p = profile_function(profile_config, ref_forw_backward, input_tensor)
-            print(f"saved {ref_path}")
-            ref_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
-            for k, v in ref_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "0_ref",
-                        k,
-                        kernel_name_to_category(k),
-                        v_ms,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
+            if dtype_filter != "float8":
+                # Profile Reference Model
+                print("profiling ref")
+                ref_suffix = f"_{model_type}_ref_compile_{compile}.json"
+                ref_path = profile_path_prefix + ref_suffix
+                profile_config = ProfileConfig(
+                    ref_path, ref_suffix, iters=profile_iters, warmup_iters=2, sync=True
+                )
+                p = profile_function(profile_config, ref_forw_backward, input_tensor)
+                print(f"saved {ref_path}")
+                ref_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in ref_times.values()) / 1e3 / profile_iters
+                for k, v in ref_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "0_ref",
+                            k,
+                            kernel_name_to_category(k),
+                            v_ms,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )

-        if dtype_filter != "bfloat16":
-            # Profile Float8 Model
-            print("profiling float8")
-            float8_suffix = (
-                f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
-            )
-            float8_path = profile_path_prefix + float8_suffix
-            profile_config = ProfileConfig(
-                float8_path,
-                float8_suffix,
-                iters=profile_iters,
-                warmup_iters=2,
-                sync=True,
-            )
-            p = profile_function(
-                profile_config, float8_forw_backward_wrapper, input_tensor
-            )
-            print(f"saved {float8_path}")
-            float8_times = profiler_output_to_time_by_kernel_name(p)
-            total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
-            for k, v in float8_times.items():
-                v_ms = v / 1e3 / profile_iters
-                data.append(
-                    [
-                        "1_float8",
-                        k,
-                        kernel_name_to_category(k),
-                        v / 1e3 / profile_iters,
-                        v_ms / total_time_ms,
-                        None,
-                    ]
-                )
+            if dtype_filter != "bfloat16":
+                # Profile Float8 Model
+                print("profiling float8")
+                float8_suffix = (
+                    f"_{model_type}_float8_compile_{compile}_{scaling_repr}.json"
+                )
+                float8_path = profile_path_prefix + float8_suffix
+                profile_config = ProfileConfig(
+                    float8_path,
+                    float8_suffix,
+                    iters=profile_iters,
+                    warmup_iters=2,
+                    sync=True,
+                )
+                p = profile_function(
+                    profile_config, float8_forw_backward_wrapper, input_tensor
+                )
+                print(f"saved {float8_path}")
+                float8_times = profiler_output_to_filtered_time_by_kernel_name(p, profile_iters, num_leaf_tensors)
+                total_time_ms = sum(v for v in float8_times.values()) / 1e3 / profile_iters
+                for k, v in float8_times.items():
+                    v_ms = v / 1e3 / profile_iters
+                    data.append(
+                        [
+                            "1_float8",
+                            k,
+                            kernel_name_to_category(k),
+                            v / 1e3 / profile_iters,
+                            v_ms / total_time_ms,
+                            None,
+                        ]
+                    )

-            # get the time spent per user annotation
-            sync_time_us = profiler_output_to_gpu_time_for_key(
-                p, "scale_amax_and_scales"
-            )
-            sync_time_ms = sync_time_us / profile_iters / 1e3
-            print(f"Sync time ms: {sync_time_ms}")
+                # get the time spent per user annotation
+                sync_time_us = profiler_output_to_gpu_time_for_key(
+                    p, "scale_amax_and_scales"
+                )
+                sync_time_ms = sync_time_us / profile_iters / 1e3
+                print(f"Sync time ms: {sync_time_ms}")

-    # print the redirected stdout back to regular stdout
-    print(f.getvalue())
+    finally:
+        # print the redirected stdout back to regular stdout
+        print(f.getvalue())

     # populate the triton kernel bandwidth
     for line in f.getvalue().split("\n"):
```

Review comment from the PR author on the `finally:` line: this is needed so stdout is still printed if the code inside the `try` statement hits an exception, useful for debugging.
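To make that concrete, here is a minimal, self-contained sketch of the capture-and-always-echo idiom; the `run_benchmark` function and its simulated failure are hypothetical, purely for illustration:

```python
import io
from contextlib import redirect_stdout

def run_benchmark():
    # Hypothetical stand-in for the profiled workload in the script above.
    print("TORCHINDUCTOR_PROFILE output would be captured here")
    raise RuntimeError("simulated failure mid-profiling")

f = io.StringIO()
try:
    try:
        with redirect_stdout(f):
            run_benchmark()
    finally:
        # Runs whether or not run_benchmark() raised, so the captured
        # output is echoed to the real stdout either way.
        print(f.getvalue())
except RuntimeError as e:
    print(f"benchmark failed: {e}")
```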
benchmarks/float8/utils.py: 64 changes (60 additions, 4 deletions)
```diff
@@ -9,10 +9,44 @@
 from typing import Optional


-def profiler_output_to_time_by_kernel_name(prof):
+def profiler_output_to_filtered_time_by_kernel_name(
+    prof,
+    num_iter: int,
+    num_leaf_tensors: int,
+):
     """
-    Input: a profiler with captured events.
-    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name
+    Input:
+    * `prof`: a profiler with captured events
+    * `num_iter`: number of iterations used to capture `prof`
+    * `num_leaf_tensors`: number of leaf tensors to accumulate gradients to
+    Output: a deduplicated list of GPU time in nanoseconds grouped by CPU kernel name,
+    with the microbenchmark overhead filtered out
+
+    Currently assumes that `prof` captured events from a microbenchmark which was
+    set up as follows:
+
+        #
+        # Forward pass
+        #
+
+        # Expected GPU kernel overhead: none
+        y = func(...)
+
+        # Convenient way to set up the backward pass without caring about shapes
+        y_sum = y.sum()
+
+        # Expected GPU kernel overhead:
+        # * the call to `sum`
+
+        #
+        # Backward pass
+        #
+        y_sum.backward()
+
+        # Expected GPU kernel overhead:
+        # * the call to `aten.fill_` to put a tensor with a single 1.0 value as the input to the backward
+        # * the call to `aten.copy_` to fill the first `grad_output` tensor with 1.0
+        # * the call to `aten.add_` to accumulate grads, once per leaf tensor
+
     Note that if there are user_annotations in the captured events, `torch.profiler`
     will include their time in the total GPU time displayed at the bottom of
```
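For reference, a minimal sketch of the microbenchmark shape the docstring assumes; the layer, shapes, and iteration count below are illustrative assumptions, not the profiling script's actual configuration:

```python
import torch
import torch.nn as nn

# Illustrative setup; requires a CUDA device for GPU kernel times.
lin = nn.Linear(1024, 1024).cuda()
x = torch.randn(512, 1024, device="cuda", requires_grad=True)
num_iter = 5
# leaf tensors receiving grads: the input plus each parameter
num_leaf_tensors = 1 + len(list(lin.parameters()))

with torch.profiler.profile(
    activities=[
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
) as prof:
    for _ in range(num_iter):
        y = lin(x)        # op under test: no expected overhead
        y_sum = y.sum()   # overhead: the `sum` kernel
        y_sum.backward()  # overhead: `fill_`, `copy_`, per-leaf `add_`
        torch.cuda.synchronize()
```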
```diff
@@ -23,13 +57,35 @@ def profiler_output_to_time_by_kernel_name(prof):
     thresh = 1e-10
     kernel_name_to_gpu_time_us = collections.defaultdict(float)
     for e in key_averages:
+
         # manually filter top-level CPU events with attributed CUDA time
-        # example CPU event row:
+        # example CPU event row from printing `key_averages`:
         #   aten::addmm 0.83% 76.554us 0.98% 90.846us 90.846us 1.022ms 31.82% 1.022ms 1.022ms 1
+        # and it maps to this CUDA event:
+        #   sm80_xmma_gemm_f32f32_f32f32_f32_tn_n_tilesize256x64... 0.00% 0.000us 0.00% 0.000us 0.000us 1.022ms 31.82% 1.022ms 1.022ms 1
         if not (e.self_cpu_time_total > thresh and e.self_device_time_total > thresh):
             continue
+
+        # manually filter expected microbenchmarking overhead, in order of execution
+        if e.key == 'aten::sum':
+            # forward pass sum
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::fill_':
+            # filling the forward pass sum with 1.0
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::copy_':
+            # copying 1.0 from grad_out of `sum` to grad_out of next op
+            assert e.count == num_iter, f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'aten::add_':
+            # accumulating gradients into leaf tensors
+            assert e.count == (num_iter * num_leaf_tensors), f'unexpected number of iter for {e.key}'
+            continue
+        elif e.key == 'cudaDeviceSynchronize':
+            continue
+
         kernel_name_to_gpu_time_us[e.key] = e.self_device_time_total
     return kernel_name_to_gpu_time_us
```
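And a hypothetical call site, reusing `prof`, `num_iter`, and `num_leaf_tensors` from the sketch above and assuming this `utils.py` is importable; the returned dict maps each surviving CPU kernel name to its attributed self GPU time:

```python
from utils import profiler_output_to_filtered_time_by_kernel_name

times = profiler_output_to_filtered_time_by_kernel_name(
    prof, num_iter, num_leaf_tensors
)
# Largest GPU time first; the harness's `sum`/`fill_`/`copy_`/`add_`
# rows have already been filtered out.
for name, gpu_time in sorted(times.items(), key=lambda kv: -kv[1]):
    print(f"{name}: {gpu_time:.1f}")
```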
