Skip to content

Commit fcb784d

Browse files
authored
Pass At K Math (#647)
* test 1 * change task * change * fix names
1 parent bfb1099 commit fcb784d

File tree

3 files changed

+184
-4
lines changed

3 files changed

+184
-4
lines changed

src/lighteval/metrics/metrics.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828
ExprExtractionConfig,
2929
IndicesExtractionConfig,
3030
LatexExtractionConfig,
31+
compare_gold_target,
32+
extract_target_from_pred,
33+
get_extraction_regexes,
3134
multilingual_extractive_match_metric,
3235
)
3336
from lighteval.metrics.harness_compatibility.drop import drop_metrics
@@ -366,6 +369,167 @@ class Metrics(Enum):
366369
corpus_level_fn=np.mean,
367370
higher_is_better=True,
368371
)
372+
math_pass_at_1_4n = SampleLevelMetric(
373+
metric_name="math_pass@1:4_samples",
374+
sample_level_fn=PassAtK(
375+
k=1,
376+
n=4,
377+
strip_strings=True,
378+
# Extracting mathematical expressions and latex expressions
379+
normalize_gold=lambda k: extract_target_from_pred(
380+
k,
381+
get_extraction_regexes(
382+
formatted_doc=None,
383+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
384+
language=Language.ENGLISH,
385+
),
386+
),
387+
# Extracting mathematical expressions and latex expressions
388+
normalize_pred=lambda k: extract_target_from_pred(
389+
k,
390+
get_extraction_regexes(
391+
formatted_doc=None,
392+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
393+
language=Language.ENGLISH,
394+
),
395+
),
396+
# Uses sympy for comparison
397+
sample_scoring_function=compare_gold_target,
398+
).compute,
399+
category=MetricCategory.GENERATIVE_SAMPLING,
400+
use_case=MetricUseCase.REASONING,
401+
corpus_level_fn=np.mean,
402+
higher_is_better=True,
403+
)
404+
math_pass_at_1_8n = SampleLevelMetric(
405+
metric_name="math_pass@1:8_samples",
406+
sample_level_fn=PassAtK(
407+
k=1,
408+
n=8,
409+
strip_strings=True,
410+
# Extracting mathematical expressions and latex expressions
411+
normalize_gold=lambda k: extract_target_from_pred(
412+
k,
413+
get_extraction_regexes(
414+
formatted_doc=None,
415+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
416+
language=Language.ENGLISH,
417+
),
418+
),
419+
# Extracting mathematical expressions and latex expressions
420+
normalize_pred=lambda k: extract_target_from_pred(
421+
k,
422+
get_extraction_regexes(
423+
formatted_doc=None,
424+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
425+
language=Language.ENGLISH,
426+
),
427+
),
428+
# Uses sympy for comparison
429+
sample_scoring_function=compare_gold_target,
430+
).compute,
431+
category=MetricCategory.GENERATIVE_SAMPLING,
432+
use_case=MetricUseCase.REASONING,
433+
corpus_level_fn=np.mean,
434+
higher_is_better=True,
435+
)
436+
math_pass_at_1_16n = SampleLevelMetric(
437+
metric_name="math_pass@1:16_samples",
438+
sample_level_fn=PassAtK(
439+
k=1,
440+
n=16,
441+
strip_strings=True,
442+
# Extracting mathematical expressions and latex expressions
443+
normalize_gold=lambda k: extract_target_from_pred(
444+
k,
445+
get_extraction_regexes(
446+
formatted_doc=None,
447+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
448+
language=Language.ENGLISH,
449+
),
450+
),
451+
# Extracting mathematical expressions and latex expressions
452+
normalize_pred=lambda k: extract_target_from_pred(
453+
k,
454+
get_extraction_regexes(
455+
formatted_doc=None,
456+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
457+
language=Language.ENGLISH,
458+
),
459+
),
460+
# Uses sympy for comparison
461+
sample_scoring_function=compare_gold_target,
462+
).compute,
463+
category=MetricCategory.GENERATIVE_SAMPLING,
464+
use_case=MetricUseCase.REASONING,
465+
corpus_level_fn=np.mean,
466+
higher_is_better=True,
467+
)
468+
math_pass_at_1_32n = SampleLevelMetric(
469+
metric_name="math_pass@1:32_samples",
470+
sample_level_fn=PassAtK(
471+
k=1,
472+
n=32,
473+
strip_strings=True,
474+
# Extracting mathematical expressions and latex expressions
475+
normalize_gold=lambda k: extract_target_from_pred(
476+
k,
477+
get_extraction_regexes(
478+
formatted_doc=None,
479+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
480+
language=Language.ENGLISH,
481+
),
482+
),
483+
# Extracting mathematical expressions and latex expressions
484+
normalize_pred=lambda k: extract_target_from_pred(
485+
k,
486+
get_extraction_regexes(
487+
formatted_doc=None,
488+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
489+
language=Language.ENGLISH,
490+
),
491+
),
492+
# Uses sympy for comparison
493+
sample_scoring_function=compare_gold_target,
494+
).compute,
495+
category=MetricCategory.GENERATIVE_SAMPLING,
496+
use_case=MetricUseCase.REASONING,
497+
corpus_level_fn=np.mean,
498+
higher_is_better=True,
499+
)
500+
math_pass_at_1_64n = SampleLevelMetric(
501+
metric_name="math_pass@1:64_samples",
502+
sample_level_fn=PassAtK(
503+
k=1,
504+
n=64,
505+
strip_strings=True,
506+
# Extracting mathematical expressions and latex expressions
507+
normalize_gold=lambda k: extract_target_from_pred(
508+
k,
509+
get_extraction_regexes(
510+
formatted_doc=None,
511+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
512+
language=Language.ENGLISH,
513+
),
514+
),
515+
# Extracting mathematical expressions and latex expressions
516+
normalize_pred=lambda k: extract_target_from_pred(
517+
k,
518+
get_extraction_regexes(
519+
formatted_doc=None,
520+
target_types=[ExprExtractionConfig(), LatexExtractionConfig()],
521+
language=Language.ENGLISH,
522+
),
523+
),
524+
# Uses sympy for comparison
525+
sample_scoring_function=compare_gold_target,
526+
).compute,
527+
category=MetricCategory.GENERATIVE_SAMPLING,
528+
use_case=MetricUseCase.REASONING,
529+
corpus_level_fn=np.mean,
530+
higher_is_better=True,
531+
)
532+
369533
mrr = SampleLevelMetric(
370534
metric_name="mrr",
371535
sample_level_fn=MRR().compute,

src/lighteval/models/endpoints/endpoint_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,9 @@ def _async_process_request(
409409
decoder_input_details=True,
410410
grammar=grammar,
411411
)
412+
generation_config_dict = {k: v for k, v in generation_config.__dict__.items() if v is not None}
412413

413-
generated_text = self.async_client.text_generation(prompt=context, generation_config=generation_config)
414+
generated_text = self.async_client.text_generation(prompt=context, **generation_config_dict)
414415

415416
return generated_text
416417

@@ -431,10 +432,11 @@ def _process_request(
431432
decoder_input_details=True,
432433
grammar=grammar,
433434
)
435+
generation_config_dict = {k: v for k, v in generation_config.__dict__.items() if v is not None}
434436

435437
generated_text = self.client.text_generation(
436438
prompt=context,
437-
generation_config=generation_config,
439+
**generation_config_dict,
438440
)
439441

440442
return generated_text

src/lighteval/tasks/default_tasks.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,14 @@
323323
few_shots_split=None,
324324
few_shots_select=None,
325325
generation_size=32768,
326-
metric=[Metrics.expr_gold_metric],
326+
metric=[
327+
Metrics.expr_gold_metric,
328+
Metrics.math_pass_at_1_4n,
329+
Metrics.math_pass_at_1_8n,
330+
Metrics.math_pass_at_1_16n,
331+
Metrics.math_pass_at_1_32n,
332+
Metrics.math_pass_at_1_64n,
333+
],
327334
version=1,
328335
)
329336
aime25 = LightevalTaskConfig(
@@ -337,7 +344,14 @@
337344
few_shots_split=None,
338345
few_shots_select=None,
339346
generation_size=10000,
340-
metric=[Metrics.expr_gold_metric],
347+
metric=[
348+
Metrics.expr_gold_metric,
349+
Metrics.math_pass_at_1_4n,
350+
Metrics.math_pass_at_1_8n,
351+
Metrics.math_pass_at_1_16n,
352+
Metrics.math_pass_at_1_32n,
353+
Metrics.math_pass_at_1_64n,
354+
],
341355
version=1,
342356
)
343357
anachronisms_bigbench = LightevalTaskConfig(

0 commit comments

Comments
 (0)