diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py new file mode 100644 index 0000000000..f58fab5cfd --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_modelsize.py @@ -0,0 +1,348 @@ +import sys + +import torch +import torch.nn as nn +from torchao.quantization import quantize_ +import random + +from naive_intNwo import intN_weight_only + +import copy +from lm_eval.evaluator import evaluate +from lm_eval.models.huggingface import HFLM +from lm_eval.tasks import get_task_dict + +from transformers import AutoModelForCausalLM, AutoTokenizer +from ax.service.ax_client import AxClient, ObjectiveProperties +import torch.multiprocessing as mp +from ax.modelbridge.cross_validation import cross_validate +from utils import write_history_to_csv, cal_wikitext_ppl, cal_model_size, load_model, quantize_by_fqn_to_config + +# return evaluation results to complete BO trials +def eval(model, tokenizer, num_PPL_eval_samples, fqn_to_config): + return { + "cal_PPL": (cal_wikitext_ppl(model, tokenizer, num_PPL_eval_samples), 0.0), + "model_size": (cal_model_size(model, fqn_to_config), 0.0), + } + +# TODO: make it into a yaml or json file to enable users specify their custom model formats +def define_parameter_list(): + + # define the search space for all layers + parameters_list = [] + + for i in range(0, 3): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "fixed", + "value_type": "int", + "value": 5, + "is_ordered": True, + "sort_values": True, + } + ) + + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "fixed", + "value_type": "int", + "value": 32, + "is_ordered": True, + "sort_values": True, + } + ) + + for i in range(3, 30): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "choice", + "value_type": "int", + "values": [2,3,4,5,6,8], + "is_ordered": True, + "sort_values": True, + } + ) + + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "choice", + "value_type": "int", + "values": [32, 64, 128, 256], + "is_ordered": True, + "sort_values": True, + } + ) + + for i in range(30, 32): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "fixed", + "value_type": "int", + "value": 5, + "is_ordered": True, + "sort_values": True, + } + ) + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "fixed", + "value_type": "int", + "value": 32, + "is_ordered": True, + "sort_values": True, + } + ) + + return parameters_list + +# add initial search points based on the sensitivity score +# TODO: automate the initial samples by better leverage the sensitivity scores +def get_initial_samples(num_BO_initial_samples=50): + initial_points_set = [] + + # auto sample the bit choices with random choice probability positive correlated to FIT score + for _ in range(num_BO_initial_samples): + initial_points = {} + for i in range(0, 3): + initial_points["bitwidth." + str(i) + "."] = 5 + initial_points["groupsize." + str(i) + "."] = 32 + + for i in range(3, 18): + if i in [5,6,7,10,11,12,16]: + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [20, 80])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [30, 70])[0] + else: + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4], [30, 70])[0] + initial_points["groupsize." 
+ str(i) + "."] = random.choices([32, 64], [40, 60])[0] + + for i in range(18, 30): + if i in [22,23,24]: + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [20, 55, 20, 5])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [30, 40, 25, 5])[0] + else: + initial_points["bitwidth." + str(i) + "."] = random.choices([5, 4, 3, 2], [30, 55, 10, 5])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [40, 40, 15, 5])[0] + + for i in range(30, 32): + initial_points["bitwidth." + str(i) + "."] = 5 + initial_points["groupsize." + str(i) + "."] = 32 + + initial_points_set.append(initial_points) + return initial_points_set + +''' +This function will run BO trials sequentially on a single GPU. +Each time the BO gets one new trial, evaluates the trial on the GPU and return the evaluation results to update the BO. +One trial, one BO update. +TODO: refactor the sequential BO and parallel BO into a single function +''' +def run_sequential_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, output_file): + + parameters_list = define_parameter_list() + initial_points_set = get_initial_samples(num_BO_initial_samples) + + #initialize ax_client + constraint="model_size <= "+str(model_size_constraint) + ax_client = AxClient() + ax_client.create_experiment( + parameters = parameters_list, + name = "test_quantize_BO", + objectives = {"cal_PPL": ObjectiveProperties(minimize=True)}, + choose_generation_strategy_kwargs = { + "num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy + }, + outcome_constraints = [constraint], + ) + + history=[] + trial_id = 0 + + # add initial points into the BO trials + for i in range(num_BO_initial_samples): + + ax_client.attach_trial(parameters=initial_points_set[i]) + + m, tokenizer = load_model(checkpoint, device) + quantize_by_fqn_to_config(m, device, initial_points_set[i]) + + eval_results = eval(m, tokenizer, num_PPL_eval_samples, initial_points_set[i]) + + print("------------") + print(trial_id, initial_points_set[i], eval_results) + + history.append((eval_results, initial_points_set[i])) + ax_client.complete_trial( + trial_index = trial_id, + raw_data = eval_results, + ) + trial_id += 1 + del m + torch.cuda.empty_cache() + + + # run new BO trials + for k_ in range(num_trials): + parameters, trial_idx = ax_client.get_next_trial() + + m, tokenizer = load_model(checkpoint, device) + + quantize_by_fqn_to_config(m, device, parameters) + + eval_results = eval(m, tokenizer, num_PPL_eval_samples, parameters) + + print("------------") + print(trial_idx, parameters, eval_results) + history.append((eval_results, parameters)) + + ax_client.complete_trial( + trial_index=trial_idx, + raw_data=eval_results, + ) + + del m + torch.cuda.empty_cache() + + + print("------Finish BO------") + for h in history: + print(h) + write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) + + print("------Best config------") + best_parameters, values = ax_client.get_best_parameters() + print(best_parameters, values) + +# Worker function to perform BO trials on a specific GPU +def eval_in_parallel(gpu_id, checkpoint, num_PPL_eval_samples, config, return_dict, proc_id, trial_id): + + model, tokenizer = load_model(checkpoint, f'cuda:{gpu_id}') + parameters_list = define_parameter_list() + + print(f"Process {proc_id} on GPU {gpu_id} starts!") + + quantize_by_fqn_to_config(model=model, 
device=f'cuda:{gpu_id}', fqn_to_config=dict(config)) + + eval_results = eval(model, tokenizer, num_PPL_eval_samples, config) + + return_dict[proc_id] = (trial_id, config, eval_results) + + del model + torch.cuda.empty_cache() + +''' +This function will run BO trials in parallel on multiple GPUs. +Each time the BO proposes a batch of new trials, they are evaluated in parallel on the GPUs and the evaluation results are returned to update the BO. +Multiple trials, one BO update. +''' +def run_parallel_BO(device, checkpoint, num_PPL_eval_samples, num_BO_initial_samples, num_trials, model_size_constraint, gpu_list, output_file): + + parameters_list = define_parameter_list() + initial_points_set = get_initial_samples(num_BO_initial_samples) + + # initialize ax_client + constraint="model_size <= "+str(model_size_constraint) + ax_client = AxClient() + ax_client.create_experiment( + parameters = parameters_list, + name = "test_quantize_BO", + objectives = {"cal_PPL": ObjectiveProperties(minimize=True)}, + choose_generation_strategy_kwargs = { + "num_initialization_trials": num_BO_initial_samples, # the number of trials to build generation strategy + }, + outcome_constraints=[constraint], + ) + + gpu_list = [int(i) for i in gpu_list.split(",")] + + history=[] + trial_id = 0 + + # Set the multiprocessing start method to 'spawn' + mp.set_start_method("spawn", force=True) + + # add initial points into the BO trials + for id in range(num_BO_initial_samples//len(gpu_list)): + processes = [] + manager = mp.Manager() + return_dict = manager.dict() + + # Start the worker processes + for i, gpu_id in enumerate(gpu_list): + ax_client.attach_trial(parameters=dict(initial_points_set[id*len(gpu_list)+i])) + p = mp.Process(target=eval_in_parallel, args=(gpu_id, checkpoint, num_PPL_eval_samples, initial_points_set[id*len(gpu_list)+i], return_dict, i, trial_id)) + trial_id += 1 + p.start() + processes.append(p) + + # Wait for all processes to finish + for p in processes: + p.join() + + # Print the results after all processes have finished + print(return_dict) + for i in range(len(gpu_list)): + current_trial_id, config, eval_results = return_dict[i] + history.append((eval_results, config)) + ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,) + + # run new BO trials + for id in range(num_trials//len(gpu_list)): + processes = [] + manager = mp.Manager() + return_dict = manager.dict() + + # Start the worker processes + for i, gpu_id in enumerate(gpu_list): + parameters, trial_idx = ax_client.get_next_trial() + parameter_tuple = [] + for k, v in parameters.items(): + parameter_tuple.append((k, v)) + p = mp.Process(target = eval_in_parallel, args = (gpu_id, checkpoint, num_PPL_eval_samples, parameter_tuple, return_dict, i, trial_idx)) + p.start() + processes.append(p) + + # Wait for all processes to finish + for p in processes: + p.join() + + # Print the results after all processes have finished + print(return_dict) + for i in range(len(gpu_list)): + current_trial_id, config, eval_results = return_dict[i] + history.append((eval_results, config)) + ax_client.complete_trial(trial_index = current_trial_id, raw_data = eval_results,) + + print("------Finish BO------") + for h in history: + print(h) + write_history_to_csv(history, output_file, ["cal_PPL", "model_size", "quant_config"]) + + print("------Best config------") + best_parameters, values = ax_client.get_best_parameters() + print(best_parameters, values) + + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser(description='Bayesian optimization for mixed-precision quantization to optimize accuracy under a model size constraint.')
+ parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') + parser.add_argument('--checkpoint', type=str, default="/tmp/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl') + parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores') + parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO') + parser.add_argument('--model_size_constraint', type=float, default=6.0, help='The model size (GB) constraint for BO') + parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_list=0,1,2,3") + parser.add_argument('--output_path', type=str, default="BO_acc_modelsize_output.csv", help="The file path to save the BO search trials") + args = parser.parse_args() + + if args.gpu_list == "": + run_sequential_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, output_file=args.output_path) + else: + run_parallel_BO(device=args.device, checkpoint=args.checkpoint, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, model_size_constraint=args.model_size_constraint, gpu_list=args.gpu_list, output_file=args.output_path) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py new file mode 100644 index 0000000000..85138403bb --- /dev/null +++ b/torchao/quantization/prototype/mixed_precision/scripts/BO_acc_throughput.py @@ -0,0 +1,487 @@ +import sys + +import torch +import time +import torch.nn as nn +from torchao.quantization import quantize_ +import random +from naive_intNwo import intN_weight_only + +import copy +from lm_eval.evaluator import evaluate +from lm_eval.models.huggingface import HFLM +from lm_eval.tasks import get_task_dict + +from transformers import AutoModelForCausalLM, AutoTokenizer +from ax.service.ax_client import AxClient, ObjectiveProperties +import torch.multiprocessing as mp + +import os +import sys +from pathlib import Path +from typing import Optional, Tuple +from datetime import datetime +import torchao +import torch._dynamo.config +import torch._inductor.config +from torchao.utils import get_model_size_in_bytes +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 +from torchao.quantization.quant_api import int4_weight_only +from torchao._models.llama.model import Transformer, prepare_inputs_for_model +from torchao._models.llama.tokenizer import get_tokenizer +from torchao._models._eval import TransformerEvalWrapper, InputRecorder + +from torchao.dtypes import TensorCoreTiledLayoutType + +from torchao._models.llama.generate import ( + device_sync, + multinomial_sample_one_no_sync, + logits_to_probs, + sample, + prefill, + decode_one_token, + model_forward, + encode_tokens, + _load_model, +) + +from utils import write_history_to_csv, cal_wikitext_ppl, load_model, quantize_by_fqn_to_config + +default_device = 'cuda' if torch.cuda.is_available() else 'cpu' + +def decode_n_tokens(model: Transformer, cur_token: 
torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here + next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + input_pos += 1 + new_tokens.append(next_token) + callback(new_tokens[-1]) + new_probs.append(next_prob) + cur_token = next_token.view(1, -1) + + return new_tokens, new_probs + +@torch.no_grad() +def generate( + model: Transformer, + prompt: torch.Tensor, + max_new_tokens: int, + *, + interactive: bool, + callback = lambda x: x, + kv_cache_quantization: bool = False, + **sampling_kwargs +) -> torch.Tensor: + """ + Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + """ + + # create an empty tensor of the expected final shape and fill in the current tokens + device = prompt.device + T = prompt.numel() + + # calculate how many tokens to generate based on max_new_tokens and model's upper bound (block_size) + max_seq_length = min(T + max_new_tokens, model.config.block_size) if not interactive else 350 + new_tokens = max_seq_length - T + + # full prompt+output will be stored in seq + seq = torch.empty(max_seq_length, dtype=prompt.dtype, device=device) + seq[:T] = prompt.view(-1) + + # setup model caches + with torch.device(device): + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + if kv_cache_quantization: + from model import AffineQuantizedKVCache + from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter + _replace_with_custom_fn_if_matches_filter( + model, + AffineQuantizedKVCache.from_float, + lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), + ) + + + # format model input + x, input_pos = prepare_inputs_for_model(prompt, max_new_tokens) + + # execute prefill + next_token = prefill(model, x, input_pos, **sampling_kwargs).clone() + seq[T] = next_token + + # execute token generation + input_pos = torch.tensor([T], device=device, dtype=torch.int) + generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs) + seq[T + 1:] = torch.cat(generated_tokens) + + return seq + +def cal_throughput( + model, + tokenizer, + device, + prompt: str = "Hello, my name is", + interactive: bool = False, + num_samples: int = 5, + max_new_tokens: int = 100, + top_k: int = 200, + temperature: float = 0.8, + checkpoint_path: Path = Path("/tmp/Meta-Llama-3-8B/model.pth"), + quantization: Optional[str] = None, + kv_cache_quantization: bool = False, + save: bool = False, + compile: bool = True, + compile_prefill: bool = False, + profile: Optional[Path] = None, + precision=torch.bfloat16, + write_result: Optional[Path] = None, +) -> None: + """Generates text samples based on a pre-trained Transformer model and tokenizer. 
+ """ + B_INST, E_INST = "[INST]", "[/INST]" + + torchao.quantization.utils.recommended_inductor_config_setter() + + is_chat = "chat" in str(checkpoint_path) + + device_sync(device=device) # MKG + + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + prompt_length = encoded.size(0) + + torch.manual_seed(1234) + + + if compile: + print("Compiling Model") + global decode_one_token, prefill + decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True) + + if compile_prefill: + prefill = torch.compile(prefill, fullgraph=True, dynamic=True) + + + aggregate_metrics = { + 'tokens_per_sec': [], + } + start = -1 if compile else 0 + + for i in range(start, num_samples): + if i==0: + torch.cuda.reset_peak_memory_stats() + device_sync(device=device) # MKG + if i >= 0 and interactive: + prompt = input("What is your prompt? ") + if is_chat: + prompt = f"{B_INST} {prompt.strip()} {E_INST}" + encoded = encode_tokens(tokenizer, prompt, bos=True, device=device) + + if interactive and i >= 0: + buffer = [] + period_id = tokenizer.encode('.')[0] + done_generating = False + def callback(x): + nonlocal done_generating + if done_generating: + return + buffer.append(tokenizer.decode([period_id] + x.tolist())[1:]) + if x.item() == tokenizer.eos_id(): + done_generating = True + if len(buffer) == 4 or done_generating: + print(''.join(buffer), end='', flush=True) + buffer.clear() + # print(, end='', flush=True) + else: + callback = lambda x : x + t0 = time.perf_counter() + import contextlib + if (i != num_samples - 1 or not profile): + prof = contextlib.nullcontext() + else: + torch.profiler._utils._init_for_cuda_graphs() + prof = torch.profiler.profile() + with prof: + y = generate( + model, + encoded, + max_new_tokens, + interactive=interactive, + callback=callback, + temperature=temperature, + top_k=top_k, + kv_cache_quantization=kv_cache_quantization, + ) + if i == -1: + print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") + continue + if hasattr(prof, "export_chrome_trace"): + prof.export_chrome_trace(f"{profile}.json") + device_sync(device=device) # MKG + t = time.perf_counter() - t0 + + if not interactive: + tok_list = y.tolist() + # truncate text after end of string token + tokens = tok_list if not tokenizer.eos_id() in y else tok_list[:tok_list.index(tokenizer.eos_id())] + else: + print() + tokens_generated = y.size(0) - prompt_length + tokens_sec = tokens_generated / t + aggregate_metrics['tokens_per_sec'].append(tokens_sec) + + tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + print("tokpersec", tokpersec) + return tokpersec + + +# return evaluation results to complete BO trials +def eval(model4ppl, model4tp, tokenizer, device, num_PPL_eval_samples, fqn_to_config): + return { + "cal_PPL": (cal_wikitext_ppl(model4ppl, tokenizer, num_PPL_eval_samples), 0.0), + "cal_throughput": (cal_throughput(model=model4tp, tokenizer=tokenizer, device=device), 0.0), + } + +# TODO: make it into a yaml or json file to enable users specify their custom model formats +def define_parameter_list(): + + # define the search space for all layers + parameters_list = [] + + for i in range(0, 3): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "fixed", + "value_type": "int", + "value": 8, + "is_ordered": True, + "sort_values": True, + } + ) + + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "fixed", + "value_type": "int", + "value": 32, + "is_ordered": True, + "sort_values": True, + } + ) + + for i in 
range(3, 30): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "choice", + "value_type": "int", + "values": [2,3,4,5,6,8], + "is_ordered": True, + "sort_values": True, + } + ) + + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "choice", + "value_type": "int", + "values": [32, 64, 128, 256], + "is_ordered": True, + "sort_values": True, + } + ) + + for i in range(30, 32): + parameters_list.append( + { + "name": f"bitwidth.{i}.", + "type": "fixed", + "value_type": "int", + "value": 8, + "is_ordered": True, + "sort_values": True, + } + ) + parameters_list.append( + { + "name": f"groupsize.{i}.", + "type": "fixed", + "value_type": "int", + "value": 32, + "is_ordered": True, + "sort_values": True, + } + ) + + return parameters_list + +# add initial search points based on the sensitivity score +# TODO: automate the initial samples by better leverage the sensitivity scores +def get_initial_samples(num_BO_initial_samples=50): + + initial_points_set = [] + + # auto sample the bit choices with random choice probability positive correlated to FIT score + for _ in range(num_BO_initial_samples): + initial_points = {} + for i in range(0, 3): + initial_points["bitwidth." + str(i) + "."] = 8 + initial_points["groupsize." + str(i) + "."] = 32 + + for i in range(3, 18): + if i in [5,6,7,10,11,12,16]: + initial_points["bitwidth." + str(i) + "."] = random.choices([8, 6, 5, 4], [25, 2, 2, 71])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [40, 60])[0] + else: + initial_points["bitwidth." + str(i) + "."] = random.choices([8, 6, 5, 4], [30, 2, 2,66])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64], [50, 50])[0] + + for i in range(18, 30): + if i in [22,23,24]: + initial_points["bitwidth." + str(i) + "."] = random.choices([8, 6, 5, 4], [10, 2, 2, 86])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [35, 45, 10, 10])[0] + else: + initial_points["bitwidth." + str(i) + "."] = random.choices([8, 6, 5, 4], [20, 2, 2, 76])[0] + initial_points["groupsize." + str(i) + "."] = random.choices([32, 64, 128, 256], [30, 40, 25, 5])[0] + + for i in range(30, 32): + initial_points["bitwidth." + str(i) + "."] = 8 + initial_points["groupsize." + str(i) + "."] = 32 + + initial_points_set.append(initial_points) + + return initial_points_set + +''' +This function will run BO trials sequentially on a single GPU. +Each time the BO gets one new trial, evaluates the trial on the GPU and return the evaluation results to update the BO. +One trial, one BO update. 
+''' +def run_sequential_BO(device, checkpoint_path, repo_id, num_PPL_eval_samples, num_BO_initial_samples, num_trials, ppl_constraint, args): + ''' + Currently uses the model loader and benchmark code from torchao/_models/llama/generate.py, + and lm_eval for the PPL evaluation. + ''' + # load tokenizers + assert checkpoint_path.is_file(), checkpoint_path + tokenizer_path = checkpoint_path.parent / "tokenizer.model" + assert tokenizer_path.is_file(), str(tokenizer_path) + device_sync(device=device) # MKG + tokenizer4tp = get_tokenizer(tokenizer_path, checkpoint_path) + tokenizer4ppl = AutoTokenizer.from_pretrained(repo_id) + + # initialize parameters + parameters_list = define_parameter_list() + + # sample initial points + initial_points_set = get_initial_samples(num_BO_initial_samples) + + # initialize BO experiment + constraint="cal_PPL <= "+str(ppl_constraint) + ax_client = AxClient() + ax_client.create_experiment( + parameters=parameters_list, + name="test_quantize_BO", + objectives={"cal_throughput": ObjectiveProperties(minimize=False)}, + choose_generation_strategy_kwargs={ + "num_initialization_trials": num_BO_initial_samples # the number of trials to build generation strategy + }, + outcome_constraints=[constraint], + ) + + history=[] + trial_id = 0 + + # add initial points into the BO trials + for i in range(num_BO_initial_samples): + + ax_client.attach_trial(parameters=initial_points_set[i]) + + # evaluate throughput of the quantized model under torch.compile() + model4tp = _load_model(checkpoint_path, device, torch.bfloat16) + quantize_by_fqn_to_config(model = model4tp, device=device, fqn_to_config = initial_points_set[i]) + tp=cal_throughput(model=model4tp, tokenizer=tokenizer4tp, device=device) + del model4tp + torch.cuda.empty_cache() + + # evaluate ppl of the quantized model + model4ppl, _ = load_model(repo_id, device) + quantize_by_fqn_to_config(model = model4ppl, device=device, fqn_to_config = initial_points_set[i]) + ppl=cal_wikitext_ppl(model4ppl, tokenizer4ppl, num_PPL_eval_samples) + del model4ppl + torch.cuda.empty_cache() + + eval_results= {"cal_PPL": (ppl, 0.0), "cal_throughput": (tp, 0.0),} + + print("------------") + print(trial_id, initial_points_set[i], eval_results) + + history.append((eval_results, initial_points_set[i])) + ax_client.complete_trial( + trial_index=trial_id, + raw_data=eval_results, + ) + trial_id += 1 + + # run new BO trials + for k_ in range(num_trials): + parameters, trial_idx = ax_client.get_next_trial() + + # evaluate throughput of the quantized model under torch.compile() + model4tp = _load_model(checkpoint_path, device, torch.bfloat16) + quantize_by_fqn_to_config(model = model4tp, device=device, fqn_to_config = parameters) + tp=cal_throughput(model=model4tp, tokenizer=tokenizer4tp, device=device) + del model4tp + torch.cuda.empty_cache() + + # evaluate ppl of the quantized model + model4ppl, _ = load_model(repo_id, device) + quantize_by_fqn_to_config(model = model4ppl, device=device, fqn_to_config = parameters) + ppl=cal_wikitext_ppl(model4ppl, tokenizer4ppl, num_PPL_eval_samples) + del model4ppl + torch.cuda.empty_cache() + + eval_results= {"cal_PPL": (ppl, 0.0), "cal_throughput": (tp, 0.0),} + + print("------------") + print(trial_idx, parameters, eval_results) + + history.append((eval_results, parameters)) + + ax_client.complete_trial( + trial_index=trial_idx, + raw_data=eval_results, + ) + + print("------Finish BO------") + for h in history: + print(h) + write_history_to_csv(history, args.output_path, ["cal_PPL", "cal_throughput", "quant_config"]) + 
+ print("------Best config------") + best_parameters, values = ax_client.get_best_parameters() + print(best_parameters, values) + +if __name__ == '__main__': + + import argparse + parser = argparse.ArgumentParser(description='Bayesian optimization for mixed-precision quantization to optimize inference speed under model accuracy constraint.') + + parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation') + parser.add_argument('--checkpoint_path', type=Path, default=Path("/tmp/Meta-Llama-3-8B/model.pth"), help='Model checkpoint path for model.pth.') + parser.add_argument('--repo_id', type=str, default=Path("/tmp/Meta-Llama-3-8B"), help='Model repo id.') + parser.add_argument('--num_PPL_eval_samples', type=int, default=None, help='Number of samples to evaluate ppl') + parser.add_argument('--num_BO_initial_samples', type=int, default=50, help='Number of initial points sampled by sensitivity scores') + parser.add_argument('--num_trials', type=int, default=150, help='Number of trials to run BO') + parser.add_argument('--ppl_constraint', type=float, default=7.5, help='The ppl constraint for BO') + parser.add_argument('--multi_gpus', action='store_true', help="Use multi-processing to run evaluation on multi-gpus") + parser.add_argument('--gpu_list', type=str, default="", help="A list of gpus to run evaluation, separated by comma, e.g., --gpu_lists=0,1,2,3") + parser.add_argument('--output_path', type=str, default="BO_acc_speed_output.csv", help="The csv file path to save the BO search trials") + + args = parser.parse_args() + run_sequential_BO(device=args.device, checkpoint_path=args.checkpoint_path, repo_id=args.repo_id, num_PPL_eval_samples=args.num_PPL_eval_samples, num_BO_initial_samples=args.num_BO_initial_samples, num_trials=args.num_trials, ppl_constraint=args.ppl_constraint, args=args) diff --git a/torchao/quantization/prototype/mixed_precision/scripts/fit.py b/torchao/quantization/prototype/mixed_precision/scripts/fit.py index 19366a01c9..78ec878d33 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/fit.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/fit.py @@ -95,7 +95,7 @@ def main(max_seqlen, checkpoint, nsamples, maxIter, num_layers): if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') + parser = argparse.ArgumentParser(description='Calculate layer-wised fish information matric trace.') parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') parser.add_argument('--maxIter', type=int, default=100, help='The number of iterations to calculate FIT') diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py index bb33cff39f..8aea888925 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_grad.py @@ -141,7 +141,7 @@ def main(layer_id, checkpoint, max_seqlen, maxIter, nsamples): if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') + parser = argparse.ArgumentParser(description='Calculate layer-wised Hessian trace leveraging autograd.') parser.add_argument('--layer_id', type=int, 
+ parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') + parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') diff --git a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py index 480365c66d..3470031cb1 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/hessian_vhp.py @@ -154,7 +154,7 @@ def f(*new_params): if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Your CLI description.') + parser = argparse.ArgumentParser(description="Calculate the layer-wise Hessian trace leveraging torch's vhp function.") parser.add_argument('--layer_id', type=int, default=0, help='Which layer to compute the Hessian trace') parser.add_argument('--checkpoint', type=str, default="/home/hanxianhuang/ao/torchao/quantization/prototype/mixed_precision/checkpoints/meta-llama/Meta-Llama-3-8B", help='Path to load model') parser.add_argument('--max_seqlen', type=int, default=2048, help='Max sequence length') diff --git a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py index d17b76159e..852fbd7e46 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/mp_quant_eval.py @@ -77,7 +77,7 @@ def filter_fn_nonsen(child: torch.nn.Module, cur_fqn:str) -> bool: if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='Run HF Model Evaluation') + parser = argparse.ArgumentParser(description='Run evaluation for uniform or mixed-precision quantization.') parser.add_argument('--repo_id', type=str, default="checkpoints/meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.') parser.add_argument('--tasks', nargs='+', type=str, default=["wikitext"], help='List of lm-eluther tasks to evaluate usage: --tasks task1 task2') parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate') diff --git a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py index 6ebe458a46..b29df38552 100644 --- a/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py +++ b/torchao/quantization/prototype/mixed_precision/scripts/naive_intNwo.py @@ -40,9 +40,11 @@ def apply_intN_weight_only_quant_sym(weight): mapping_type = MappingType.SYMMETRIC block_size = (1, group_size) target_dtype = torch.int8 + quant_min = -2**(n-1) + quant_max = 2**(n-1)-1 eps = 1e-6 zero_point_dtype = torch.int64 - return to_affine_quantized(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return to_affine_quantized(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps=eps, zero_point_dtype=zero_point_dtype) try: assert n in [8, 6, 5, 4, 3, 2], "n must be one of [8, 6, 5, 4, 3, 2]" diff --git a/torchao/quantization/prototype/mixed_precision/scripts/utils.py b/torchao/quantization/prototype/mixed_precision/scripts/utils.py new file mode 100644 index 0000000000..4d075b4699 --- /dev/null +++
b/torchao/quantization/prototype/mixed_precision/scripts/utils.py @@ -0,0 +1,108 @@ +import csv +import sys + +import torch +import torch.nn as nn +from torchao.quantization import quantize_ +import random + +from naive_intNwo import intN_weight_only + +import copy +from lm_eval.evaluator import evaluate +from lm_eval.models.huggingface import HFLM +from lm_eval.tasks import get_task_dict + +from transformers import AutoModelForCausalLM, AutoTokenizer + +def write_history_to_csv(history, output_file, keyword): + #keyword example: ['cal_PPL', 'cal_throughput', 'config'] + + with open(output_file, mode='w', newline='') as file: + writer = csv.writer(file) + + # Write the header row + writer.writerow(keyword) + + for eval_results, config in history: + obj1 = eval_results[keyword[0]][0] + obj2 = eval_results[keyword[1]][0] + + writer.writerow([obj1, obj2, config]) + +# quantize a model based on a given quantization configuration +def quantize_by_fqn_to_config(model, device, fqn_to_config): + it = iter(fqn_to_config.items()) + while True: + try: + k1, v1 = next(it) + k2, v2 = next(it) + fqn = k1[8:] + bit_width, groupsize = v1, v2 + + def filter_fn_sen(child: torch.nn.Module, cur_fqn: str) -> bool: + return isinstance(child, torch.nn.Linear) and (fqn in cur_fqn) + + quantize_( + model.to(device=device), + intN_weight_only(n=bit_width, group_size=groupsize), + filter_fn_sen, + ) + except StopIteration: + break + + +# calculate perplexity on wikitext-document, need to support more tasks +def cal_wikitext_ppl(model, tokenizer, limit=62): + + with torch.no_grad(): + result = evaluate( + HFLM(pretrained=model, tokenizer=tokenizer, batch_size=1), + get_task_dict("wikitext"), + limit=limit + ) + + return result["results"]["wikitext"]["word_perplexity,none"] + +# TODO: make it generalize to more models +def cal_model_size(model, fqn_to_config): + _sum = 0 + fqn_cofg_dict = dict() + + it = iter(fqn_to_config.items()) + while True: + try: + k1, v1 = next(it) + k2, v2 = next(it) + bit_width, groupsize = v1, v2 + bit_zeropoint = 32 + bit_scale = 8 + fqn = k1[8:] + fqn_cofg_dict[fqn] = (bit_width, groupsize, bit_zeropoint, bit_scale) + except StopIteration: + break + + for name, parameter in model.named_parameters(): + flag = 0 + for fqn in fqn_cofg_dict: + if fqn in name: + flag = 1 + if "self_attn" in name or "mlp" in name: + _sum += parameter.numel() * fqn_cofg_dict[fqn][ + 0 + ] + parameter.numel() // fqn_cofg_dict[fqn][1] * ( + fqn_cofg_dict[fqn][2] + fqn_cofg_dict[fqn][3] + ) + if flag == 0: + _sum += parameter.numel() * 16 + + _sum_in_byte = _sum / 8.0 + _sum_in_GB = _sum_in_byte / (1024**3) / 1.0 + return _sum_in_GB + +def load_model(repo_id, device): + tokenizer = AutoTokenizer.from_pretrained(repo_id) + model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to( + device=device + ) + return model, tokenizer
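
The fqn_to_config dictionaries produced by the BO search pair a "bitwidth.{layer_id}." key with a "groupsize.{layer_id}." key for each transformer block; quantize_by_fqn_to_config consumes the keys two at a time, strips the "bitwidth" prefix, and matches the remaining ".{layer_id}." substring against the fully qualified names of nn.Linear modules. A minimal usage sketch of these utilities outside the BO loop follows; it assumes it is run from this scripts/ directory and that a HF Llama-3-8B checkpoint is available at the illustrative path /tmp/Meta-Llama-3-8B.

# Hedged usage sketch (not part of the scripts above): apply a hand-written
# mixed-precision config, then report the estimated size and wikitext PPL.
# Assumes utils.py and naive_intNwo.py from this directory are importable and
# that /tmp/Meta-Llama-3-8B (illustrative path) holds a HF Llama-3-8B checkpoint.
import torch

from utils import (
    cal_model_size,
    cal_wikitext_ppl,
    load_model,
    quantize_by_fqn_to_config,
)

# Keys must come in (bitwidth, groupsize) pairs per transformer block, in this
# order, matching the "bitwidth.{i}." / "groupsize.{i}." BO parameter names.
toy_config = {
    "bitwidth.0.": 8, "groupsize.0.": 32,  # keep block 0 at 8 bits
    "bitwidth.1.": 4, "groupsize.1.": 64,  # quantize block 1 more aggressively
}

if __name__ == "__main__":
    model, tokenizer = load_model("/tmp/Meta-Llama-3-8B", "cuda")
    quantize_by_fqn_to_config(model, "cuda", toy_config)
    print("estimated model size (GB):", cal_model_size(model, toy_config))
    print("wikitext word perplexity:", cal_wikitext_ppl(model, tokenizer, limit=62))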
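
For reference, cal_model_size charges each quantized self_attn/mlp weight bit_width bits per element plus one zero-point (assumed 32-bit) and one scale (assumed 8-bit) per group of groupsize elements, and 16 bits per element for everything left unquantized. A small worked example of that accounting, using illustrative numbers rather than values from the search:

# Worked example of the size accounting in cal_model_size (illustrative only).
numel = 4096 * 4096                # one hypothetical 4096x4096 linear weight
bit_width, groupsize = 4, 64       # a candidate BO choice for this layer
bit_zeropoint, bit_scale = 32, 8   # the constants hard-coded in cal_model_size
bits = numel * bit_width + numel // groupsize * (bit_zeropoint + bit_scale)
print(bits / 8 / 1024**3)          # ~0.009 GB for this single weight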