diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py
index 056899b6e..d5a1d260e 100644
--- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py
+++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field.py
@@ -96,12 +96,14 @@ def _cast_value(ctx: Context, value: Any, data_type: type | term.URIRef | None):
         else:
             raise ValueError(f"Type {type(value)} is not accepted for an image.")
     elif data_type == DataType.AUDIO_OBJECT:
-        output = deps.librosa.load(io.BytesIO(value))
+        output = deps.librosa.load(io.BytesIO(value), sr=None)
         return output
     elif data_type == DataType.BOUNDING_BOX:  # pytype: disable=wrong-arg-types
         return bounding_box.parse(value)
     elif not isinstance(data_type, type):
         raise ValueError(f"No special case for type {data_type}.")
+    elif isinstance(value, np.ndarray) and issubclass(data_type, np.generic):
+        return value.astype(data_type)
     elif isinstance(value, list) or isinstance(value, np.ndarray):
         return [_cast_value(ctx=ctx, value=v, data_type=data_type) for v in value]
     elif data_type == bytes and not isinstance(value, bytes):
diff --git a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py
index 4299193b6..f71d8c3b1 100644
--- a/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py
+++ b/python/mlcroissant/mlcroissant/_src/operation_graph/operations/field_test.py
@@ -51,6 +51,23 @@ def test_cast_value(conforms_to, value, data_type, expected):
     assert field._cast_value(ctx, value, data_type) == expected
 
 
+@parametrize_conforms_to()
+@pytest.mark.parametrize(
+    ["value", "data_type", "expected"],
+    [
+        [np.array([1, 2, 3]), np.int64, np.array([1, 2, 3], dtype=np.int64)],
+        [np.array([1, 2, 3]), np.float32, np.array([1.0, 2.0, 3.0], dtype=np.float32)],
+    ],
+)
+def test_cast_value_ndarray(conforms_to, value, data_type, expected):
+    """Checks that an ndarray cast to a NumPy scalar type has that dtype."""
+    ctx = Context(conforms_to=conforms_to)
+    cast_value = field._cast_value(ctx, value, data_type)
+    # `==` on arrays is element-wise, so a bare `assert a == b` raises ValueError.
+    np.testing.assert_array_equal(cast_value, expected)
+    assert cast_value.dtype == expected.dtype
+
+
 @parametrize_conforms_to()
 @pytest.mark.parametrize(
     ["value", "data_type"],