Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,6 @@ private static TlsAlertMessage GetAlertMessageFromChain(X509Chain chain)
return TlsAlertMessage.CertificateUnknown;
}

Debug.Fail("GetAlertMessageFromChain was called but none of the chain elements had errors.");
return TlsAlertMessage.BadCertificate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ namespace System.Net.Security.Tests

public class SslStreamNetworkStreamTest
{
private readonly X509Certificate2 _serverCert;
private readonly X509CertificateCollection _serverChain;

public SslStreamNetworkStreamTest()
{
(_serverCert, _serverChain) = TestHelper.GenerateCertificates("localhost", DateTimeOffset.UtcNow.AddMinutes(-5));
}

[Fact]
public async Task SslStream_SendReceiveOverNetworkStream_Ok()
{
Expand Down Expand Up @@ -193,6 +201,67 @@ public async Task SslStream_NestedAuth_Throws()
}
}

[Fact]
public async Task SslStream_UntrustedCaWithCustomCallback_OK()
{
var options = new SslClientAuthenticationOptions() { TargetHost = "localhost" };
options.RemoteCertificateValidationCallback =
(sender, certificate, chain, sslPolicyErrors) =>
{
chain.ChainPolicy.ExtraStore.AddRange(_serverChain);
chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;

bool result = chain.Build((X509Certificate2)certificate);
Assert.True(result);

return result;
};

(Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
using (clientStream)
using (serverStream)
using (SslStream client = new SslStream(clientStream))
using (SslStream server = new SslStream(serverStream))
{
Task t1 = client.AuthenticateAsClientAsync(options, default);
Task t2 = server.AuthenticateAsServerAsync(_serverCert);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(t1, t2);
}
}

[Fact]
public async Task SslStream_UntrustedCaWithCustomCallback_Throws()
{
var options = new SslClientAuthenticationOptions() { TargetHost = "localhost" };
options.RemoteCertificateValidationCallback =
(sender, certificate, chain, sslPolicyErrors) =>
{
chain.ChainPolicy.ExtraStore.AddRange(_serverChain);
chain.ChainPolicy.CustomTrustStore.Add(_serverChain[_serverChain.Count -1]);
chain.ChainPolicy.TrustMode = X509ChainTrustMode.CustomRootTrust;
// This should work and we should be able to trust the chain.
Assert.True(chain.Build((X509Certificate2)certificate));
// Reject it in custom callback to simulate for example pinning.
return false;
};

(Stream clientStream, Stream serverStream) = TestHelper.GetConnectedStreams();
using (clientStream)
using (serverStream)
using (SslStream client = new SslStream(clientStream))
using (SslStream server = new SslStream(serverStream))
{
Task t1 = client.AuthenticateAsClientAsync(options, default);
Task t2 = server.AuthenticateAsServerAsync(_serverCert);

await Assert.ThrowsAsync<AuthenticationException>(() => t1);
// Server side should finish since we run custom callback after handshake is done.
await t2;
}
}

private static bool ValidateServerCertificate(
object sender,
X509Certificate retrievedServerPublicCertificate,
Expand Down
137 changes: 135 additions & 2 deletions src/libraries/System.Net.Security/tests/FunctionalTests/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Net.Test.Common;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
using System.Text;

namespace System.Net.Security.Tests
{
Expand Down Expand Up @@ -47,5 +48,137 @@ internal static (VirtualNetworkStream ClientStream, VirtualNetworkStream ServerS

return (new VirtualNetworkStream(vn, isServer: false), new VirtualNetworkStream(vn, isServer: true));
}

internal static (X509Certificate2 certificate, X509Certificate2Collection) GenerateCertificates(string name, DateTimeOffset startTime, string? caUrl = null)
{
X509Certificate2Collection chain = new X509Certificate2Collection();

using (RSA root = RSA.Create())
using (RSA intermediate = RSA.Create())
using (RSA server = RSA.Create())
{
CertificateRequest rootReq = new CertificateRequest(
"CN=Root",
root,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);

rootReq.CertificateExtensions.Add(
new X509BasicConstraintsExtension(true, false, 0, true));
rootReq.CertificateExtensions.Add(
new X509SubjectKeyIdentifierExtension(rootReq.PublicKey, false));
rootReq.CertificateExtensions.Add(
new X509KeyUsageExtension(X509KeyUsageFlags.KeyCertSign, false));

//DateTimeOffset start = DateTimeOffset.UtcNow.AddMinutes(-5);
DateTimeOffset endTime = startTime.AddMonths(1);

X509Certificate2 rootCertWithKey = rootReq.CreateSelfSigned(startTime, endTime);

CertificateRequest intermedReq = new CertificateRequest(
"CN=Intermediate",
intermediate,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);

intermedReq.CertificateExtensions.Add(
new X509BasicConstraintsExtension(true, false, 0, true));
intermedReq.CertificateExtensions.Add(
new X509SubjectKeyIdentifierExtension(intermedReq.PublicKey, false));
intermedReq.CertificateExtensions.Add(
new X509KeyUsageExtension(X509KeyUsageFlags.KeyCertSign, false));

byte[] serial = new byte[8];
RandomNumberGenerator.Fill(serial);

X509Certificate2 intermedCertWithKey;
using (X509Certificate2 intermedPub = intermedReq.Create(rootCertWithKey, startTime, endTime, serial))
{
intermedCertWithKey = intermedPub.CopyWithPrivateKey(intermediate);
}

CertificateRequest serverReq = new CertificateRequest(
$"CN={name}",
server,
HashAlgorithmName.SHA256,
RSASignaturePadding.Pkcs1);

serverReq.CertificateExtensions.Add(
new X509BasicConstraintsExtension(false, false, 0, false));
serverReq.CertificateExtensions.Add(
new X509SubjectKeyIdentifierExtension(serverReq.PublicKey, false));

// Add Issuer KeyIdentifier
using (SHA1 sha1 = SHA1.Create())
{
byte[] data = new byte[24];
data[0] = 0x30; //SEQUENCE
data[1] = 22;
data[2] = 0x80;
data[3] = 20;
Buffer.BlockCopy(sha1.ComputeHash(intermedCertWithKey.PublicKey.EncodedKeyValue.RawData), 0, data, 4, 20);
serverReq.CertificateExtensions.Add(new X509Extension(new Oid("2.5.29.35"), data, false));
}

// 1.3.6.1.5.5.7.1.1
if (caUrl != null)
{
var urlBytes = Encoding.ASCII.GetBytes(caUrl);

byte[] data = new byte[urlBytes.Length + 16];
data[0] = 0x30; //SEQUENCE
data[1] = (byte)(urlBytes.Length + 14);
data[2] = 0x30; //SEQUENCE;
data[3] = (byte)(urlBytes.Length + 12);
data[4] = 6; // OBJECT
data[5] = 8; // LENGTH
// OID
data[6] = 0x2b;
data[7] = 0x6;
data[8] = 0x1;
data[9] = 0x5;
data[10] = 0x5;
data[11] = 0x7;
data[12] = 0x30; // SEQUENCE
data[13] = 02;
data[14] = 0x86;
data[15] = (byte)(urlBytes.Length);
data[16] = 0x74;
data[17] = 0x74;
data[18] = 0x70;
Buffer.BlockCopy(urlBytes, 0, data, 16, urlBytes.Length);

serverReq.CertificateExtensions.Add(new X509Extension(new Oid("1.3.6.1.5.5.7.1.1"), data, false));
}

serverReq.CertificateExtensions.Add(
new X509KeyUsageExtension(
X509KeyUsageFlags.DigitalSignature | X509KeyUsageFlags.KeyEncipherment | X509KeyUsageFlags.DataEncipherment,
false));
serverReq.CertificateExtensions.Add(
new X509EnhancedKeyUsageExtension(
new OidCollection()
{
new Oid("1.3.6.1.5.5.7.3.1", null),
},
false));

SubjectAlternativeNameBuilder builder = new SubjectAlternativeNameBuilder();
builder.AddDnsName(name);
builder.AddIpAddress(IPAddress.Loopback);
builder.AddIpAddress(IPAddress.IPv6Loopback);
serverReq.CertificateExtensions.Add(builder.Build());

RandomNumberGenerator.Fill(serial);

X509Certificate2 serverCert = serverReq.Create(intermedCertWithKey, startTime, endTime, serial);
X509Certificate2 serverCertWithKey = serverCert.CopyWithPrivateKey(server);

chain.Add(intermedCertWithKey);
chain.Add(rootCertWithKey);

return (serverCertWithKey, chain);
}
}
}
}