diff --git a/src/Renci.SshNet/BaseClient.cs b/src/Renci.SshNet/BaseClient.cs index ddd434c2e..bdfd5cb5f 100644 --- a/src/Renci.SshNet/BaseClient.cs +++ b/src/Renci.SshNet/BaseClient.cs @@ -153,6 +153,11 @@ public TimeSpan KeepAliveInterval /// public event EventHandler HostKeyReceived; + /// + /// Occurs when server identification received. + /// + public event EventHandler ServerIdentificationReceived; + /// /// Initializes a new instance of the class. /// @@ -390,6 +395,11 @@ private void Session_HostKeyReceived(object sender, HostKeyEventArgs e) HostKeyReceived?.Invoke(this, e); } + private void Session_ServerIdentificationReceived(object sender, SshIdentificationEventArgs e) + { + ServerIdentificationReceived?.Invoke(this, e); + } + /// /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources. /// @@ -532,6 +542,7 @@ private Timer CreateKeepAliveTimer(TimeSpan dueTime, TimeSpan period) private ISession CreateAndConnectSession() { var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory()); + session.ServerIdentificationReceived += Session_ServerIdentificationReceived; session.HostKeyReceived += Session_HostKeyReceived; session.ErrorOccured += Session_ErrorOccured; @@ -550,6 +561,7 @@ private ISession CreateAndConnectSession() private async Task CreateAndConnectSessionAsync(CancellationToken cancellationToken) { var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory()); + session.ServerIdentificationReceived += Session_ServerIdentificationReceived; session.HostKeyReceived += Session_HostKeyReceived; session.ErrorOccured += Session_ErrorOccured; @@ -569,6 +581,7 @@ private void DisposeSession(ISession session) { session.ErrorOccured -= Session_ErrorOccured; session.HostKeyReceived -= Session_HostKeyReceived; + session.ServerIdentificationReceived -= Session_ServerIdentificationReceived; session.Dispose(); } diff --git a/src/Renci.SshNet/Common/SshIdentificationEventArgs.cs b/src/Renci.SshNet/Common/SshIdentificationEventArgs.cs new file mode 100644 index 000000000..f618112bd --- /dev/null +++ b/src/Renci.SshNet/Common/SshIdentificationEventArgs.cs @@ -0,0 +1,26 @@ +using System; + +using Renci.SshNet.Connection; + +namespace Renci.SshNet.Common +{ + /// + /// Provides data for the ServerIdentificationReceived events. + /// + public class SshIdentificationEventArgs : EventArgs + { + /// + /// Initializes a new instance of the class. + /// + /// The SSH identification. + public SshIdentificationEventArgs(SshIdentification sshIdentification) + { + SshIdentification = sshIdentification; + } + + /// + /// Gets the SSH identification. + /// + public SshIdentification SshIdentification { get; private set; } + } +} diff --git a/src/Renci.SshNet/Connection/SshIdentification.cs b/src/Renci.SshNet/Connection/SshIdentification.cs index 931656296..727cc4d94 100644 --- a/src/Renci.SshNet/Connection/SshIdentification.cs +++ b/src/Renci.SshNet/Connection/SshIdentification.cs @@ -5,7 +5,7 @@ namespace Renci.SshNet.Connection /// /// Represents an SSH identification. /// - internal sealed class SshIdentification + public sealed class SshIdentification { /// /// Initializes a new instance of the class with the specified protocol version diff --git a/src/Renci.SshNet/ISession.cs b/src/Renci.SshNet/ISession.cs index 5a035b104..e78ff75f8 100644 --- a/src/Renci.SshNet/ISession.cs +++ b/src/Renci.SshNet/ISession.cs @@ -260,6 +260,11 @@ internal interface ISession : IDisposable /// event EventHandler ErrorOccured; + /// + /// Occurs when server identification received. + /// + event EventHandler ServerIdentificationReceived; + /// /// Occurs when host key received. /// diff --git a/src/Renci.SshNet/Session.cs b/src/Renci.SshNet/Session.cs index 326e9b139..c984d3109 100644 --- a/src/Renci.SshNet/Session.cs +++ b/src/Renci.SshNet/Session.cs @@ -366,6 +366,11 @@ public Message ClientInitMessage /// public event EventHandler Disconnected; + /// + /// Occurs when server identification received. + /// + public event EventHandler ServerIdentificationReceived; + /// /// Occurs when host key received. /// @@ -624,6 +629,8 @@ public void Connect() DisconnectReason.ProtocolVersionNotSupported); } + ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification)); + // Register Transport response messages RegisterMessage("SSH_MSG_DISCONNECT"); RegisterMessage("SSH_MSG_IGNORE"); @@ -736,6 +743,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken) DisconnectReason.ProtocolVersionNotSupported); } + ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification)); + // Register Transport response messages RegisterMessage("SSH_MSG_DISCONNECT"); RegisterMessage("SSH_MSG_IGNORE"); diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs index 8f4bed7c0..3c1c2fdcf 100644 --- a/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs @@ -46,7 +46,8 @@ public abstract class SessionTest_ConnectedBase protected Session Session { get; private set; } protected Socket ClientSocket { get; private set; } protected Socket ServerSocket { get; private set; } - internal SshIdentification ServerIdentification { get; private set; } + internal SshIdentification ServerIdentification { get; set; } + protected bool CallSessionConnectWhenArrange { get; set; } [TestInitialize] public void Setup() @@ -159,6 +160,8 @@ protected virtual void SetupData() ServerListener.Start(); ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo); + + CallSessionConnectWhenArrange = true; } private void CreateMocks() @@ -180,7 +183,7 @@ private void SetupMocks() _ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange()) .Returns(_protocolVersionExchangeMock.Object); _ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout)) - .Returns(ServerIdentification); + .Returns(() => ServerIdentification); _ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object); _ = _keyExchangeMock.Setup(p => p.Name) .Returns(_keyExchangeAlgorithm); @@ -212,7 +215,10 @@ protected void Arrange() SetupData(); SetupMocks(); - Session.Connect(); + if (CallSessionConnectWhenArrange) + { + Session.Connect(); + } } protected virtual void ClientAuthentication_Callback() diff --git a/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs new file mode 100644 index 000000000..7b5ff1d86 --- /dev/null +++ b/test/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerIdentificationReceived.cs @@ -0,0 +1,65 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +using Renci.SshNet.Connection; + +namespace Renci.SshNet.Tests.Classes +{ + [TestClass] + public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase + { + protected override void SetupData() + { + base.SetupData(); + + CallSessionConnectWhenArrange = false; + + Session.ServerIdentificationReceived += (s, e) => + { + if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal)) + && !e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6.1", System.StringComparison.Ordinal)) + { + _ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256"); + _ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256@libssh.org"); + } + }; + } + + protected override void Act() + { + } + + [TestMethod] + [DataRow("OpenSSH_6.5")] + [DataRow("OpenSSH_6.5p1")] + [DataRow("OpenSSH_6.5 PKIX")] + [DataRow("OpenSSH_6.6")] + [DataRow("OpenSSH_6.6p1")] + [DataRow("OpenSSH_6.6 PKIX")] + public void ShouldExcludeCurve25519KexWhenServerIs(string softwareVersion) + { + ServerIdentification = new SshIdentification("2.0", softwareVersion); + + Session.Connect(); + + Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256")); + Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org")); + } + + [TestMethod] + [DataRow("OpenSSH_6.6.1")] + [DataRow("OpenSSH_6.6.1p1")] + [DataRow("OpenSSH_6.6.1 PKIX")] + [DataRow("OpenSSH_6.7")] + [DataRow("OpenSSH_6.7p1")] + [DataRow("OpenSSH_6.7 PKIX")] + public void ShouldIncludeCurve25519KexWhenServerIs(string softwareVersion) + { + ServerIdentification = new SshIdentification("2.0", softwareVersion); + + Session.Connect(); + + Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256")); + Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256@libssh.org")); + } + } +}