Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions scripts/compare-llama-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,23 @@
parser.add_argument("--check", action="store_true", help="check if all required Python libraries are installed")
parser.add_argument("-s", "--show", help=help_s)
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
parser.add_argument("--plot", help="generate a performance comparison plot and save to specified file (e.g., plot.png)")
parser.add_argument("--plot_x", help="parameter to use as x-axis for plotting (default: n_depth)", default="n_depth")

known_args, unknown_args = parser.parse_known_args()

logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO)

# Check for matplotlib if plotting is requested
if known_args.plot:
try:
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
except ImportError as e:
print("matplotlib is required for --plot.")
raise e

if known_args.check:
# Check if all required Python libraries are installed. Would have failed earlier if not.
sys.exit(0)
Expand Down Expand Up @@ -600,6 +612,157 @@
headers = [PRETTY_NAMES[p] for p in show]
headers += ["Test", f"t/s {name_baseline}", f"t/s {name_compare}", "Speedup"]

if known_args.plot:
def create_performance_plot(table_data, headers, baseline_name, compare_name, output_file, plot_x_param):

data_headers = headers[:-4] #Exclude the last 4 columns (Test, baseline t/s, compare t/s, Speedup)
plot_x_index = None
plot_x_label = plot_x_param

if plot_x_param not in ["n_prompt", "n_gen", "n_depth"]:
pretty_name = PRETTY_NAMES.get(plot_x_param, plot_x_param)
if pretty_name in data_headers:
plot_x_index = data_headers.index(pretty_name)
plot_x_label = pretty_name
elif plot_x_param in data_headers:
plot_x_index = data_headers.index(plot_x_param)
plot_x_label = plot_x_param
else:
logger.error(f"Parameter '{plot_x_param}' not found in current table columns. Available columns: {', '.join(data_headers)}")
logger.error(f"To plot by '{plot_x_param}', include it in --show parameter or ensure it varies in your data.")
return
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the code should just automatically add plot_x_param to the shown if it's not contained already.


grouped_data = {}

for i, row in enumerate(table_data):
group_key_parts = []
test_name = row[-4]

if plot_x_param in ["n_prompt", "n_gen", "n_depth"]:
for j, val in enumerate(row[:-4]):
header_name = data_headers[j]
if val is not None and str(val).strip():
group_key_parts.append(f"{header_name}={val}")

if plot_x_param == "n_prompt":
assert "pp" in test_name, f"n_prompt test name {test_name} does not contain 'pp'"
base_test = test_name.split("@")[0]
x_value = base_test
elif plot_x_param == "n_gen" and "tg" in test_name:
assert "tg" in test_name, f"n_gen test name {test_name} does not contain 'tg'"
x_value = test_name.split("@")[0]
elif plot_x_param == "n_depth" and "@d" in test_name:
assert "@d" in test_name, f"n_depth test name {test_name} does not contain '@d'"
base_test = test_name.split("@d")[0]
x_value = int(test_name.split("@d")[1])
else:
base_test = test_name

if base_test.strip():

Check failure on line 661 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"base_test" is possibly unbound (reportPossiblyUnboundVariable)
group_key_parts.append(f"Test={base_test}")

Check failure on line 662 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"base_test" is possibly unbound (reportPossiblyUnboundVariable)
else:
for j, val in enumerate(row[:-4]):
if j != plot_x_index:
header_name = data_headers[j]
if val is not None and str(val).strip():
group_key_parts.append(f"{header_name}={val}")
else:
x_value = val

group_key_parts.append(f"Test={test_name}")

group_key = tuple(sorted(group_key_parts))

if group_key not in grouped_data:
grouped_data[group_key] = []

grouped_data[group_key].append({
'x_value': x_value,

Check failure on line 680 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"x_value" is possibly unbound (reportPossiblyUnboundVariable)
'baseline': float(row[-3]),
'compare': float(row[-2]),
'speedup': float(row[-1])
})

if not grouped_data:
logger.error("No data available for plotting")
return


def make_axes(num_groups, max_cols=2, base_size=(8, 4)):
from math import ceil
cols = 1 if num_groups == 1 else min(max_cols, num_groups)
rows = ceil(num_groups / cols)

# scale figure size by grid dimensions
w, h = base_size
fig, ax_arr = plt.subplots(rows, cols,

Check failure on line 698 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"plt" is possibly unbound (reportPossiblyUnboundVariable)
figsize=(w * cols, h * rows),
squeeze=False)

axes = ax_arr.flatten()[:num_groups]
return fig, axes

num_groups = len(grouped_data)
fig, axes = make_axes(num_groups)

plot_idx = 0

for group_key, points in grouped_data.items():
if plot_idx >= len(axes):
break
ax = axes[plot_idx]

try:
points_sorted = sorted(points, key=lambda p: float(p['x_value']) if p['x_value'] is not None else 0)
x_values = [float(p['x_value']) if p['x_value'] is not None else 0 for p in points_sorted]
except ValueError:
points_sorted = sorted(points, key=lambda p: group_key)
x_values = [p['x_value'] for p in points_sorted]

baseline_vals = [p['baseline'] for p in points_sorted]
compare_vals = [p['compare'] for p in points_sorted]

ax.plot(x_values, baseline_vals, 'o-', color='skyblue',
label=f'{baseline_name}', linewidth=2, markersize=6)
ax.plot(x_values, compare_vals, 's--', color='lightcoral', alpha=0.8,
label=f'{compare_name}', linewidth=2, markersize=6)

if plot_x_param == "n_depth" and max(x_values) > 0 and max(x_values) > min(x_values) * 4:
ax.set_xscale('log', base=2)
unique_x = sorted(set(x_values))
ax.set_xticks(unique_x)
ax.set_xticklabels([str(int(x)) for x in unique_x])

title_parts = []
for part in group_key:
if '=' in part:
key, value = part.split('=', 1)
title_parts.append(f"{key}: {value}")

title = ', '.join(title_parts) if title_parts else "Performance Comparison"

ax.set_xlabel(plot_x_label, fontsize=12, fontweight='bold')
ax.set_ylabel('Tokens per Second (t/s)', fontsize=12, fontweight='bold')
ax.set_title(title, fontsize=12, fontweight='bold')
ax.legend(loc='best', fontsize=10)
ax.grid(True, alpha=0.3)

plot_idx += 1

for i in range(plot_idx, len(axes)):
axes[i].set_visible(False)

fig.suptitle(f'Performance Comparison: {compare_name} vs {baseline_name}',
fontsize=14, fontweight='bold')
fig.subplots_adjust(top=1)


plt.tight_layout()

Check failure on line 760 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"plt" is possibly unbound (reportPossiblyUnboundVariable)
plt.savefig(output_file, dpi=300, bbox_inches='tight')

Check failure on line 761 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"plt" is possibly unbound (reportPossiblyUnboundVariable)
plt.close()

Check failure on line 762 in scripts/compare-llama-bench.py

View workflow job for this annotation

GitHub Actions / pyright type-check

"plt" is possibly unbound (reportPossiblyUnboundVariable)

create_performance_plot(table, headers, name_baseline, name_compare, known_args.plot, known_args.plot_x)

print(tabulate( # noqa: NP100
table,
headers=headers,
Expand Down
Loading