3232 MultiAccountId ,
3333)
3434from websockets .asyncio .client import connect
35- from websockets .exceptions import ConnectionClosed
35+ from websockets .exceptions import ConnectionClosed , WebSocketException
3636
3737from async_substrate_interface .const import SS58_FORMAT
3838from async_substrate_interface .errors import (
@@ -535,6 +535,8 @@ def __init__(
535535 self ._open_subscriptions = 0
536536 self ._options = options if options else {}
537537 self ._log_raw_websockets = _log_raw_websockets
538+ self ._is_connecting = False
539+ self ._is_closing = False
538540
539541 try :
540542 now = asyncio .get_running_loop ().time ()
@@ -560,38 +562,63 @@ async def __aenter__(self):
560562 async def loop_time () -> float :
561563 return asyncio .get_running_loop ().time ()
562564
565+ async def _cancel (self ):
566+ try :
567+ self ._receiving_task .cancel ()
568+ await self ._receiving_task
569+ await self .ws .close ()
570+ except (
571+ AttributeError ,
572+ asyncio .CancelledError ,
573+ WebSocketException ,
574+ ):
575+ pass
576+ except Exception as e :
577+ logger .warning (
578+ f"{ e } encountered while trying to close websocket connection."
579+ )
580+
563581 async def connect (self , force = False ):
564- now = await self .loop_time ()
565- self .last_received = now
566- self .last_sent = now
567- if self ._exit_task :
568- self ._exit_task .cancel ()
569- async with self ._lock :
570- if not self ._initialized or force :
571- try :
572- self ._receiving_task .cancel ()
573- await self ._receiving_task
574- await self .ws .close ()
575- except (AttributeError , asyncio .CancelledError ):
576- pass
577- self .ws = await asyncio .wait_for (
578- connect (self .ws_url , ** self ._options ), timeout = 10
579- )
580- self ._receiving_task = asyncio .create_task (self ._start_receiving ())
581- self ._initialized = True
582+ self ._is_connecting = True
583+ try :
584+ now = await self .loop_time ()
585+ self .last_received = now
586+ self .last_sent = now
587+ if self ._exit_task :
588+ self ._exit_task .cancel ()
589+ if not self ._is_closing :
590+ if not self ._initialized or force :
591+ try :
592+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
593+ except asyncio .TimeoutError :
594+ pass
595+
596+ self .ws = await asyncio .wait_for (
597+ connect (self .ws_url , ** self ._options ), timeout = 10.0
598+ )
599+ self ._receiving_task = asyncio .get_running_loop ().create_task (
600+ self ._start_receiving ()
601+ )
602+ self ._initialized = True
603+ finally :
604+ self ._is_connecting = False
582605
583606 async def __aexit__ (self , exc_type , exc_val , exc_tb ):
584- async with self ._lock : # TODO is this actually what I want to happen?
585- self ._in_use -= 1
586- if self ._exit_task is not None :
587- self ._exit_task .cancel ()
588- try :
589- await self ._exit_task
590- except asyncio .CancelledError :
591- pass
592- if self ._in_use == 0 and self .ws is not None :
593- self ._open_subscriptions = 0
594- self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
607+ self ._is_closing = True
608+ try :
609+ if not self ._is_connecting :
610+ self ._in_use -= 1
611+ if self ._exit_task is not None :
612+ self ._exit_task .cancel ()
613+ try :
614+ await self ._exit_task
615+ except asyncio .CancelledError :
616+ pass
617+ if self ._in_use == 0 and self .ws is not None :
618+ self ._open_subscriptions = 0
619+ self ._exit_task = asyncio .create_task (self ._exit_with_timer ())
620+ finally :
621+ self ._is_closing = False
595622
596623 async def _exit_with_timer (self ):
597624 """
@@ -605,16 +632,15 @@ async def _exit_with_timer(self):
605632 pass
606633
607634 async def shutdown (self ):
608- async with self ._lock :
609- try :
610- self ._receiving_task .cancel ()
611- await self ._receiving_task
612- await self .ws .close ()
613- except (AttributeError , asyncio .CancelledError ):
614- pass
615- self .ws = None
616- self ._initialized = False
617- self ._receiving_task = None
635+ self ._is_closing = True
636+ try :
637+ await asyncio .wait_for (self ._cancel (), timeout = 10.0 )
638+ except asyncio .TimeoutError :
639+ pass
640+ self .ws = None
641+ self ._initialized = False
642+ self ._receiving_task = None
643+ self ._is_closing = False
618644
619645 async def _recv (self ) -> None :
620646 try :
@@ -624,10 +650,6 @@ async def _recv(self) -> None:
624650 raw_websocket_logger .debug (f"WEBSOCKET_RECEIVE> { recd .decode ()} " )
625651 response = json .loads (recd )
626652 self .last_received = await self .loop_time ()
627- async with self ._lock :
628- # note that these 'subscriptions' are all waiting sent messages which have not received
629- # responses, and thus are not the same as RPC 'subscriptions', which are unique
630- self ._open_subscriptions -= 1
631653 if "id" in response :
632654 self ._received [response ["id" ]] = response
633655 self ._in_use_ids .remove (response ["id" ])
@@ -647,8 +669,7 @@ async def _start_receiving(self):
647669 except asyncio .CancelledError :
648670 pass
649671 except ConnectionClosed :
650- async with self ._lock :
651- await self .connect (force = True )
672+ await self .connect (force = True )
652673
653674 async def send (self , payload : dict ) -> int :
654675 """
@@ -674,8 +695,7 @@ async def send(self, payload: dict) -> int:
674695 self .last_sent = await self .loop_time ()
675696 return original_id
676697 except (ConnectionClosed , ssl .SSLError , EOFError ):
677- async with self ._lock :
678- await self .connect (force = True )
698+ await self .connect (force = True )
679699
680700 async def retrieve (self , item_id : int ) -> Optional [dict ]:
681701 """
@@ -710,6 +730,7 @@ def __init__(
710730 retry_timeout : float = 60.0 ,
711731 _mock : bool = False ,
712732 _log_raw_websockets : bool = False ,
733+ ws_shutdown_timer : float = 5.0 ,
713734 ):
714735 """
715736 The asyncio-compatible version of the subtensor interface commands we use in bittensor. It is important to
@@ -728,6 +749,7 @@ def __init__(
728749 retry_timeout: how to long wait since the last ping to retry the RPC request
729750 _mock: whether to use mock version of the subtensor interface
730751 _log_raw_websockets: whether to log raw websocket requests during RPC requests
752+ ws_shutdown_timer: how long after the last connection your websocket should close
731753
732754 """
733755 self .max_retries = max_retries
@@ -744,6 +766,7 @@ def __init__(
744766 "max_size" : self .ws_max_size ,
745767 "write_limit" : 2 ** 16 ,
746768 },
769+ shutdown_timer = ws_shutdown_timer ,
747770 )
748771 else :
749772 self .ws = AsyncMock (spec = Websocket )
0 commit comments