@@ -99,6 +99,7 @@ define!(
9999callback ! ( DbgCallback : Fn ( i32 , Cow <' _, str >, i32 , Cow <' _, str >) -> ( ) ) ;
100100callback ! ( SniCallback : Fn ( & mut HandshakeContext , & [ u8 ] ) -> Result <( ) >) ;
101101callback ! ( CaCallback : Fn ( & MbedtlsList <Certificate >) -> Result <MbedtlsList <Certificate >>) ;
102+ callback ! ( PskCallback : Fn ( & mut HandshakeContext , & str ) -> Result <( ) >) ;
102103
103104
104105#[ repr( transparent) ]
@@ -164,6 +165,7 @@ define!(
164165 sni_callback: Option <Arc <dyn SniCallback + ' static >>,
165166 ticket_callback: Option <Arc <dyn TicketCallback + ' static >>,
166167 ca_callback: Option <Arc <dyn CaCallback + ' static >>,
168+ psk_callback: Option <Arc <dyn PskCallback + ' static >>,
167169 } ;
168170 const drop: fn ( & mut Self ) = ssl_config_free;
169171 impl <' a> Into <ptr> { }
@@ -199,6 +201,7 @@ impl Config {
199201 sni_callback : None ,
200202 ticket_callback : None ,
201203 ca_callback : None ,
204+ psk_callback : None ,
202205 }
203206 }
204207
@@ -457,6 +460,43 @@ impl Config {
457460 self . dbg_callback = Some ( Arc :: new ( cb) ) ;
458461 unsafe { ssl_conf_dbg ( self . into ( ) , Some ( dbg_callback :: < F > ) , & * * self . dbg_callback . as_mut ( ) . unwrap ( ) as * const _ as * mut c_void ) }
459462 }
463+
464+ pub fn set_psk ( & mut self , psk : & [ u8 ] , psk_identity : & str ) -> Result < ( ) > {
465+ unsafe { ssl_conf_psk ( & mut self . inner ,
466+ psk. as_ptr ( ) , psk. len ( ) ,
467+ psk_identity. as_ptr ( ) , psk_identity. len ( ) )
468+ . into_result ( ) . map ( |_| ( ) )
469+ }
470+ }
471+
472+ pub fn set_psk_callback < F > ( & mut self , cb : F )
473+ where
474+ F : PskCallback + ' static ,
475+ {
476+ unsafe extern "C" fn psk_callback < F > (
477+ closure : * mut c_void ,
478+ ctx : * mut ssl_context ,
479+ psk_identity : * const c_uchar ,
480+ identity_len : size_t ,
481+ ) -> c_int
482+ where
483+ F : PskCallback + ' static ,
484+ {
485+ let cb = & mut * ( closure as * mut F ) ;
486+ let context = UnsafeFrom :: from ( ctx) . unwrap ( ) ;
487+
488+ let mut ctx = HandshakeContext :: init ( context) ;
489+
490+ let psk_identity = std:: str:: from_utf8_unchecked ( from_raw_parts ( psk_identity, identity_len) ) ;
491+ match cb ( & mut ctx, psk_identity) {
492+ Ok ( ( ) ) => 0 ,
493+ Err ( e) => e. to_int ( ) ,
494+ }
495+ }
496+
497+ self . psk_callback = Some ( Arc :: new ( cb) ) ;
498+ unsafe { ssl_conf_psk_cb ( self . into ( ) , Some ( psk_callback :: < F > ) , & * * self . psk_callback . as_mut ( ) . unwrap ( ) as * const _ as * mut c_void ) }
499+ }
460500}
461501
462502// TODO
@@ -466,8 +506,6 @@ impl Config {
466506// ssl_conf_dtls_badmac_limit
467507// ssl_conf_handshake_timeout
468508// ssl_conf_session_cache
469- // ssl_conf_psk
470- // ssl_conf_psk_cb
471509// ssl_conf_sig_hashes
472510// ssl_conf_alpn_protocols
473511// ssl_conf_fallback
0 commit comments