Skip to content

Commit a944890

Browse files
Fix callable annotations (#4216)
1 parent 521db35 commit a944890

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

tests/testing_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import random
1717
import signal
1818
import warnings
19+
from collections.abc import Callable
1920

2021
import psutil
2122
import pytest
@@ -73,7 +74,7 @@ def set_tmp_dir(self, tmp_path):
7374
self.tmp_dir = str(tmp_path)
7475

7576

76-
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
77+
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
7778
"""
7879
Decorator to ignore warnings with a specific message and/or category.
7980

trl/extras/profiling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import contextlib
1616
import functools
1717
import time
18-
from collections.abc import Generator
18+
from collections.abc import Callable, Generator
1919

2020
from transformers import Trainer
2121
from transformers.integrations import is_mlflow_available, is_wandb_available
@@ -68,12 +68,12 @@ def some_method(self):
6868
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
6969

7070

71-
def profiling_decorator(func: callable) -> callable:
71+
def profiling_decorator(func: Callable) -> Callable:
7272
"""
7373
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
7474
7575
Args:
76-
func (`callable`):
76+
func (`Callable`):
7777
Function to be profiled.
7878
7979
Example:

trl/trainer/callbacks.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import logging
1616
import os
17+
from collections.abc import Callable
1718
from typing import Optional, Union
1819

1920
import pandas as pd
@@ -583,7 +584,7 @@ def __init__(
583584
self,
584585
trainer: Trainer,
585586
project_name: Optional[str] = None,
586-
scorers: Optional[dict[str, callable]] = None,
587+
scorers: Optional[dict[str, Callable]] = None,
587588
generation_config: Optional[GenerationConfig] = None,
588589
num_prompts: Optional[int] = None,
589590
dataset_name: str = "eval_dataset",

0 commit comments

Comments
 (0)