1- from typing import Callable , Tuple
1+ from __future__ import annotations
2+
3+ from typing import Callable , Sequence , Tuple
24
35from pathlib import Path
46
1214 REGRESSION ,
1315)
1416from autosklearn .data .xy_data_manager import XYDataManager
15- from autosklearn .metrics import Scorer , accuracy , precision , r2
17+ from autosklearn .metrics import Scorer , accuracy , log_loss , precision , r2
1618from autosklearn .util .logging_ import PicklableClientLogger
1719
1820import pytest
2123
2224
2325@parametrize (
24- "dataset, metric , task" ,
26+ "dataset, metrics , task" ,
2527 [
26- ("breast_cancer" , accuracy , BINARY_CLASSIFICATION ),
27- ("wine" , accuracy , MULTICLASS_CLASSIFICATION ),
28- ("diabetes" , r2 , REGRESSION ),
28+ ("breast_cancer" , [accuracy ], BINARY_CLASSIFICATION ),
29+ ("breast_cancer" , [accuracy , log_loss ], BINARY_CLASSIFICATION ),
30+ ("wine" , [accuracy ], MULTICLASS_CLASSIFICATION ),
31+ ("diabetes" , [r2 ], REGRESSION ),
2932 ],
3033)
3134def test_produces_correct_output (
3235 dataset : str ,
3336 task : int ,
34- metric : Scorer ,
37+ metrics : Sequence [ Scorer ] ,
3538 mock_logger : PicklableClientLogger ,
3639 make_automl : Callable [..., AutoML ],
3740 make_sklearn_dataset : Callable [..., XYDataManager ],
@@ -45,8 +48,8 @@ def test_produces_correct_output(
4548 task : int
4649 The task type of the dataset
4750
48- metric: Scorer
49- Metric to use, required as fit usually determines the metric to use
51+ metrics: Sequence[ Scorer]
52+ Metric(s) to use, required as fit usually determines the metric to use
5053
5154 Fixtures
5255 --------
@@ -66,7 +69,7 @@ def test_produces_correct_output(
6669 * It should produce predictions "predictions_ensemble_1337_1_0.0.npy"
6770 """
6871 seed = 1337
69- automl = make_automl (metrics = [ metric ] , seed = seed )
72+ automl = make_automl (metrics = metrics , seed = seed )
7073 automl ._logger = mock_logger
7174
7275 datamanager = make_sklearn_dataset (
0 commit comments