diff --git a/unified-runtime/source/adapters/offload/event.cpp b/unified-runtime/source/adapters/offload/event.cpp index aab41ed3d2d0e..5df95eee02d1b 100644 --- a/unified-runtime/source/adapters/offload/event.cpp +++ b/unified-runtime/source/adapters/offload/event.cpp @@ -38,6 +38,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urEventGetInfo(ur_event_handle_t hEvent, return UR_RESULT_SUCCESS; } +namespace { +struct callback_data_t { + ur_event_callback_t Callback; + ur_event_handle_t Event; + ur_execution_info_t Status; + void *UserData; +}; +void CallbackHandler(void *CallbackData) { + auto *Data = reinterpret_cast(CallbackData); + Data->Callback(Data->Event, Data->Status, Data->UserData); + delete Data; +} +} // namespace + +UR_APIEXPORT ur_result_t UR_APICALL +urEventSetCallback(ur_event_handle_t hEvent, ur_execution_info_t execStatus, + ur_event_callback_t pfnNotify, void *pUserData) { + // Liboffload only supports a transition from SUBMITTED to COMPLETE + ol_queue_handle_t Queue; + OL_RETURN_ON_ERR(olCreateQueue(hEvent->UrQueue->OffloadDevice, &Queue)); + OL_RETURN_ON_ERR(olWaitEvents(Queue, &hEvent->OffloadEvent, 1)); + auto CallbackData = + new callback_data_t{pfnNotify, hEvent, execStatus, pUserData}; + OL_RETURN_ON_ERR(olLaunchHostFunction(Queue, CallbackHandler, CallbackData)); + OL_RETURN_ON_ERR(olDestroyQueue(Queue)); + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urEventGetProfilingInfo(ur_event_handle_t, ur_profiling_info_t, size_t, void *, diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index 5b4c8bd13bc50..bc19bc6f76481 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -74,7 +74,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEventProcAddrTable( pDdiTable->pfnGetProfilingInfo = urEventGetProfilingInfo; pDdiTable->pfnRelease = urEventRelease; pDdiTable->pfnRetain = urEventRetain; - pDdiTable->pfnSetCallback = nullptr; + pDdiTable->pfnSetCallback = urEventSetCallback; pDdiTable->pfnWait = urEventWait; return UR_RESULT_SUCCESS; }