1- import io
21import paramiko
32import requests .adapters
43import six
54import logging
65import os
6+ import signal
77import socket
88import subprocess
99
2323RecentlyUsedContainer = urllib3 ._collections .RecentlyUsedContainer
2424
2525
26- def create_paramiko_client (base_url ):
27- logging .getLogger ("paramiko" ).setLevel (logging .WARNING )
28- ssh_client = paramiko .SSHClient ()
29- base_url = six .moves .urllib_parse .urlparse (base_url )
30- ssh_params = {
31- "hostname" : base_url .hostname ,
32- "port" : base_url .port ,
33- "username" : base_url .username
34- }
35- ssh_config_file = os .path .expanduser ("~/.ssh/config" )
36- if os .path .exists (ssh_config_file ):
37- conf = paramiko .SSHConfig ()
38- with open (ssh_config_file ) as f :
39- conf .parse (f )
40- host_config = conf .lookup (base_url .hostname )
41- ssh_conf = host_config
42- if 'proxycommand' in host_config :
43- ssh_params ["sock" ] = paramiko .ProxyCommand (
44- ssh_conf ['proxycommand' ]
45- )
46- if 'hostname' in host_config :
47- ssh_params ['hostname' ] = host_config ['hostname' ]
48- if 'identityfile' in host_config :
49- ssh_params ['key_filename' ] = host_config ['identityfile' ]
50- if base_url .port is None and 'port' in host_config :
51- ssh_params ['port' ] = ssh_conf ['port' ]
52- if base_url .username is None and 'user' in host_config :
53- ssh_params ['username' ] = ssh_conf ['user' ]
54-
55- ssh_client .load_system_host_keys ()
56- ssh_client .set_missing_host_key_policy (paramiko .WarningPolicy ())
57- return ssh_client , ssh_params
58-
59-
6026class SSHSocket (socket .socket ):
6127 def __init__ (self , host ):
6228 super (SSHSocket , self ).__init__ (
@@ -80,7 +46,8 @@ def connect(self, **kwargs):
8046 ' ' .join (args ),
8147 shell = True ,
8248 stdout = subprocess .PIPE ,
83- stdin = subprocess .PIPE )
49+ stdin = subprocess .PIPE ,
50+ preexec_fn = lambda : signal .signal (signal .SIGINT , signal .SIG_IGN ))
8451
8552 def _write (self , data ):
8653 if not self .proc or self .proc .stdin .closed :
@@ -96,17 +63,18 @@ def sendall(self, data):
9663 def send (self , data ):
9764 return self ._write (data )
9865
99- def recv (self ):
66+ def recv (self , n ):
10067 if not self .proc :
10168 raise Exception ('SSH subprocess not initiated.'
10269 'connect() must be called first.' )
103- return self .proc .stdout .read ()
70+ return self .proc .stdout .read (n )
10471
10572 def makefile (self , mode ):
106- if not self .proc or self .proc .stdout .closed :
107- buf = io .BytesIO ()
108- buf .write (b'\n \n ' )
109- return buf
73+ if not self .proc :
74+ self .connect ()
75+ if six .PY3 :
76+ self .proc .stdout .channel = self
77+
11078 return self .proc .stdout
11179
11280 def close (self ):
@@ -124,15 +92,15 @@ def __init__(self, ssh_transport=None, timeout=60, host=None):
12492 )
12593 self .ssh_transport = ssh_transport
12694 self .timeout = timeout
127- self .host = host
95+ self .ssh_host = host
12896
12997 def connect (self ):
13098 if self .ssh_transport :
13199 sock = self .ssh_transport .open_session ()
132100 sock .settimeout (self .timeout )
133101 sock .exec_command ('docker system dial-stdio' )
134102 else :
135- sock = SSHSocket (self .host )
103+ sock = SSHSocket (self .ssh_host )
136104 sock .settimeout (self .timeout )
137105 sock .connect ()
138106
@@ -147,16 +115,16 @@ def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
147115 'localhost' , timeout = timeout , maxsize = maxsize
148116 )
149117 self .ssh_transport = None
118+ self .timeout = timeout
150119 if ssh_client :
151120 self .ssh_transport = ssh_client .get_transport ()
152- self .timeout = timeout
153- self .host = host
154- self .port = None
121+ self .ssh_host = host
122+ self .ssh_port = None
155123 if ':' in host :
156- self .host , self .port = host .split (':' )
124+ self .ssh_host , self .ssh_port = host .split (':' )
157125
158126 def _new_conn (self ):
159- return SSHConnection (self .ssh_transport , self .timeout , self .host )
127+ return SSHConnection (self .ssh_transport , self .timeout , self .ssh_host )
160128
161129 # When re-using connections, urllib3 calls fileno() on our
162130 # SSH channel instance, quickly overloading our fd limit. To avoid this,
@@ -193,22 +161,59 @@ def __init__(self, base_url, timeout=60,
193161 shell_out = True ):
194162 self .ssh_client = None
195163 if not shell_out :
196- self .ssh_client , self . ssh_params = create_paramiko_client (base_url )
164+ self ._create_paramiko_client (base_url )
197165 self ._connect ()
198- base_url = base_url . lstrip ( 'ssh://' )
199- self .host = base_url
166+
167+ self .ssh_host = base_url . lstrip ( 'ssh://' )
200168 self .timeout = timeout
201169 self .max_pool_size = max_pool_size
202170 self .pools = RecentlyUsedContainer (
203171 pool_connections , dispose_func = lambda p : p .close ()
204172 )
205173 super (SSHHTTPAdapter , self ).__init__ ()
206174
175+ def _create_paramiko_client (self , base_url ):
176+ logging .getLogger ("paramiko" ).setLevel (logging .WARNING )
177+ self .ssh_client = paramiko .SSHClient ()
178+ base_url = six .moves .urllib_parse .urlparse (base_url )
179+ self .ssh_params = {
180+ "hostname" : base_url .hostname ,
181+ "port" : base_url .port ,
182+ "username" : base_url .username
183+ }
184+ ssh_config_file = os .path .expanduser ("~/.ssh/config" )
185+ if os .path .exists (ssh_config_file ):
186+ conf = paramiko .SSHConfig ()
187+ with open (ssh_config_file ) as f :
188+ conf .parse (f )
189+ host_config = conf .lookup (base_url .hostname )
190+ self .ssh_conf = host_config
191+ if 'proxycommand' in host_config :
192+ self .ssh_params ["sock" ] = paramiko .ProxyCommand (
193+ self .ssh_conf ['proxycommand' ]
194+ )
195+ if 'hostname' in host_config :
196+ self .ssh_params ['hostname' ] = host_config ['hostname' ]
197+ if base_url .port is None and 'port' in host_config :
198+ self .ssh_params ['port' ] = self .ssh_conf ['port' ]
199+ if base_url .username is None and 'user' in host_config :
200+ self .ssh_params ['username' ] = self .ssh_conf ['user' ]
201+
202+ self .ssh_client .load_system_host_keys ()
203+ self .ssh_client .set_missing_host_key_policy (paramiko .WarningPolicy ())
204+
207205 def _connect (self ):
208206 if self .ssh_client :
209207 self .ssh_client .connect (** self .ssh_params )
210208
211209 def get_connection (self , url , proxies = None ):
210+ if not self .ssh_client :
211+ return SSHConnectionPool (
212+ ssh_client = self .ssh_client ,
213+ timeout = self .timeout ,
214+ maxsize = self .max_pool_size ,
215+ host = self .ssh_host
216+ )
212217 with self .pools .lock :
213218 pool = self .pools .get (url )
214219 if pool :
@@ -222,7 +227,7 @@ def get_connection(self, url, proxies=None):
222227 ssh_client = self .ssh_client ,
223228 timeout = self .timeout ,
224229 maxsize = self .max_pool_size ,
225- host = self .host
230+ host = self .ssh_host
226231 )
227232 self .pools [url ] = pool
228233
0 commit comments