Commit cc578b6

🧺 [2/N] Refactor _generate in GRPO/RLOO: Use prompt_ids from generation (#4152)

1 parent 30cf68a commit cc578b6
File tree: 5 files changed, +117 -62 lines changed

tests/test_vllm_client_server.py

Lines changed: 69 additions & 41 deletions
@@ -74,36 +74,42 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
             "completion_ids"
         ]
 
         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)
 
         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)
 
         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32
 
     def test_update_model_params(self):
@@ -148,36 +154,42 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_generate_with_params(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
+        completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
            "completion_ids"
        ]
 
         # Check that the output is a list
-        assert isinstance(outputs, list)
+        assert isinstance(completion_ids, list)
 
         # Check that the number of generated sequences is 2 times the number of prompts
-        assert len(outputs) == 2 * len(prompts)
+        assert len(completion_ids) == 2 * len(prompts)
 
         # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
         # Check that the length of the generated sequences is less than or equal to 32
-        for seq in outputs:
+        for seq in completion_ids:
             assert len(seq) <= 32
 
     def test_update_model_params(self):
@@ -224,16 +236,22 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_update_model_params(self):
@@ -280,16 +298,22 @@ def setup_class(cls):
 
     def test_generate(self):
         prompts = ["Hello, AI!", "Tell me a joke"]
-        outputs = self.client.generate(prompts)["completion_ids"]
+        outputs = self.client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
 
-        # Check that the output is a list
-        assert isinstance(outputs, list)
+        # Check that the outputs are lists
+        assert isinstance(prompt_ids, list)
+        assert isinstance(completion_ids, list)
 
-        # Check that the number of generated sequences is equal to the number of prompts
-        assert len(outputs) == len(prompts)
+        # Check that the number of sequences are equal to the number of prompts
+        assert len(prompt_ids) == len(prompts)
+        assert len(completion_ids) == len(prompts)
 
-        # Check that the generated sequences are lists of integers
-        for seq in outputs:
+        # Check that the sequences are lists of integers
+        for seq in prompt_ids:
+            assert all(isinstance(tok, int) for tok in seq)
+        for seq in completion_ids:
             assert all(isinstance(tok, int) for tok in seq)
 
     def test_update_model_params(self):
@@ -336,9 +360,13 @@ def test_init_communicator_with_device_int(self):
 
         # Test basic functionality
         prompts = ["Hello, AI!"]
-        outputs = client.generate(prompts)["completion_ids"]
-        assert isinstance(outputs, list)
-        assert len(outputs) == len(prompts)
+        outputs = client.generate(prompts)
+        prompt_ids = outputs["prompt_ids"]
+        completion_ids = outputs["completion_ids"]
+        assert isinstance(prompt_ids, list)
+        assert len(prompt_ids) == len(prompts)
+        assert isinstance(completion_ids, list)
+        assert len(completion_ids) == len(prompts)
 
         client.close_communicator()

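For context, the updated tests exercise a `generate` call that now returns `prompt_ids` alongside `completion_ids`, so callers can rebuild the full token sequence without re-tokenizing the prompt. A minimal consumption sketch, assuming a TRL vLLM server is already running on the client's default host and port (the server setup itself is not part of this diff):

```python
from trl.extras.vllm_client import VLLMClient

# Assumes a server started with `trl vllm-serve --model <model>` is reachable
# at the client's default host and port.
client = VLLMClient()
outputs = client.generate(["Hello, AI!", "Tell me a joke"])

# With prompt token IDs now returned by the server, the full sequence for each
# sample can be rebuilt by simple concatenation, with no re-tokenization.
full_sequences = [
    prompt + completion
    for prompt, completion in zip(outputs["prompt_ids"], outputs["completion_ids"])
]
print([len(seq) for seq in full_sequences])
```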
trl/extras/vllm_client.py

Lines changed: 13 additions & 3 deletions
@@ -83,8 +83,12 @@ class VLLMClient:
 
     >>> client = VLLMClient()
     >>> client.generate(["Hello, AI!", "Tell me a joke"])
-    [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
-    [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]]
+    {'prompt_ids': [[9707, 11, 15235, 0],
+    [40451, 752, 264, 21646]],
+    'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
+    [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
+    'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
+    [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}
 
 
     >>> from transformers import AutoModelForCausalLM
@@ -212,6 +216,8 @@ def generate(
 
         Returns:
             `dict` with keys:
+                - `prompt_ids` (`list[list[int]]`):
+                    List of lists of token IDs representing the tokenized input prompts.
                 - `completion_ids` (`list[list[int]]`):
                     List of lists of token IDs representing the model-generated completions for each prompt.
                 - `logprobs` (`list[list[float]]`):
@@ -246,7 +252,11 @@ def pil_to_base64(image):
         )
         if response.status_code == 200:
             json_response = response.json()
-            return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]}
+            return {
+                "prompt_ids": json_response["prompt_ids"],
+                "completion_ids": json_response["completion_ids"],
+                "logprobs": json_response["logprobs"],
+            }
         else:
             raise Exception(f"Request failed: {response.status_code}, {response.text}")

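With the extra key, `VLLMClient.generate` now returns three parallel lists. A small sketch of reading them together, assuming the same running server as above; the one-to-one pairing of `logprobs` with completion tokens follows from how the server builds them per generated token:

```python
from trl.extras.vllm_client import VLLMClient

client = VLLMClient()  # assumes a reachable `trl vllm-serve` instance
out = client.generate(["Hello, AI!", "Tell me a joke"])

for prompt_ids, completion_ids, logprobs in zip(
    out["prompt_ids"], out["completion_ids"], out["logprobs"]
):
    # Log probabilities are reported per generated token, so the two lists
    # should line up one-to-one.
    assert len(completion_ids) == len(logprobs)
    print(f"{len(prompt_ids)} prompt tokens -> {len(completion_ids)} completion tokens")
```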
trl/scripts/vllm_serve.py

Lines changed: 9 additions & 2 deletions
@@ -499,6 +499,7 @@ class GenerateRequest(BaseModel):
     generation_kwargs: dict = field(default_factory=dict)
 
 class GenerateResponse(BaseModel):
+    prompt_ids: list[list[int]]
     completion_ids: list[list[int]]
     logprobs: list[list[float]]
 
@@ -532,6 +533,7 @@ async def generate(request: GenerateRequest):
 
     Returns:
         `GenerateResponse`:
+            - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt.
            - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion.
            - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the
              generated completions.
@@ -543,7 +545,11 @@ async def generate(request: GenerateRequest):
 
     Example response:
     ```json
-    {"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]}
+    {
+        "prompt_ids": [[101, 102], [201, 202]],
+        "completion_ids": [[103, 104, 105], [203, 204, 205]],
+        "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]
+    }
     ```
     """
     request.images = request.images or [None] * len(request.prompts)
@@ -596,13 +602,14 @@ async def generate(request: GenerateRequest):
 
     # Flatten and combine all results
    all_outputs = list(chain.from_iterable(all_outputs))  # from list of list to single list
+   prompt_ids = [output.prompt_token_ids for output in all_outputs]
    completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs]
    logprobs: list[list[float]] = [
        [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs]
        for outputs in all_outputs
        for output in outputs.outputs
    ]
-   return {"completion_ids": completion_ids, "logprobs": logprobs}
+   return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs}
 
 class InitCommunicatorRequest(BaseModel):
     host: str

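On the server side, the response from the generation endpoint now carries `prompt_ids` alongside completions and log probabilities. A minimal sketch of calling it directly over HTTP; the host, port, and trailing-slash route path are assumptions here, and only the `prompts` field of `GenerateRequest` is used:

```python
import requests

# Assumptions: the `trl vllm-serve` process listens on localhost:8000 and the
# generation route is mounted at /generate/.
response = requests.post(
    "http://localhost:8000/generate/",
    json={"prompts": ["Hello, AI!", "Tell me a joke"]},
)
response.raise_for_status()
payload = response.json()

# The JSON body now includes the tokenized prompts as well.
print(payload["prompt_ids"])
print(payload["completion_ids"])
print(payload["logprobs"])
```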
trl/trainer/grpo_trainer.py

Lines changed: 13 additions & 8 deletions
@@ -1101,11 +1101,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
            **kwargs,
        )
        prompt_inputs = super()._prepare_inputs(prompt_inputs)
-       prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
        forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-       prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
 
        if self.max_prompt_length is not None:
+           prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
+           prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
+
            # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
            # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
            # tokens are needed for generation.
@@ -1187,19 +1188,23 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                    guided_decoding_regex=self.guided_decoding_regex,
                    generation_kwargs=self.args.generation_kwargs,
                )
-               payload = (output["completion_ids"], output["logprobs"])
+               payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"])
            else:
                payload = None
 
            # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice.
            obj_list = [payload]
            broadcast_object_list(obj_list, from_process=0)
-           all_completion_ids, all_logprobs = obj_list[0]
+           all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0]
+
+           # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times
+           all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)]
 
            process_slice = slice(
                self.accelerator.process_index * len(prompts),
                (self.accelerator.process_index + 1) * len(prompts),
            )
+           prompt_ids = all_prompt_ids[process_slice]
            completion_ids = all_completion_ids[process_slice]
            logprobs = all_logprobs[process_slice]
@@ -1254,6 +1259,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
            with profiling_context(self, "vLLM.generate"):
                all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False)
 
+           all_prompt_ids = [output.prompt_token_ids for output in all_outputs]
            all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs]
            all_logprobs = [
                [next(iter(lp.values())).logprob for lp in output.logprobs]
@@ -1266,9 +1272,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
                # Each rank generates all outputs — we keep only our share.
                local_rank_in_group = torch.distributed.get_rank(group=self.tp_group)
                tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size)
+               prompt_ids = all_prompt_ids[tp_slice]
                completion_ids = all_completion_ids[tp_slice]
                logprobs = all_logprobs[tp_slice]
            else:
+               prompt_ids = all_prompt_ids
                completion_ids = all_completion_ids
                logprobs = all_logprobs
@@ -1311,10 +1319,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]):
 
        else:
            # Regular generation path
-           prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids]
-           prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids]
-           prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
-           prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
+           prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
 
            with (
                profiling_context(self, "transformers.generate"),

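In the trainer's server path above, the broadcast carries a single copy of each prompt's token IDs, so they are repeated `num_generations` times before each process takes its contiguous slice. A simplified, self-contained illustration of that repeat-then-slice pattern, using made-up sizes rather than trainer state:

```python
# Simplified illustration of the repeat-then-slice logic in the diff above,
# with made-up values: 2 unique prompts, num_generations = 2, and 2 processes
# that each handle len(prompts) = 2 entries.
num_generations = 2
prompts_per_process = 2

# One copy of each prompt's token IDs, as broadcast from the main process.
all_prompt_ids = [[101, 102], [201, 202, 203]]

# Repeat each prompt num_generations times so it lines up with the completions.
all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(num_generations)]
# -> [[101, 102], [101, 102], [201, 202, 203], [201, 202, 203]]

# Each process then keeps only its own contiguous slice.
for process_index in range(2):
    process_slice = slice(
        process_index * prompts_per_process,
        (process_index + 1) * prompts_per_process,
    )
    print(process_index, all_prompt_ids[process_slice])
```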