@@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
284284
285285 async def connect (self ):
286286 """Connects to the Redis server if not already connected"""
287+ await self .connect_check_health (check_health = True )
288+
289+ async def connect_check_health (self , check_health : bool = True ):
287290 if self .is_connected :
288291 return
289292 try :
@@ -302,7 +305,7 @@ async def connect(self):
302305 try :
303306 if not self .redis_connect_func :
304307 # Use the default on_connect function
305- await self .on_connect ( )
308+ await self .on_connect_check_health ( check_health = check_health )
306309 else :
307310 # Use the passed function redis_connect_func
308311 (
@@ -341,6 +344,9 @@ def get_protocol(self):
341344
342345 async def on_connect (self ) -> None :
343346 """Initialize the connection, authenticate and select a database"""
347+ await self .on_connect_check_health (check_health = True )
348+
349+ async def on_connect_check_health (self , check_health : bool = True ) -> None :
344350 self ._parser .on_connect (self )
345351 parser = self ._parser
346352
@@ -398,7 +404,7 @@ async def on_connect(self) -> None:
398404 # update cluster exception classes
399405 self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
400406 self ._parser .on_connect (self )
401- await self .send_command ("HELLO" , self .protocol )
407+ await self .send_command ("HELLO" , self .protocol , check_health = check_health )
402408 response = await self .read_response ()
403409 # if response.get(b"proto") != self.protocol and response.get(
404410 # "proto"
@@ -407,18 +413,35 @@ async def on_connect(self) -> None:
407413
408414 # if a client_name is given, set it
409415 if self .client_name :
410- await self .send_command ("CLIENT" , "SETNAME" , self .client_name )
416+ await self .send_command (
417+ "CLIENT" ,
418+ "SETNAME" ,
419+ self .client_name ,
420+ check_health = check_health ,
421+ )
411422 if str_if_bytes (await self .read_response ()) != "OK" :
412423 raise ConnectionError ("Error setting client name" )
413424
414425 # set the library name and version, pipeline for lower startup latency
415426 if self .lib_name :
416- await self .send_command ("CLIENT" , "SETINFO" , "LIB-NAME" , self .lib_name )
427+ await self .send_command (
428+ "CLIENT" ,
429+ "SETINFO" ,
430+ "LIB-NAME" ,
431+ self .lib_name ,
432+ check_health = check_health ,
433+ )
417434 if self .lib_version :
418- await self .send_command ("CLIENT" , "SETINFO" , "LIB-VER" , self .lib_version )
435+ await self .send_command (
436+ "CLIENT" ,
437+ "SETINFO" ,
438+ "LIB-VER" ,
439+ self .lib_version ,
440+ check_health = check_health ,
441+ )
419442 # if a database is specified, switch to it. Also pipeline this
420443 if self .db :
421- await self .send_command ("SELECT" , self .db )
444+ await self .send_command ("SELECT" , self .db , check_health = check_health )
422445
423446 # read responses from pipeline
424447 for _ in (sent for sent in (self .lib_name , self .lib_version ) if sent ):
@@ -480,8 +503,8 @@ async def send_packed_command(
480503 self , command : Union [bytes , str , Iterable [bytes ]], check_health : bool = True
481504 ) -> None :
482505 if not self .is_connected :
483- await self .connect ( )
484- elif check_health :
506+ await self .connect_check_health ( check_health = False )
507+ if check_health :
485508 await self .check_health ()
486509
487510 try :
0 commit comments