Skip to content

Commit 4df29c9

Browse files
adding span version of ReceiveMessageFrom (#46285)
* adding span version of ReceiveMessageFrom fix #43933 * Started using new testsetup for the receivemessagefrom * removed code duplication in socketpal.windows * removed unused buffers parameter * adding non-zero offset to the ReceiveMessageFrom test * added documentation * offset fix for EAP tests * fixing pr comments
1 parent a56f84b commit 4df29c9

File tree

8 files changed

+221
-11
lines changed

8 files changed

+221
-11
lines changed

src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,7 @@ public void Listen(int backlog) { }
388388
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveFromResult> ReceiveFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
389389
public bool ReceiveFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
390390
public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
391+
public int ReceiveMessageFrom(System.Span<byte> buffer, ref System.Net.Sockets.SocketFlags socketFlags, ref System.Net.EndPoint remoteEP, out System.Net.Sockets.IPPacketInformation ipPacketInformation) { throw null; }
391392
public System.Threading.Tasks.Task<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.ArraySegment<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint) { throw null; }
392393
public System.Threading.Tasks.ValueTask<System.Net.Sockets.SocketReceiveMessageFromResult> ReceiveMessageFromAsync(System.Memory<byte> buffer, System.Net.Sockets.SocketFlags socketFlags, System.Net.EndPoint remoteEndPoint, System.Threading.CancellationToken cancellationToken = default) { throw null; }
393394
public bool ReceiveMessageFromAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }

src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1603,6 +1603,100 @@ public int ReceiveMessageFrom(byte[] buffer, int offset, int size, ref SocketFla
16031603
return bytesTransferred;
16041604
}
16051605

1606+
/// <summary>
1607+
/// Receives the specified number of bytes of data into the specified location of the data buffer,
1608+
/// using the specified <paramref name="socketFlags"/>, and stores the endpoint and packet information.
1609+
/// </summary>
1610+
/// <param name="buffer">
1611+
/// An <see cref="Span{T}"/> of type <see cref="byte"/> that is the storage location for received data.
1612+
/// </param>
1613+
/// <param name="socketFlags">
1614+
/// A bitwise combination of the <see cref="SocketFlags"/> values.
1615+
/// </param>
1616+
/// <param name="remoteEP">
1617+
/// An <see cref="EndPoint"/>, passed by reference, that represents the remote server.
1618+
/// </param>
1619+
/// <param name="ipPacketInformation">
1620+
/// An <see cref="IPPacketInformation"/> holding address and interface information.
1621+
/// </param>
1622+
/// <returns>
1623+
/// The number of bytes received.
1624+
/// </returns>
1625+
/// <exception cref="ObjectDisposedException">The <see cref="Socket"/> object has been closed.</exception>
1626+
/// <exception cref="ArgumentNullException">The <see cref="EndPoint"/> remoteEP is null.</exception>
1627+
/// <exception cref="ArgumentException">The <see cref="AddressFamily"/> of the <see cref="EndPoint"/> used in
1628+
/// <see cref="Socket.ReceiveMessageFrom(Span{byte}, ref SocketFlags, ref EndPoint, out IPPacketInformation)"/>
1629+
/// needs to match the <see cref="AddressFamily"/> of the <see cref="EndPoint"/> used in SendTo.</exception>
1630+
/// <exception cref="InvalidOperationException">
1631+
/// <para>The <see cref="Socket"/> object is not in blocking mode and cannot accept this synchronous call.</para>
1632+
/// <para>You must call the Bind method before performing this operation.</para></exception>
1633+
public int ReceiveMessageFrom(Span<byte> buffer, ref SocketFlags socketFlags, ref EndPoint remoteEP, out IPPacketInformation ipPacketInformation)
1634+
{
1635+
ThrowIfDisposed();
1636+
1637+
if (remoteEP == null)
1638+
{
1639+
throw new ArgumentNullException(nameof(remoteEP));
1640+
}
1641+
if (!CanTryAddressFamily(remoteEP.AddressFamily))
1642+
{
1643+
throw new ArgumentException(SR.Format(SR.net_InvalidEndPointAddressFamily, remoteEP.AddressFamily, _addressFamily), nameof(remoteEP));
1644+
}
1645+
if (_rightEndPoint == null)
1646+
{
1647+
throw new InvalidOperationException(SR.net_sockets_mustbind);
1648+
}
1649+
1650+
SocketPal.CheckDualModeReceiveSupport(this);
1651+
ValidateBlockingMode();
1652+
1653+
// We don't do a CAS demand here because the contents of remoteEP aren't used by
1654+
// WSARecvMsg; all that matters is that we generate a unique-to-this-call SocketAddress
1655+
// with the right address family.
1656+
EndPoint endPointSnapshot = remoteEP;
1657+
Internals.SocketAddress socketAddress = Serialize(ref endPointSnapshot);
1658+
1659+
// Save a copy of the original EndPoint.
1660+
Internals.SocketAddress socketAddressOriginal = IPEndPointExtensions.Serialize(endPointSnapshot);
1661+
1662+
SetReceivingPacketInformation();
1663+
1664+
Internals.SocketAddress receiveAddress;
1665+
int bytesTransferred;
1666+
SocketError errorCode = SocketPal.ReceiveMessageFrom(this, _handle, buffer, ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred);
1667+
1668+
UpdateReceiveSocketErrorForDisposed(ref errorCode, bytesTransferred);
1669+
// Throw an appropriate SocketException if the native call fails.
1670+
if (errorCode != SocketError.Success && errorCode != SocketError.MessageSize)
1671+
{
1672+
UpdateStatusAfterSocketErrorAndThrowException(errorCode);
1673+
}
1674+
else if (SocketsTelemetry.Log.IsEnabled())
1675+
{
1676+
SocketsTelemetry.Log.BytesReceived(bytesTransferred);
1677+
if (errorCode == SocketError.Success && SocketType == SocketType.Dgram) SocketsTelemetry.Log.DatagramReceived();
1678+
}
1679+
1680+
if (!socketAddressOriginal.Equals(receiveAddress))
1681+
{
1682+
try
1683+
{
1684+
remoteEP = endPointSnapshot.Create(receiveAddress);
1685+
}
1686+
catch
1687+
{
1688+
}
1689+
if (_rightEndPoint == null)
1690+
{
1691+
// Save a copy of the EndPoint so we can use it for Create().
1692+
_rightEndPoint = endPointSnapshot;
1693+
}
1694+
}
1695+
1696+
if (NetEventSource.Log.IsEnabled()) NetEventSource.Error(this, errorCode);
1697+
return bytesTransferred;
1698+
}
1699+
16061700
// Receives a datagram into a specific location in the data buffer and stores
16071701
// the end point.
16081702
public int ReceiveFrom(byte[] buffer, int offset, int size, SocketFlags socketFlags, ref EndPoint remoteEP)

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketAsyncContext.Unix.cs

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,31 @@ public override void InvokeCallback(bool allowPooling) =>
565565
Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode);
566566
}
567567

568+
private sealed unsafe class BufferPtrReceiveMessageFromOperation : ReadOperation
569+
{
570+
public byte* BufferPtr;
571+
public int Length;
572+
public SocketFlags Flags;
573+
public int BytesTransferred;
574+
public SocketFlags ReceivedFlags;
575+
576+
public bool IsIPv4;
577+
public bool IsIPv6;
578+
public IPPacketInformation IPPacketInformation;
579+
580+
public BufferPtrReceiveMessageFromOperation(SocketAsyncContext context) : base(context) { }
581+
582+
protected sealed override void Abort() { }
583+
584+
public Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError>? Callback { get; set; }
585+
586+
protected override bool DoTryComplete(SocketAsyncContext context) =>
587+
SocketPal.TryCompleteReceiveMessageFrom(context._socket, new Span<byte>(BufferPtr, Length), null, Flags, SocketAddress!, ref SocketAddressLen, IsIPv4, IsIPv6, out BytesTransferred, out ReceivedFlags, out IPPacketInformation, out ErrorCode);
588+
589+
public override void InvokeCallback(bool allowPooling) =>
590+
Callback!(BytesTransferred, SocketAddress!, SocketAddressLen, ReceivedFlags, IPPacketInformation, ErrorCode);
591+
}
592+
568593
private sealed class AcceptOperation : ReadOperation
569594
{
570595
public IntPtr AcceptedFileDescriptor;
@@ -1696,15 +1721,15 @@ public SocketError ReceiveFromAsync(IList<ArraySegment<byte>> buffers, SocketFla
16961721
}
16971722

16981723
public SocketError ReceiveMessageFrom(
1699-
Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
1724+
Memory<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
17001725
{
17011726
Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}");
17021727

17031728
SocketFlags receivedFlags;
17041729
SocketError errorCode;
17051730
int observedSequenceNumber;
17061731
if (_receiveQueue.IsReady(this, out observedSequenceNumber) &&
1707-
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, buffers, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
1732+
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer.Span, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
17081733
!ShouldRetrySyncOperation(out errorCode)))
17091734
{
17101735
flags = receivedFlags;
@@ -1714,7 +1739,7 @@ public SocketError ReceiveMessageFrom(
17141739
var operation = new ReceiveMessageFromOperation(this)
17151740
{
17161741
Buffer = buffer,
1717-
Buffers = buffers,
1742+
Buffers = null,
17181743
Flags = flags,
17191744
SocketAddress = socketAddress,
17201745
SocketAddressLen = socketAddressLen,
@@ -1731,6 +1756,45 @@ public SocketError ReceiveMessageFrom(
17311756
return operation.ErrorCode;
17321757
}
17331758

1759+
public unsafe SocketError ReceiveMessageFrom(
1760+
Span<byte> buffer, ref SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, int timeout, out IPPacketInformation ipPacketInformation, out int bytesReceived)
1761+
{
1762+
Debug.Assert(timeout == -1 || timeout > 0, $"Unexpected timeout: {timeout}");
1763+
1764+
SocketFlags receivedFlags;
1765+
SocketError errorCode;
1766+
int observedSequenceNumber;
1767+
if (_receiveQueue.IsReady(this, out observedSequenceNumber) &&
1768+
(SocketPal.TryCompleteReceiveMessageFrom(_socket, buffer, null, flags, socketAddress, ref socketAddressLen, isIPv4, isIPv6, out bytesReceived, out receivedFlags, out ipPacketInformation, out errorCode) ||
1769+
!ShouldRetrySyncOperation(out errorCode)))
1770+
{
1771+
flags = receivedFlags;
1772+
return errorCode;
1773+
}
1774+
1775+
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer))
1776+
{
1777+
var operation = new BufferPtrReceiveMessageFromOperation(this)
1778+
{
1779+
BufferPtr = bufferPtr,
1780+
Length = buffer.Length,
1781+
Flags = flags,
1782+
SocketAddress = socketAddress,
1783+
SocketAddressLen = socketAddressLen,
1784+
IsIPv4 = isIPv4,
1785+
IsIPv6 = isIPv6,
1786+
};
1787+
1788+
PerformSyncOperation(ref _receiveQueue, operation, timeout, observedSequenceNumber);
1789+
1790+
socketAddressLen = operation.SocketAddressLen;
1791+
flags = operation.ReceivedFlags;
1792+
ipPacketInformation = operation.IPPacketInformation;
1793+
bytesReceived = operation.BytesTransferred;
1794+
return operation.ErrorCode;
1795+
}
1796+
}
1797+
17341798
public SocketError ReceiveMessageFromAsync(Memory<byte> buffer, IList<ArraySegment<byte>>? buffers, SocketFlags flags, byte[] socketAddress, ref int socketAddressLen, bool isIPv4, bool isIPv6, out int bytesReceived, out SocketFlags receivedFlags, out IPPacketInformation ipPacketInformation, Action<int, byte[], int, SocketFlags, IPPacketInformation, SocketError> callback, CancellationToken cancellationToken = default)
17351799
{
17361800
SetNonBlocking();

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Unix.cs

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han
11721172
SocketError errorCode;
11731173
if (!handle.IsNonBlocking)
11741174
{
1175-
errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory<byte>(buffer, offset, count), null, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
1175+
errorCode = handle.AsyncContext.ReceiveMessageFrom(new Memory<byte>(buffer, offset, count), ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
11761176
}
11771177
else
11781178
{
@@ -1187,6 +1187,33 @@ public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle han
11871187
return errorCode;
11881188
}
11891189

1190+
1191+
public static SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span<byte> buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
1192+
{
1193+
byte[] socketAddressBuffer = socketAddress.Buffer;
1194+
int socketAddressLen = socketAddress.Size;
1195+
1196+
bool isIPv4, isIPv6;
1197+
Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out isIPv4, out isIPv6);
1198+
1199+
SocketError errorCode;
1200+
if (!handle.IsNonBlocking)
1201+
{
1202+
errorCode = handle.AsyncContext.ReceiveMessageFrom(buffer, ref socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, handle.ReceiveTimeout, out ipPacketInformation, out bytesTransferred);
1203+
}
1204+
else
1205+
{
1206+
if (!TryCompleteReceiveMessageFrom(handle, buffer, null, socketFlags, socketAddressBuffer, ref socketAddressLen, isIPv4, isIPv6, out bytesTransferred, out socketFlags, out ipPacketInformation, out errorCode))
1207+
{
1208+
errorCode = SocketError.WouldBlock;
1209+
}
1210+
}
1211+
1212+
socketAddress.InternalSize = socketAddressLen;
1213+
receiveAddress = socketAddress;
1214+
return errorCode;
1215+
}
1216+
11901217
public static SocketError ReceiveFrom(SafeSocketHandle handle, byte[] buffer, int offset, int count, SocketFlags socketFlags, byte[] socketAddress, ref int socketAddressLen, out int bytesTransferred)
11911218
{
11921219
if (!handle.IsNonBlocking)

src/libraries/System.Net.Sockets/src/System/Net/Sockets/SocketPal.Windows.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -451,14 +451,19 @@ public static unsafe IPPacketInformation GetIPPacketInformation(Interop.Winsock.
451451
}
452452

453453
public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, byte[] buffer, int offset, int size, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
454+
{
455+
return ReceiveMessageFrom(socket, handle, new Span<byte>(buffer, offset, size), ref socketFlags, socketAddress, out receiveAddress, out ipPacketInformation, out bytesTransferred);
456+
}
457+
458+
public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHandle handle, Span<byte> buffer, ref SocketFlags socketFlags, Internals.SocketAddress socketAddress, out Internals.SocketAddress receiveAddress, out IPPacketInformation ipPacketInformation, out int bytesTransferred)
454459
{
455460
bool ipv4, ipv6;
456461
Socket.GetIPProtocolInformation(socket.AddressFamily, socketAddress, out ipv4, out ipv6);
457462

458463
bytesTransferred = 0;
459464
receiveAddress = socketAddress;
460465
ipPacketInformation = default(IPPacketInformation);
461-
fixed (byte* ptrBuffer = buffer)
466+
fixed (byte* bufferPtr = &MemoryMarshal.GetReference(buffer))
462467
fixed (byte* ptrSocketAddress = socketAddress.Buffer)
463468
{
464469
Interop.Winsock.WSAMsg wsaMsg;
@@ -467,8 +472,8 @@ public static unsafe SocketError ReceiveMessageFrom(Socket socket, SafeSocketHan
467472
wsaMsg.flags = socketFlags;
468473

469474
WSABuffer wsaBuffer;
470-
wsaBuffer.Length = size;
471-
wsaBuffer.Pointer = (IntPtr)(ptrBuffer + offset);
475+
wsaBuffer.Length = buffer.Length;
476+
wsaBuffer.Pointer = (IntPtr)bufferPtr;
472477
wsaMsg.buffers = (IntPtr)(&wsaBuffer);
473478
wsaMsg.count = 1;
474479

src/libraries/System.Net.Sockets/tests/FunctionalTests/ReceiveMessageFrom.cs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ public async Task ReceiveSent_TCP_Success(bool ipv6)
5454
[InlineData(true)]
5555
public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
5656
{
57+
// [ActiveIssue("https://github.com/dotnet/runtime/issues/47637")]
58+
int Offset = UsesSync || !PlatformDetection.IsWindows ? 10 : 0;
5759
const int DatagramSize = 256;
5860
const int DatagramsToSend = 16;
5961

@@ -69,7 +71,9 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
6971
sender.BindToAnonymousPort(address);
7072

7173
byte[] sendBuffer = new byte[DatagramSize];
72-
byte[] receiveBuffer = new byte[DatagramSize];
74+
var receiveInternalBuffer = new byte[DatagramSize + Offset];
75+
var emptyBuffer = new byte[Offset];
76+
ArraySegment<byte> receiveBuffer = new ArraySegment<byte>(receiveInternalBuffer, Offset, DatagramSize);
7377
Random rnd = new Random(0);
7478

7579
IPEndPoint remoteEp = new IPEndPoint(ipv4 ? IPAddress.Any : IPAddress.IPv6Any, 0);
@@ -83,7 +87,8 @@ public async Task ReceiveSentMessages_UDP_Success(bool ipv4)
8387
IPPacketInformation packetInformation = result.PacketInformation;
8488

8589
Assert.Equal(DatagramSize, result.ReceivedBytes);
86-
AssertExtensions.SequenceEqual(sendBuffer, receiveBuffer);
90+
AssertExtensions.SequenceEqual(emptyBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, 0, Offset));
91+
AssertExtensions.SequenceEqual(sendBuffer, new ReadOnlySpan<byte>(receiveInternalBuffer, Offset, DatagramSize));
8792
Assert.Equal(sender.LocalEndPoint, result.RemoteEndPoint);
8893
Assert.Equal(((IPEndPoint)sender.LocalEndPoint).Address, packetInformation.Address);
8994
}

0 commit comments

Comments
 (0)