3
3
import os
4
4
import pathlib
5
5
from typing import Dict , Generator , List , Optional , Set , Tuple
6
+ from pqdm .processes import pqdm
6
7
7
8
from tqdm import tqdm
8
9
from tree_sitter import Node
@@ -178,8 +179,48 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
178
179
return sanitized_output
179
180
180
181
182
def process_solution(
    sample_solution: Dict,
    dataset: Dict,
    entry_point: Dict,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
    is_folder: bool = False,
    target_path: Optional[str] = None,
    samples: Optional[str] = None,
):
    """Sanitize a single sample solution.

    Written as a module-level function taking only keyword-friendly
    arguments so it can be dispatched once per sample (e.g. by a process
    pool called with ``argument_type="kwargs"``).

    Args:
        sample_solution: One loaded sample. Read keys: ``"task_id"``,
            ``"_identifier"``, and either ``"solution"`` or ``"completion"``.
        dataset: Mapping ``task_id -> problem``; ``problem["complete_prompt"]``
            is read when a completion has to be prefixed or when calibrating.
        entry_point: Mapping ``task_id -> entry point name`` handed to
            ``sanitize``; a missing task yields ``None``.
        debug_task: When given, every sample whose task id differs is skipped.
        calibrate: When true and the sample carries a ``"solution"``, the
            task prompt is spliced in right after the opening code fence.
        is_folder: Whether the samples come from a directory; only affects
            the progress message.
        target_path: Output location used in the progress message.
        samples: Input location that ``target_path`` replaces inside the
            sample identifier when building the progress message. Callers
            that want the ``" -> <target>"`` suffix must pass it.

    Returns:
        ``{"task_id": ..., "solution": <sanitized code>}``, or ``None`` when
        the sample is skipped (no/unknown task id, or filtered by
        ``debug_task``).
    """
    task_id = sample_solution.get("task_id")
    if not task_id or task_id not in dataset:
        return None

    dbg_identifier = sample_solution["_identifier"]
    if debug_task is not None and task_id != debug_task:
        return None

    function_name = entry_point.get(task_id)
    old_code = sample_solution.get("solution")

    # NOTE(review): the whitespace inside the string literals below is copied
    # from a whitespace-mangled view of the file -- confirm the exact literals
    # against the repository before merging.
    if old_code is None:
        # No full solution stored: rebuild it from the prompt + completion.
        assert "completion" in sample_solution, sample_solution
        old_code = (
            dataset[task_id]["complete_prompt"]
            + "\n "
            + sample_solution.get("completion")
        )
    elif calibrate:
        old_code = old_code.replace(
            "```python\n ",
            "```python\n " + dataset[task_id]["complete_prompt"] + " ",
        )

    new_code = sanitize(code=old_code, entrypoint=function_name)

    # Report only the samples that sanitization actually changed.
    if new_code != old_code:
        msg = "Sanitized: " + dbg_identifier
        # `samples` used to be read as an undefined free variable here, which
        # raised NameError inside the worker. It is now an explicit argument,
        # and the suffix is skipped when either path is unknown.
        if is_folder and samples is not None and target_path is not None:
            msg += " -> " + dbg_identifier.replace(samples, target_path)
        print(msg)

    return {"task_id": task_id, "solution": new_code}
220
+
221
+
181
222
def script (
182
- samples : str , inplace : bool = False , debug_task : str = None , calibrate : bool = False
223
+ samples : str , inplace : bool = False , debug_task : str = None , calibrate : bool = False , parallel : int = 32
183
224
):
184
225
# task_id -> entry_point
185
226
entry_point = {}
@@ -211,38 +252,26 @@ def script(
211
252
212
253
new_solutions = []
213
254
214
- for solution in tqdm (load_solutions (samples )):
215
- task_id = solution ["task_id" ]
216
- if task_id not in dataset :
217
- print (
218
- f"Skiping { task_id } as it does not existing in the latest EvalPlus dataset."
219
- )
220
- continue
221
-
222
- function_name = entry_point [task_id ] if task_id in entry_point else None
223
- dbg_identifier = solution ["_identifier" ]
224
- if debug_task is not None and task_id != debug_task :
225
- continue
226
-
227
- ntotal += 1
228
- if "solution" in solution :
229
- old_code = solution ["solution" ]
230
- if calibrate :
231
- old_code = solution ["solution" ].replace ("```python\n " , "```python\n " + dataset [task_id ]["complete_prompt" ]+ " " )
232
- else :
233
- assert "completion" in solution
234
- old_code = dataset [task_id ]["complete_prompt" ] + "\n " + solution ["completion" ]
235
-
236
- new_code = sanitize (code = old_code , entrypoint = function_name )
237
- # if changed, print the message
238
- if new_code != old_code :
239
- msg = "Sanitized: " + dbg_identifier
240
- if is_folder :
241
- msg += " -> " + dbg_identifier .replace (samples , target_path )
242
- print (msg )
255
+ parallel_arg_list = [
256
+ {
257
+ "sample_solution" : sample_solution ,
258
+ "dataset" : dataset ,
259
+ "entry_point" : entry_point ,
260
+ "debug_task" : debug_task ,
261
+ "calibrate" : calibrate ,
262
+ "is_folder" : is_folder ,
263
+ "target_path" : target_path
264
+ }
265
+ for sample_solution in load_solutions (samples )
266
+ ]
267
+
268
+ results = pqdm (parallel_arg_list , process_solution , n_jobs = min (parallel , os .cpu_count ()), argument_type = "kwargs" )
269
+
270
+ for result in results :
271
+ if result is not None :
272
+ new_solutions .append (result )
243
273
nsan += 1
244
-
245
- new_solutions .append ({"task_id" : task_id , "solution" : new_code })
274
+ ntotal += 1
246
275
247
276
if is_folder :
248
277
write_directory (target_path , new_solutions )
@@ -263,4 +292,4 @@ def main():
263
292
264
293
265
294
if __name__ == "__main__" :
266
- main ()
295
+ main ()
0 commit comments