Skip to content

Commit d10e227

Browse files
authored
feat(sanitize): multiprocessing support
2 parents 19baeb4 + 10d062f commit d10e227

File tree

3 files changed

+67
-34
lines changed

3 files changed

+67
-34
lines changed

README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,9 @@ We provide a tool namely `bigcodebench.sanitize` to clean up the code:
193193
bigcodebench.sanitize --samples samples.jsonl --calibrate
194194
# Sanitized code will be produced to `samples-sanitized-calibrated.jsonl`
195195

196+
# 💡 Optionally run the sanitization step with multiprocessing to speed it up
197+
bigcodebench.sanitize --samples samples.jsonl --calibrate --parallel 8
198+
196199
# 💡 If you want to get the original results:
197200
bigcodebench.sanitize --samples samples.jsonl
198201
# Sanitized code will be produced to `samples-sanitized.jsonl`

Requirements/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
appdirs>=1.4.4
22
fire>=0.6.0
33
multipledispatch>=0.6.0
4+
pqdm>=0.2.0
45
tempdir>=0.7.1
56
termcolor>=2.0.0
67
tqdm>=4.56.0
78
tree_sitter_languages>=1.10.2
89
tree-sitter==0.21.3
9-
wget>=3.2
10+
wget>=3.2

bigcodebench/sanitize.py

Lines changed: 62 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import pathlib
55
from typing import Dict, Generator, List, Optional, Set, Tuple
6+
from pqdm.processes import pqdm
67

78
from tqdm import tqdm
89
from tree_sitter import Node
@@ -178,8 +179,48 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
178179
return sanitized_output
179180

180181

182+
def process_solution(
    sample_solution: Dict,
    dataset: Dict,
    entry_point: Dict,
    debug_task: str = None,
    calibrate: bool = False,
    is_folder: bool = False,
    target_path: str = None,
    samples: Optional[str] = None,
):
    """Sanitize a single sample; designed to run as a parallel worker.

    Args:
        sample_solution: One loaded sample. Must carry ``"_identifier"`` and
            either ``"solution"`` or ``"completion"``; ``"task_id"`` selects
            the task.
        dataset: Mapping of task_id -> task record; the record's
            ``"complete_prompt"`` is read when rebuilding the code.
        entry_point: Mapping of task_id -> entry-point function name.
        debug_task: When set, every sample of any other task is skipped.
        calibrate: When true, splice the task's ``complete_prompt`` into a
            ``"solution"`` right after its opening python code fence.
        is_folder: Whether samples come from a directory; only affects the
            log message.
        target_path: Output location, used only in the log message.
        samples: Input location, used only in the log message. Optional so
            that existing callers that do not pass it keep working.

    Returns:
        ``{"task_id": ..., "solution": <sanitized code>}``, or ``None`` when
        the sample is skipped (missing/unknown task_id, or filtered out by
        ``debug_task``).
    """
    task_id = sample_solution.get("task_id")
    if not task_id or task_id not in dataset:
        return None

    dbg_identifier = sample_solution["_identifier"]
    if debug_task is not None and task_id != debug_task:
        return None

    function_name = entry_point.get(task_id)
    old_code = sample_solution.get("solution")

    if old_code is None:
        # No full solution was provided: rebuild it from prompt + completion.
        assert "completion" in sample_solution, sample_solution
        old_code = dataset[task_id]["complete_prompt"] + "\n" + sample_solution.get("completion")
    elif calibrate:
        # NOTE(review): the whitespace inside these literals is reproduced as
        # it appears in this view of the diff; confirm the intended
        # indentation width against the repository.
        old_code = old_code.replace("```python\n ", "```python\n"+dataset[task_id]["complete_prompt"]+" ")

    new_code = sanitize(code=old_code, entrypoint=function_name)

    # Report only the samples that sanitization actually changed.
    if new_code != old_code:
        msg = "Sanitized: " + dbg_identifier
        # `samples` used to be an undefined name here (NameError in the
        # worker); only add the destination hint when both paths are known.
        if is_folder and samples is not None and target_path is not None:
            msg += " -> " + dbg_identifier.replace(samples, target_path)
        print(msg)

    return {"task_id": task_id, "solution": new_code}
220+
221+
181222
def script(
182-
samples: str, inplace: bool = False, debug_task: str = None, calibrate: bool = False
223+
samples: str, inplace: bool = False, debug_task: str = None, calibrate: bool = False, parallel: int=32
183224
):
184225
# task_id -> entry_point
185226
entry_point = {}
@@ -211,38 +252,26 @@ def script(
211252

212253
new_solutions = []
213254

214-
for solution in tqdm(load_solutions(samples)):
215-
task_id = solution["task_id"]
216-
if task_id not in dataset:
217-
print(
218-
f"Skiping {task_id} as it does not existing in the latest EvalPlus dataset."
219-
)
220-
continue
221-
222-
function_name = entry_point[task_id] if task_id in entry_point else None
223-
dbg_identifier = solution["_identifier"]
224-
if debug_task is not None and task_id != debug_task:
225-
continue
226-
227-
ntotal += 1
228-
if "solution" in solution:
229-
old_code = solution["solution"]
230-
if calibrate:
231-
old_code = solution["solution"].replace("```python\n ", "```python\n"+dataset[task_id]["complete_prompt"]+" ")
232-
else:
233-
assert "completion" in solution
234-
old_code = dataset[task_id]["complete_prompt"] + "\n" + solution["completion"]
235-
236-
new_code = sanitize(code=old_code, entrypoint=function_name)
237-
# if changed, print the message
238-
if new_code != old_code:
239-
msg = "Sanitized: " + dbg_identifier
240-
if is_folder:
241-
msg += " -> " + dbg_identifier.replace(samples, target_path)
242-
print(msg)
255+
parallel_arg_list = [
256+
{
257+
"sample_solution": sample_solution,
258+
"dataset": dataset,
259+
"entry_point": entry_point,
260+
"debug_task": debug_task,
261+
"calibrate": calibrate,
262+
"is_folder": is_folder,
263+
"target_path": target_path
264+
}
265+
for sample_solution in load_solutions(samples)
266+
]
267+
268+
results = pqdm(parallel_arg_list, process_solution, n_jobs=min(parallel, os.cpu_count()), argument_type="kwargs")
269+
270+
for result in results:
271+
if result is not None:
272+
new_solutions.append(result)
243273
nsan += 1
244-
245-
new_solutions.append({"task_id": task_id, "solution": new_code})
274+
ntotal += 1
246275

247276
if is_folder:
248277
write_directory(target_path, new_solutions)
@@ -263,4 +292,4 @@ def main():
263292

264293

265294
if __name__ == "__main__":
266-
main()
295+
main()

0 commit comments

Comments
 (0)