@@ -1151,6 +1151,12 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help="Do not oversample if the dataset has " \
              "fewer samples than num-prompts.",
     )
+    parser.add_argument(
+        "--skip-chat-template",
+        action="store_true",
+        help=
+        "Skip applying chat template to prompt for datasets that support it.",
+    )
 
     # group for dataset specific arguments
     custom_group = parser.add_argument_group("custom dataset options")
@@ -1161,12 +1167,6 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
         help=
         "Number of output tokens per request, used only for custom dataset.",
     )
-    custom_group.add_argument(
-        "--custom-skip-chat-template",
-        action="store_true",
-        help=
-        "Skip applying chat template to prompt, used only for custom dataset.",
-    )
 
     spec_bench_group = parser.add_argument_group("spec bench dataset options")
     spec_bench_group.add_argument(
@@ -1435,7 +1435,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             output_len=args.custom_output_len,
-            skip_chat_template=args.custom_skip_chat_template,
+            skip_chat_template=args.skip_chat_template,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
         )
@@ -1576,6 +1576,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
             output_len=args.hf_output_len,
             request_id_prefix=args.request_id_prefix,
             no_oversample=args.no_oversample,
+            skip_chat_template=args.skip_chat_template,
             **hf_kwargs
         )
 
@@ -1815,7 +1816,6 @@ def load_data(self) -> None:
 
     def sample(self, **kwargs) -> list:
         # leverage CustomDataset sample
-        kwargs["skip_chat_template"] = False
         return super().sample(**kwargs)
 
 
@@ -2221,6 +2221,7 @@ def sample(self,
                num_requests: int,
                output_len: Optional[int] = None,
                enable_multimodal_chat: bool = False,
+               skip_chat_template: bool = False,
                request_id_prefix: str = "",
                no_oversample: bool = False,
                **kwargs) -> list:
@@ -2236,14 +2237,15 @@ def sample(self,
             )
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2284,6 +2286,7 @@ def sample(
         num_requests: int,
         output_len: Optional[int] = None,
         enable_multimodal_chat: bool = False,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         **kwargs,
@@ -2298,14 +2301,15 @@ def sample(
             prompt = item["turns"][0]
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": prompt
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
             sampled_requests.append(
@@ -2349,6 +2356,7 @@ def sample(
         tokenizer: PreTrainedTokenizerBase,
         num_requests: int,
         output_len: Optional[int] = None,
+        skip_chat_template: bool = False,
         request_id_prefix: str = "",
         no_oversample: bool = False,
         min_distance: float = 0.0,
@@ -2372,7 +2380,7 @@ def sample(
 
             # template copied from
             # https://github.com/ise-uiuc/blazedit/blob/7765137e656fd62de877422d2e4cf8de51228054/dataset/create_refined_dataset.py#L94-L105  # noqa: E501
-            instruction = f"""Given a code file, please apply the change requests and generate the new file.
+            prompt = f"""Given a code file, please apply the change requests and generate the new file.
 
 Original file:
 ```python
@@ -2385,14 +2393,15 @@ def sample(
 Please generate the new code file in the "New file" section below."""  # noqa: E501
 
             # apply template
-            prompt = tokenizer.apply_chat_template(
-                [{
-                    "role": "user",
-                    "content": instruction
-                }],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
+            if not skip_chat_template:
+                prompt = tokenizer.apply_chat_template(
+                    [{
+                        "role": "user",
+                        "content": prompt
+                    }],
+                    add_generation_prompt=True,
+                    tokenize=False,
+                )
 
             prompt_len = len(tokenizer(prompt).input_ids)
 
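
For reference, below is a minimal, self-contained sketch of the behavior the new `--skip-chat-template` flag toggles, mirroring the `apply_chat_template` pattern this diff threads through each dataset's `sample()`. The model name is only an arbitrary example, not something the benchmark pins:

```python
# Sketch of the prompt construction gated by --skip-chat-template.
# Assumes an instruct model whose tokenizer ships a chat template;
# the model name below is an arbitrary example.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
prompt = "Summarize the plot of Hamlet in one sentence."

skip_chat_template = False  # True would benchmark the raw prompt instead
if not skip_chat_template:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        add_generation_prompt=True,
        tokenize=False,
    )

# Prompt length is measured on whichever string results.
prompt_len = len(tokenizer(prompt).input_ids)
print(prompt_len)
```

Skipping the template is mainly useful for base models whose tokenizers have no chat template, where `apply_chat_template` would raise, or where chat markup would distort the measured prompt lengths.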