Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
<Compile Include="SqlConnectionBasicTests.cs" />
<Compile Include="SqlCommandTest.cs" />
<Compile Include="SqlConnectionTest.cs" />
<Compile Include="SslOverTdsStreamTest.cs" />
<Compile Include="TestTdsServer.cs" />
<Compile Include="AADAccessTokenTest.cs" />
<Compile Include="CloneTests.cs" />
Expand Down Expand Up @@ -60,6 +61,7 @@
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' != 'netcoreapp2.1' AND '$(TargetGroup)' == 'netcoreapp'">
<PackageReference Include="System.Data.Odbc" Version="$(SystemDataOdbcVersion)" />
<Compile Include="SslOverTdsStreamTest.NetCoreApp.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="$(TestsPath)ManualTests\SQL\UdtTest\UDTs\Address\Address.csproj">
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.Data.SqlClient.Tests
{
public static partial class SslOverTdsStreamTest
{
static partial void SyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
{
byte[] input;
byte[] output;
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);

byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
(Stream stream, int index) =>
{
stream.Write(input.AsSpan(TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE));
}
);

ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
(Stream stream, byte[] bytes, int offset, int count) =>
{
return stream.Read(bytes.AsSpan(offset, count));
}
);

Validate(input, output);
}

static partial void AsyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
{
byte[] input;
byte[] output;
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);

byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
async (Stream stream, int index) =>
{
await stream.WriteAsync(
new ReadOnlyMemory<byte>(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE)
);
}
);

ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
async (Stream stream, byte[] bytes, int offset, int count) =>
{
return await stream.ReadAsync(
new Memory<byte>(bytes, offset, count)
);
}
);

Validate(input, output);
}
}

public sealed partial class LimitedMemoryStream : MemoryStream
{
public override int Read(Span<byte> destination)
{
if (_readLimit > 0)
{
return base.Read(destination.Slice(0, Math.Min(_readLimit, destination.Length)));
}
else
{
return base.Read(destination);
}
}

public override ValueTask<int> ReadAsync(Memory<byte> destination, CancellationToken cancellationToken = default)
{

if (_readLimit > 0)
{
return base.ReadAsync(destination.Slice(0, Math.Min(_readLimit, destination.Length)), cancellationToken);
}
else
{
return base.ReadAsync(destination, cancellationToken);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,278 @@
using System;
using System.Diagnostics;
using System.IO;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.Data.SqlClient.Tests
{
public static partial class SslOverTdsStreamTest
{
[Theory]
[SkipOnTargetFramework(TargetFrameworkMonikers.NetFramework)]
[InlineData(0),InlineData(3),InlineData(128), InlineData(2048), InlineData(8192)]
public static void ReadWrite(int readLimit)
{
const int EncapulatedPacketCount = 4;
const int PassThroughPacketCount = 5;

SyncTest(EncapulatedPacketCount, PassThroughPacketCount, readLimit);
SyncCoreTest(EncapulatedPacketCount, PassThroughPacketCount, readLimit);
AsyncTest(EncapulatedPacketCount, PassThroughPacketCount, readLimit);
AsyncCoreTest(EncapulatedPacketCount, PassThroughPacketCount, readLimit);
}

private static void SyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
{
byte[] input;
byte[] output;
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);

byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
(Stream stream, int index) =>
{
stream.Write(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE);
}
);

ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
(Stream stream, byte[] bytes, int offset, int count) =>
{
return stream.Read(bytes, offset, count);
}
);

Validate(input, output);
}

static partial void SyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength);

private static void AsyncTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength)
{
byte[] input;
byte[] output;
SetupArrays(encapsulatedPacketCount + passthroughPacketCount, out input, out output);
byte[] buffer = WritePackets(encapsulatedPacketCount, passthroughPacketCount,
async (Stream stream, int index) =>
{
await stream.WriteAsync(input, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE);
}
);

ReadPackets(buffer, encapsulatedPacketCount, passthroughPacketCount, maxPacketReadLength, output,
async (Stream stream, byte[] bytes, int offset, int count) =>
{
return await stream.ReadAsync(bytes, offset, count);
}
);

Validate(input, output);
}

static partial void AsyncCoreTest(int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength);



private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func<Stream, byte[], int, int, Task<int>> action)
{
using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength))
using (Stream tdsStream = CreateSslOverTdsStream(stream))
{
int offset = 0;
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
for (int index = 0; index < encapsulatedPacketCount; index++)
{
Array.Clear(bytes, 0, bytes.Length);
int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult();
Array.Copy(bytes, 0, output, offset, packetBytes);
offset += packetBytes;
}
InvokeFinishHandshake(tdsStream);//tdsStream.FinishHandshake();
for (int index = 0; index < passthroughPacketCount; index++)
{
Array.Clear(bytes, 0, bytes.Length);
int packetBytes = ReadPacket(tdsStream, action, bytes).GetAwaiter().GetResult();
Array.Copy(bytes, 0, output, offset, packetBytes);
offset += packetBytes;
}
}
}

private static void InvokeFinishHandshake(Stream stream)
{
MethodInfo method = stream.GetType().GetMethod("FinishHandshake", BindingFlags.Public | BindingFlags.Instance);
method.Invoke(stream, null);
}

private static Stream CreateSslOverTdsStream(Stream stream)
{
Type type = typeof(SqlClientFactory).Assembly.GetType("Microsoft.Data.SqlClient.SNI.SslOverTdsStream");
ConstructorInfo ctor = type.GetConstructor(new Type[] { typeof(Stream) });
Stream instance = (Stream)ctor.Invoke(new object[] { stream });
return instance;
}

private static void ReadPackets(byte[] buffer, int encapsulatedPacketCount, int passthroughPacketCount, int maxPacketReadLength, byte[] output, Func<Stream, byte[], int, int, int> action)
{
using (LimitedMemoryStream stream = new LimitedMemoryStream(buffer, maxPacketReadLength))
using (Stream tdsStream = CreateSslOverTdsStream(stream))
{
int offset = 0;
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
for (int index = 0; index < encapsulatedPacketCount; index++)
{
Array.Clear(bytes, 0, bytes.Length);
int packetBytes = ReadPacket(tdsStream, action, bytes);
Array.Copy(bytes, 0, output, offset, packetBytes);
offset += packetBytes;
}
InvokeFinishHandshake(tdsStream);
for (int index = 0; index < passthroughPacketCount; index++)
{
Array.Clear(bytes, 0, bytes.Length);
int packetBytes = ReadPacket(tdsStream, action, bytes);
Array.Copy(bytes, 0, output, offset, packetBytes);
offset += packetBytes;
}
}

}

private static int ReadPacket(Stream tdsStream, Func<Stream, byte[], int, int, int> action, byte[] output)
{
int readCount;
int offset = 0;
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
do
{
readCount = action(tdsStream, bytes, offset, bytes.Length - offset);
if (readCount > 0)
{
offset += readCount;
}
}
while (readCount > 0 && offset < bytes.Length);
Array.Copy(bytes, 0, output, 0, offset);
return offset;
}

private static async Task<int> ReadPacket(Stream tdsStream, Func<Stream, byte[], int, int, Task<int>> action, byte[] output)
{
int readCount;
int offset = 0;
byte[] bytes = new byte[TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
do
{
readCount = await action(tdsStream, bytes, offset, bytes.Length - offset);
if (readCount > 0)
{
offset += readCount;
}
}
while (readCount > 0 && offset < bytes.Length);
Array.Copy(bytes, 0, output, 0, offset);
return offset;
}

private static byte[] WritePackets(int encapsulatedPacketCount, int passthroughPacketCount, Action<Stream, int> action)
{
byte[] buffer = null;
using (LimitedMemoryStream stream = new LimitedMemoryStream())
{
//using (SslOverTdsStream tdsStream = new SslOverTdsStream(stream))
using (Stream tdsStream = CreateSslOverTdsStream(stream))
{
for (int index = 0; index < encapsulatedPacketCount; index++)
{
action(tdsStream, index);
}
InvokeFinishHandshake(tdsStream);//tdsStream.FinishHandshake();
for (int index = 0; index < passthroughPacketCount; index++)
{
action(tdsStream, encapsulatedPacketCount + index);
}
}
buffer = stream.ToArray();
}
return buffer;
}

private static void SetupArrays(int packetCount, out byte[] input, out byte[] output)
{
byte[] pattern = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13 };
input = new byte[packetCount * TdsEnums.DEFAULT_LOGIN_PACKET_SIZE];
output = new byte[input.Length];
for (int index = 0; index < packetCount; index++)
{
int position = 0;
while (position < TdsEnums.DEFAULT_LOGIN_PACKET_SIZE)
{
int copyCount = Math.Min(pattern.Length, TdsEnums.DEFAULT_LOGIN_PACKET_SIZE - position);
Array.Copy(pattern, 0, input, (TdsEnums.DEFAULT_LOGIN_PACKET_SIZE * index) + position, copyCount);
position += copyCount;
}
}
}

private static void Validate(byte[] input, byte[] output)
{
Assert.True(input.AsSpan().SequenceEqual(output.AsSpan()));
}

internal static class TdsEnums
{
public const int DEFAULT_LOGIN_PACKET_SIZE = 4096;
}
}

[DebuggerStepThrough]
public sealed partial class LimitedMemoryStream : MemoryStream
{
private readonly int _readLimit;
private readonly int _delay;

public LimitedMemoryStream(int readLimit = 0, int delay = 0)
{
_readLimit = readLimit;
_delay = delay;
}

public LimitedMemoryStream(byte[] buffer, int readLimit = 0, int delay = 0)
: base(buffer)
{
_readLimit = readLimit;
_delay = delay;
}

public override int Read(byte[] buffer, int offset, int count)
{
if (_readLimit > 0)
{
return base.Read(buffer, offset, Math.Min(_readLimit, count));
}
else
{
return base.Read(buffer, offset, count);
}
}

public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
if (_delay > 0)
{
await Task.Delay(_delay, cancellationToken);
}
if (_readLimit > 0)
{
return await base.ReadAsync(buffer, offset, Math.Min(_readLimit, count), cancellationToken).ConfigureAwait(false);
}
else
{
return await base.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
}
}

}
}