3131from prompt_toolkit .history import FileHistory
3232from prompt_toolkit .auto_suggest import AutoSuggestFromHistory
3333
34+ from mycli .packages .ssh_client import create_ssh_client
3435from .packages .special .main import NO_QUERY
3536from .packages .prompt_utils import confirm , confirm_destructive_query
3637from .packages .tabular_output import sql_format
37- from .packages import special
38+ from .packages import special , ssh_client
3839from .packages .special .favoritequeries import FavoriteQueries
3940from .sqlcompleter import SQLCompleter
4041from .clitoolbar import create_toolbar_tokens_func
6364 from urllib .parse import unquote
6465
6566
66- try :
67- import paramiko
68- except ImportError :
69- from mycli .packages .paramiko_stub import paramiko
70-
7167# Query tuples are used for maintaining history
7268Query = namedtuple ('Query' , ['query' , 'successful' , 'mutating' ])
7369
@@ -198,6 +194,8 @@ def __init__(self, sqlexecute=None, prompt=None,
198194
199195 self .prompt_app = None
200196
197+ self .ssh_client = None
198+
201199 def register_special_commands (self ):
202200 special .register_special_command (self .change_db , 'use' ,
203201 '\\ u' , 'Change to a new database.' , aliases = ('\\ u' ,))
@@ -358,9 +356,7 @@ def merge_ssl_with_cnf(self, ssl, cnf):
358356 return merged
359357
360358 def connect (self , database = '' , user = '' , passwd = '' , host = '' , port = '' ,
361- socket = '' , charset = '' , local_infile = '' , ssl = '' ,
362- ssh_user = '' , ssh_host = '' , ssh_port = '' ,
363- ssh_password = '' , ssh_key_filename = '' ):
359+ socket = '' , charset = '' , local_infile = '' , ssl = None ):
364360
365361 cnf = {'database' : None ,
366362 'user' : None ,
@@ -384,7 +380,7 @@ def connect(self, database='', user='', passwd='', host='', port='',
384380
385381 database = database or cnf ['database' ]
386382 # Socket interface not supported for SSH connections
387- if port or host or ssh_host or ssh_port :
383+ if port or host or self . ssh_client :
388384 socket = ''
389385 else :
390386 socket = socket or cnf ['socket' ] or guess_socket_location ()
@@ -416,17 +412,15 @@ def _connect():
416412 try :
417413 self .sqlexecute = SQLExecute (
418414 database , user , passwd , host , port , socket , charset ,
419- local_infile , ssl , ssh_user , ssh_host , ssh_port ,
420- ssh_password , ssh_key_filename
415+ local_infile , ssl , ssh_client = self .ssh_client
421416 )
422417 except OperationalError as e :
423418 if ('Access denied for user' in e .args [1 ]):
424419 new_passwd = click .prompt ('Password' , hide_input = True ,
425420 show_default = False , type = str , err = True )
426421 self .sqlexecute = SQLExecute (
427422 database , user , new_passwd , host , port , socket ,
428- charset , local_infile , ssl , ssh_user , ssh_host ,
429- ssh_port , ssh_password , ssh_key_filename
423+ charset , local_infile , ssl , ssh_client = self .ssh_client
430424 )
431425 else :
432426 raise e
@@ -1092,16 +1086,17 @@ def cli(database, user, host, port, socket, password, dbname,
10921086 else :
10931087 click .secho (alias )
10941088 sys .exit (0 )
1089+
10951090 if list_ssh_config :
1096- ssh_config = read_ssh_config (ssh_config_path )
1097- for host in ssh_config . get_hostnames ():
1091+ hosts = ssh_client . get_config_hosts (ssh_config_path )
1092+ for host , hostname in hosts . items ():
10981093 if verbose :
1099- host_config = ssh_config .lookup (host )
11001094 click .secho ("{} : {}" .format (
1101- host , host_config . get ( ' hostname' ) ))
1095+ host , hostname ))
11021096 else :
11031097 click .secho (host )
11041098 sys .exit (0 )
1099+
11051100 # Choose which ever one has a valid value.
11061101 database = dbname or database
11071102
@@ -1153,7 +1148,7 @@ def cli(database, user, host, port, socket, password, dbname,
11531148 port = uri .port
11541149
11551150 if ssh_config_host :
1156- ssh_config = read_ssh_config (
1151+ ssh_config = ssh_client . read_config_file (
11571152 ssh_config_path
11581153 ).lookup (ssh_config_host )
11591154 ssh_host = ssh_host if ssh_host else ssh_config .get ('hostname' )
@@ -1164,7 +1159,10 @@ def cli(database, user, host, port, socket, password, dbname,
11641159 ssh_key_filename = ssh_key_filename if ssh_key_filename else ssh_config .get (
11651160 'identityfile' , [None ])[0 ]
11661161
1167- ssh_key_filename = ssh_key_filename and os .path .expanduser (ssh_key_filename )
1162+ if ssh_host :
1163+ mycli .ssh_client = create_ssh_client (
1164+ ssh_host , ssh_port , ssh_user , ssh_password , ssh_key_filename
1165+ )
11681166
11691167 mycli .connect (
11701168 database = database ,
@@ -1175,11 +1173,6 @@ def cli(database, user, host, port, socket, password, dbname,
11751173 socket = socket ,
11761174 local_infile = local_infile ,
11771175 ssl = ssl ,
1178- ssh_user = ssh_user ,
1179- ssh_host = ssh_host ,
1180- ssh_port = ssh_port ,
1181- ssh_password = ssh_password ,
1182- ssh_key_filename = ssh_key_filename
11831176 )
11841177
11851178 mycli .logger .debug ('Launch Params: \n '
@@ -1298,26 +1291,5 @@ def edit_and_execute(event):
12981291 buff .open_in_editor (validate_and_handle = False )
12991292
13001293
1301- def read_ssh_config (ssh_config_path ):
1302- ssh_config = paramiko .config .SSHConfig ()
1303- try :
1304- with open (ssh_config_path ) as f :
1305- ssh_config .parse (f )
1306- # Paramiko prior to version 2.7 raises Exception on parse errors.
1307- # In 2.7 it has become paramiko.ssh_exception.SSHException,
1308- # but let's catch everything for compatibility
1309- except Exception as err :
1310- click .secho (
1311- f'Could not parse SSH configuration file { ssh_config_path } :\n { err } ' ,
1312- err = True , fg = 'red'
1313- )
1314- sys .exit (1 )
1315- except FileNotFoundError as e :
1316- click .secho (str (e ), err = True , fg = 'red' )
1317- sys .exit (1 )
1318- else :
1319- return ssh_config
1320-
1321-
13221294if __name__ == "__main__" :
13231295 cli ()
0 commit comments