diff --git a/mbedtls/Cargo.toml b/mbedtls/Cargo.toml index f9605a243..48a80937a 100644 --- a/mbedtls/Cargo.toml +++ b/mbedtls/Cargo.toml @@ -119,3 +119,8 @@ required-features = ["std"] name = "ssl_conf_verify" path = "tests/ssl_conf_verify.rs" required-features = ["std"] + +[[test]] +name = "ssl_conf_psk" +path = "tests/ssl_conf_psk.rs" +required-features = ["std"] diff --git a/mbedtls/src/ssl/config.rs b/mbedtls/src/ssl/config.rs index 936d047db..f6ebda734 100644 --- a/mbedtls/src/ssl/config.rs +++ b/mbedtls/src/ssl/config.rs @@ -329,6 +329,49 @@ impl<'c> Config<'c> { ) } } + + /// psk and psk_identity cannot be empty + pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> { + assert!(psk_identity.len()>0); + assert!(psk.len()>0); + unsafe { ssl_conf_psk(&mut self.inner, + psk.as_ptr(), psk.len(), + psk_identity.as_ptr(), psk_identity.len()) + .into_result().map(|_| ()) + } + } + + pub fn set_psk_callback(&mut self, cb: &'c mut F) + where + F: FnMut(&mut HandshakeContext, &str) -> Result<()>, + { + unsafe extern "C" fn psk_callback( + closure: *mut c_void, + ctx: *mut ssl_context, + psk_identity: *const c_uchar, + identity_len: size_t) -> c_int + where + F: FnMut(&mut HandshakeContext, &str) -> Result<()>, + { + assert!(identity_len>0); + let cb = &mut *(closure as *mut F); + let mut ctx = UnsafeFrom::from(ctx).expect("valid context"); + let psk_identity = &*(from_raw_parts(psk_identity, identity_len) + as *const [u8] as *const str); + match cb(&mut ctx, psk_identity) { + Ok(()) => 0, + Err(e) => e.to_int(), + } + } + + unsafe { + ssl_conf_psk_cb( + &mut self.inner, + Some(psk_callback::), + cb as *mut F as _ + ) + } + } } /// Builds a linked list of x509_crt instances, all of which are owned by mbedtls. That is, the @@ -417,8 +460,6 @@ impl<'a> Iterator for KeyCertIter<'a> { // ssl_conf_dtls_badmac_limit // ssl_conf_handshake_timeout // ssl_conf_session_cache -// ssl_conf_psk -// ssl_conf_psk_cb // ssl_conf_sig_hashes // ssl_conf_alpn_protocols // ssl_conf_fallback diff --git a/mbedtls/src/ssl/context.rs b/mbedtls/src/ssl/context.rs index cbc17534a..8fd7752c4 100644 --- a/mbedtls/src/ssl/context.rs +++ b/mbedtls/src/ssl/context.rs @@ -205,6 +205,16 @@ impl<'ctx> HandshakeContext<'ctx> { .map(|_| ()) } } + + /// psk cannot be empty + pub fn set_psk(&mut self, psk: &[u8]) -> Result<()> { + assert!(psk.len()>0); + unsafe { + ssl_set_hs_psk(self.inner, psk.as_ptr(), psk.len()) + .into_result() + .map(|_| ()) + } + } } impl<'ctx> ::core::ops::Deref for HandshakeContext<'ctx> { @@ -315,7 +325,6 @@ impl<'a> Drop for Session<'a> { // ssl_renegotiate // ssl_send_alert_message // ssl_set_client_transport_id -// ssl_set_hs_psk // ssl_set_timer_cb // // ssl_handshake_step diff --git a/mbedtls/tests/ssl_conf_psk.rs b/mbedtls/tests/ssl_conf_psk.rs new file mode 100644 index 000000000..f02a3fcf5 --- /dev/null +++ b/mbedtls/tests/ssl_conf_psk.rs @@ -0,0 +1,58 @@ +#![allow(dead_code)] +extern crate mbedtls; + +use std::net::TcpStream; + +mod support; +use support::entropy::entropy_new; + +use mbedtls::rng::CtrDrbg; +use mbedtls::ssl::config::{Endpoint, Preset, Transport}; +use mbedtls::ssl::{Config, Context, HandshakeContext}; +use mbedtls::Result as TlsResult; + + +fn client(mut conn: TcpStream, psk: &[u8]) -> TlsResult<()> { + { + let mut entropy = entropy_new(); + let mut rng = CtrDrbg::new(&mut entropy, None)?; + let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default); + config.set_rng(Some(&mut rng)); + config.set_psk(psk, "Client_identity")?; + let mut ctx = Context::new(&config)?; + ctx.establish(&mut conn, None).map(|_| ())?; + Ok(()) + } +} + +fn server(mut conn: TcpStream, mut psk_callback: F) -> TlsResult<()> + where + F: FnMut(&mut HandshakeContext, &str) -> TlsResult<()> { + let mut entropy = entropy_new(); + let mut rng = CtrDrbg::new(&mut entropy, None)?; + let mut config = Config::new(Endpoint::Server, Transport::Stream, Preset::Default); + config.set_rng(Some(&mut rng)); + config.set_psk_callback(&mut psk_callback); + let mut ctx = Context::new(&config)?; + let _ = ctx.establish(&mut conn, None)?; + Ok(()) +} + +#[cfg(unix)] +mod test { + use super::*; + use std::thread; + use crate::support::net::create_tcp_pair; + use crate::support::keys; + + #[test] + fn callback_standard_psk() { + let (c, s) = create_tcp_pair().unwrap(); + let psk_callback = + |ctx: &mut HandshakeContext, _: &str| { ctx.set_psk(keys::PRESHARED_KEY) }; + let c = thread::spawn(move || super::client(c, keys::PRESHARED_KEY).unwrap()); + let s = thread::spawn(move || super::server(s, psk_callback).unwrap()); + c.join().unwrap(); + s.join().unwrap(); + } +} diff --git a/mbedtls/tests/support/keys.rs b/mbedtls/tests/support/keys.rs index 86d0d212b..57c7a0fa6 100644 --- a/mbedtls/tests/support/keys.rs +++ b/mbedtls/tests/support/keys.rs @@ -279,3 +279,7 @@ C4j3yqL0Gbs+moaswS1UR8XSnKt8TBcXVozCAy12A4qKSjkP7VKPTLeTOZxw0UBe 8CzQYNKoIGy4ayFVi+VKaNCHKvJm0diQkKw5Tz7L5quBBjt8JpmRtNbPsjXiq4Is y14Xc4kb05mM5M9u685eWefa -----END PRIVATE KEY-----\0"; + +pub const PRESHARED_KEY: &'static [u8] = &[ + 234, 206, 151, 23, 219, 21, 71, 144, + 107, 42, 23, 67, 249, 173, 182, 224 ];