diff --git a/unified-runtime/source/adapters/level_zero/v2/context.cpp b/unified-runtime/source/adapters/level_zero/v2/context.cpp index 3d2a7758d6be4..acd84bf36015e 100644 --- a/unified-runtime/source/adapters/level_zero/v2/context.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/context.cpp @@ -80,11 +80,14 @@ ur_context_handle_t_::ur_context_handle_t_(ze_context_handle_t hContext, phDevices[0]->Platform->ZeMutableCmdListExt.Supported}), eventPoolCacheImmediate( this, phDevices[0]->Platform->getNumDevices(), - [context = this](DeviceId /* deviceId*/, v2::event_flags_t flags) - -> std::unique_ptr { + [context = this, platform = phDevices[0]->Platform]( + DeviceId deviceId, + v2::event_flags_t flags) -> std::unique_ptr { + auto device = platform->getDeviceById(deviceId); + // TODO: just use per-context id? - return std::make_unique( - context, v2::QUEUE_IMMEDIATE, flags); + return std::make_unique( + platform, context, v2::QUEUE_IMMEDIATE, device, flags); }), eventPoolCacheRegular(this, phDevices[0]->Platform->getNumDevices(), [context = this, platform = phDevices[0]->Platform]( diff --git a/unified-runtime/source/adapters/level_zero/v2/event_provider.hpp b/unified-runtime/source/adapters/level_zero/v2/event_provider.hpp index c6bedb8fc1cf8..2c7529d6cb288 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_provider.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_provider.hpp @@ -28,6 +28,11 @@ enum event_flag_t { }; static constexpr size_t EVENT_FLAGS_USED_BITS = 2; +enum queue_type { + QUEUE_REGULAR, + QUEUE_IMMEDIATE, +}; + class event_provider; namespace raii { diff --git a/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.cpp b/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.cpp index 886fd53db4b4c..8a94a3131c12d 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.cpp @@ -22,9 +22,14 @@ namespace v2 { provider_counter::provider_counter(ur_platform_handle_t platform, ur_context_handle_t context, - ur_device_handle_t device) { + queue_type queueType, + ur_device_handle_t device, + event_flags_t flags) + : queueType(queueType), flags(flags) { + assert(flags & EVENT_FLAGS_COUNTER); + ZE2UR_CALL_THROWS(zeDriverGetExtensionFunctionAddress, - (platform->ZeDriver, "zexCounterBasedEventCreate", + (platform->ZeDriver, "zexCounterBasedEventCreate2", (void **)&this->eventCreateFunc)); ZE2UR_CALL_THROWS(zelLoaderTranslateHandle, (ZEL_HANDLE_CONTEXT, context->getZeHandle(), @@ -34,17 +39,35 @@ provider_counter::provider_counter(ur_platform_handle_t platform, (ZEL_HANDLE_DEVICE, device->ZeDevice, (void **)&translatedDevice)); } +static zex_counter_based_event_exp_flags_t createZeFlags(queue_type queueType, + event_flags_t flags) { + zex_counter_based_event_exp_flags_t zeFlags = + ZEX_COUNTER_BASED_EVENT_FLAG_HOST_VISIBLE; + if (flags & EVENT_FLAGS_PROFILING_ENABLED) { + zeFlags |= ZE_EVENT_POOL_FLAG_KERNEL_TIMESTAMP; + } + + if (queueType == QUEUE_IMMEDIATE) { + zeFlags |= ZEX_COUNTER_BASED_EVENT_FLAG_IMMEDIATE; + } else { + zeFlags |= ZEX_COUNTER_BASED_EVENT_FLAG_NON_IMMEDIATE; + } + + return zeFlags; +} + raii::cache_borrowed_event provider_counter::allocate() { if (freelist.empty()) { - ZeStruct desc; - desc.index = 0; - desc.signal = ZE_EVENT_SCOPE_FLAG_HOST; - desc.wait = 0; + zex_counter_based_event_desc_t desc = {}; + desc.stype = ZEX_STRUCTURE_COUNTER_BASED_EVENT_DESC; + desc.flags = createZeFlags(queueType, flags); + desc.signalScope = ZE_EVENT_SCOPE_FLAG_HOST; + ze_event_handle_t handle; // TODO: allocate host and device buffers to use here - ZE2UR_CALL_THROWS(eventCreateFunc, (translatedContext, translatedDevice, - nullptr, nullptr, 0, &desc, &handle)); + ZE2UR_CALL_THROWS(eventCreateFunc, + (translatedContext, translatedDevice, &desc, &handle)); freelist.emplace_back(handle); } @@ -57,8 +80,6 @@ raii::cache_borrowed_event provider_counter::allocate() { [this](ze_event_handle_t handle) { freelist.push_back(handle); }); } -event_flags_t provider_counter::eventFlags() const { - return EVENT_FLAGS_COUNTER; -} +event_flags_t provider_counter::eventFlags() const { return flags; } } // namespace v2 diff --git a/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.hpp b/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.hpp index bb46cb5daf42b..baa4d5875507a 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_provider_counter.hpp @@ -25,22 +25,27 @@ #include "../device.hpp" +#include +#include + namespace v2 { typedef ze_result_t (*zexCounterBasedEventCreate)( ze_context_handle_t hContext, ze_device_handle_t hDevice, - uint64_t *deviceAddress, uint64_t *hostAddress, uint64_t completionValue, - const ze_event_desc_t *desc, ze_event_handle_t *phEvent); + const zex_counter_based_event_desc_t *desc, ze_event_handle_t *phEvent); class provider_counter : public event_provider { public: provider_counter(ur_platform_handle_t platform, ur_context_handle_t, - ur_device_handle_t); + queue_type, ur_device_handle_t, event_flags_t); raii::cache_borrowed_event allocate() override; event_flags_t eventFlags() const override; private: + queue_type queueType; + event_flags_t flags; + ze_context_handle_t translatedContext; ze_device_handle_t translatedDevice; diff --git a/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.cpp b/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.cpp index 6239f3f5f7412..06267059dc91b 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.cpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.cpp @@ -17,9 +17,8 @@ #include "event_provider.hpp" #include "event_provider_normal.hpp" -#include "../common/latency_tracker.hpp" - #include "../common.hpp" +#include "../common/latency_tracker.hpp" namespace v2 { static constexpr int EVENTS_BURST = 64; diff --git a/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.hpp b/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.hpp index 811b32f2e23f8..df6946e7ac580 100644 --- a/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.hpp +++ b/unified-runtime/source/adapters/level_zero/v2/event_provider_normal.hpp @@ -21,17 +21,13 @@ #include "common.hpp" #include "event.hpp" +#include "event_provider.hpp" #include "../device.hpp" #include "../ur_interface_loader.hpp" namespace v2 { -enum queue_type { - QUEUE_REGULAR, - QUEUE_IMMEDIATE, -}; - class provider_pool { public: provider_pool(ur_context_handle_t, queue_type, event_flags_t flags);