Skip to content

Commit 33b5f47

Browse files
committed
Add some configurations for task based subscription processing
- Propagate subscription manager tasks exceptions to the main loop. - Allow configuring exceptions to be logged silently, rather than raised. - Add task cleanup.
1 parent 81a3780 commit 33b5f47

File tree

6 files changed

+237
-78
lines changed

6 files changed

+237
-78
lines changed

tests/core/subscriptions/test_subscription_manager.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import pytest
2+
import asyncio
23
import itertools
4+
import logging
5+
import time
36
from unittest.mock import (
47
AsyncMock,
58
)
@@ -14,8 +17,12 @@
1417
PersistentConnectionProvider,
1518
)
1619
from web3.exceptions import (
20+
SubscriptionHandlerTaskException,
1721
Web3ValueError,
1822
)
23+
from web3.providers.persistent.request_processor import (
24+
TaskReliantQueue,
25+
)
1926
from web3.providers.persistent.subscription_manager import (
2027
SubscriptionManager,
2128
)
@@ -213,3 +220,161 @@ async def test_unsubscribe_with_subscriptions_reference_does_not_mutate_the_list
213220

214221
await subscription_manager.unsubscribe_all()
215222
assert subscription_manager.subscriptions == []
223+
224+
225+
@pytest.mark.asyncio
226+
async def test_high_throughput_subscription_task_based(
227+
subscription_manager,
228+
) -> None:
229+
provider = subscription_manager._w3.provider
230+
num_msgs = 5_000
231+
232+
provider._request_processor._handler_subscription_queue = TaskReliantQueue(
233+
maxsize=num_msgs
234+
)
235+
236+
# Turn on task-based processing. This test should fail the time constraint if this
237+
# is not set to ``True`` (not task-based processing).
238+
subscription_manager.task_based = True
239+
240+
class Counter:
241+
val: int = 0
242+
243+
counter = Counter()
244+
245+
async def high_throughput_handler(handler_context) -> None:
246+
# if we awaited all `num_msgs`, we would sleep at least 5 seconds total
247+
await asyncio.sleep(5 / num_msgs)
248+
249+
handler_context.counter.val += 1
250+
if handler_context.counter.val == num_msgs:
251+
await handler_context.subscription.unsubscribe()
252+
253+
# build a meaningless subscription since we are fabricating the messages
254+
sub_id = await subscription_manager.subscribe(
255+
NewHeadsSubscription(
256+
handler=high_throughput_handler, handler_context={"counter": counter}
257+
),
258+
)
259+
provider._request_processor.cache_request_information(
260+
request_id=sub_id,
261+
method="eth_subscribe",
262+
params=[],
263+
response_formatters=((), (), ()),
264+
)
265+
266+
# put `num_msgs` messages in the queue
267+
for _ in range(num_msgs):
268+
provider._request_processor._handler_subscription_queue.put_nowait(
269+
{
270+
"jsonrpc": "2.0",
271+
"method": "eth_subscription",
272+
"params": {"subscription": sub_id, "result": "0x0"},
273+
}
274+
)
275+
276+
start = time.time()
277+
await subscription_manager.handle_subscriptions()
278+
stop = time.time()
279+
280+
assert counter.val == num_msgs
281+
282+
assert subscription_manager.total_handler_calls == num_msgs
283+
assert stop - start < 3, "subscription handling took too long!"
284+
285+
286+
@pytest.mark.asyncio
287+
async def test_task_based_subscription_handling_error_propagation(
288+
subscription_manager,
289+
) -> None:
290+
provider = subscription_manager._w3.provider
291+
subscription_manager.task_based = True
292+
293+
async def high_throughput_handler(_handler_context) -> None:
294+
raise ValueError("Test error msg.")
295+
296+
# build a meaningless subscription since we are fabricating the messages
297+
sub_id = await subscription_manager.subscribe(
298+
NewHeadsSubscription(handler=high_throughput_handler)
299+
)
300+
provider._request_processor.cache_request_information(
301+
request_id=sub_id,
302+
method="eth_subscribe",
303+
params=[],
304+
response_formatters=((), (), ()),
305+
)
306+
provider._request_processor._handler_subscription_queue.put_nowait(
307+
{
308+
"jsonrpc": "2.0",
309+
"method": "eth_subscription",
310+
"params": {"subscription": sub_id, "result": "0x0"},
311+
}
312+
)
313+
314+
with pytest.raises(
315+
SubscriptionHandlerTaskException,
316+
match="Test error msg.",
317+
):
318+
await subscription_manager.handle_subscriptions()
319+
320+
321+
@pytest.mark.asyncio
322+
async def test_task_based_subscription_handling_ignore_errors(
323+
subscription_manager, caplog
324+
) -> None:
325+
provider = subscription_manager._w3.provider
326+
subscription_manager.task_based = True
327+
subscription_manager.ignore_task_exceptions = True
328+
329+
class TestObject:
330+
exception = None
331+
332+
async def sub_handler(handler_context) -> None:
333+
if handler_context.obj.exception:
334+
# on the second call, yield to loop so we log, unsubscribe, and return
335+
await asyncio.sleep(0.01)
336+
await handler_context.subscription.unsubscribe()
337+
return
338+
339+
e = ValueError("Test error msg.")
340+
handler_context.obj.exception = e
341+
raise e
342+
343+
# build a meaningless subscription since we are fabricating the messages
344+
test_obj = TestObject()
345+
sub_id = await subscription_manager.subscribe(
346+
NewHeadsSubscription(handler=sub_handler, handler_context={"obj": test_obj})
347+
)
348+
provider._request_processor.cache_request_information(
349+
request_id=sub_id,
350+
method="eth_subscribe",
351+
params=[],
352+
response_formatters=((), (), ()),
353+
)
354+
for _ in range(2):
355+
provider._request_processor._handler_subscription_queue.put_nowait(
356+
{
357+
"jsonrpc": "2.0",
358+
"method": "eth_subscription",
359+
"params": {"subscription": sub_id, "result": "0x0"},
360+
}
361+
)
362+
363+
with caplog.at_level(
364+
logging.WARNING, logger="web3.providers.persistent.subscription_manager"
365+
):
366+
await subscription_manager.handle_subscriptions()
367+
368+
# find the warning so and assert it was logged
369+
warning_records = [r for r in caplog.records if r.levelname == "WARNING"]
370+
assert len(warning_records) == 1
371+
record = warning_records[0]
372+
assert (
373+
"An exception occurred in a subscription handler task but was ignored, `"
374+
"`ignore_task_exceptions==True``." in record.message
375+
)
376+
await subscription_manager.handle_subscriptions()
377+
378+
assert subscription_manager.total_handler_calls == 2
379+
assert subscription_manager.subscriptions == []
380+
assert subscription_manager._tasks == set()

web3/_utils/module_testing/persistent_connection_provider.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from dataclasses import (
44
dataclass,
55
)
6-
import time
76
from typing import (
87
TYPE_CHECKING,
98
Any,
@@ -39,9 +38,6 @@
3938
from web3.middleware import (
4039
ExtraDataToPOAMiddleware,
4140
)
42-
from web3.providers.persistent.request_processor import (
43-
TaskReliantQueue,
44-
)
4541
from web3.types import (
4642
BlockData,
4743
FormattedEthSubscriptionResponse,
@@ -918,59 +914,3 @@ async def test_run_forever_starts_with_0_subs_and_runs_until_task_cancelled(
918914

919915
# cleanup
920916
await clean_up_task(run_forever_task)
921-
922-
@pytest.mark.asyncio
923-
async def test_high_throughput_subscription_task_based(
924-
self, async_w3: AsyncWeb3
925-
) -> None:
926-
async_w3.provider._request_processor._handler_subscription_queue = (
927-
TaskReliantQueue(maxsize=5_000)
928-
)
929-
sub_manager = async_w3.subscription_manager
930-
sub_manager.task_based = True # turn on task-based processing
931-
932-
class Counter:
933-
val: int = 0
934-
935-
counter = Counter()
936-
937-
async def high_throughput_handler(
938-
handler_context: Any,
939-
) -> None:
940-
handler_context.counter.val += 1
941-
if handler_context.counter.val == 5_000:
942-
await handler_context.subscription.unsubscribe()
943-
# if we awaited all 5_000 messages, we would sleep at least 5 seconds
944-
await asyncio.sleep(5 // 5_000)
945-
946-
# build a meaningless subscription since we are fabricating the messages
947-
sub_id = await async_w3.eth.subscribe(
948-
"syncing",
949-
handler=high_throughput_handler,
950-
handler_context={"counter": counter},
951-
)
952-
async_w3.provider._request_processor.cache_request_information(
953-
request_id=sub_id,
954-
method=RPCEndpoint("eth_subscribe"),
955-
params=[],
956-
response_formatters=((), (), ()), # type: ignore
957-
)
958-
959-
# put 5_000 messages in the queue
960-
for _ in range(5_000):
961-
async_w3.provider._request_processor._handler_subscription_queue.put_nowait(
962-
{
963-
"jsonrpc": "2.0",
964-
"method": "eth_subscription",
965-
"params": {"subscription": HexBytes(sub_id), "result": False},
966-
}
967-
)
968-
969-
start = time.time()
970-
await sub_manager.handle_subscriptions()
971-
stop = time.time()
972-
973-
assert counter.val == 5_000
974-
975-
assert sub_manager.total_handler_calls == 5_000
976-
assert stop - start < 3, "subscription handling took too long!"

web3/exceptions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,13 +353,20 @@ class PersistentConnectionClosedOK(PersistentConnectionError):
353353
"""
354354

355355

356-
class SubscriptionProcessingFinished(Web3Exception):
356+
class SubscriptionProcessingFinished(PersistentConnectionError):
357357
"""
358358
Raised to alert the subscription manager that the processing of subscriptions
359359
has finished.
360360
"""
361361

362362

363+
class SubscriptionHandlerTaskException(TaskNotRunning):
364+
"""
365+
Raised to alert the subscription manager that an exception occurred in the
366+
subscription processing task.
367+
"""
368+
369+
363370
class Web3RPCError(Web3Exception):
364371
"""
365372
Raised when a JSON-RPC response contains an error field.

web3/manager.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -550,13 +550,10 @@ async def _message_stream(
550550
else:
551551
# if not an active sub, skip processing and continue
552552
continue
553-
except TaskNotRunning:
553+
except TaskNotRunning as e:
554554
await asyncio.sleep(0)
555555
self._provider._handle_listener_task_exceptions()
556-
self.logger.error(
557-
"Message listener background task has stopped unexpectedly. "
558-
"Stopping message stream."
559-
)
556+
self.logger.error("Stopping message stream: %s", e.message)
560557
return
561558

562559
async def _process_response(

web3/providers/persistent/persistent.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ async def send_batch_func(
164164
if cache_key != self._send_batch_func_cache[0]:
165165

166166
async def send_func(
167-
requests: List[Tuple[RPCEndpoint, Any]]
167+
requests: List[Tuple[RPCEndpoint, Any]],
168168
) -> List[RPCRequest]:
169169
for mw in middleware:
170170
initialized = mw(async_w3)
@@ -376,11 +376,12 @@ def _message_listener_callback(
376376
) -> None:
377377
# Puts a `TaskNotRunning` in appropriate queues to signal the end of the
378378
# listener task to any listeners relying on the queues.
379+
message = "Message listener task has ended."
379380
self._request_processor._subscription_response_queue.put_nowait(
380-
TaskNotRunning(message_listener_task)
381+
TaskNotRunning(message_listener_task, message=message)
381382
)
382383
self._request_processor._handler_subscription_queue.put_nowait(
383-
TaskNotRunning(message_listener_task)
384+
TaskNotRunning(message_listener_task, message=message)
384385
)
385386

386387
def _raise_stray_errors_from_cache(self) -> None:

0 commit comments

Comments
 (0)