@@ -318,7 +318,9 @@ class TestConnectParams(tb.TestCase):
318318 'result' : ([('host' , 123 )], {
319319 'user' : 'user' ,
320320 'password' : 'passw' ,
321- 'database' : 'testdb' })
321+ 'database' : 'testdb' ,
322+ 'ssl' : True ,
323+ 'ssl_is_advisory' : True })
322324 },
323325
324326 {
@@ -384,7 +386,7 @@ class TestConnectParams(tb.TestCase):
384386 'user' : 'user3' ,
385387 'password' : '123123' ,
386388 'database' : 'abcdef' ,
387- 'ssl' : ssl . SSLContext ,
389+ 'ssl' : True ,
388390 'ssl_is_advisory' : True })
389391 },
390392
@@ -461,7 +463,7 @@ class TestConnectParams(tb.TestCase):
461463 'user' : 'me' ,
462464 'password' : 'ask' ,
463465 'database' : 'db' ,
464- 'ssl' : ssl . SSLContext ,
466+ 'ssl' : True ,
465467 'ssl_is_advisory' : False })
466468 },
467469
@@ -617,7 +619,7 @@ def run_testcase(self, testcase):
617619 password = testcase .get ('password' )
618620 passfile = testcase .get ('passfile' )
619621 database = testcase .get ('database' )
620- ssl = testcase .get ('ssl' )
622+ sslmode = testcase .get ('ssl' )
621623 server_settings = testcase .get ('server_settings' )
622624
623625 expected = testcase .get ('result' )
@@ -640,21 +642,25 @@ def run_testcase(self, testcase):
640642
641643 addrs , params = connect_utils ._parse_connect_dsn_and_args (
642644 dsn = dsn , host = host , port = port , user = user , password = password ,
643- passfile = passfile , database = database , ssl = ssl ,
645+ passfile = passfile , database = database , ssl = sslmode ,
644646 connect_timeout = None , server_settings = server_settings )
645647
646- params = {k : v for k , v in params ._asdict ().items ()
647- if v is not None }
648+ params = {
649+ k : v for k , v in params ._asdict ().items () if v is not None
650+ }
651+
652+ if isinstance (params .get ('ssl' ), ssl .SSLContext ):
653+ params ['ssl' ] = True
648654
649655 result = (addrs , params )
650656
651657 if expected is not None :
652- for k , v in expected [1 ]. items () :
653- # If `expected` contains a type, allow that to "match" any
654- # instance of that type tyat `result` may contain. We need
655- # this because different SSLContexts don't compare equal.
656- if isinstance ( v , type ) and isinstance ( result [ 1 ]. get ( k ), v ):
657- result [ 1 ][ k ] = v
658+ if 'ssl' not in expected [1 ]:
659+ # Avoid the hassle of specifying the default SSL mode
660+ # unless explicitly tested for.
661+ params . pop ( 'ssl' , None )
662+ params . pop ( 'ssl_is_advisory' , None )
663+
658664 self .assertEqual (expected , result , 'Testcase: {}' .format (testcase ))
659665
660666 def test_test_connect_params_environ (self ):
@@ -1063,16 +1069,6 @@ async def verify_fails(sslmode):
10631069 await verify_fails ('verify-ca' )
10641070 await verify_fails ('verify-full' )
10651071
1066- async def test_connection_ssl_unix (self ):
1067- ssl_context = ssl .SSLContext (ssl .PROTOCOL_SSLv23 )
1068- ssl_context .load_verify_locations (SSL_CA_CERT_FILE )
1069-
1070- with self .assertRaisesRegex (asyncpg .InterfaceError ,
1071- 'can only be enabled for TCP addresses' ):
1072- await self .connect (
1073- host = '/tmp' ,
1074- ssl = ssl_context )
1075-
10761072 async def test_connection_implicit_host (self ):
10771073 conn_spec = self .get_connection_spec ()
10781074 con = await asyncpg .connect (
0 commit comments