"""Evaluate a Hugging Face causal LM with lm-eval, optionally using torchao.

Reconstructed from a two-commit patch series that adds ``scripts/hf_eval.py``:

* 32278eaff6127761cb38effe5b6daaf581b73f97 (HDCharles, 2024-06-04):
  "Adding a quick way for users to test model eval for hf models" -- the
  script allows users to run evaluation and try out torchao APIs.
* 516bdf73e8419e320bcc718524e91b5e1ca983c7 (Mark Saroufim, 2024-06-04):
  "Update scripts/hf_eval.py" -- fixes the "Dvice" typo in the ``--device``
  help text.

Test plan (from the original commit message): ``python hf_eval.py``
"""
import argparse

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

from lm_eval.models.huggingface import HFLM
from lm_eval.evaluator import evaluate
from lm_eval.tasks import get_task_dict

from torchao.quantization.quant_api import (
    change_linear_weights_to_int4_woqtensors,
    change_linear_weights_to_int8_dqtensors,
    change_linear_weights_to_int8_woqtensors,
    autoquant,
)

torch._inductor.config.force_fuse_int_mm_with_mul = True
torch._inductor.config.fx_graph_cache = True

# Quantization modes whose torchao entry point is called for its side effect
# on the model (the return value is not used). "autoquant" is handled
# separately because its return value replaces the model.
_IN_PLACE_QUANTIZERS = {
    "int8dq": change_linear_weights_to_int8_dqtensors,
    "int8wo": change_linear_weights_to_int8_woqtensors,
    "int4wo": change_linear_weights_to_int4_woqtensors,
}


def run_evaluation(repo_id, task_list, limit, device, precision, quantization, compile):
    """Load a HF causal LM, optionally compile/quantize it, and evaluate it.

    Args:
        repo_id: Hugging Face repository ID passed to ``from_pretrained``.
        task_list: Task names passed to ``lm_eval.tasks.get_task_dict``.
        limit: Forwarded to ``evaluate`` as ``limit``; ``None`` means no limit
            is passed on by this script.
        device: Device the model is moved to before evaluation.
        precision: ``torch.dtype`` the model is cast to.
        quantization: One of ``"int8dq"``, ``"int8wo"``, ``"int4wo"``,
            ``"autoquant"``; any other value (the CLI uses the string
            ``"None"``) leaves the model unquantized.
        compile: When true, wrap the model with ``torch.compile``.

    Returns:
        The ``"results"`` mapping produced by ``evaluate`` (task name ->
        metrics). Each entry is also printed.
    """
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # Honor the caller's device instead of always moving the model to "cuda".
    model = AutoModelForCausalLM.from_pretrained(repo_id).to(device=device, dtype=precision)

    if compile:
        # torch.compile returns the compiled module; the result must be kept,
        # otherwise the uncompiled model is what gets evaluated.
        model = torch.compile(model, mode="max-autotune", fullgraph=True)

    if quantization == "autoquant":
        model = autoquant(model)
    elif quantization in _IN_PLACE_QUANTIZERS:
        _IN_PLACE_QUANTIZERS[quantization](model)

    with torch.no_grad():
        result = evaluate(
            HFLM(pretrained=model, tokenizer=tokenizer),
            get_task_dict(task_list),
            limit=limit,
        )

    results = result["results"]
    for task, res in results.items():
        print(f"{task}: {res}")
    return results


def _parse_dtype(name):
    """Convert a CLI string such as ``bfloat16`` or ``torch.bfloat16`` to a dtype.

    Raises:
        argparse.ArgumentTypeError: if the name is not a ``torch.dtype``, so
            argparse reports a usage error instead of a traceback.
    """
    dtype = getattr(torch, name.split(".")[-1], None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"{name!r} is not a torch dtype")
    return dtype


def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(description='Run HF Model Evaluation')
    parser.add_argument('--repo_id', type=str, default="meta-llama/Meta-Llama-3-8B", help='Repository ID to download from HF.')
    parser.add_argument('--task_list', nargs='+', type=str, default=["wikitext"], help='List of lm-evaluation-harness tasks to evaluate, usage: --task_list task1 task2')
    parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
    parser.add_argument('--precision', type=_parse_dtype, default=torch.bfloat16, help='dtype precision to use')
    parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
    parser.add_argument('--quantization', default="None", choices=["int8dq", "int8wo", "int4wo", "autoquant", "None"], help='Which quantization technique to apply')
    parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
    return parser


if __name__ == '__main__':
    args = _build_parser().parse_args()
    run_evaluation(args.repo_id, args.task_list, args.limit, args.device, args.precision, args.quantization, args.compile)