@@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList;
7474class  SSLContext 
7575{
7676public: 
77-     SSLContext ()
77+     SSLContext (bool  isServer =  false )
7878    {
79-         if  (_ssl_ctx_refcnt == 0 ) {
80-             _ssl_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
79+         _isServer = isServer;
80+         if  (!_isServer) {
81+             if  (_ssl_client_ctx_refcnt == 0 ) {
82+                 _ssl_client_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
83+             }
84+             ++_ssl_client_ctx_refcnt;
85+         } else  {
86+             if  (_ssl_svr_ctx_refcnt == 0 ) {
87+                 _ssl_svr_ctx = ssl_ctx_new (SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0 );
88+             }
89+             ++_ssl_svr_ctx_refcnt;
8190        }
82-         ++_ssl_ctx_refcnt;
8391    }
8492
8593    ~SSLContext ()
8694    {
87-         if  (_ssl ) {
88-             ssl_free (_ssl );
89-             _ssl  = nullptr ;
95+         if  (io_ctx ) {
96+             io_ctx-> unref ( );
97+             io_ctx  = nullptr ;
9098        }
91- 
92-         --_ssl_ctx_refcnt;
93-         if  (_ssl_ctx_refcnt == 0 ) {
94-             ssl_ctx_free (_ssl_ctx);
99+         _ssl = nullptr ;
100+         if  (!_isServer) {
101+             --_ssl_client_ctx_refcnt;
102+             if  (_ssl_client_ctx_refcnt == 0 ) {
103+                 ssl_ctx_free (_ssl_client_ctx);
104+                 _ssl_client_ctx = nullptr ;
105+             }
106+         } else  {
107+             --_ssl_svr_ctx_refcnt;
108+             if  (_ssl_svr_ctx_refcnt == 0 ) {
109+                 ssl_ctx_free (_ssl_svr_ctx);
110+                 _ssl_svr_ctx = nullptr ;
111+             }
95112        }
96113    }
97114
98-     void  ref ()
99-     {
100-         ++_refcnt;
101-     }
102- 
103-     void  unref ()
115+     static  void  _delete_shared_SSL (SSL *_to_del)
104116    {
105-         if  (--_refcnt == 0 ) {
106-             delete  this ;
107-         }
117+         ssl_free (_to_del);
108118    }
109119
110120    void  connect (ClientContext* ctx, const  char * hostName, uint32_t  timeout_ms)
@@ -116,50 +126,67 @@ class SSLContext
116126               ssl_free will want to send a close notify alert, but the old TCP connection 
117127               is already gone at this point, so reset io_ctx. */  
118128            io_ctx = nullptr ;
119-             ssl_free ( _ssl) ;
129+             _ssl =  nullptr ;
120130            _available = 0 ;
121131            _read_ptr = nullptr ;
122132        }
123133        io_ctx = ctx;
124-         _ssl = ssl_client_new (_ssl_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
134+         ctx->ref ();
135+ 
136+         //  Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
137+         SSL *_new_ssl = ssl_client_new (_ssl_client_ctx, reinterpret_cast <int >(this ), nullptr , 0 , ext);
138+         std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
139+         _ssl = _new_ssl_shared;
140+ 
125141        uint32_t  t = millis ();
126142
127-         while  (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
143+         while  (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
128144            uint8_t * data;
129-             int  rc = ssl_read (_ssl, &data);
145+             int  rc = ssl_read (_ssl. get () , &data);
130146            if  (rc < SSL_OK) {
131147                ssl_display_error (rc);
132148                break ;
133149            }
134150        }
135151    }
136152
137-     void  connectServer (ClientContext *ctx) {
153+     void  connectServer (ClientContext *ctx, uint32_t  timeout_ms)
154+     {
138155        io_ctx = ctx;
139- 	_ssl = ssl_server_new (_ssl_ctx, reinterpret_cast <int >(this ));
140-         _isServer = true ;
156+         ctx->ref ();
157+ 
158+         //  Wrap the new SSL with a smart pointer, custom deleter to call ssl_free
159+ 	SSL *_new_ssl = ssl_server_new (_ssl_svr_ctx, reinterpret_cast <int >(this ));
160+         std::shared_ptr<SSL> _new_ssl_shared (_new_ssl, _delete_shared_SSL);
161+         _ssl = _new_ssl_shared;
141162
142- 	uint32_t  timeout_ms = 5000 ;
143163        uint32_t  t = millis ();
144164
145-         while  (millis () - t < timeout_ms && ssl_handshake_status (_ssl) != SSL_OK) {
165+         while  (millis () - t < timeout_ms && ssl_handshake_status (_ssl. get () ) != SSL_OK) {
146166            uint8_t * data;
147-             int  rc = ssl_read (_ssl, &data);
167+             int  rc = ssl_read (_ssl. get () , &data);
148168            if  (rc < SSL_OK) {
169+                 ssl_display_error (rc);
149170                break ;
150171            }
151172        }
152173    }
153174
154175    void  stop ()
155176    {
177+         if  (io_ctx) {
178+             io_ctx->unref ();
179+         }
156180        io_ctx = nullptr ;
157181    }
158182
159183    bool  connected ()
160184    {
161-         if  (_isServer) return  _ssl != nullptr ;
162-         else  return  _ssl != nullptr  && ssl_handshake_status (_ssl) == SSL_OK;
185+         if  (_isServer) {
186+             return  _ssl != nullptr ;
187+         } else  {
188+             return  _ssl != nullptr  && ssl_handshake_status (_ssl.get ()) == SSL_OK;
189+         }
163190    }
164191
165192    int  read (uint8_t * dst, size_t  size)
@@ -289,10 +316,9 @@ class SSLContext
289316        return  loadObject (type, buf.get (), size);
290317    }
291318
292- 
293319    bool  loadObject (int  type, const  uint8_t * data, size_t  size)
294320    {
295-         int  rc = ssl_obj_memory_load (_ssl_ctx , type, data, static_cast <int >(size), nullptr );
321+         int  rc = ssl_obj_memory_load (_isServer?_ssl_svr_ctx:_ssl_client_ctx , type, data, static_cast <int >(size), nullptr );
296322        if  (rc != SSL_OK) {
297323            DEBUGV (" loadObject: ssl_obj_memory_load returned %d\n " 
298324            return  false ;
@@ -302,7 +328,7 @@ class SSLContext
302328
303329    bool  verifyCert ()
304330    {
305-         int  rc = ssl_verify_cert (_ssl);
331+         int  rc = ssl_verify_cert (_ssl. get () );
306332        if  (_allowSelfSignedCerts && rc == SSL_X509_ERROR (X509_VFY_ERROR_SELF_SIGNED)) {
307333            DEBUGV (" Allowing self-signed certificate\n " 
308334            return  true ;
@@ -321,12 +347,16 @@ class SSLContext
321347
322348    operator  SSL*()
323349    {
324-         return  _ssl;
350+         return  _ssl. get () ;
325351    }
326352
327353    static  ClientContext* getIOContext (int  fd)
328354    {
329-         return  reinterpret_cast <SSLContext*>(fd)->io_ctx ;
355+         if  (fd) {
356+             SSLContext *thisSSL = reinterpret_cast <SSLContext*>(fd);
357+             return  thisSSL->io_ctx ;
358+         }
359+         return  nullptr ;
330360    }
331361
332362protected: 
@@ -339,10 +369,9 @@ class SSLContext
339369        optimistic_yield (100 );
340370
341371        uint8_t * data;
342-         int  rc = ssl_read (_ssl, &data);
372+         int  rc = ssl_read (_ssl. get () , &data);
343373        if  (rc <= 0 ) {
344374            if  (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) {
345-                 ssl_free (_ssl);
346375                _ssl = nullptr ;
347376            }
348377            return  0 ;
@@ -359,7 +388,7 @@ class SSLContext
359388            return  0 ;
360389        }
361390
362-         int  rc = ssl_write (_ssl, src, size);
391+         int  rc = ssl_write (_ssl. get () , src, size);
363392        if  (rc >= 0 ) {
364393            return  rc;
365394        }
@@ -404,19 +433,22 @@ class SSLContext
404433    }
405434
406435    bool  _isServer = false ;
407-     static  SSL_CTX* _ssl_ctx;
408-     static  int  _ssl_ctx_refcnt;
409-     SSL* _ssl = nullptr ;
410-     int  _refcnt = 0 ;
436+     static  SSL_CTX* _ssl_client_ctx;
437+     static  int  _ssl_client_ctx_refcnt;
438+     static  SSL_CTX* _ssl_svr_ctx;
439+     static  int  _ssl_svr_ctx_refcnt;
440+     std::shared_ptr<SSL> _ssl = nullptr ;
411441    const  uint8_t * _read_ptr = nullptr ;
412442    size_t  _available = 0 ;
413443    BufferList _writeBuffers;
414444    bool  _allowSelfSignedCerts = false ;
415445    ClientContext* io_ctx = nullptr ;
416446};
417447
418- SSL_CTX* SSLContext::_ssl_ctx = nullptr ;
419- int  SSLContext::_ssl_ctx_refcnt = 0 ;
448+ SSL_CTX* SSLContext::_ssl_client_ctx = nullptr ;
449+ int  SSLContext::_ssl_client_ctx_refcnt = 0 ;
450+ SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr ;
451+ int  SSLContext::_ssl_svr_ctx_refcnt = 0 ;
420452
421453WiFiClientSecure::WiFiClientSecure ()
422454{
@@ -426,41 +458,25 @@ WiFiClientSecure::WiFiClientSecure()
426458
427459WiFiClientSecure::~WiFiClientSecure ()
428460{
429-     if  (_ssl) {
430-         _ssl->unref ();
431-     }
432- }
433- 
434- WiFiClientSecure::WiFiClientSecure (const  WiFiClientSecure& other)
435-     : WiFiClient(static_cast <const  WiFiClient&>(other))
436- {
437-     _ssl = other._ssl ;
438-     if  (_ssl) {
439-         _ssl->ref ();
440-     }
441- }
442- 
443- WiFiClientSecure& WiFiClientSecure::operator =(const  WiFiClientSecure& rhs)
444- {
445-     (WiFiClient&) *this  = rhs;
446-     _ssl = rhs._ssl ;
447-     if  (_ssl) {
448-         _ssl->ref ();
449-     }
450-     return  *this ;
461+    _ssl = nullptr ;
451462}
452463
453464//  Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning
454- WiFiClientSecure::WiFiClientSecure (ClientContext* client, bool  usePMEM, const  uint8_t  *rsakey, int  rsakeyLen, const  uint8_t  *cert, int  certLen)
465+ WiFiClientSecure::WiFiClientSecure (ClientContext* client, bool  usePMEM,
466+                                    const  uint8_t  *rsakey, int  rsakeyLen,
467+                                    const  uint8_t  *cert, int  certLen)
455468{
469+     //  TLS handshake may take more than the 5 second default timeout
470+     _timeout = 15000 ;
471+ 
472+     //  We've been given the client context from the available() call
456473    _client = client;
457-     if  (_ssl) {
458-         _ssl->unref ();
459-         _ssl = nullptr ;
460-     }
474+     _client->ref ();
461475
462-     _ssl = new  SSLContext;
463-     _ssl->ref ();
476+     //  Make the "_ssl" SSLContext, in the constructor there should be none yet
477+     SSLContext *_new_ssl = new  SSLContext (true );
478+     std::shared_ptr<SSLContext> _new_ssl_shared (_new_ssl);
479+     _ssl = _new_ssl_shared;
464480
465481    if  (usePMEM) {
466482        if  (rsakey && rsakeyLen) {
@@ -477,8 +493,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui
477493            _ssl->loadObject (SSL_OBJ_X509_CERT, cert, certLen);
478494        }
479495    }
480-     _client->ref ();
481-     _ssl->connectServer (client);
496+     _ssl->connectServer (client, _timeout);
482497}
483498
484499int  WiFiClientSecure::connect (IPAddress ip, uint16_t  port)
@@ -510,14 +525,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port)
510525int  WiFiClientSecure::_connectSSL (const  char * hostName)
511526{
512527    if  (!_ssl) {
513-         _ssl = new  SSLContext;
514-         _ssl->ref ();
528+         _ssl = std::make_shared<SSLContext>();
515529    }
516530    _ssl->connect (_client, hostName, _timeout);
517531
518532    auto  status = ssl_handshake_status (*_ssl);
519533    if  (status != SSL_OK) {
520-         _ssl->unref ();
521534        _ssl = nullptr ;
522535        return  0 ;
523536    }
@@ -537,7 +550,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size)
537550    }
538551
539552    if  (rc != SSL_CLOSE_NOTIFY) {
540-         _ssl->unref ();
541553        _ssl = nullptr ;
542554    }
543555
@@ -640,8 +652,6 @@ void WiFiClientSecure::stop()
640652{
641653    if  (_ssl) {
642654        _ssl->stop ();
643-         _ssl->unref ();
644-         _ssl = nullptr ;
645655    }
646656    WiFiClient::stop ();
647657}
@@ -723,9 +733,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name)
723733    String domain_name_str (domain_name);
724734    domain_name_str.toLowerCase ();
725735
726-     const  char * san = NULL ;
736+     const  char * san = nullptr ;
727737    int  i = 0 ;
728-     while  ((san = ssl_get_cert_subject_alt_dnsname (*_ssl, i)) != NULL ) {
738+     while  ((san = ssl_get_cert_subject_alt_dnsname (*_ssl, i)) != nullptr ) {
729739        String san_str (san);
730740        san_str.toLowerCase ();
731741        if  (matchName (san_str, domain_name_str)) {
@@ -759,8 +769,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name)
759769void  WiFiClientSecure::_initSSLContext ()
760770{
761771    if  (!_ssl) {
762-         _ssl = new  SSLContext;
763-         _ssl->ref ();
772+         _ssl = std::make_shared<SSLContext>();
764773    }
765774}
766775
0 commit comments