diff --git a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs index 3f0b9c182d2cf9..c348418d510c4d 100644 --- a/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs +++ b/src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs @@ -472,7 +472,7 @@ public void Release() { if (_comWrappers != null) { - _comWrappers.RemoveRCWFromCache(_externalComObject); + _comWrappers.RemoveRCWFromCache(_externalComObject, _proxyHandle); _comWrappers = null; } @@ -720,8 +720,18 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( { if (_rcwCache.TryGetValue(externalComObject, out GCHandle handle)) { - retValue = handle.Target; - return true; + object? cachedWrapper = handle.Target; + if (cachedWrapper is not null) + { + retValue = cachedWrapper; + return true; + } + else + { + // The GCHandle has been clear out but the NativeObjectWrapper + // finalizer has not yet run to remove the entry from _rcwCache + _rcwCache.Remove(externalComObject); + } } if (wrapperMaybe is not null) @@ -765,9 +775,21 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( using (LockHolder.Hold(_lock)) { + object? cachedWrapper = null; if (_rcwCache.TryGetValue(externalComObject, out var existingHandle)) { - retValue = existingHandle.Target; + cachedWrapper = existingHandle.Target; + if (cachedWrapper is null) + { + // The GCHandle has been clear out but the NativeObjectWrapper + // finalizer has not yet run to remove the entry from _rcwCache + _rcwCache.Remove(externalComObject); + } + } + + if (cachedWrapper is not null) + { + retValue = cachedWrapper; } else { @@ -788,11 +810,17 @@ private unsafe bool TryGetOrCreateObjectForComInstanceInternal( } #pragma warning restore IDE0060 - private void RemoveRCWFromCache(IntPtr comPointer) + private void RemoveRCWFromCache(IntPtr comPointer, GCHandle expectedValue) { using (LockHolder.Hold(_lock)) { - _rcwCache.Remove(comPointer); + // TryGetOrCreateObjectForComInstanceInternal may have put a new entry into the cache + // in the time between the GC cleared the contents of the GC handle but before the + // NativeObjectWrapper finializer ran. + if (_rcwCache.TryGetValue(comPointer, out GCHandle cachedValue) && expectedValue.Equals(cachedValue)) + { + _rcwCache.Remove(comPointer); + } } } diff --git a/src/tests/Interop/COM/ComWrappers/API/Program.cs b/src/tests/Interop/COM/ComWrappers/API/Program.cs index 3c1fbc83b1c5f0..106df3cec434e5 100644 --- a/src/tests/Interop/COM/ComWrappers/API/Program.cs +++ b/src/tests/Interop/COM/ComWrappers/API/Program.cs @@ -299,6 +299,36 @@ static void ValidateCreateObjectCachingScenario() Assert.NotEqual(trackerObj1, trackerObj3); } + // Make sure that if one wrapper is GCed, another can be created. + static void ValidateCreateObjectGcBehavior() + { + Console.WriteLine($"Running {nameof(ValidateCreateObjectCachingScenario)}..."); + + var cw = new TestComWrappers(); + + // Get an object from a tracker runtime. + IntPtr trackerObjRaw = MockReferenceTrackerRuntime.CreateTrackerObject(); + + // Create the first native object wrapper and run the GC. + CreateObject(); + GC.Collect(); + + // Try to create another wrapper for the same object. The above GC + // may have collected parts of the ComWrapper cache, but this should + // still work. + CreateObject(); + ForceGC(); + + Marshal.Release(trackerObjRaw); + + [MethodImpl(MethodImplOptions.NoInlining)] + void CreateObject() + { + var obj = (ITrackerObjectWrapper)cw.GetOrCreateObjectForComInstance(trackerObjRaw, CreateObjectFlags.None); + Assert.NotNull(obj); + } + } + static void ValidateMappingAPIs() { Console.WriteLine($"Running {nameof(ValidateMappingAPIs)}..."); @@ -777,6 +807,7 @@ static int Main() ValidateCreatingAComInterfaceForObjectAfterTheFirstIsFree(); ValidateFallbackQueryInterface(); ValidateCreateObjectCachingScenario(); + ValidateCreateObjectGcBehavior(); ValidateMappingAPIs(); ValidateWrappersInstanceIsolation(); ValidatePrecreatedExternalWrapper();