From 50279e9954e9d575c0c3fd69cccd6cffc6d3c091 Mon Sep 17 00:00:00 2001
From: Apurva Jain
Date: Mon, 18 Nov 2024 15:47:24 -0800
Subject: [PATCH 1/2] Added cpu support

---
 torchao/_models/llama/eval.py     |  1 -
 torchao/_models/llama/generate.py | 31 +++++++++++++++++++------------
 2 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 25b65cd1ec..8e19ef4054 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -22,7 +22,6 @@
     unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
-    float8_static_activation_float8_weight,
 )
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 1efa6b04b3..7e5bbc57a1 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -67,7 +67,7 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
+        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
@@ -347,7 +347,10 @@ def main(
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
     if memory_profile:
-        torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
+        if device != "cuda":
+            print("Memory profiling only works on CUDA")
+        else:
+            torch.cuda.memory._record_memory_history(True,trace_alloc_max_entries=250000, trace_alloc_record_context=True)
     aggregate_metrics = {
         'tokens_per_sec': [],
     }
@@ -355,7 +358,8 @@ def main(
 
     for i in range(start, num_samples):
         if i==0:
-            torch.cuda.reset_peak_memory_stats()
+            if device == "cuda":
+                torch.cuda.reset_peak_memory_stats() # MKG
         device_sync(device=device) # MKG
         if i >= 0 and interactive:
             prompt = input("What is your prompt? ")
@@ -423,15 +427,18 @@ def callback(x):
         print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s")
 
         if memory_profile and i==0:
-            snapshot = torch.cuda.memory._snapshot()
-            with open(f"{memory_profile}.pickle", 'wb') as f:
-                from pickle import dump
-                dump(snapshot, f)
-            print(
-                f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
-                "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
-            )
-            break
+            if device != "cuda":
+                print("Memory profiling only works on CUDA")
+            else:
+                snapshot = torch.cuda.memory._snapshot()
+                with open(f"{memory_profile}.pickle", 'wb') as f:
+                    from pickle import dump
+                    dump(snapshot, f)
+                print(
+                    f"\nmemory profile {memory_profile}.pickle saved, to convert that to a usable file, use",
+                    "python pytorch/torch/cuda/_memory_viz.py trace_plot <pickle file> -o <desired output name>.html"
+                )
+                break
 
     print("==========")

From 569f0b584f709dbb66c48f3c9b7e595d11917a2b Mon Sep 17 00:00:00 2001
From: jainapurva
Date: Mon, 18 Nov 2024 16:06:26 -0800
Subject: [PATCH 2/2] Eval fixes

---
 torchao/_models/llama/eval.py | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index 8e19ef4054..43667487d8 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -10,26 +10,22 @@
 from generate import (
     _load_model,
     device_sync,
-
 )
-from torchao.quantization.quant_api import (
+from torchao.quantization import (
     quantize_,
     int4_weight_only,
     int8_weight_only,
     int8_dynamic_activation_int8_weight,
     fpx_weight_only,
     uintx_weight_only,
-    unwrap_tensor_subclass,
     float8_weight_only,
     float8_dynamic_activation_float8_weight,
 )
-from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
-from torchao.quantization.granularity import PerRow, PerTensor
-
+from torchao.quantization import PerRow, PerTensor
 from tokenizer import get_tokenizer
 import time
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
 
 def run_evaluation(
     checkpoint_path: Path,