diff --git a/pymc/pytensorf.py b/pymc/pytensorf.py index 6d44603a8..64c3cf479 100644 --- a/pymc/pytensorf.py +++ b/pymc/pytensorf.py @@ -78,9 +78,12 @@ ] -def convert_observed_data(data): +def convert_observed_data(data) -> np.ndarray | Variable: """Convert user provided dataset to accepted formats.""" + if isgenerator(data): + return floatX(generator(data)) + if hasattr(data, "to_numpy") and hasattr(data, "isnull"): # typically, but not limited to pandas objects vals = data.to_numpy() @@ -116,8 +119,6 @@ def convert_observed_data(data): ret = data elif sps.issparse(data): ret = data - elif isgenerator(data): - ret = generator(data) else: ret = np.asarray(data)