From 19ee6b48d6dbee6addcd83c9a1002dc371439c9f Mon Sep 17 00:00:00 2001 From: yingfhu Date: Fri, 1 Sep 2023 18:43:39 +0800 Subject: [PATCH] [Feat] update proprocessing with new params --- evals/humanevalx/evaluation.py | 17 ++++++++++++++--- evals/humanevalx/utils.py | 27 ++++++++++----------------- scripts/eval_humanevalx.sh | 6 +++++- server.py | 10 +++++++--- 4 files changed, 36 insertions(+), 24 deletions(-) diff --git a/evals/humanevalx/evaluation.py b/evals/humanevalx/evaluation.py index 0afee48..6d35b10 100644 --- a/evals/humanevalx/evaluation.py +++ b/evals/humanevalx/evaluation.py @@ -41,7 +41,7 @@ def postprocess_generation(sample, generation_mode="completion"): return sample -def process_test(sample, problems, dataset_type, language_type, generation_mode): +def process_test(sample, problems, dataset_type, language_type, generation_mode, with_prompt): sample["generation"] = cleanup_code(sample["generation"], dataset_type, language_type) if dataset_type == "humanevalx": task_id = sample["task_id"] @@ -49,12 +49,19 @@ def process_test(sample, problems, dataset_type, language_type, generation_mode) test = problems[task_id]["test"] code = sample["generation"] + if not with_prompt: + prompt = '' # Pre-process for different languages if language_type == "python": test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n" test_string = test_setup + prompt + code + "\n" + test + "\n" elif language_type == "cpp": test_set_up = "" + funcname = re.search('using namespace std;\n(.*?){\s', prompt) + if funcname: + funcname = funcname.group()[21:].strip() + if funcname in code: + code = code[code.find('{')+1:] for s in IMPORT_HELPER["cpp"]: if s not in prompt: test_set_up += s + "\n" @@ -66,6 +73,9 @@ def process_test(sample, problems, dataset_type, language_type, generation_mode) elif language_type == "go": import_string = problems[task_id]["import"] prompt = prompt.replace(import_string, "") + # remove unnecessary title if not use prompt + if "func " in code and not with_prompt: + code = code[code.find("func "):] test = problems[task_id]["test"] test_setup = problems[task_id]["test_setup"] other_pkgs = [] @@ -106,6 +116,7 @@ def evaluate_functional_correctness( dataset_type: str = "humanevalx", generation_mode: str = "completion", test_groundtruth: bool = False, + with_prompt: bool = True, ): if log_path is None: log_path = os.path.join(output_path, "evaluation.log") @@ -127,7 +138,7 @@ def evaluate_functional_correctness( if output_path is not None: os.makedirs(output_path, exist_ok=True) - + with ThreadPoolExecutor(max_workers=n_workers) as executor: futures = [] @@ -152,7 +163,7 @@ def evaluate_functional_correctness( sample["prompt"] = problems[task_id]["prompt"] sample["prompt"] = problems[task_id]["prompt"] sample = postprocess_generation(sample, generation_mode) - sample["test_code"] = process_test(sample, problems, dataset_type, language_type, generation_mode) + sample["test_code"] = process_test(sample, problems, dataset_type, language_type, generation_mode, with_prompt) if sample["test_code"] is None: continue if "completion_id" in sample: diff --git a/evals/humanevalx/utils.py b/evals/humanevalx/utils.py index fe77063..f21ddb3 100644 --- a/evals/humanevalx/utils.py +++ b/evals/humanevalx/utils.py @@ -387,21 +387,8 @@ def cleanup_code( code = first_block(code, stop_words) elif dataset_type == "humanevalx": if language_type.lower() == "python": - code_splits = code.split("\n") - is_empty_line = False - ind_empty_line = None - for i, line in enumerate(code_splits): - if len(line.strip()) > 0 and line[0] != ' ' and line[0] != '\t': - is_empty_line = True - ind_empty_line = i - break - if is_empty_line: - code = "\n".join(code_splits[:ind_empty_line]) - else: - end_words = ["\ndef", "\nclass", "\n#", "\nassert", '\n"""', "\nprint", "\nif", "\n\n\n"] - for w in end_words: - if w in code: - code = code[:code.rfind(w)] + # already processed in Opencompass + pass elif language_type.lower() == "java": main_pos = code.find("public static void main") if main_pos != -1: @@ -410,12 +397,18 @@ def cleanup_code( code = code[:code.rfind('}')] + '}' if code.count('{') + 1 == code.count('}'): code += "\n}" + if "public class" in code: + code = code[:code.find("public class")] elif language_type.lower() == "go": - if "\nfunc main(" in code: - code = code[:code.rfind("func main(")] + pattern = re.compile("func main\((.*?)\n}", re.DOTALL) + code = pattern.sub('', code) + if "package main" in code: + code = code[code.find("package main"):] if '}' in code: code = code[:code.rfind('}')] + '}' elif language_type.lower() == "cpp": + if "using namespace std;" not in code: + code = "using namespace std;\n"+code code = extract_block(code) if "\nint main()" in code: code = code[:code.rfind("int main()")] diff --git a/scripts/eval_humanevalx.sh b/scripts/eval_humanevalx.sh index 1a8b11d..1649675 100755 --- a/scripts/eval_humanevalx.sh +++ b/scripts/eval_humanevalx.sh @@ -12,7 +12,7 @@ OUTPUT_DIR=outputs/humanevalx-${LANGUAGE} TMP_DIR=tmp OPTIND=3 -while getopts "n:o:t:l:" OPT; +while getopts "n:o:t:p:l:" OPT; do case $OPT in n) @@ -24,6 +24,9 @@ do t) TMP_DIR="$OPTARG" ;; + p) + WITH_PROMPT=$OPTARG + ;; \?) echo "Invalid option -$OPTARG" >&2 exit 1 @@ -77,6 +80,7 @@ CMD="python ./evals/humanevalx/evaluation.py \ --n_workers $NUM_WORKERS \ --tmp_dir $TMP_DIR \ --problem_file $DATA_PATH \ + --with_prompt $WITH_PROMPT \ --timeout $TIMEOUT" echo "Running CMD: " $CMD diff --git a/server.py b/server.py index d0353c6..95494ee 100644 --- a/server.py +++ b/server.py @@ -35,7 +35,7 @@ def check_datasets(dataset): else: raise NotImplementedError(f"{dataset} not implemented...") -def make_cmd(eval_filepath, dataset, ip_address): +def make_cmd(eval_filepath, dataset, ip_address, with_prompt): if 'humanevalx' in dataset: dataset, language = dataset.split("/") result_dir = f"outputs/{ip_address}-{dataset}-{language}" @@ -45,6 +45,7 @@ def make_cmd(eval_filepath, dataset, ip_address): eval_filepath, language, "-n", '8', + "-p", f"{with_prompt}", "-o", result_dir, "-t", tmp_dir], result_dir @@ -56,7 +57,10 @@ def _eval(single_request): dataset = single_request.form.get('dataset') ip_address = single_request.remote_addr - + if 'With-Prompt' in single_request.headers: + with_prompt = single_request.headers['With-Prompt'] + else: + with_prompt = True try: check_datasets(dataset) except ValueError as e: @@ -64,7 +68,7 @@ def _eval(single_request): except NotImplementedError as e: return {'message':f'Dataset({dataset}) not supported.', 'exception': e}, 400 - cmd_items, result_dir = make_cmd(eval_filepath, dataset, ip_address) + cmd_items, result_dir = make_cmd(eval_filepath, dataset, ip_address, with_prompt) print("RUN CMD : " + " ".join(cmd_items)) result = subprocess.run(cmd_items, capture_output=True, text=True)