From 5fd3e67810b0fe651b138f74c0ee534b409c27ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Sat, 6 Jun 2020 16:30:41 +0200 Subject: [PATCH 01/19] New transport class to handle Phoenix channels --- gql/transport/phoenix_channel_websockets.py | 207 ++++++++++++++++++++ gql/transport/websockets.py | 19 +- tests/conftest.py | 28 ++- tests/test_async_client_validation.py | 8 +- tests/test_phoenix_channel_query.py | 69 +++++++ tests/test_phoenix_channel_subscription.py | 175 +++++++++++++++++ tests/test_websocket_exceptions.py | 18 +- tests/test_websocket_query.py | 16 +- tests/test_websocket_subscription.py | 14 +- 9 files changed, 514 insertions(+), 40 deletions(-) create mode 100644 gql/transport/phoenix_channel_websockets.py create mode 100644 tests/test_phoenix_channel_query.py create mode 100644 tests/test_phoenix_channel_subscription.py diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py new file mode 100644 index 00000000..84aff78e --- /dev/null +++ b/gql/transport/phoenix_channel_websockets.py @@ -0,0 +1,207 @@ +import asyncio +import json +from typing import Dict, Optional, Tuple + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, +) +from .websockets import WebsocketsTransport + + +class PhoenixChannelWebsocketsTransport(WebsocketsTransport): + def __init__(self, channel_name, heartbeat_interval=30, *args, **kwargs) -> None: + self.channel_name = channel_name + self.heartbeat_interval = heartbeat_interval + self.subscription_ids_to_query_ids: Dict[str, int] = {} + super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) + + async def _send_init_message_and_wait_ack(self) -> None: + """Join the specified channel and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + init_message = json.dumps( + { + "topic": self.channel_name, + "event": "phx_join", + "payload": {}, + "ref": query_id, + } + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout) + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type != "reply": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def heartbeat_coro(): + while True: + await asyncio.sleep(self.heartbeat_interval) + await self._send(json.dumps({"topic": "phoenix", "event": "heartbeat"})) + + self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) + + async def _send_stop_message(self, query_id: int) -> None: + pass + + async def _send_connection_terminate_message(self) -> None: + """Send a phx_leave message to disconnect from the provided channel. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + connection_terminate_message = json.dumps( + { + "topic": self.channel_name, + "event": "phx_leave", + "payload": {}, + "ref": query_id, + } + ) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + { + "topic": self.channel_name, + "event": "doc", + "payload": { + "query": print_ast(document), + "variables": variable_values or {}, + }, + "ref": query_id, + } + ) + + await self._send(query_str) + + return query_id + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server + + Returns a list consisting of: + - the answer_type (between: + 'heartbeat', 'data', 'reply', 'error', 'close') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + event: str = "" + answer_id: Optional[int] = None + answer_type: str = "" + execution_result: Optional[ExecutionResult] = None + + try: + json_answer = json.loads(answer) + + event = str(json_answer.get("event")) + + if event == "subscription:data": + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + subscription_id = str(payload.get("subscriptionId")) + answer_id = self.subscription_ids_to_query_ids[subscription_id] + result = payload.get("result") + + if not isinstance(result, dict): + raise ValueError("result is not a dict") + + answer_type = "data" + + execution_result = ExecutionResult( + errors=payload.get("errors"), data=result.get("data") + ) + + elif event == "phx_reply": + answer_id = int(json_answer.get("ref")) + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + status = str(payload.get("status")) + + if status == "ok": + + answer_type = "reply" + response = payload.get("response") + + if isinstance(response, dict) and "subscriptionId" in response: + subscription_id = str(response.get("subscriptionId")) + self.subscription_ids_to_query_ids[subscription_id] = answer_id + + elif status == "error": + response = payload.get("response") + + if isinstance(response, dict): + raise TransportQueryError( + response.get("reason"), query_id=answer_id + ) + else: + raise ValueError("reply error") + + elif status == "timeout": + raise ValueError("reply timeout") + + elif event == "phx_error": + raise TransportServerError("Server error") + elif event == "phx_close": + answer_type = "close" + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + "Server did not return a GraphQL result" + ) from e + + return answer_type, answer_id, execution_result + + async def _handle_answer(self, answer_id, answer_type, execution_result) -> None: + if answer_type == "close": + for listener in self.listeners.values(): + await listener.put(("complete", execution_result)) + else: + await super()._handle_answer(answer_id, answer_type, execution_result) + + async def close(self) -> None: + self.heartbeat_task.cancel() + await super().close() diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 081c677b..954d7d79 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -348,19 +348,20 @@ async def _receive_data_loop(self) -> None: await self._fail(e, clean_close=False) break - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put( - (answer_type, execution_result) - ) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass + await self._handle_answer(answer_id, answer_type, execution_result) finally: log.debug("Exiting _receive_data_loop()") + async def _handle_answer(self, answer_id, answer_type, execution_result) -> None: + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + async def subscribe( self, document: DocumentNode, diff --git a/tests/conftest.py b/tests/conftest.py index 6d345953..8ce81f8d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,6 +135,8 @@ async def stop(self): print("Server stopped\n\n\n") + +class WebSocketServerHelper: @staticmethod async def send_complete(ws, query_id): await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') @@ -164,6 +166,26 @@ async def wait_connection_terminate(ws): assert json_result["type"] == "connection_terminate" +class PhoenixChannelServerHelper: + @staticmethod + async def send_close(ws): + await ws.send('{"event":"phx_close"}') + + @staticmethod + async def send_connection_ack(ws): + + # Line return for easy debugging + print("") + + # Wait for init + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "phx_join" + + # Send ack + await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}') + + def get_server_handler(request): """Get the server handler. @@ -180,7 +202,7 @@ def get_server_handler(request): async def default_server_handler(ws, path): try: - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) query_id = 1 for answer in answers: @@ -194,10 +216,10 @@ async def default_server_handler(ws, path): formatted_answer = answer await ws.send(formatted_answer) - await WebSocketServer.send_complete(ws, query_id) + await WebSocketServerHelper.send_complete(ws, query_id) query_id += 1 - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() except ConnectionClosed: pass diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index ec651866..8558fafd 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -8,7 +8,7 @@ from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef starwars_expected_one = { @@ -25,7 +25,7 @@ async def server_starwars(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) try: await ws.recv() @@ -42,8 +42,8 @@ async def server_starwars(ws, path): await ws.send(data) await asyncio.sleep(2 * MS) - await WebSocketServer.send_complete(ws, 1) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.wait_connection_terminate(ws) except websockets.exceptions.ConnectionClosedOK: pass diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py new file mode 100644 index 00000000..d59050ac --- /dev/null +++ b/tests/test_phoenix_channel_query.py @@ -0,0 +1,69 @@ +import pytest + +from gql import Client, gql +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import PhoenixChannelServerHelper + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +query1_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"test_subscription","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":3,' + '"topic":"test_topic"}' +) + + +@pytest.fixture +def ws_server_helper(request): + yield PhoenixChannelServerHelper + + +async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(query1_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [phoenix_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_simple_query(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + result = await session.execute(query) + + print("Client received:", result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py new file mode 100644 index 00000000..8efc9899 --- /dev/null +++ b/tests/test_phoenix_channel_subscription.py @@ -0,0 +1,175 @@ +import asyncio +import json + +import pytest +import websockets +from parse import search + +from gql import Client, gql +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import MS, PhoenixChannelServerHelper + +subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +countdown_server_answer = ( + '{{"event":"subscription:data",' + '"payload":{{"subscriptionId":"test_subscription","result":' + '{{"data":{{"number":{number}}}}}}},' + '"ref":{query_id}}}' +) + + +async def server_countdown(ws, path): + try: + await PhoenixChannelServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "doc" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["ref"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + await ws.send(subscription_server_answer) + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + result = await ws.recv() + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + stopping_task = asyncio.ensure_future(stopping_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + await PhoenixChannelServerHelper.send_close(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_phoenix_channel_subscription(event_loop, server, subscription_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with Client(transport=sample_transport) as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +heartbeat_server_answer = ( + '{{"event":"subscription:data",' + '"payload":{{"subscriptionId":"test_subscription","result":' + '{{"data":{{"heartbeat_count":{count}}}}}}},' + '"ref":1}}' +) + + +async def phoenix_heartbeat_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + + for i in range(3): + heartbeat_result = await ws.recv() + json_result = json.loads(heartbeat_result) + assert json_result["event"] == "heartbeat" + await ws.send(heartbeat_server_answer.format(count=i)) + + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +heartbeat_subscription_str = """ + subscription { + heartbeat { + heartbeat_count + } + } +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [phoenix_heartbeat_server], indirect=True) +@pytest.mark.parametrize("subscription_str", [heartbeat_subscription_str]) +async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, heartbeat_interval=1 + ) + + subscription = gql(heartbeat_subscription_str) + async with Client(transport=sample_transport) as session: + i = 0 + async for result in session.subscribe(subscription): + heartbeat_count = result["heartbeat_count"] + print(f"Heartbeat count received: {heartbeat_count}") + + assert heartbeat_count == i + i += 1 diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index a5482d43..ee7129fb 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -13,7 +13,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper invalid_query_str = """ query getContinents { @@ -59,10 +59,10 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) async def server_invalid_subscription(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -105,7 +105,7 @@ async def test_websocket_server_does_not_send_ack(event_loop, server, query_str) async def server_connection_error(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(connection_error_server_answer) @@ -132,11 +132,11 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que async def server_invalid_payload(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(invalid_payload_server_answer) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() @@ -215,7 +215,7 @@ async def test_websocket_transport_protocol_errors(event_loop, client_and_server async def server_without_ack(ws, path): # Sending something else than an ack - await WebSocketServer.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) await ws.wait_closed() @@ -252,7 +252,7 @@ async def test_websocket_server_closing_directly(event_loop, server): async def server_closing_after_ack(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) await ws.close() @@ -274,7 +274,7 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) async def server_sending_invalid_query_errors(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) invalid_error = ( '{"type":"error","id":"404","payload":' '{"message":"error for no good reason on non existing query"}}' diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 4d235d95..069d1aab 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -15,7 +15,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper query1_str = """ query getContinents { @@ -153,16 +153,16 @@ async def test_websocket_two_queries_in_series( async def server1_two_queries_in_parallel(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) await ws.send(query1_server_answer.format(query_id=2)) - await WebSocketServer.send_complete(ws, 1) - await WebSocketServer.send_complete(ws, 2) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 2) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() @@ -200,11 +200,11 @@ async def task2_coro(): async def server_closing_while_we_are_doing_something_else(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await asyncio.sleep(1 * MS) # Closing server after first query @@ -350,7 +350,7 @@ async def server_with_authentication_in_connection_init_payload(ws, path): result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) else: await ws.send( '{"type":"connection_error", "payload": "Invalid Authorization token"}' diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 2a9942ff..49ec0155 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -8,7 +8,7 @@ from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper countdown_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' @@ -20,9 +20,9 @@ async def server_countdown(ws, path): global WITH_KEEPALIVE try: - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) if WITH_KEEPALIVE: - await WebSocketServer.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) result = await ws.recv() json_result = json.loads(result) @@ -58,7 +58,7 @@ async def stopping_coro(): async def keepalive_coro(): while True: await asyncio.sleep(5 * MS) - await WebSocketServer.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) stopping_task = asyncio.ensure_future(stopping_coro()) keepalive_task = asyncio.ensure_future(keepalive_coro()) @@ -82,8 +82,8 @@ async def keepalive_coro(): except asyncio.CancelledError: print("Now keepalive task is cancelled") - await WebSocketServer.send_complete(ws, query_id) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, query_id) + await WebSocketServerHelper.wait_connection_terminate(ws) except websockets.exceptions.ConnectionClosedOK: pass finally: @@ -228,7 +228,7 @@ async def close_transport_task_coro(): async def server_countdown_close_connection_in_middle(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() json_result = json.loads(result) From f0853738a8889c425a634358604b5cbe7273c6e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Sun, 7 Jun 2020 20:26:06 +0200 Subject: [PATCH 02/19] Adding forgotten typing hints --- gql/transport/phoenix_channel_websockets.py | 11 +++++++++-- gql/transport/websockets.py | 7 ++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 84aff78e..31ad2edc 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -13,7 +13,9 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): - def __init__(self, channel_name, heartbeat_interval=30, *args, **kwargs) -> None: + def __init__( + self, channel_name: str, heartbeat_interval: int = 30, *args, **kwargs + ) -> None: self.channel_name = channel_name self.heartbeat_interval = heartbeat_interval self.subscription_ids_to_query_ids: Dict[str, int] = {} @@ -195,7 +197,12 @@ def _parse_answer( return answer_type, answer_id, execution_result - async def _handle_answer(self, answer_id, answer_type, execution_result) -> None: + async def _handle_answer( + self, + answer_id: str, + answer_type: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: if answer_type == "close": for listener in self.listeners.values(): await listener.put(("complete", execution_result)) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 954d7d79..1f639b8e 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -353,7 +353,12 @@ async def _receive_data_loop(self) -> None: finally: log.debug("Exiting _receive_data_loop()") - async def _handle_answer(self, answer_id, answer_type, execution_result) -> None: + async def _handle_answer( + self, + answer_id: str, + answer_type: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: try: # Put the answer in the queue if answer_id is not None: From 5e6018e942d1c215452ac00eb699fddce8129236 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Sun, 7 Jun 2020 22:44:41 +0200 Subject: [PATCH 03/19] Fix _handle_answer args order --- gql/transport/phoenix_channel_websockets.py | 6 +++--- gql/transport/websockets.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 31ad2edc..f773ecc1 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -199,15 +199,15 @@ def _parse_answer( async def _handle_answer( self, - answer_id: str, - answer_type: Optional[int], + answer_type: str, + answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": for listener in self.listeners.values(): await listener.put(("complete", execution_result)) else: - await super()._handle_answer(answer_id, answer_type, execution_result) + await super()._handle_answer(answer_type, answer_id, execution_result) async def close(self) -> None: self.heartbeat_task.cancel() diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 1f639b8e..df379577 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -348,15 +348,15 @@ async def _receive_data_loop(self) -> None: await self._fail(e, clean_close=False) break - await self._handle_answer(answer_id, answer_type, execution_result) + await self._handle_answer(answer_type, answer_id, execution_result) finally: log.debug("Exiting _receive_data_loop()") async def _handle_answer( self, - answer_id: str, - answer_type: Optional[int], + answer_type: str, + answer_id: Optional[int], execution_result: Optional[ExecutionResult], ) -> None: try: From e1eccbc72b94cb12f1fec9f9157ac7c3a8cea3d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Sun, 7 Jun 2020 23:12:14 +0200 Subject: [PATCH 04/19] Better handle exceptions --- gql/transport/phoenix_channel_websockets.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index f773ecc1..1bebf87f 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Tuple from graphql import DocumentNode, ExecutionResult, print_ast +from websockets.exceptions import ConnectionClosed from .exceptions import ( TransportProtocolError, @@ -54,7 +55,12 @@ async def _send_init_message_and_wait_ack(self) -> None: async def heartbeat_coro(): while True: await asyncio.sleep(self.heartbeat_interval) - await self._send(json.dumps({"topic": "phoenix", "event": "heartbeat"})) + try: + await self._send( + json.dumps({"topic": "phoenix", "event": "heartbeat"}) + ) + except ConnectionClosed: + pass self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) @@ -140,7 +146,13 @@ def _parse_answer( raise ValueError("payload is not a dict") subscription_id = str(payload.get("subscriptionId")) - answer_id = self.subscription_ids_to_query_ids[subscription_id] + try: + answer_id = self.subscription_ids_to_query_ids[subscription_id] + except KeyError: + raise ValueError( + f"subscription '{subscription_id}' has not been registerd" + ) + result = payload.get("result") if not isinstance(result, dict): From aeb18ec421f58453ab72bdbc2dd3370e66a3969f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Mon, 8 Jun 2020 22:11:51 +0200 Subject: [PATCH 05/19] Simulate complete messages in _send_stop_message --- gql/transport/phoenix_channel_websockets.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 1bebf87f..3b125177 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -65,7 +65,10 @@ async def heartbeat_coro(): self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) async def _send_stop_message(self, query_id: int) -> None: - pass + try: + await self.listeners[query_id].put(("complete", None)) + except KeyError: + pass async def _send_connection_terminate_message(self) -> None: """Send a phx_leave message to disconnect from the provided channel. @@ -216,8 +219,7 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": - for listener in self.listeners.values(): - await listener.put(("complete", execution_result)) + await self.close() else: await super()._handle_answer(answer_type, answer_id, execution_result) From 9c870061853ce5f7f77db9d6585be7719f3a7007 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Mon, 8 Jun 2020 22:16:59 +0200 Subject: [PATCH 06/19] Cancel heartbeat task in _close_coro --- gql/transport/phoenix_channel_websockets.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 3b125177..04d918c0 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -223,6 +223,11 @@ async def _handle_answer( else: await super()._handle_answer(answer_type, answer_id, execution_result) + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() + + await super()._close_coro(e, clean_close) + async def close(self) -> None: - self.heartbeat_task.cancel() await super().close() From b635138d925543e5b54ec85e38e44eb903f87d37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Mon, 8 Jun 2020 23:10:30 +0200 Subject: [PATCH 07/19] Remove useless close override --- gql/transport/phoenix_channel_websockets.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 04d918c0..ef12cc86 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -228,6 +228,3 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.heartbeat_task.cancel() await super()._close_coro(e, clean_close) - - async def close(self) -> None: - await super().close() From b7bd5a3bdbe970f08940cc7495c9e5f5e47841dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Wed, 17 Jun 2020 20:49:01 +0200 Subject: [PATCH 08/19] Set a correct ref in the heartbeat message --- gql/transport/phoenix_channel_websockets.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index ef12cc86..724ca3ea 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -56,8 +56,18 @@ async def heartbeat_coro(): while True: await asyncio.sleep(self.heartbeat_interval) try: + query_id = self.next_query_id + self.next_query_id += 1 + await self._send( - json.dumps({"topic": "phoenix", "event": "heartbeat"}) + json.dumps( + { + "topic": "phoenix", + "event": "heartbeat", + "payload": {}, + "ref": query_id, + } + ) ) except ConnectionClosed: pass From 4d6de484687ab3d9affd80b503fee3a33d83d5ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Wed, 17 Jun 2020 22:52:25 +0200 Subject: [PATCH 09/19] Adding unit tests for PhoenixChannelWebsocketsTransport exceptions --- gql/transport/phoenix_channel_websockets.py | 2 +- tests/test_phoenix_channel_exceptions.py | 135 ++++++++++++++++++++ 2 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 tests/test_phoenix_channel_exceptions.py diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 724ca3ea..b2ac2465 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -206,7 +206,7 @@ def _parse_answer( raise ValueError("reply error") elif status == "timeout": - raise ValueError("reply timeout") + raise TransportQueryError("reply timeout", query_id=answer_id) elif event == "phx_error": raise TransportServerError("Server error") diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py new file mode 100644 index 00000000..00dce76f --- /dev/null +++ b/tests/test_phoenix_channel_exceptions.py @@ -0,0 +1,135 @@ +import pytest + +from gql import Client, gql +from gql.transport.exceptions import TransportProtocolError, TransportQueryError +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import PhoenixChannelServerHelper + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + + +subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"reason":"internal error"},' + '"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +async def phoenix_server_reply_error(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(error_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +timeout_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"timeout"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +async def phoenix_server_timeout(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(timeout_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +generic_error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +async def phoenix_server_generic_error(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(generic_error_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +protocol_server_answer = '{"event":"unknown"}' + + +async def phoenix_server_protocol_error(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(protocol_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [phoenix_server_reply_error, phoenix_server_timeout], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportQueryError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [phoenix_server_generic_error, phoenix_server_protocol_error], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_protocol_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + await session.execute(query) From 76fa98acacc0ebbcb2977cd56a62a1c2630f441b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Thu, 18 Jun 2020 11:33:55 +0200 Subject: [PATCH 10/19] Increase PhoenixChannelWebsocketsTransport unit tests coverage --- tests/test_phoenix_channel_exceptions.py | 135 +++++++++++++++-------- 1 file changed, 92 insertions(+), 43 deletions(-) diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 00dce76f..4e0f1bcb 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -1,7 +1,11 @@ import pytest from gql import Client, gql -from gql.transport.exceptions import TransportProtocolError, TransportQueryError +from gql.transport.exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, +) from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport from .conftest import PhoenixChannelServerHelper @@ -15,8 +19,7 @@ } """ - -subscription_server_answer = ( +default_subscription_server_answer = ( '{"event":"phx_reply",' '"payload":' '{"response":' @@ -26,7 +29,6 @@ '"topic":"test_topic"}' ) - error_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -37,16 +39,6 @@ '"topic":"test_topic"}' ) - -async def phoenix_server_reply_error(ws, path): - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) - await ws.send(error_server_answer) - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() - - timeout_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -56,14 +48,67 @@ async def phoenix_server_reply_error(ws, path): ) -async def phoenix_server_timeout(ws, path): - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) - await ws.send(timeout_server_answer) - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() +def server( + query_server_answer, subscription_server_answer=default_subscription_server_answer, +): + async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + if query_server_answer is not None: + await ws.send(query_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + return phoenix_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [server(error_server_answer), server(timeout_server_answer)], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + query = gql(query_str) + with pytest.raises(TransportQueryError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +invalid_subscription_id_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"INVALID","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":3,' + '"topic":"test_topic"}' +) + +invalid_payload_server_answer = ( + '{"event":"subscription:data",' + '"payload":"INVALID",' + '"ref":3,' + '"topic":"test_topic"}' +) + +invalid_result_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"test_subscription","result": "INVALID"},' + '"ref":3,' + '"topic":"test_topic"}' +) generic_error_server_answer = ( '{"event":"phx_reply",' @@ -73,34 +118,35 @@ async def phoenix_server_timeout(ws, path): '"topic":"test_topic"}' ) - -async def phoenix_server_generic_error(ws, path): - await PhoenixChannelServerHelper.send_connection_ack(ws) - await ws.recv() - await ws.send(subscription_server_answer) - await ws.send(generic_error_server_answer) - await PhoenixChannelServerHelper.send_close(ws) - await ws.wait_closed() - - protocol_server_answer = '{"event":"unknown"}' +invalid_payload_subscription_server_answer = ( + '{"event":"phx_reply", "payload":"INVALID", "ref":2, "topic":"test_topic"}' +) + -async def phoenix_server_protocol_error(ws, path): - await PhoenixChannelServerHelper.send_connection_ack(ws) +async def no_connection_ack_phoenix_server(ws, path): await ws.recv() - await ws.send(subscription_server_answer) - await ws.send(protocol_server_answer) await PhoenixChannelServerHelper.send_close(ws) await ws.wait_closed() @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [phoenix_server_reply_error, phoenix_server_timeout], indirect=True + "server", + [ + server(invalid_subscription_id_server_answer), + server(invalid_result_server_answer), + server(generic_error_server_answer), + no_connection_ack_phoenix_server, + server(protocol_server_answer), + server(invalid_payload_server_answer), + server(None, invalid_payload_subscription_server_answer), + ], + indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_query_error(event_loop, server, query_str): +async def test_phoenix_channel_protocol_error(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -109,19 +155,22 @@ async def test_phoenix_channel_query_error(event_loop, server, query_str): ) query = gql(query_str) - with pytest.raises(TransportQueryError): + with pytest.raises(TransportProtocolError): async with Client(transport=sample_transport) as session: await session.execute(query) +server_error_subscription_server_answer = ( + '{"event":"phx_error", "ref":2, "topic":"test_topic"}' +) + + @pytest.mark.asyncio @pytest.mark.parametrize( - "server", - [phoenix_server_generic_error, phoenix_server_protocol_error], - indirect=True, + "server", [server(None, server_error_subscription_server_answer)], indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_phoenix_channel_protocol_error(event_loop, server, query_str): +async def test_phoenix_channel_server_error(event_loop, server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" @@ -130,6 +179,6 @@ async def test_phoenix_channel_protocol_error(event_loop, server, query_str): ) query = gql(query_str) - with pytest.raises(TransportProtocolError): + with pytest.raises(TransportServerError): async with Client(transport=sample_transport) as session: await session.execute(query) From dd6101fde2586597dd441d36298a345ff8c7ae48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Tue, 23 Jun 2020 08:02:23 +0200 Subject: [PATCH 11/19] Better handle a case when there are multiple errors in a query --- gql/transport/phoenix_channel_websockets.py | 14 +++++++++----- tests/test_phoenix_channel_exceptions.py | 18 ++++++++++++++++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index b2ac2465..82a46682 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -199,11 +199,15 @@ def _parse_answer( response = payload.get("response") if isinstance(response, dict): - raise TransportQueryError( - response.get("reason"), query_id=answer_id - ) - else: - raise ValueError("reply error") + if "errors" in response: + raise TransportQueryError( + response.get("errors"), query_id=answer_id + ) + elif "reason" in response: + raise TransportQueryError( + response.get("reason"), query_id=answer_id + ) + raise ValueError("reply error") elif status == "timeout": raise TransportQueryError("reply timeout", query_id=answer_id) diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 4e0f1bcb..97283650 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -29,7 +29,7 @@ '"topic":"test_topic"}' ) -error_server_answer = ( +error_with_reason_server_answer = ( '{"event":"phx_reply",' '"payload":' '{"response":' @@ -39,6 +39,16 @@ '"topic":"test_topic"}' ) +multiple_errors_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"errors": ["error 1", "error 2"]},' + '"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + timeout_server_answer = ( '{"event":"phx_reply",' '"payload":' @@ -66,7 +76,11 @@ async def phoenix_server(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize( "server", - [server(error_server_answer), server(timeout_server_answer)], + [ + server(error_with_reason_server_answer), + server(multiple_errors_server_answer), + server(timeout_server_answer), + ], indirect=True, ) @pytest.mark.parametrize("query_str", [query1_str]) From 33d51e6e1d7d3096251ab3601ec8266411447a7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Catinon?= Date: Mon, 29 Jun 2020 21:39:04 +0200 Subject: [PATCH 12/19] Fix mypy errors --- gql/transport/phoenix_channel_websockets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 82a46682..5b6345e0 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -201,11 +201,11 @@ def _parse_answer( if isinstance(response, dict): if "errors" in response: raise TransportQueryError( - response.get("errors"), query_id=answer_id + str(response.get("errors")), query_id=answer_id ) elif "reason" in response: raise TransportQueryError( - response.get("reason"), query_id=answer_id + str(response.get("reason")), query_id=answer_id ) raise ValueError("reply error") From f0f62be6b821ec57754c56139e880704b021cbd0 Mon Sep 17 00:00:00 2001 From: aurel Date: Sat, 11 Jul 2020 21:31:20 +0200 Subject: [PATCH 13/19] Update WebSocketServer to WebSocketServerHelper in test_websocket_non_regression_bug_108 --- tests/test_websocket_query.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index ce90432b..d44aa779 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -475,15 +475,15 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): async def server_sending_keep_alive_before_connection_ack(ws, path): - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() From 28ad41ddf5330098dc46c0fb95834f2d293de451 Mon Sep 17 00:00:00 2001 From: aurel Date: Sat, 11 Jul 2020 22:23:25 +0200 Subject: [PATCH 14/19] Adding doc on PhoenixChannelWebsocketsTransport --- gql/transport/phoenix_channel_websockets.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 5b6345e0..32a8a73a 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -15,12 +15,18 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransport): def __init__( - self, channel_name: str, heartbeat_interval: int = 30, *args, **kwargs + self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs ) -> None: self.channel_name = channel_name self.heartbeat_interval = heartbeat_interval self.subscription_ids_to_query_ids: Dict[str, int] = {} super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) + """Initialize the transport with the given request parameters. + + :param channel_name Channel on the server this transport will join + :param heartbeat_interval Interval in second between each heartbeat messages + sent by the client + """ async def _send_init_message_and_wait_ack(self) -> None: """Join the specified channel and wait for the connection ACK. From bd72e79cf2f131bbf57040b3b468e6b3519c6d0e Mon Sep 17 00:00:00 2001 From: XuZvvHYmZfYdWJNRunkJ <3367239+JBrVJxsc@users.noreply.github.com> Date: Mon, 17 Aug 2020 00:56:13 -0700 Subject: [PATCH 15/19] DSL: Fixed bug where a nested GraphQLInputObjectType is causing infinite recursive calls (#132) Co-authored-by: xzhang2 --- gql/dsl.py | 21 ++++++--- tests/nested_input/__init__.py | 0 tests/nested_input/schema.py | 30 ++++++++++++ tests/nested_input/test_nested_input.py | 63 +++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 7 deletions(-) create mode 100644 tests/nested_input/__init__.py create mode 100644 tests/nested_input/schema.py create mode 100644 tests/nested_input/test_nested_input.py diff --git a/gql/dsl.py b/gql/dsl.py index 0f66ff5e..bd592ee1 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -105,7 +105,7 @@ def args(self, **kwargs): arg = self.field.args.get(name) if not arg: raise KeyError(f"Argument {name} does not exist in {self.field}.") - arg_type_serializer = get_arg_serializer(arg.type) + arg_type_serializer = get_arg_serializer(arg.type, known_serializers=dict()) serialized_value = arg_type_serializer(value) added_args.append( ArgumentNode(name=NameNode(value=name), value=serialized_value) @@ -151,21 +151,28 @@ def serialize_list(serializer, list_values): return ListValueNode(values=FrozenList(serializer(v) for v in list_values)) -def get_arg_serializer(arg_type): +def get_arg_serializer(arg_type, known_serializers): if isinstance(arg_type, GraphQLNonNull): - return get_arg_serializer(arg_type.of_type) + return get_arg_serializer(arg_type.of_type, known_serializers) if isinstance(arg_type, GraphQLInputField): - return get_arg_serializer(arg_type.type) + return get_arg_serializer(arg_type.type, known_serializers) if isinstance(arg_type, GraphQLInputObjectType): - serializers = {k: get_arg_serializer(v) for k, v in arg_type.fields.items()} - return lambda value: ObjectValueNode( + if arg_type in known_serializers: + return known_serializers[arg_type] + known_serializers[arg_type] = None + serializers = { + k: get_arg_serializer(v, known_serializers) + for k, v in arg_type.fields.items() + } + known_serializers[arg_type] = lambda value: ObjectValueNode( fields=FrozenList( ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) for k, v in value.items() ) ) + return known_serializers[arg_type] if isinstance(arg_type, GraphQLList): - inner_serializer = get_arg_serializer(arg_type.of_type) + inner_serializer = get_arg_serializer(arg_type.of_type, known_serializers) return partial(serialize_list, inner_serializer) if isinstance(arg_type, GraphQLEnumType): return lambda value: EnumValueNode(value=arg_type.serialize(value)) diff --git a/tests/nested_input/__init__.py b/tests/nested_input/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/nested_input/schema.py b/tests/nested_input/schema.py new file mode 100644 index 00000000..f27a94e8 --- /dev/null +++ b/tests/nested_input/schema.py @@ -0,0 +1,30 @@ +from graphql import ( + GraphQLArgument, + GraphQLField, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLInt, + GraphQLObjectType, + GraphQLSchema, +) + +nestedInput = GraphQLInputObjectType( + "Nested", + description="The input object that has a field pointing to itself", + fields={"foo": GraphQLInputField(GraphQLInt, description="foo")}, +) + +nestedInput.fields["child"] = GraphQLInputField(nestedInput, description="child") + +queryType = GraphQLObjectType( + "Query", + fields=lambda: { + "foo": GraphQLField( + args={"nested": GraphQLArgument(type_=nestedInput)}, + resolve=lambda *args, **kwargs: 1, + type_=GraphQLInt, + ), + }, +) + +NestedInputSchema = GraphQLSchema(query=queryType, types=[nestedInput],) diff --git a/tests/nested_input/test_nested_input.py b/tests/nested_input/test_nested_input.py new file mode 100644 index 00000000..037d1518 --- /dev/null +++ b/tests/nested_input/test_nested_input.py @@ -0,0 +1,63 @@ +from functools import partial + +import pytest +from graphql import ( + EnumValueNode, + GraphQLEnumType, + GraphQLInputField, + GraphQLInputObjectType, + GraphQLList, + GraphQLNonNull, + NameNode, + ObjectFieldNode, + ObjectValueNode, + ast_from_value, +) +from graphql.pyutils import FrozenList + +import gql.dsl as dsl +from gql import Client +from gql.dsl import DSLSchema, serialize_list +from tests.nested_input.schema import NestedInputSchema + +# back up the new func +new_get_arg_serializer = dsl.get_arg_serializer + + +def old_get_arg_serializer(arg_type, known_serializers=None): + if isinstance(arg_type, GraphQLNonNull): + return old_get_arg_serializer(arg_type.of_type) + if isinstance(arg_type, GraphQLInputField): + return old_get_arg_serializer(arg_type.type) + if isinstance(arg_type, GraphQLInputObjectType): + serializers = {k: old_get_arg_serializer(v) for k, v in arg_type.fields.items()} + return lambda value: ObjectValueNode( + fields=FrozenList( + ObjectFieldNode(name=NameNode(value=k), value=serializers[k](v)) + for k, v in value.items() + ) + ) + if isinstance(arg_type, GraphQLList): + inner_serializer = old_get_arg_serializer(arg_type.of_type) + return partial(serialize_list, inner_serializer) + if isinstance(arg_type, GraphQLEnumType): + return lambda value: EnumValueNode(value=arg_type.serialize(value)) + return lambda value: ast_from_value(arg_type.serialize(value), arg_type) + + +@pytest.fixture +def ds(): + client = Client(schema=NestedInputSchema) + ds = DSLSchema(client) + return ds + + +def test_nested_input_with_old_get_arg_serializer(ds): + dsl.get_arg_serializer = old_get_arg_serializer + with pytest.raises(RecursionError, match="maximum recursion depth exceeded"): + ds.query(ds.Query.foo.args(nested={"foo": 1})) + + +def test_nested_input_with_new_get_arg_serializer(ds): + dsl.get_arg_serializer = new_get_arg_serializer + assert ds.query(ds.Query.foo.args(nested={"foo": 1})) == {"foo": 1} From 1a2dcec6b2834e81d66c4f23a7e50415d20e86cc Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Thu, 27 Aug 2020 22:12:45 +0200 Subject: [PATCH 16/19] Fix race condition in websocket transport close (#133) --- gql/transport/websockets.py | 60 +++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 5 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index b0c064fc..7ad91519 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -488,6 +488,8 @@ async def connect(self) -> None: GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws") + log.debug("connect: starting") + if self.websocket is None and not self._connecting: # Set connecting to True to avoid a race condition if user is trying @@ -543,6 +545,8 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") + log.debug("connect: done") + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: @@ -575,35 +579,81 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - close the websocket connection - send the exception to all the remaining listeners """ - if self.websocket: + + log.debug("_close_coro: starting") + + try: + + # We should always have an active websocket connection here + assert self.websocket is not None # Saving exception to raise it later if trying to use the transport # after it has already closed. self.close_exception = e if clean_close: - await self._clean_close(e) + log.debug("_close_coro: starting clean_close") + try: + await self._clean_close(e) + except Exception as exc: # pragma: no cover + log.warning("Ignoring exception in _clean_close: " + repr(exc)) + + log.debug("_close_coro: sending exception to listeners") # Send an exception to all remaining listeners for query_id, listener in self.listeners.items(): await listener.set_exception(e) + log.debug("_close_coro: close websocket connection") + await self.websocket.close() - self.websocket = None + log.debug("_close_coro: websocket connection closed") + + except Exception as exc: # pragma: no cover + log.warning("Exception catched in _close_coro: " + repr(exc)) + + finally: + log.debug("_close_coro: start cleanup") + + self.websocket = None self.close_task = None + self._wait_closed.set() + log.debug("_close_coro: exiting") + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + log.debug("_fail: starting with exception: " + repr(e)) + if self.close_task is None: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + + if self.websocket is None: + log.debug("_fail started with self.websocket == None -> already closed") + else: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + else: + log.debug( + "close_task is not None in _fail. Previous exception is: " + + repr(self.close_exception) + + " New exception is: " + + repr(e) ) async def close(self) -> None: + log.debug("close: starting") + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) await self.wait_closed() + log.debug("close: done") + async def wait_closed(self) -> None: + log.debug("wait_close: starting") + await self._wait_closed.wait() + + log.debug("wait_close: done") From 29f7f2b950a84a06d9e5a2914c059af5740d5be5 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Fri, 28 Aug 2020 09:52:00 +0200 Subject: [PATCH 17/19] add the data property in TransportQueryError (#136) --- gql/client.py | 12 +++++++++--- gql/transport/exceptions.py | 2 ++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/gql/client.py b/gql/client.py index 348592be..c4dca409 100644 --- a/gql/client.py +++ b/gql/client.py @@ -240,7 +240,9 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: # Raise an error if an error is returned in the ExecutionResult object if result.errors: - raise TransportQueryError(str(result.errors[0]), errors=result.errors) + raise TransportQueryError( + str(result.errors[0]), errors=result.errors, data=result.data + ) assert ( result.data is not None @@ -315,7 +317,9 @@ async def subscribe( # Raise an error if an error is returned in the ExecutionResult object if result.errors: - raise TransportQueryError(str(result.errors[0]), errors=result.errors) + raise TransportQueryError( + str(result.errors[0]), errors=result.errors, data=result.data + ) elif result.data is not None: yield result.data @@ -340,7 +344,9 @@ async def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: # Raise an error if an error is returned in the ExecutionResult object if result.errors: - raise TransportQueryError(str(result.errors[0]), errors=result.errors) + raise TransportQueryError( + str(result.errors[0]), errors=result.errors, data=result.data + ) assert ( result.data is not None diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 8119c8d2..4df2ec43 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -30,10 +30,12 @@ def __init__( msg: str, query_id: Optional[int] = None, errors: Optional[List[Any]] = None, + data: Optional[Any] = None, ): super().__init__(msg) self.query_id = query_id self.errors = errors + self.data = data class TransportClosed(TransportError): From 3c82d1eb6ea1bf38f37730aa62d9b85962ccd8bb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 7 Sep 2020 21:14:34 +0200 Subject: [PATCH 18/19] Fix running execute and subscribe of client in a Thread (#135) --- gql/client.py | 18 +++++++-- gql/transport/websockets.py | 8 ++++ tests/conftest.py | 19 +++++++++ tests/test_aiohttp.py | 59 ++++++++++++++++++++++++++++ tests/test_requests.py | 33 ++++++---------- tests/test_websocket_subscription.py | 29 ++++++++++++++ 6 files changed, 141 insertions(+), 25 deletions(-) diff --git a/gql/client.py b/gql/client.py index c4dca409..13f67327 100644 --- a/gql/client.py +++ b/gql/client.py @@ -110,7 +110,13 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict: if isinstance(self.transport, AsyncTransport): - loop = asyncio.get_event_loop() + # Get the current asyncio event loop + # Or create a new event loop if there isn't one (in a new Thread) + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) assert not loop.is_running(), ( "Cannot run client.execute(query) if an asyncio loop is running." @@ -146,9 +152,15 @@ def subscribe( We need an async transport for this functionality. """ - async_generator = self.subscribe_async(document, *args, **kwargs) + # Get the current asyncio event loop + # Or create a new event loop if there isn't one (in a new Thread) + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) - loop = asyncio.get_event_loop() + async_generator = self.subscribe_async(document, *args, **kwargs) assert not loop.is_running(), ( "Cannot run client.subscribe(query) if an asyncio loop is running." diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 7ad91519..b4552b8c 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -128,6 +128,14 @@ def __init__( self.receive_data_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None + # We need to set an event loop here if there is none + # Or else we will not be able to create an asyncio.Event() + try: + self._loop = asyncio.get_event_loop() + except RuntimeError: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._wait_closed: asyncio.Event = asyncio.Event() self._wait_closed.set() diff --git a/tests/conftest.py b/tests/conftest.py index 8ce81f8d..c2edc236 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import pathlib import ssl import types +from concurrent.futures import ThreadPoolExecutor import pytest import websockets @@ -288,3 +289,21 @@ async def client_and_server(server): # Yield both client session and server yield session, server + + +@pytest.fixture +async def run_sync_test(): + async def run_sync_test_inner(event_loop, server, test_function): + """This function will run the test in a different Thread. + + This allows us to run sync code while aiohttp server can still run. + """ + executor = ThreadPoolExecutor(max_workers=2) + test_task = event_loop.run_in_executor(executor, test_function) + + await test_task + + if hasattr(server, "close"): + await server.close() + + return run_sync_test_inner diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index bc6bd219..0e97655f 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -262,3 +262,62 @@ async def handler(request): continent = result["continent"] assert continent["name"] == "Europe" + + +@pytest.mark.asyncio +async def test_aiohttp_execute_running_in_thread( + event_loop, aiohttp_server, run_sync_test +): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = AIOHTTPTransport(url=url) + + client = Client(transport=sample_transport) + + query = gql(query1_str) + + client.execute(query) + + await run_sync_test(event_loop, server, test_code) + + +@pytest.mark.asyncio +async def test_aiohttp_subscribe_running_in_thread( + event_loop, aiohttp_server, run_sync_test +): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + sample_transport = AIOHTTPTransport(url=url) + + client = Client(transport=sample_transport) + + query = gql(query1_str) + + # Note: subscriptions are not supported on the aiohttp transport + # But we add this test in order to have 100% code coverage + # It is to check that we will correctly set an event loop + # in the subscribe function if there is none (in a Thread for example) + # We cannot test this with the websockets transport because + # the websockets transport will set an event loop in its init + + with pytest.raises(NotImplementedError): + for result in client.subscribe(query): + pass + + await run_sync_test(event_loop, server, test_code) diff --git a/tests/test_requests.py b/tests/test_requests.py index 24fab2d2..b46c8611 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,9 +1,7 @@ -from concurrent.futures import ThreadPoolExecutor - import pytest from aiohttp import web -from gql import Client, gql +from gql import Client, RequestsHTTPTransport, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -11,7 +9,6 @@ TransportQueryError, TransportServerError, ) -from gql.transport.requests import RequestsHTTPTransport query1_str = """ query getContinents { @@ -31,20 +28,8 @@ ) -async def run_sync_test(event_loop, server, test_function): - """This function will run the test in a different Thread. - - This allows us to run sync code while aiohttp server can still run. - """ - executor = ThreadPoolExecutor(max_workers=2) - test_task = event_loop.run_in_executor(executor, test_function) - - await test_task - await server.close() - - @pytest.mark.asyncio -async def test_requests_query(event_loop, aiohttp_server): +async def test_requests_query(event_loop, aiohttp_server, run_sync_test): async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -74,7 +59,7 @@ def test_code(): @pytest.mark.asyncio -async def test_requests_error_code_500(event_loop, aiohttp_server): +async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test): async def handler(request): # Will generate http error code 500 raise Exception("Server error") @@ -102,7 +87,7 @@ def test_code(): @pytest.mark.asyncio -async def test_requests_error_code(event_loop, aiohttp_server): +async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test): async def handler(request): return web.Response( text=query1_server_error_answer, content_type="application/json" @@ -136,7 +121,9 @@ def test_code(): @pytest.mark.asyncio @pytest.mark.parametrize("response", invalid_protocol_responses) -async def test_requests_invalid_protocol(event_loop, aiohttp_server, response): +async def test_requests_invalid_protocol( + event_loop, aiohttp_server, response, run_sync_test +): async def handler(request): return web.Response(text=response, content_type="application/json") @@ -160,7 +147,7 @@ def test_code(): @pytest.mark.asyncio -async def test_requests_cannot_connect_twice(event_loop, aiohttp_server): +async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test): async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") @@ -182,7 +169,9 @@ def test_code(): @pytest.mark.asyncio -async def test_requests_cannot_execute_if_not_connected(event_loop, aiohttp_server): +async def test_requests_cannot_execute_if_not_connected( + event_loop, aiohttp_server, run_sync_test +): async def handler(request): return web.Response(text=query1_server_answer, content_type="application/json") diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 728b7d6c..8152e07c 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -446,3 +446,32 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) # Check that the server received a connection_terminate message last assert logged_messages.pop() == '{"type": "connection_terminate"}' + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_running_in_thread( + event_loop, server, subscription_str, run_sync_test +): + def test_code(): + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + await run_sync_test(event_loop, server, test_code) From b4ab94164ef2d9f56b587f7a3206b41b3a20d477 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Mon, 7 Sep 2020 21:24:37 +0200 Subject: [PATCH 19/19] Allow to import PhoenixChannelWebsocketsTransport directly from gql && no cover pragmas --- gql/__init__.py | 2 ++ gql/transport/phoenix_channel_websockets.py | 6 +++--- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/gql/__init__.py b/gql/__init__.py index 7c21c1c8..bad425d4 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,6 +1,7 @@ from .client import Client from .gql import gql from .transport.aiohttp import AIOHTTPTransport +from .transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport from .transport.requests import RequestsHTTPTransport from .transport.websockets import WebsocketsTransport @@ -8,6 +9,7 @@ "gql", "AIOHTTPTransport", "Client", + "PhoenixChannelWebsocketsTransport", "RequestsHTTPTransport", "WebsocketsTransport", ] diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 32a8a73a..6e96b72e 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -75,15 +75,15 @@ async def heartbeat_coro(): } ) ) - except ConnectionClosed: - pass + except ConnectionClosed: # pragma: no cover + return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) async def _send_stop_message(self, query_id: int) -> None: try: await self.listeners[query_id].put(("complete", None)) - except KeyError: + except KeyError: # pragma: no cover pass async def _send_connection_terminate_message(self) -> None: