From b97a5b01b01411f4de71a4b3b27956bc53d6acc4 Mon Sep 17 00:00:00 2001 From: Laura Helleckes Date: Mon, 6 May 2024 10:00:31 +0200 Subject: [PATCH] Refactor convert_observed data to simplify typing --- pymc/pytensorf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 6d44603a8..64c3cf479 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -78,9 +78,12 @@ ] -def convert_observed_data(data): +def convert_observed_data(data) -> np.ndarray | Variable: """Convert user provided dataset to accepted formats.""" + if isgenerator(data): + return floatX(generator(data)) + if hasattr(data, "to_numpy") and hasattr(data, "isnull"): # typically, but not limited to pandas objects vals = data.to_numpy() @@ -116,8 +119,6 @@ def convert_observed_data(data): ret = data elif sps.issparse(data): ret = data - elif isgenerator(data): - ret = generator(data) else: ret = np.asarray(data)