diff --git a/datafusion_ray/core.py b/datafusion_ray/core.py index 856a237..c68f284 100644 --- a/datafusion_ray/core.py +++ b/datafusion_ray/core.py @@ -75,9 +75,16 @@ def call_sync(coro): log.exception(e) +# work around for https://github.com/ray-project/ray/issues/31606 +async def _ensure_coro(maybe_obj_ref): + return await maybe_obj_ref + + async def wait_for(coros, name=""): return_values = [] - done, _ = await asyncio.wait(coros) + # wrap the coro in a task to work with python 3.10 and 3.11+ where asyncio.wait semantics + # changed to not accept any awaitable + done, _ = await asyncio.wait([asyncio.create_task(_ensure_coro(c)) for c in coros]) for d in done: e = d.exception() if e is not None: