diff --git a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs index 0d062862653c36..a7ab5415d25d97 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.CoreCLR.cs @@ -140,7 +140,9 @@ public static partial class AsyncHelpers private struct RuntimeAsyncAwaitState { public Continuation? SentinelContinuation; + public ICriticalNotifyCompletion? CriticalNotifier; public INotifyCompletion? Notifier; + public Task? CalledTask; } [ThreadStatic] @@ -203,7 +205,21 @@ private static unsafe object AllocContinuationResultBox(void* ptr) return RuntimeTypeHandle.InternalAllocNoChecks((MethodTable*)pMT); } - private interface IThunkTaskOps + [BypassReadyToRun] + [MethodImpl(MethodImplOptions.NoInlining | MethodImplOptions.Async)] + [RequiresPreviewFeatures] + private static void TransparentAwaitTask(Task t) + { + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + Continuation? sentinelContinuation = state.SentinelContinuation; + if (sentinelContinuation == null) + state.SentinelContinuation = sentinelContinuation = new Continuation(); + + state.CalledTask = t; + AsyncSuspend(sentinelContinuation); + } + + private interface IRuntimeAsyncTaskOps { static abstract Action GetContinuationAction(T task); static abstract Continuation GetContinuationState(T task); @@ -212,9 +228,12 @@ private interface IThunkTaskOps static abstract void PostToSyncContext(T task, SynchronizationContext syncCtx); } - private sealed class ThunkTask : Task + /// + /// Represents a wrapped runtime async operation. + /// + private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction { - public ThunkTask() + public RuntimeAsyncTask() { // We use the base Task's state object field to store the Continuation while posting the task around. // Ensure that state object isn't published out for others to see. @@ -231,31 +250,38 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread) private void MoveNext() { - ThunkTaskCore.MoveNext, Ops>(this); + RuntimeAsyncTaskCore.DispatchContinuations, Ops>(this); } public void HandleSuspended() { - ThunkTaskCore.HandleSuspended, Ops>(this); + RuntimeAsyncTaskCore.HandleSuspended, Ops>(this); + } + + void ITaskCompletionAction.Invoke(Task completingTask) + { + MoveNext(); } + bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true; + private static readonly SendOrPostCallback s_postCallback = static state => { - Debug.Assert(state is ThunkTask); - ((ThunkTask)state).MoveNext(); + Debug.Assert(state is RuntimeAsyncTask); + ((RuntimeAsyncTask)state).MoveNext(); }; - private struct Ops : IThunkTaskOps> + private struct Ops : IRuntimeAsyncTaskOps> { - public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!; - public static void MoveNext(ThunkTask task) => task.MoveNext(); - public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!; - public static void SetContinuationState(ThunkTask task, Continuation value) + public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; + public static void MoveNext(RuntimeAsyncTask task) => task.MoveNext(); + public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!; + public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) { task.m_stateObject = value; } - public static bool SetCompleted(ThunkTask task, Continuation continuation) + public static bool SetCompleted(RuntimeAsyncTask task, Continuation continuation) { T result; if (RuntimeHelpers.IsReferenceOrContainsReferences()) @@ -277,16 +303,19 @@ public static bool SetCompleted(ThunkTask task, Continuation continuation) return task.TrySetResult(result); } - public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext) + public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext) { syncContext.Post(s_postCallback, task); } } } - private sealed class ThunkTask : Task + /// + /// Represents a wrapped runtime async operation. + /// + private sealed class RuntimeAsyncTask : Task, ITaskCompletionAction { - public ThunkTask() + public RuntimeAsyncTask() { // We use the base Task's state object field to store the Continuation while posting the task around. // Ensure that state object isn't published out for others to see. @@ -303,45 +332,52 @@ internal override void ExecuteFromThreadPool(Thread threadPoolThread) private void MoveNext() { - ThunkTaskCore.MoveNext(this); + RuntimeAsyncTaskCore.DispatchContinuations(this); } public void HandleSuspended() { - ThunkTaskCore.HandleSuspended(this); + RuntimeAsyncTaskCore.HandleSuspended(this); } + void ITaskCompletionAction.Invoke(Task completingTask) + { + MoveNext(); + } + + bool ITaskCompletionAction.InvokeMayRunArbitraryCode => true; + private static readonly SendOrPostCallback s_postCallback = static state => { - Debug.Assert(state is ThunkTask); - ((ThunkTask)state).MoveNext(); + Debug.Assert(state is RuntimeAsyncTask); + ((RuntimeAsyncTask)state).MoveNext(); }; - private struct Ops : IThunkTaskOps + private struct Ops : IRuntimeAsyncTaskOps { - public static Action GetContinuationAction(ThunkTask task) => (Action)task.m_action!; - public static void MoveNext(ThunkTask task) => task.MoveNext(); - public static Continuation GetContinuationState(ThunkTask task) => (Continuation)task.m_stateObject!; - public static void SetContinuationState(ThunkTask task, Continuation value) + public static Action GetContinuationAction(RuntimeAsyncTask task) => (Action)task.m_action!; + public static void MoveNext(RuntimeAsyncTask task) => task.MoveNext(); + public static Continuation GetContinuationState(RuntimeAsyncTask task) => (Continuation)task.m_stateObject!; + public static void SetContinuationState(RuntimeAsyncTask task, Continuation value) { task.m_stateObject = value; } - public static bool SetCompleted(ThunkTask task, Continuation continuation) + public static bool SetCompleted(RuntimeAsyncTask task, Continuation continuation) { return task.TrySetResult(); } - public static void PostToSyncContext(ThunkTask task, SynchronizationContext syncContext) + public static void PostToSyncContext(RuntimeAsyncTask task, SynchronizationContext syncContext) { syncContext.Post(s_postCallback, task); } } } - private static class ThunkTaskCore + private static class RuntimeAsyncTaskCore { - public static unsafe void MoveNext(T task) where T : Task where TOps : IThunkTaskOps + public static unsafe void DispatchContinuations(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps { ExecutionAndSyncBlockStore contexts = default; contexts.Push(); @@ -422,9 +458,20 @@ private static Continuation UnwindToPossibleHandler(Continuation continuation) } } - public static void HandleSuspended(T task) where T : Task where TOps : IThunkTaskOps + public static void HandleSuspended(T task) where T : Task, ITaskCompletionAction where TOps : IRuntimeAsyncTaskOps { - Continuation headContinuation = UnlinkHeadContinuation(out INotifyCompletion? notifier); + ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; + ICriticalNotifyCompletion? critNotifier = state.CriticalNotifier; + INotifyCompletion? notifier = state.Notifier; + Task? calledTask = state.CalledTask; + + state.CriticalNotifier = null; + state.Notifier = null; + state.CalledTask = null; + + Continuation sentinelContinuation = state.SentinelContinuation!; + Continuation headContinuation = sentinelContinuation.Next!; + sentinelContinuation.Next = null; // Head continuation should be the result of async call to AwaitAwaiter or UnsafeAwaitAwaiter. // These never have special continuation handling. @@ -438,9 +485,19 @@ public static void HandleSuspended(T task) where T : Task where TOps : try { - if (notifier is ICriticalNotifyCompletion crit) + if (critNotifier != null) + { + critNotifier.UnsafeOnCompleted(TOps.GetContinuationAction(task)); + } + else if (calledTask != null) { - crit.UnsafeOnCompleted(TOps.GetContinuationAction(task)); + // Runtime async callable wrapper for task returning + // method. This implements the context transparent + // forwarding and makes these wrappers minimal cost. + if (!calledTask.TryAddCompletionAction(task)) + { + ThreadPool.UnsafeQueueUserWorkItemInternal(task, preferLocal: true); + } } else { @@ -454,19 +511,7 @@ public static void HandleSuspended(T task) where T : Task where TOps : } } - private static Continuation UnlinkHeadContinuation(out INotifyCompletion? notifier) - { - ref RuntimeAsyncAwaitState state = ref t_runtimeAsyncAwaitState; - notifier = state.Notifier; - state.Notifier = null; - - Continuation sentinelContinuation = state.SentinelContinuation!; - Continuation head = sentinelContinuation.Next!; - sentinelContinuation.Next = null; - return head; - } - - private static bool QueueContinuationFollowUpActionIfNecessary(T task, Continuation continuation) where T : Task where TOps : IThunkTaskOps + private static bool QueueContinuationFollowUpActionIfNecessary(T task, Continuation continuation) where T : Task where TOps : IRuntimeAsyncTaskOps { if ((continuation.Flags & CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL) != 0) { @@ -554,7 +599,7 @@ private static bool QueueContinuationFollowUpActionIfNecessary(T task, continuation.Next = finalContinuation; - ThunkTask result = new(); + RuntimeAsyncTask result = new(); result.HandleSuspended(); return result; } @@ -567,7 +612,7 @@ private static Task FinalizeTaskReturningThunk(Continuation continuation) }; continuation.Next = finalContinuation; - ThunkTask result = new(); + RuntimeAsyncTask result = new(); result.HandleSuspended(); return result; } @@ -679,5 +724,16 @@ private static void CaptureContinuationContext(SynchronizationContext syncCtx, r flags |= CorInfoContinuationFlags.CORINFO_CONTINUATION_CONTINUE_ON_THREAD_POOL; } + + internal static T CompletedTaskResult(Task task) + { + TaskAwaiter.ValidateEnd(task); + return task.ResultOnSuccess; + } + + internal static void CompletedTask(Task task) + { + TaskAwaiter.ValidateEnd(task); + } } } diff --git a/src/coreclr/vm/asyncthunks.cpp b/src/coreclr/vm/asyncthunks.cpp index ffaddb452fb97f..cf4e290bc6e7c4 100644 --- a/src/coreclr/vm/asyncthunks.cpp +++ b/src/coreclr/vm/asyncthunks.cpp @@ -484,68 +484,28 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pAsyncOtherVariant, MetaSig& m // Implement IL that is effectively the following: /* { - TaskAwaiter awaiter = other(arg).GetAwaiter(); - if (!awaiter.IsCompleted) + Task task = other(arg); + if (!task.IsCompleted) { // Magic function which will suspend the current run of async methods - AsyncHelpers.UnsafeAwaitAwaiter>(awaiter); + AsyncHelpers.TransparentAwaitTask(task); } - return awaiter.GetResult(); + return AsyncHelpers.CompletedTaskResult(task); } - */ - ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); - - TypeHandle thTaskAwaiter; - MethodTable* pMTTask; - MethodDesc* mdGetAwaiter; - MethodDesc* mdIsCompleted; - MethodDesc* mdGetResult; - - bool isValueTask = IsValueTaskAsyncThunk(); - if (msig.IsReturnTypeVoid()) + For ValueTask: { - pMTTask = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK : CLASS__TASK); - thTaskAwaiter = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_AWAITER : CLASS__TASK_AWAITER); - mdGetAwaiter = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK__GET_AWAITER : METHOD__TASK__GET_AWAITER); - mdIsCompleted = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER__GET_ISCOMPLETED : METHOD__TASK_AWAITER__GET_ISCOMPLETED); - mdGetResult = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER__GET_RESULT : METHOD__TASK_AWAITER__GET_RESULT); - } - else - { - TypeHandle thLogicalRetType = msig.GetRetTypeHandleThrowing(); - MethodTable* pMTTaskOpen = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_1 : CLASS__TASK_1); - pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable(); - MethodTable* pMTTaskAwaiterOpen = CoreLibBinder::GetClass(isValueTask ? CLASS__VALUETASK_AWAITER_1 : CLASS__TASK_AWAITER_1); - - thTaskAwaiter = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskAwaiterOpen->GetModule(), pMTTaskAwaiterOpen->GetCl(), Instantiation(&thLogicalRetType, 1)); + ValueTask vt = other(arg); + if (vt.IsCompleted) + return vt.Result/vt.ThrowIfCompletedUnsuccessfully(); - mdGetAwaiter = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_1__GET_AWAITER : METHOD__TASK_1__GET_AWAITER); - mdGetAwaiter = MethodDesc::FindOrCreateAssociatedMethodDesc(mdGetAwaiter, pMTTask, FALSE, Instantiation(), FALSE); - - mdIsCompleted = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER_1__GET_ISCOMPLETED : METHOD__TASK_AWAITER_1__GET_ISCOMPLETED); - mdIsCompleted = MethodDesc::FindOrCreateAssociatedMethodDesc(mdIsCompleted, thTaskAwaiter.GetMethodTable(), FALSE, Instantiation(), FALSE); - - mdGetResult = CoreLibBinder::GetMethod(isValueTask ? METHOD__VALUETASK_AWAITER_1__GET_RESULT : METHOD__TASK_AWAITER_1__GET_RESULT); - mdGetResult = MethodDesc::FindOrCreateAssociatedMethodDesc(mdGetResult, thTaskAwaiter.GetMethodTable(), FALSE, Instantiation(), FALSE); - } - - DWORD localArg = 0; - ILCodeLabel* pGetResultLabel = pCode->NewCodeLabel(); - - LocalDesc awaiterLocalDesc(thTaskAwaiter); - DWORD awaiterLocal = pCode->NewLocal(awaiterLocalDesc); - - if (msig.HasThis()) - { - pCode->EmitLDARG(localArg++); - } - for (UINT iArg = 0; iArg < msig.NumFixedArgs(); iArg++) - { - pCode->EmitLDARG(localArg++); + Task task = vt.AsTask(); + } + */ + ILCodeStream* pCode = pSL->NewCodeStream(ILStubLinker::kDispatch); - int token; + int userFuncToken; _ASSERTE(!pAsyncOtherVariant->IsWrapperStub()); if (pAsyncOtherVariant->HasClassOrMethodInstantiation()) { @@ -588,95 +548,129 @@ void MethodDesc::EmitAsyncMethodThunk(MethodDesc* pAsyncOtherVariant, MetaSig& m DWORD sigLen; PCCOR_SIGNATURE sig = (PCCOR_SIGNATURE)methodSigBuilder.GetSignature(&sigLen); int methodSigToken = pCode->GetSigToken(sig, sigLen); - token = pCode->GetToken(pAsyncOtherVariant, typeSigToken, methodSigToken); + userFuncToken = pCode->GetToken(pAsyncOtherVariant, typeSigToken, methodSigToken); } else { - token = pCode->GetToken(pAsyncOtherVariant, typeSigToken); + userFuncToken = pCode->GetToken(pAsyncOtherVariant, typeSigToken); } } else { - token = pCode->GetToken(pAsyncOtherVariant); + userFuncToken = pCode->GetToken(pAsyncOtherVariant); } - pCode->EmitCALL(token, localArg, 1); - - int getAwaiterToken; - int getIsCompletedToken; - int getResultToken; - if (!msig.IsReturnTypeVoid()) + DWORD localArg = 0; + if (msig.HasThis()) { - getAwaiterToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, mdGetAwaiter); - getIsCompletedToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, mdIsCompleted); - getResultToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, mdGetResult); + pCode->EmitLDARG(localArg++); } - else + for (UINT iArg = 0; iArg < msig.NumFixedArgs(); iArg++) { - getAwaiterToken = pCode->GetToken(mdGetAwaiter); - getIsCompletedToken = pCode->GetToken(mdIsCompleted); - getResultToken = pCode->GetToken(mdGetResult); + pCode->EmitLDARG(localArg++); } - if (isValueTask) - { - LocalDesc valuetaskLocalDesc(pMTTask); - DWORD valuetaskLocal = pCode->NewLocal(valuetaskLocalDesc); - pCode->EmitSTLOC(valuetaskLocal); - pCode->EmitLDLOCA(valuetaskLocal); - pCode->EmitCALL(getAwaiterToken, 1, 1); - } - else + pCode->EmitCALL(userFuncToken, localArg, 1); + + TypeHandle thLogicalRetType = msig.GetRetTypeHandleThrowing(); + if (IsValueTaskAsyncThunk()) { - pCode->EmitCALLVIRT(getAwaiterToken, 1, 1); - } + // Emit + // if (vtask.IsCompleted) + // return vtask.Result/vtask.ThrowIfCompletedUnsuccessfully() + // task = vtask.AsTask() + // + + MethodTable* pMTValueTask; + int isCompletedToken; + int completionResultToken; + int asTaskToken; + if (msig.IsReturnTypeVoid()) + { + pMTValueTask = CoreLibBinder::GetClass(CLASS__VALUETASK); - pCode->EmitSTLOC(awaiterLocal); - pCode->EmitLDLOCA(awaiterLocal); - pCode->EmitCALL(getIsCompletedToken, 1, 1); - pCode->EmitBRTRUE(pGetResultLabel); - pCode->EmitLDLOC(awaiterLocal); + MethodDesc* pMDValueTaskIsCompleted = CoreLibBinder::GetMethod(METHOD__VALUETASK__GET_ISCOMPLETED); + MethodDesc* pMDCompletionResult = CoreLibBinder::GetMethod(METHOD__VALUETASK__THROW_IF_COMPLETED_UNSUCCESSFULLY); + MethodDesc* pMDAsTask = CoreLibBinder::GetMethod(METHOD__VALUETASK__AS_TASK); - int awaitAwaiterToken = GetTokenForAwaitAwaiterInstantiatedOverTaskAwaiterType(pCode, thTaskAwaiter); - pCode->EmitCALL(awaitAwaiterToken, 1, 0); - pCode->EmitLabel(pGetResultLabel); + isCompletedToken = pCode->GetToken(pMDValueTaskIsCompleted); + completionResultToken = pCode->GetToken(pMDCompletionResult); + asTaskToken = pCode->GetToken(pMDAsTask); + } + else + { + MethodTable* pMTValueTaskOpen = CoreLibBinder::GetClass(CLASS__VALUETASK_1); + pMTValueTask = ClassLoader::LoadGenericInstantiationThrowing(pMTValueTaskOpen->GetModule(), pMTValueTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable(); - pCode->EmitLDLOCA(awaiterLocal); - pCode->EmitCALL(getResultToken, 1, mdGetResult->IsVoid() ? 0 : 1); + MethodDesc* pMDValueTaskIsCompleted = CoreLibBinder::GetMethod(METHOD__VALUETASK_1__GET_ISCOMPLETED); + MethodDesc* pMDCompletionResult = CoreLibBinder::GetMethod(METHOD__VALUETASK_1__GET_RESULT); + MethodDesc* pMDAsTask = CoreLibBinder::GetMethod(METHOD__VALUETASK_1__AS_TASK); - pCode->EmitRET(); -} + pMDValueTaskIsCompleted = FindOrCreateAssociatedMethodDesc(pMDValueTaskIsCompleted, pMTValueTask, FALSE, Instantiation(), FALSE); + pMDCompletionResult = FindOrCreateAssociatedMethodDesc(pMDCompletionResult, pMTValueTask, FALSE, Instantiation(), FALSE); + pMDAsTask = FindOrCreateAssociatedMethodDesc(pMDAsTask, pMTValueTask, FALSE, Instantiation(), FALSE); -// Get a token for AsyncHelpers.UnsafeAwaitAwaiter>() -// with T substituted by the return type of the async method. -int MethodDesc::GetTokenForAwaitAwaiterInstantiatedOverTaskAwaiterType(ILCodeStream* pCode, TypeHandle taskAwaiterType) -{ - MethodDesc* awaitAwaiter = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__UNSAFE_AWAIT_AWAITER_1); - TypeHandle thInstantiations[]{ taskAwaiterType }; - awaitAwaiter = FindOrCreateAssociatedMethodDesc(awaitAwaiter, awaitAwaiter->GetMethodTable(), FALSE, Instantiation(thInstantiations, 1), FALSE); + isCompletedToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, pMDValueTaskIsCompleted); + completionResultToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, pMDCompletionResult); + asTaskToken = GetTokenForGenericTypeMethodCallWithAsyncReturnType(pCode, pMDAsTask); + } + + LocalDesc valueTaskLocalDesc(pMTValueTask); + DWORD valueTaskLocal = pCode->NewLocal(valueTaskLocalDesc); + ILCodeLabel* valueTaskNotCompletedLabel = pCode->NewCodeLabel(); + + // Store value task returned by call to actual user func + pCode->EmitSTLOC(valueTaskLocal); - if (!taskAwaiterType.IsSharedByGenericInstantiations()) + pCode->EmitLDLOCA(valueTaskLocal); + pCode->EmitCALL(isCompletedToken, 1, 1); + pCode->EmitBRFALSE(valueTaskNotCompletedLabel); + + pCode->EmitLDLOCA(valueTaskLocal); + pCode->EmitCALL(completionResultToken, 1, msig.IsReturnTypeVoid() ? 0 : 1); + pCode->EmitRET(); + + pCode->EmitLabel(valueTaskNotCompletedLabel); + pCode->EmitLDLOCA(valueTaskLocal); + pCode->EmitCALL(asTaskToken, 1, 1); + } + + MethodTable* pMTTask; + + int completedTaskResultToken; + if (msig.IsReturnTypeVoid()) { - return pCode->GetToken(awaitAwaiter); + pMTTask = CoreLibBinder::GetClass(CLASS__TASK); + + MethodDesc* pMDCompletedTask = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK); + completedTaskResultToken = pCode->GetToken(pMDCompletedTask); } + else + { + MethodTable* pMTTaskOpen = CoreLibBinder::GetClass(CLASS__TASK_1); + pMTTask = ClassLoader::LoadGenericInstantiationThrowing(pMTTaskOpen->GetModule(), pMTTaskOpen->GetCl(), Instantiation(&thLogicalRetType, 1)).GetMethodTable(); - SigBuilder methodSigBuilder; - methodSigBuilder.AppendByte(IMAGE_CEE_CS_CALLCONV_GENERICINST); - methodSigBuilder.AppendData(1); - SigPointer retTypeSig = GetAsyncThunkResultTypeSig(); - PCCOR_SIGNATURE retTypeSigRaw; - uint32_t retTypeSigLen; - retTypeSig.GetSignature(&retTypeSigRaw, &retTypeSigLen); + MethodDesc* pMDCompletedTaskResult = CoreLibBinder::GetMethod(METHOD__ASYNC_HELPERS__COMPLETED_TASK_RESULT); + pMDCompletedTaskResult = FindOrCreateAssociatedMethodDesc(pMDCompletedTaskResult, pMDCompletedTaskResult->GetMethodTable(), FALSE, Instantiation(&thLogicalRetType, 1), FALSE); + completedTaskResultToken = GetTokenForGenericMethodCallWithAsyncReturnType(pCode, pMDCompletedTaskResult); + } - methodSigBuilder.AppendElementType(ELEMENT_TYPE_GENERICINST); - methodSigBuilder.AppendElementType(ELEMENT_TYPE_INTERNAL); - methodSigBuilder.AppendPointer(taskAwaiterType.GetMethodTable()); - methodSigBuilder.AppendData(1); - methodSigBuilder.AppendBlob((const PVOID)retTypeSigRaw, retTypeSigLen); + LocalDesc taskLocalDesc(pMTTask); + DWORD taskLocal = pCode->NewLocal(taskLocalDesc); + ILCodeLabel* pGetResultLabel = pCode->NewCodeLabel(); - DWORD methodSigLen; - PCCOR_SIGNATURE methodSig = (PCCOR_SIGNATURE)methodSigBuilder.GetSignature(&methodSigLen); - int methodSigToken = pCode->GetSigToken(methodSig, methodSigLen); + // Store task returned by actual user func or by ValueTask.AsTask + pCode->EmitSTLOC(taskLocal); - return pCode->GetToken(awaitAwaiter, mdTokenNil, methodSigToken); + pCode->EmitLDLOC(taskLocal); + pCode->EmitCALL(METHOD__TASK__GET_ISCOMPLETED, 1, 1); + pCode->EmitBRTRUE(pGetResultLabel); + + pCode->EmitLDLOC(taskLocal); + pCode->EmitCALL(METHOD__ASYNC_HELPERS__TRANSPARENT_AWAIT_TASK, 1, 0); + + pCode->EmitLabel(pGetResultLabel); + pCode->EmitLDLOC(taskLocal); + pCode->EmitCALL(completedTaskResultToken, 1, msig.IsReturnTypeVoid() ? 0 : 1); + pCode->EmitRET(); } diff --git a/src/coreclr/vm/corelib.h b/src/coreclr/vm/corelib.h index 89428c2b6cd9e0..bfc6cc586ca38c 100644 --- a/src/coreclr/vm/corelib.h +++ b/src/coreclr/vm/corelib.h @@ -344,40 +344,28 @@ DEFINE_CLASS(THREAD_START_EXCEPTION,Threading, ThreadStartException DEFINE_METHOD(THREAD_START_EXCEPTION,EX_CTOR, .ctor, IM_Exception_RetVoid) DEFINE_CLASS(VALUETASK_1, Tasks, ValueTask`1) -DEFINE_METHOD(VALUETASK_1, GET_AWAITER, GetAwaiter, NoSig) +DEFINE_METHOD(VALUETASK_1, GET_ISCOMPLETED, get_IsCompleted, NoSig) +DEFINE_METHOD(VALUETASK_1, GET_RESULT, get_Result, NoSig) +DEFINE_METHOD(VALUETASK_1, AS_TASK, AsTask, IM_RetTaskOfT) DEFINE_CLASS(VALUETASK, Tasks, ValueTask) DEFINE_METHOD(VALUETASK, FROM_EXCEPTION, FromException, SM_Exception_RetValueTask) DEFINE_METHOD(VALUETASK, FROM_EXCEPTION_1, FromException, GM_Exception_RetValueTaskOfT) DEFINE_METHOD(VALUETASK, FROM_RESULT_T, FromResult, GM_T_RetValueTaskOfT) DEFINE_METHOD(VALUETASK, GET_COMPLETED_TASK, get_CompletedTask, SM_RetValueTask) -DEFINE_METHOD(VALUETASK, GET_AWAITER, GetAwaiter, NoSig) +DEFINE_METHOD(VALUETASK, GET_ISCOMPLETED, get_IsCompleted, NoSig) +DEFINE_METHOD(VALUETASK, THROW_IF_COMPLETED_UNSUCCESSFULLY, ThrowIfCompletedUnsuccessfully, NoSig) +DEFINE_METHOD(VALUETASK, AS_TASK, AsTask, IM_RetTask) DEFINE_CLASS(TASK_1, Tasks, Task`1) -DEFINE_METHOD(TASK_1, GET_AWAITER, GetAwaiter, NoSig) +DEFINE_METHOD(TASK_1, GET_RESULTONSUCCESS, get_ResultOnSuccess, NoSig) DEFINE_CLASS(TASK, Tasks, Task) DEFINE_METHOD(TASK, FROM_EXCEPTION, FromException, SM_Exception_RetTask) DEFINE_METHOD(TASK, FROM_EXCEPTION_1, FromException, GM_Exception_RetTaskOfT) DEFINE_METHOD(TASK, FROM_RESULT_T, FromResult, GM_T_RetTaskOfT) DEFINE_METHOD(TASK, GET_COMPLETED_TASK, get_CompletedTask, SM_RetTask) -DEFINE_METHOD(TASK, GET_AWAITER, GetAwaiter, NoSig) - -DEFINE_CLASS(TASK_AWAITER_1, CompilerServices, TaskAwaiter`1) -DEFINE_METHOD(TASK_AWAITER_1, GET_ISCOMPLETED, get_IsCompleted, NoSig) -DEFINE_METHOD(TASK_AWAITER_1, GET_RESULT, GetResult, NoSig) - -DEFINE_CLASS(TASK_AWAITER, CompilerServices, TaskAwaiter) -DEFINE_METHOD(TASK_AWAITER, GET_ISCOMPLETED, get_IsCompleted, NoSig) -DEFINE_METHOD(TASK_AWAITER, GET_RESULT, GetResult, NoSig) - -DEFINE_CLASS(VALUETASK_AWAITER_1, CompilerServices, ValueTaskAwaiter`1) -DEFINE_METHOD(VALUETASK_AWAITER_1, GET_ISCOMPLETED, get_IsCompleted, NoSig) -DEFINE_METHOD(VALUETASK_AWAITER_1, GET_RESULT, GetResult, NoSig) - -DEFINE_CLASS(VALUETASK_AWAITER, CompilerServices, ValueTaskAwaiter) -DEFINE_METHOD(VALUETASK_AWAITER, GET_ISCOMPLETED, get_IsCompleted, NoSig) -DEFINE_METHOD(VALUETASK_AWAITER, GET_RESULT, GetResult, NoSig) +DEFINE_METHOD(TASK, GET_ISCOMPLETED, get_IsCompleted, NoSig) DEFINE_CLASS(TYPE_HANDLE, System, RuntimeTypeHandle) DEFINE_CLASS(RT_TYPE_HANDLE, System, RuntimeTypeHandle) @@ -743,6 +731,9 @@ DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION, ValueTaskFromExcepti DEFINE_METHOD(ASYNC_HELPERS, VALUETASK_FROM_EXCEPTION_1, ValueTaskFromException, GM_Exception_RetValueTaskOfT) DEFINE_METHOD(ASYNC_HELPERS, UNSAFE_AWAIT_AWAITER_1, UnsafeAwaitAwaiter, GM_T_RetVoid) +DEFINE_METHOD(ASYNC_HELPERS, TRANSPARENT_AWAIT_TASK, TransparentAwaitTask, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, COMPLETED_TASK_RESULT, CompletedTaskResult, NoSig) +DEFINE_METHOD(ASYNC_HELPERS, COMPLETED_TASK, CompletedTask, NoSig) DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_EXECUTION_CONTEXT, CaptureExecutionContext, NoSig) DEFINE_METHOD(ASYNC_HELPERS, RESTORE_EXECUTION_CONTEXT, RestoreExecutionContext, NoSig) DEFINE_METHOD(ASYNC_HELPERS, CAPTURE_CONTINUATION_CONTEXT, CaptureContinuationContext, NoSig) diff --git a/src/coreclr/vm/metasig.h b/src/coreclr/vm/metasig.h index ddb9b5ea403e63..94157031ee6489 100644 --- a/src/coreclr/vm/metasig.h +++ b/src/coreclr/vm/metasig.h @@ -559,15 +559,9 @@ DEFINE_METASIG_T(IM(Dec_RetVoid, g(DECIMAL), v)) DEFINE_METASIG_T(IM(Currency_RetVoid, g(CURRENCY), v)) DEFINE_METASIG_T(SM(RefDec_RetVoid, r(g(DECIMAL)), v)) -DEFINE_METASIG_T(IM(Exception_RetTaskOfT, C(EXCEPTION), GI(C(TASK_1), 1, G(0)))) -DEFINE_METASIG_T(IM(T_RetTaskOfT, G(0), GI(C(TASK_1), 1, G(0)))) - -DEFINE_METASIG_T(IM(Exception_RetTask, C(EXCEPTION), C(TASK))) +DEFINE_METASIG_T(IM(RetTaskOfT, _, GI(C(TASK_1), 1, G(0)))) DEFINE_METASIG_T(IM(RetTask, _, C(TASK))) -DEFINE_METASIG_T(IM(Exception_RetValueTaskOfT, C(EXCEPTION), GI(g(VALUETASK_1), 1, G(0)))) -DEFINE_METASIG_T(IM(T_RetValueTaskOfT, G(0), GI(g(VALUETASK_1), 1, G(0)))) - DEFINE_METASIG_T(IM(Exception_RetValueTask, C(EXCEPTION), g(VALUETASK))) DEFINE_METASIG_T(IM(RetValueTask, _, g(VALUETASK))) diff --git a/src/coreclr/vm/method.hpp b/src/coreclr/vm/method.hpp index 4b717f485d51eb..f56eddccf9012d 100644 --- a/src/coreclr/vm/method.hpp +++ b/src/coreclr/vm/method.hpp @@ -2126,7 +2126,6 @@ class MethodDesc bool IsValueTaskAsyncThunk(); int GetTokenForGenericMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); int GetTokenForGenericTypeMethodCallWithAsyncReturnType(ILCodeStream* pCode, MethodDesc* md); - int GetTokenForAwaitAwaiterInstantiatedOverTaskAwaiterType(ILCodeStream* pCode, TypeHandle taskAwaiterType); public: static void CreateDerivedTargetSigWithExtraParams(MetaSig& msig, SigBuilder* stubSigBuilder); bool TryGenerateTransientILImplementation(DynamicResolver** resolver, COR_ILMETHOD_DECODER** methodILDecoder); diff --git a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs index 7b2a94ce56b175..e2ecf8aee702e7 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Runtime/CompilerServices/AsyncHelpers.cs @@ -44,7 +44,7 @@ public static void UnsafeAwaitAwaiter(TAwaiter awaiter) where TAwaiter if (sentinelContinuation == null) state.SentinelContinuation = sentinelContinuation = new Continuation(); - state.Notifier = awaiter; + state.CriticalNotifier = awaiter; AsyncSuspend(sentinelContinuation); } diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs index 1f3020782e754b..1cd3b2e3b540dd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/Tasks/Task.cs @@ -4532,6 +4532,9 @@ internal void AddCompletionAction(ITaskCompletionAction action, bool addBeforeOt action.Invoke(this); // run the action directly if we failed to queue the continuation (i.e., the task completed) } + internal bool TryAddCompletionAction(ITaskCompletionAction action, bool addBeforeOthers = false) + => AddTaskContinuation(action, addBeforeOthers); + // Support method for AddTaskContinuation that takes care of multi-continuation logic. // Returns true if and only if the continuation was successfully queued. private bool AddTaskContinuationComplex(object tc, bool addBeforeOthers) diff --git a/src/tests/async/synchronization-context/synchronization-context.cs b/src/tests/async/synchronization-context/synchronization-context.cs index 3e3a3cd70d38d6..33280981cebee7 100644 --- a/src/tests/async/synchronization-context/synchronization-context.cs +++ b/src/tests/async/synchronization-context/synchronization-context.cs @@ -67,8 +67,11 @@ private static async Task WrappedYieldToThreadWithCustomSyncContext() private class MySyncContext : SynchronizationContext { + public int NumPosts; + public override void Post(SendOrPostCallback d, object state) { + NumPosts++; ThreadPool.UnsafeQueueUserWorkItem(_ => { SynchronizationContext prevContext = Current; @@ -216,4 +219,26 @@ private static async Task SetContext(SynchronizationContext context, bool suspen if (suspend) await Task.Yield(); } + + [Fact] + public static void TestNoSyncContextInRuntimeCallableThunk() + { + SynchronizationContext prevContext = SynchronizationContext.Current; + try + { + SynchronizationContext.SetSynchronizationContext(new MySyncContext()); + TestNoSyncContextInRuntimeCallableThunkAsync().GetAwaiter().GetResult(); + } + finally + { + SynchronizationContext.SetSynchronizationContext(prevContext); + } + } + + private static async Task TestNoSyncContextInRuntimeCallableThunkAsync() + { + MySyncContext syncCtx = (MySyncContext)SynchronizationContext.Current; + await Task.Delay(100).ConfigureAwait(false); + Assert.Equal(0, syncCtx.NumPosts); + } } diff --git a/src/tests/async/valuetask/valuetask.cs b/src/tests/async/valuetask/valuetask.cs index 4740181e27394f..1bdd69f9e6120a 100644 --- a/src/tests/async/valuetask/valuetask.cs +++ b/src/tests/async/valuetask/valuetask.cs @@ -6,15 +6,15 @@ using System.Threading.Tasks; using Xunit; -public class Async2valuetask +public class Async2ValueTask { [Fact] - public static int TestEntryPoint() + public static int TestBasic() { - return (int)AsyncTestEntryPoint(100).Result; + return (int)AsyncTestBasicEntryPoint(100).Result; } - private static ValueTask AsyncTestEntryPoint(int arg) + private static ValueTask AsyncTestBasicEntryPoint(int arg) { return M1(arg); } @@ -24,4 +24,38 @@ private static async ValueTask M1(int arg) await Task.Yield(); return arg; } + + [Fact] + public static void RuntimeAsyncCallableThunks() + { + RuntimeAsyncCallableThunksAsync().GetAwaiter().GetResult(); + } + + private static async ValueTask RuntimeAsyncCallableThunksAsync() + { + int result = await Foo(); + Assert.Equal(123, result); + await Bar(); + result = await Baz(); + Assert.Equal(456, result); + string strResult = await Beef(); + Assert.Equal("foo", strResult); + } + + private static ValueTask Foo() => new ValueTask(123); + private static ValueTask Bar() => ValueTask.CompletedTask; + + [RuntimeAsyncMethodGeneration(false)] + private static async ValueTask Baz() + { + await Task.Yield(); + return 456; + } + + [RuntimeAsyncMethodGeneration(false)] + private static async ValueTask Beef() + { + await Task.Yield(); + return "foo"; + } }