diff --git a/README.md b/README.md index 17d066c0..10922e5e 100755 --- a/README.md +++ b/README.md @@ -193,6 +193,9 @@ We provide a tool namely `bigcodebench.sanitize` to clean up the code: bigcodebench.sanitize --samples samples.jsonl --calibrate # Sanitized code will be produced to `samples-sanitized-calibrated.jsonl` +# 💡 Optionally run the sanitization step with multiprocessing to speedup +bigcodebench.sanitize --samples samples.jsonl --calibrate --parallel 8 + # 💡 If you want to get the original results: bigcodebench.sanitize --samples samples.jsonl # Sanitized code will be produced to `samples-sanitized.jsonl` diff --git a/Requirements/requirements.txt b/Requirements/requirements.txt index 69d8b6c2..178ae814 100644 --- a/Requirements/requirements.txt +++ b/Requirements/requirements.txt @@ -1,9 +1,10 @@ appdirs>=1.4.4 fire>=0.6.0 multipledispatch>=0.6.0 +pqdm>=0.2.0 tempdir>=0.7.1 termcolor>=2.0.0 tqdm>=4.56.0 tree_sitter_languages>=1.10.2 tree-sitter==0.21.3 -wget>=3.2 \ No newline at end of file +wget>=3.2 diff --git a/bigcodebench/sanitize.py b/bigcodebench/sanitize.py index 6a93f2e0..df9ed4eb 100644 --- a/bigcodebench/sanitize.py +++ b/bigcodebench/sanitize.py @@ -3,6 +3,7 @@ import os import pathlib from typing import Dict, Generator, List, Optional, Set, Tuple +from pqdm.processes import pqdm from tqdm import tqdm from tree_sitter import Node @@ -178,8 +179,48 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str: return sanitized_output +def process_solution( + sample_solution: Dict, + dataset: Dict, + entry_point: Dict, + debug_task: str = None, + calibrate: bool = False, + is_folder: bool = False, + target_path: str = None, +): + + task_id = sample_solution.get("task_id") + if not task_id or task_id not in dataset: + return None + + dbg_identifier = sample_solution["_identifier"] + if debug_task is not None and task_id != debug_task: + return None + + function_name = entry_point.get(task_id) + old_code = sample_solution.get("solution") + + if old_code is None: + assert "completion" in sample_solution, sample_solution + old_code = dataset[task_id]["complete_prompt"] + "\n" + sample_solution.get("completion") + else: + if calibrate: + old_code = old_code.replace("```python\n ", "```python\n"+dataset[task_id]["complete_prompt"]+" ") + + new_code = sanitize(code=old_code, entrypoint=function_name) + + # if old code and new code are different, print msg + if new_code != old_code: + msg = "Sanitized: " + dbg_identifier + if is_folder: + msg += " -> " + dbg_identifier.replace(samples, target_path) + print(msg) + + return {"task_id": task_id, "solution": new_code} + + def script( - samples: str, inplace: bool = False, debug_task: str = None, calibrate: bool = False + samples: str, inplace: bool = False, debug_task: str = None, calibrate: bool = False, parallel: int=32 ): # task_id -> entry_point entry_point = {} @@ -211,38 +252,26 @@ def script( new_solutions = [] - for solution in tqdm(load_solutions(samples)): - task_id = solution["task_id"] - if task_id not in dataset: - print( - f"Skiping {task_id} as it does not existing in the latest EvalPlus dataset." - ) - continue - - function_name = entry_point[task_id] if task_id in entry_point else None - dbg_identifier = solution["_identifier"] - if debug_task is not None and task_id != debug_task: - continue - - ntotal += 1 - if "solution" in solution: - old_code = solution["solution"] - if calibrate: - old_code = solution["solution"].replace("```python\n ", "```python\n"+dataset[task_id]["complete_prompt"]+" ") - else: - assert "completion" in solution - old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"] - - new_code = sanitize(code=old_code, entrypoint=function_name) - # if changed, print the message - if new_code != old_code: - msg = "Sanitized: " + dbg_identifier - if is_folder: - msg += " -> " + dbg_identifier.replace(samples, target_path) - print(msg) + parallel_arg_list = [ + { + "sample_solution": sample_solution, + "dataset": dataset, + "entry_point": entry_point, + "debug_task": debug_task, + "calibrate": calibrate, + "is_folder": is_folder, + "target_path": target_path + } + for sample_solution in load_solutions(samples) + ] + + results = pqdm(parallel_arg_list, process_solution, n_jobs=min(parallel, os.cpu_count()), argument_type="kwargs") + + for result in results: + if result is not None: + new_solutions.append(result) nsan += 1 - - new_solutions.append({"task_id": task_id, "solution": new_code}) + ntotal += 1 if is_folder: write_directory(target_path, new_solutions) @@ -263,4 +292,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file