Skip to content

Commit 7c5dae2

Browse files
Fix a tiny bug in PromptManager::FewShotSampler::_init_fewshot_sampling_random (#423)
--------- Co-authored-by: Clémentine Fourrier <[email protected]>
1 parent c2337cf commit 7c5dae2

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

src/lighteval/tasks/prompt_manager.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def doc_to_fewshot_sorting_class(formatted_doc: Doc) -> str:
9292
formatted_doc (Doc): Formatted document.
9393
9494
Returns:
95-
str: Class of the
95+
str: Class of the fewshot document
9696
"""
9797
return formatted_doc.fewshot_sorting_class or PromptManager.doc_to_target(formatted_doc)
9898

@@ -356,12 +356,13 @@ def _init_fewshot_sampling_sequential(self, num_fewshot: int, variance_seed: int
356356
self._fewshot_cache[variance_seed] = fewshotpool # Store few shot examples
357357

358358
def _init_fewshot_sampling_random(self, variance_seed: int):
359-
fewshotpool = self.task.fewshot_docs()
359+
fewshotpool = list(self.task.fewshot_docs())
360360
if variance_seed == 0:
361361
self._fewshot_cache[variance_seed] = fewshotpool
362362
else: # we shuffle
363363
rnd = random.Random(variance_seed)
364-
self._fewshot_cache[variance_seed] = rnd.shuffle(fewshotpool)
364+
rnd.shuffle(fewshotpool)
365+
self._fewshot_cache[variance_seed] = fewshotpool
365366

366367
def _init_fewshot_sampling_balanced(
367368
self,

tests/test_prompt_manager.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# MIT License
2+
3+
# Copyright (c) 2024 The HuggingFace Team
4+
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
12+
# The above copyright notice and this permission notice shall be included in all
13+
# copies or substantial portions of the Software.
14+
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
# SOFTWARE.
22+
23+
import random
24+
from collections import Counter
25+
26+
import pytest
27+
28+
from lighteval.tasks.lighteval_task import LightevalTask, LightevalTaskConfig
29+
from lighteval.tasks.prompt_manager import FewShotSampler, PromptManager
30+
from lighteval.tasks.requests import Doc
31+
32+
33+
@pytest.mark.parametrize("fewshot_select", ["sequential", "random", "balanced"])
34+
def test_fewshot_sampler(fewshot_select: str):
35+
config = LightevalTaskConfig(
36+
name="test_fewshot_task",
37+
prompt_function=lambda _, __: None,
38+
hf_repo=None,
39+
hf_subset="default",
40+
metric=[],
41+
few_shots_split="test",
42+
few_shots_select=fewshot_select,
43+
)
44+
task = LightevalTask("test_fewshot_task", config)
45+
rnd = random.Random(0)
46+
task._fewshot_docs = [
47+
Doc(str(i), ["A", "B"], rnd.randint(0, 2), fewshot_sorting_class=str(i % 20)) for i in range(100)
48+
]
49+
sampler = FewShotSampler(task)
50+
seed = 1
51+
docs = sampler.sample_fewshot_examples(20, seed)
52+
match task.fewshot_selection:
53+
case "balanced":
54+
labels = Counter([PromptManager.doc_to_fewshot_sorting_class(d) for d in docs])
55+
assert labels.total() / len(labels) == 1
56+
case "sequential":
57+
assert docs == task.fewshot_docs()[:20]
58+
case "random":
59+
rnd = random.Random(seed)
60+
task_docs = task.fewshot_docs()
61+
rnd.shuffle(task_docs)
62+
assert docs == task_docs[:20]

0 commit comments

Comments
 (0)