From cda8f6c1396253035766d5bd45f0e19d773ed41c Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 18 Jan 2022 15:22:55 +0000 Subject: [PATCH 01/32] chore: add asyncpg dependency --- requirements.txt | 1 + setup.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 942e6e822..e0c140972 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ aiohttp==3.8.1 cryptography==36.0.1 PyMySQL==1.0.2 pg8000==1.23.0 +asyncpg==0.25.0 python-tds==1.11.0 pyopenssl==21.0.0 Requests==2.27.1 diff --git a/setup.py b/setup.py index cf0c845f4..30af905e8 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,8 @@ extras_require={ "pymysql": ["PyMySQL==1.0.2"], "pg8000": ["pg8000==1.23.0"], - "pytds": ["python-tds==1.11.0"] + "pytds": ["python-tds==1.11.0"], + "asyncpg": ["asyncpg==0.25.0"], }, python_requires=">=3.6", include_package_data=True, From 124b7321228a39b25b290f37acee4067dc48f676 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 18 Jan 2022 16:35:43 +0000 Subject: [PATCH 02/32] chore: add connect method for asyncpg --- .../connector/instance_connection_manager.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index eb7365bbb..c900f6903 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -49,6 +49,7 @@ import pymysql import pg8000 import pytds + import asyncpg logger = logging.getLogger(name=__name__) APPLICATION_NAME = "cloud-sql-python-connector" @@ -554,6 +555,7 @@ async def _connect( "pymysql": self._connect_with_pymysql, "pg8000": self._connect_with_pg8000, "pytds": self._connect_with_pytds, + "asyncpg": self._connect_with_asyncpg, } instance_data: InstanceMetadata @@ -643,6 +645,41 @@ def _connect_with_pg8000( **kwargs, ) + def _connect_with_asyncpg( + self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ) -> "asyncpg.Connection": + """Helper function to create an asyncpg DB-API connection object. + + :type ip_address: str + :param ip_address: A string containing an IP address for the Cloud SQL + instance. + + :type ctx: ssl.SSLContext + :param ctx: An SSLContext object created from the Cloud SQL server CA + cert and ephemeral cert. + + :rtype: asyncpg.Connection + :returns: An asyncpg Connection object for the Cloud SQL instance. + """ + try: + import asyncpg + except ImportError: + raise ImportError( + 'Unable to import module "asyncpg." Please install and try again.' + ) + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + return asyncpg.connect( + user=user, + database=db, + password=passwd, + host=ip_address, + port=SERVER_PROXY_PORT, + ssl=ctx, + **kwargs, + ) + def _connect_with_pytds( self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any ) -> "pytds.Connection": From 4fa7fcc6a676655a701b8e377ec974baf8ad4998 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 28 Jan 2022 15:20:56 +0000 Subject: [PATCH 03/32] chore: testing async_connect method --- google/cloud/sql/connector/connector.py | 71 +++++++++++++++++++ .../connector/instance_connection_manager.py | 11 ++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 7dfd6d186..4aa215a47 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -144,6 +144,77 @@ def connect( icm.force_refresh() raise (e) + async def async_connect( + self, instance_connection_string: str, driver: str, **kwargs: Any + ) -> Any: + """Prepares and returns an async database connection object and starts a + background thread to refresh the certificates and metadata. + + :type instance_connection_string: str + :param instance_connection_string: + A string containing the GCP project name, region name, and instance + name separated by colons. + + Example: example-proj:example-region-us6:example-instance + + :type driver: str + :param: driver: + A string representing the driver to connect with. Supported drivers are + pymysql, pg8000, and pytds. + + :param kwargs: + Pass in any driver-specific arguments needed to connect to the Cloud + SQL instance. + + :rtype: Connection + :returns: + A DB-API connection to the specified Cloud SQL instance. + """ + if instance_connection_string in self._instances: + icm = self._instances[instance_connection_string] + else: + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + icm = InstanceConnectionManager( + instance_connection_string, + driver, + self._keys, + self._loop, + self._credentials, + enable_iam_auth, + ) + self._instances[instance_connection_string] = icm + + if "ip_types" in kwargs: + ip_type = kwargs.pop("ip_types") + logger.warning( + "Deprecation Warning: Parameter `ip_types` is deprecated and may be removed" + "in a future release. Please use `ip_type` instead." + ) + else: + ip_type = kwargs.pop("ip_type", self._ip_type) + if "timeout" in kwargs: + timeout = kwargs["timeout"] + elif "connect_timeout" in kwargs: + timeout = kwargs["connect_timeout"] + else: + timeout = self._timeout + try: + connect_future: concurrent.futures.Future = ( + asyncio.run_coroutine_threadsafe( + icm._connect(driver, ip_type, **kwargs), self._loop + ) + ) + conn = connect_future.result(timeout) + except concurrent.futures.TimeoutError: + connect_future.cancel() + raise TimeoutError(f"Connection timed out after {timeout}s") + try: + return conn + except Exception as e: + # with any other exception, we attempt a force refresh, then throw the error + icm.force_refresh() + raise (e) + def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any: """Uses a Connector object with default settings and returns a database diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index c900f6903..04b34d68d 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -564,7 +564,12 @@ async def _connect( ip_address: str = instance_data.get_preferred_ip(ip_type) try: - connector = connect_func[driver] + if driver == "asyncpg": + return await self._connect_with_asyncpg( + ip_address, instance_data.context, **kwargs + ) + else: + connector = connect_func[driver] except KeyError: raise KeyError("Driver {} is not supported.".format(driver)) @@ -645,7 +650,7 @@ def _connect_with_pg8000( **kwargs, ) - def _connect_with_asyncpg( + async def _connect_with_asyncpg( self, ip_address: str, ctx: ssl.SSLContext, **kwargs: Any ) -> "asyncpg.Connection": """Helper function to create an asyncpg DB-API connection object. @@ -670,7 +675,7 @@ def _connect_with_asyncpg( user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) - return asyncpg.connect( + return await asyncpg.connect( user=user, database=db, password=passwd, From 2a871ca882fdd09bc044b211f35b367458f0c14e Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 18 Feb 2022 15:18:35 +0000 Subject: [PATCH 04/32] chore: attempt new async_connect method --- google/cloud/sql/connector/connector.py | 107 ++++++++---------- .../connector/instance_connection_manager.py | 33 +----- 2 files changed, 46 insertions(+), 94 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 53d16365a..b4676c0c9 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -29,6 +29,27 @@ _default_connector = None +# This thread is used for background processing +_thread: Optional[Thread] = None +_loop: Optional[asyncio.AbstractEventLoop] = None + + +def _get_loop() -> asyncio.AbstractEventLoop: + global _loop, _thread + try: + loop = asyncio.get_running_loop() + print("Using found event loop!") + return loop + except RuntimeError as e: + if _loop is None: + print("Creating new background loop!") + _loop = asyncio.new_event_loop() + _thread = Thread(target=_loop.run_forever, daemon=True) + _thread.start() + else: + print("Using already created background loop!") + return _loop + class Connector: """A class to configure and create connections to Cloud SQL instances. @@ -59,9 +80,7 @@ def __init__( timeout: int = 30, credentials: Optional[Credentials] = None, ) -> None: - self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) - self._thread.start() + self._loop: asyncio.AbstractEventLoop = _get_loop() self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( generate_keys(), self._loop ) @@ -108,48 +127,10 @@ def connect( # Use the InstanceConnectionManager to establish an SSL Connection. # # Return a DBAPI connection - enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) - if instance_connection_string in self._instances: - icm = self._instances[instance_connection_string] - if enable_iam_auth != icm._enable_iam_auth: - raise ValueError( - "connect() called with `enable_iam_auth={}`, but previously used " - "enable_iam_auth={}`. If you require both for your use case, " - "please use a new connector.Connector object.".format( - enable_iam_auth, icm._enable_iam_auth - ) - ) - else: - icm = InstanceConnectionManager( - instance_connection_string, - driver, - self._keys, - self._loop, - self._credentials, - enable_iam_auth, - ) - self._instances[instance_connection_string] = icm - - if "ip_types" in kwargs: - ip_type = kwargs.pop("ip_types") - logger.warning( - "Deprecation Warning: Parameter `ip_types` is deprecated and may be removed" - " in a future release. Please use `ip_type` instead." - ) - else: - ip_type = kwargs.pop("ip_type", self._ip_type) - if "timeout" in kwargs: - return icm.connect(driver, ip_type, **kwargs) - elif "connect_timeout" in kwargs: - timeout = kwargs["connect_timeout"] - else: - timeout = self._timeout - try: - return icm.connect(driver, ip_type, timeout, **kwargs) - except Exception as e: - # with any other exception, we attempt a force refresh, then throw the error - icm.force_refresh() - raise (e) + connect_task = asyncio.run_coroutine_threadsafe( + self.async_connect(instance_connection_string, driver, **kwargs), self._loop + ) + return connect_task.result() async def async_connect( self, instance_connection_string: str, driver: str, **kwargs: Any @@ -177,10 +158,17 @@ async def async_connect( :returns: A DB-API connection to the specified Cloud SQL instance. """ + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: icm = self._instances[instance_connection_string] + if enable_iam_auth != icm._enable_iam_auth: + raise ValueError( + f"connect() called with `enable_iam_auth={enable_iam_auth}`, " + f"but previously used enable_iam_auth={icm._enable_iam_auth}`. " + "If you require both for your use case, please use a new " + "connector.Connector object." + ) else: - enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) icm = InstanceConnectionManager( instance_connection_string, driver, @@ -195,31 +183,26 @@ async def async_connect( ip_type = kwargs.pop("ip_types") logger.warning( "Deprecation Warning: Parameter `ip_types` is deprecated and may be removed" - "in a future release. Please use `ip_type` instead." + " in a future release. Please use `ip_type` instead." ) else: ip_type = kwargs.pop("ip_type", self._ip_type) - if "timeout" in kwargs: - timeout = kwargs["timeout"] - elif "connect_timeout" in kwargs: + timeout = kwargs.pop("timeout", self._timeout) + if "connect_timeout" in kwargs: timeout = kwargs["connect_timeout"] - else: - timeout = self._timeout + try: - connect_future: concurrent.futures.Future = ( - asyncio.run_coroutine_threadsafe( - icm._connect(driver, ip_type, **kwargs), self._loop - ) + connection_task = self._loop.create_task( + icm._connect(driver, ip_type, **kwargs) ) - conn = connect_future.result(timeout) - except concurrent.futures.TimeoutError: - connect_future.cancel() + await asyncio.wait_for(connection_task, timeout) + return await connection_task + except asyncio.TimeoutError: raise TimeoutError(f"Connection timed out after {timeout}s") - try: - return conn except Exception as e: # with any other exception, we attempt a force refresh, then throw the error - icm.force_refresh() + refresh_task = self._loop.create_task(icm._force_refresh()) + await asyncio.wait_for(refresh_task, None) raise (e) diff --git a/google/cloud/sql/connector/instance_connection_manager.py b/google/cloud/sql/connector/instance_connection_manager.py index d53c3e44d..5253b83c2 100644 --- a/google/cloud/sql/connector/instance_connection_manager.py +++ b/google/cloud/sql/connector/instance_connection_manager.py @@ -501,37 +501,6 @@ async def _schedule_refresh(self, delay: Optional[int] = None) -> asyncio.Task: raise e return await self._perform_refresh() - def connect( - self, - driver: str, - ip_type: IPTypes, - timeout: int, - **kwargs: Any, - ) -> Any: - """A method that returns a DB-API connection to the database. - - :type driver: str - :param driver: A string representing the driver. e.g. "pymysql" - - :type timeout: int - :param timeout: The time limit for the connection before raising - a TimeoutError - - :returns: A DB-API connection to the primary IP of the database. - """ - - connect_future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - self._connect(driver, ip_type, **kwargs), self._loop - ) - - try: - connection = connect_future.result(timeout) - except concurrent.futures.TimeoutError: - connect_future.cancel() - raise TimeoutError(f"Connection timed out after {timeout}s") - else: - return connection - async def _connect( self, driver: str, @@ -546,7 +515,7 @@ async def _connect( :returns: A DB-API connection to the primary IP of the database. """ logger.debug("Entered connect method") - + print("_connect thread ID: ", self._loop._thread_id) # Host and ssl options come from the certificates and metadata, so we don't # want the user to specify them. kwargs.pop("host", None) From 2489bbf2de35115cae1fef57f2b953600eb647e8 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 4 Apr 2022 20:56:41 +0000 Subject: [PATCH 05/32] chore: add loop param to Connector __init__ --- google/cloud/sql/connector/connector.py | 45 ++++--------------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 3fabac6bb..7e7838bc6 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -84,8 +84,12 @@ def __init__( enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, + loop: asyncio.AbstractEventLoop = None, ) -> None: - self._loop: asyncio.AbstractEventLoop = _get_loop() + if loop: + self._loop = loop + else: + self._loop: asyncio.AbstractEventLoop = _get_loop() self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( generate_keys(), self._loop ) @@ -128,45 +132,8 @@ def connect( ) return connect_task.result() - async def connect_async( - self, instance_connection_string: str, driver: str, **kwargs: Any - ) -> Any: - """Prepares and returns a database connection object and starts a - background task to refresh the certificates and metadata. - - :type instance_connection_string: str - :param instance_connection_string: - A string containing the GCP project name, region name, and instance - name separated by colons. - - Example: example-proj:example-region-us6:example-instance - - :type driver: str - :param: driver: - A string representing the driver to connect with. Supported drivers are - pymysql, pg8000, and pytds. - - :param kwargs: - Pass in any driver-specific arguments needed to connect to the Cloud - SQL instance. - - :rtype: Connection - :returns: - A DB-API connection to the specified Cloud SQL instance. - """ - - # Create an Instance object from the connection string. - # The Instance should verify arguments. - # - # Use the Instance to establish an SSL Connection. - # - # Return a DBAPI connection - connect_task = asyncio.run_coroutine_threadsafe( - self.async_connect(instance_connection_string, driver, **kwargs), self._loop - ) - return connect_task.result() - async def async_connect( + async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: """Prepares and returns an async database connection object and starts a From 7eafe6887b822f1ad8bc02ecb6cb58ff156c1077 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 8 Apr 2022 17:09:06 +0000 Subject: [PATCH 06/32] chore: testing asyncpg.connect --- google/cloud/sql/connector/asyncpg.py | 57 +++++++++++++------------ google/cloud/sql/connector/connector.py | 12 +++--- 2 files changed, 35 insertions(+), 34 deletions(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index de5121fdb..4880ed78e 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -6,37 +6,38 @@ if TYPE_CHECKING: import asyncpg + async def connect( ip_address: str, ctx: ssl.SSLContext, **kwargs: Any - ) -> "asyncpg.Connection": - """Helper function to create an asyncpg DB-API connection object. +) -> "asyncpg.Connection": + """Helper function to create an asyncpg DB-API connection object. - :type ip_address: str - :param ip_address: A string containing an IP address for the Cloud SQL - instance. + :type ip_address: str + :param ip_address: A string containing an IP address for the Cloud SQL + instance. - :type ctx: ssl.SSLContext - :param ctx: An SSLContext object created from the Cloud SQL server CA - cert and ephemeral cert. + :type ctx: ssl.SSLContext + :param ctx: An SSLContext object created from the Cloud SQL server CA + cert and ephemeral cert. - :rtype: asyncpg.Connection - :returns: An asyncpg Connection object for the Cloud SQL instance. - """ - try: - import asyncpg - except ImportError: - raise ImportError( - 'Unable to import module "asyncpg." Please install and try again.' - ) - user = kwargs.pop("user") - db = kwargs.pop("db") - passwd = kwargs.pop("password", None) - return await asyncpg.connect( - user=user, - database=db, - password=passwd, - host=ip_address, - port=SERVER_PROXY_PORT, - ssl=ctx, - **kwargs, + :rtype: asyncpg.Connection + :returns: An asyncpg Connection object for the Cloud SQL instance. + """ + try: + import asyncpg + except ImportError: + raise ImportError( + 'Unable to import module "asyncpg." Please install and try again.' ) + user = kwargs.pop("user") + db = kwargs.pop("db") + passwd = kwargs.pop("password", None) + return await asyncpg.connect( + user=user, + database=db, + password=passwd, + host=ip_address, + port=SERVER_PROXY_PORT, + ssl=ctx, + timeout=30, # remove prior to PR + ) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 41465b036..769bcf661 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -44,15 +44,12 @@ def _get_loop() -> asyncio.AbstractEventLoop: global _loop, _thread try: loop = asyncio.get_running_loop() - print("Using found event loop!") return loop except RuntimeError as e: if _loop is None: _loop = asyncio.new_event_loop() _thread = Thread(target=_loop.run_forever, daemon=True) _thread.start() - else: - print("Using already created background loop!") return _loop @@ -132,7 +129,6 @@ def connect( ) return connect_task.result() - async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: @@ -159,6 +155,7 @@ async def connect_async( :returns: A DB-API connection to the specified Cloud SQL instance. """ + loop = kwargs.pop("loop", self._loop) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: instance = self._instances[instance_connection_string] @@ -174,7 +171,7 @@ async def connect_async( instance_connection_string, driver, self._keys, - self._loop, + loop, self._credentials, enable_iam_auth, ) @@ -217,8 +214,11 @@ async def get_connection() -> Any: connect_partial = partial( connector, ip_address, instance_data.context, **kwargs ) + # simplified while testing if driver == "asyncpg": - return await asyncpg.connect(ip_address, instance_data.context, **kwargs) + return await asyncpg.connect( + ip_address, instance_data.context, **kwargs + ) return await self._loop.run_in_executor(None, connect_partial) # attempt to make connection to Cloud SQL instance for given timeout From 3d632a25b657159de84ae51a049f5246c70e6df9 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 13 Jun 2022 19:54:13 +0000 Subject: [PATCH 07/32] chore: update asyncpg param to direct_tls --- google/cloud/sql/connector/asyncpg.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index b2b6728d6..aa70fdb69 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -39,6 +39,5 @@ async def connect( host=ip_address, port=SERVER_PROXY_PORT, ssl=ctx, - tls_proxy=True, - timeout=30, # remove prior to PR + direct_tls=True, ) From bf59a2f7efde7e7d5fd981cff0fee8e18e2e8bd4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 30 Jun 2022 13:45:51 +0000 Subject: [PATCH 08/32] chore: add asyncpg unit test --- google/cloud/sql/connector/connector.py | 15 ++++++----- setup.py | 2 +- tests/unit/test_asyncpg.py | 34 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) create mode 100644 tests/unit/test_asyncpg.py diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 769bcf661..b6ce60b88 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -33,19 +33,20 @@ logger = logging.getLogger(name=__name__) -_default_connector = None - # This thread is used for background processing _thread: Optional[Thread] = None _loop: Optional[asyncio.AbstractEventLoop] = None +_default_connector = None def _get_loop() -> asyncio.AbstractEventLoop: + """Get event loop to use with Connector object. + Looks for an existing loop, creates one if one does not exist.""" global _loop, _thread try: loop = asyncio.get_running_loop() return loop - except RuntimeError as e: + except RuntimeError: if _loop is None: _loop = asyncio.new_event_loop() _thread = Thread(target=_loop.run_forever, daemon=True) @@ -114,7 +115,7 @@ def connect( :type driver: str :param: driver: A string representing the driver to connect with. Supported drivers are - pymysql, pg8000, and pytds. + pymysql, pg8000, asyncpg, and pytds. :param kwargs: Pass in any driver-specific arguments needed to connect to the Cloud @@ -133,7 +134,7 @@ async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: """Prepares and returns an async database connection object and starts a - background thread to refresh the certificates and metadata. + background task to refresh the certificates and metadata. :type instance_connection_string: str :param instance_connection_string: @@ -145,7 +146,7 @@ async def connect_async( :type driver: str :param: driver: A string representing the driver to connect with. Supported drivers are - pymysql, pg8000, and pytds. + pymysql, pg8000, asyncpg, and pytds. :param kwargs: Pass in any driver-specific arguments needed to connect to the Cloud @@ -180,8 +181,8 @@ async def connect_async( connect_func = { "pymysql": pymysql.connect, "pg8000": pg8000.connect, - "pytds": pytds.connect, "asyncpg": asyncpg.connect, + "pytds": pytds.connect, } # only accept supported database drivers diff --git a/setup.py b/setup.py index 7cc77c63b..6c50cc35d 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ "pymysql": ["PyMySQL==1.0.2"], "pg8000": ["pg8000==1.29.1"], "pytds": ["python-tds==1.11.0"], - "asyncpg": ["git+https://github.com/MagicStack/asyncpg.git@f2a937d2f25d1f997a066e6ba02acc3c4de676a4"] + "asyncpg": ["asyncpg==0.25.0"] }, python_requires=">=3.7", include_package_data=True, diff --git a/tests/unit/test_asyncpg.py b/tests/unit/test_asyncpg.py new file mode 100644 index 000000000..e08dc8ed6 --- /dev/null +++ b/tests/unit/test_asyncpg.py @@ -0,0 +1,34 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import ssl +import pytest +from typing import Any +from mock import patch, AsyncMock + +from google.cloud.sql.connector.asyncpg import connect + + +@pytest.mark.asyncio +@patch("asyncpg.connect", new_callable=AsyncMock) +async def test_asyncpg(mock_connect: AsyncMock, kwargs: Any) -> None: + """Test to verify that asyncpg gets to proper connection call.""" + ip_addr = "0.0.0.0" + context = ssl.create_default_context() + mock_connect.return_value = True + connection = await connect(ip_addr, context, **kwargs) + assert connection is True + # verify that driver connection call would be made + assert mock_connect.assert_called_once From 77dc3a71658a2b233980cebd8e0e9b6dd9141cd9 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 4 Jul 2022 20:38:52 +0000 Subject: [PATCH 09/32] chore: add system test for asyncpg driver --- .mypy.ini | 3 ++ google/cloud/sql/connector/connector.py | 41 +++++----------- requirements-test.txt | 1 + setup.py | 2 +- tests/system/test_asyncpg_connection.py | 64 +++++++++++++++++++++++++ 5 files changed, 80 insertions(+), 31 deletions(-) create mode 100644 tests/system/test_asyncpg_connection.py diff --git a/.mypy.ini b/.mypy.ini index 7a891e5ad..5ba446402 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -12,6 +12,9 @@ ignore_missing_imports = True [mypy-pg8000] ignore_missing_imports = True +[mypy-asyncpg] +ignore_missing_imports = True + [mypy-pytds] ignore_missing_imports = True diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index b6ce60b88..0f22cb315 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -33,25 +33,8 @@ logger = logging.getLogger(name=__name__) -# This thread is used for background processing -_thread: Optional[Thread] = None -_loop: Optional[asyncio.AbstractEventLoop] = None _default_connector = None - - -def _get_loop() -> asyncio.AbstractEventLoop: - """Get event loop to use with Connector object. - Looks for an existing loop, creates one if one does not exist.""" - global _loop, _thread - try: - loop = asyncio.get_running_loop() - return loop - except RuntimeError: - if _loop is None: - _loop = asyncio.new_event_loop() - _thread = Thread(target=_loop.run_forever, daemon=True) - _thread.start() - return _loop +ASYNC_DRIVERS = ["asyncpg"] class Connector: @@ -82,12 +65,10 @@ def __init__( enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, - loop: asyncio.AbstractEventLoop = None, ) -> None: - if loop: - self._loop = loop - else: - self._loop: asyncio.AbstractEventLoop = _get_loop() + self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) + self._thread.start() self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( generate_keys(), self._loop ) @@ -156,7 +137,8 @@ async def connect_async( :returns: A DB-API connection to the specified Cloud SQL instance. """ - loop = kwargs.pop("loop", self._loop) + # allow specific event loop to be passed in + loop = kwargs.pop("loop", asyncio.get_running_loop()) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: instance = self._instances[instance_connection_string] @@ -212,15 +194,14 @@ async def connect_async( # helper function to wrap in timeout async def get_connection() -> Any: instance_data, ip_address = await instance.connect_info(ip_type) + # async drivers are unblocking and can be awaited directly + if driver in ASYNC_DRIVERS: + return await connector(ip_address, instance_data.context, **kwargs) + # synchronous drivers are blocking and run using executor connect_partial = partial( connector, ip_address, instance_data.context, **kwargs ) - # simplified while testing - if driver == "asyncpg": - return await asyncpg.connect( - ip_address, instance_data.context, **kwargs - ) - return await self._loop.run_in_executor(None, connect_partial) + return await loop.run_in_executor(None, connect_partial) # attempt to make connection to Cloud SQL instance for given timeout try: diff --git a/requirements-test.txt b/requirements-test.txt index ab8688709..63dd56201 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -15,5 +15,6 @@ types-mock==4.0.15 twine==4.0.1 PyMySQL==1.0.2 pg8000==1.29.1 +asyncpg==0.26.0 python-tds==1.11.0 aioresponses==0.7.3 diff --git a/setup.py b/setup.py index 6c50cc35d..b803e8ef3 100644 --- a/setup.py +++ b/setup.py @@ -80,7 +80,7 @@ "pymysql": ["PyMySQL==1.0.2"], "pg8000": ["pg8000==1.29.1"], "pytds": ["python-tds==1.11.0"], - "asyncpg": ["asyncpg==0.25.0"] + "asyncpg": ["asyncpg==0.26.0"] }, python_requires=">=3.7", include_package_data=True, diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py new file mode 100644 index 000000000..7a0e759dc --- /dev/null +++ b/tests/system/test_asyncpg_connection.py @@ -0,0 +1,64 @@ +"""" +Copyright 2021 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import uuid +from typing import AsyncGenerator + +import asyncpg +import pytest +from google.cloud.sql.connector import Connector + +table_name = f"books_{uuid.uuid4().hex}" + + +@pytest.fixture(name="conn") +async def setup() -> AsyncGenerator: + # initialize Cloud SQL Python Connector object + connector = Connector() + conn: asyncpg.Connection = await connector.connect_async( + os.environ["POSTGRES_CONNECTION_NAME"], + "asyncpg", + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASS"], + db=os.environ["POSTGRES_DB"], + ) + await conn.execute( + f"CREATE TABLE IF NOT EXISTS {table_name}" + " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" + ) + + yield conn + + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + # close asyncpg connection + await conn.close() + # cleanup Connector object + connector.close() + + +@pytest.mark.asyncio +async def test__connection_with_asyncpg(conn: asyncpg.Connection) -> None: + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" + ) + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" + ) + + rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") + titles = [row[0] for row in rows] + + assert titles == ["Book One", "Book Two"] From 2445ee739b0f4ac93a5ad8bf99309a0de8e946c6 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 4 Jul 2022 20:48:01 +0000 Subject: [PATCH 10/32] chore: add asyncpg to readme --- README.md | 8 ++++++++ google/cloud/sql/connector/connector.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b7a6829f9..f0ba02850 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ The Cloud SQL Python Connector is a package to be used alongside a database driv Currently supported drivers are: - [`pymysql`](https://github.com/PyMySQL/PyMySQL) (MySQL) - [`pg8000`](https://github.com/tlocke/pg8000) (PostgreSQL) + - [`asyncpg`](https://github.com/MagicStack/asyncpg) (PostgreSQL) - [`pytds`](https://github.com/denisenkom/pytds) (SQL Server) @@ -37,9 +38,16 @@ based on your database dialect. pip install "cloud-sql-python-connector[pymysql]" ``` ### Postgres +There are two different database drivers that are supported for the Postgres dialect: + +#### pg8000 ``` pip install "cloud-sql-python-connector[pg8000]" ``` +#### asyncpg +``` +pip install "cloud-sql-python-connector[asyncpg]" +``` ### SQL Server ``` pip install "cloud-sql-python-connector[pytds]" diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 0f22cb315..a9d9f2551 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -114,7 +114,7 @@ def connect( async def connect_async( self, instance_connection_string: str, driver: str, **kwargs: Any ) -> Any: - """Prepares and returns an async database connection object and starts a + """Prepares and returns a database connection object and starts a background task to refresh the certificates and metadata. :type instance_connection_string: str From add5dfcbae0b70140036349f13b570db272a4f9a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 4 Jul 2022 20:55:59 +0000 Subject: [PATCH 11/32] chore: update test header --- tests/system/test_asyncpg_connection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 7a0e759dc..69b8cd460 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -1,5 +1,5 @@ """" -Copyright 2021 Google LLC +Copyright 2022 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From 00ff19a2ec82815ec90a25e1d4e668b8d7e2fd39 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 4 Jul 2022 21:04:59 +0000 Subject: [PATCH 12/32] chore: add header to asyncpg.py file --- google/cloud/sql/connector/asyncpg.py | 15 +++++++++++++++ tests/system/test_asyncpg_connection.py | 2 +- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index aa70fdb69..9aeadb3e7 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -1,3 +1,18 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" import ssl from typing import Any, TYPE_CHECKING diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 69b8cd460..a05df4666 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -1,4 +1,4 @@ -"""" +""" Copyright 2022 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); From 6a7600741b5e844b28abfdffedd83de729091bf6 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 5 Jul 2022 13:14:57 +0000 Subject: [PATCH 13/32] chore: add iam auth test for asyncpg --- tests/system/test_asyncpg_connection.py | 2 +- tests/system/test_asyncpg_iam_auth.py | 63 +++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 tests/system/test_asyncpg_iam_auth.py diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index a05df4666..1dc955477 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -50,7 +50,7 @@ async def setup() -> AsyncGenerator: @pytest.mark.asyncio -async def test__connection_with_asyncpg(conn: asyncpg.Connection) -> None: +async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: await conn.execute( f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" ) diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py new file mode 100644 index 000000000..b4c15e43c --- /dev/null +++ b/tests/system/test_asyncpg_iam_auth.py @@ -0,0 +1,63 @@ +""" +Copyright 2022 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import os +import uuid +from typing import AsyncGenerator + +import asyncpg +import pytest +from google.cloud.sql.connector import Connector + +table_name = f"books_{uuid.uuid4().hex}" + + +@pytest.fixture(name="conn") +async def setup() -> AsyncGenerator: + # initialize Cloud SQL Python Connector object + connector = Connector() + conn: asyncpg.Connection = await connector.connect_async( + os.environ["POSTGRES_IAM_CONNECTION_NAME"], + "asyncpg", + user=os.environ["POSTGRES_IAM_USER"], + db=os.environ["POSTGRES_DB"], + ) + await conn.execute( + f"CREATE TABLE IF NOT EXISTS {table_name}" + " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" + ) + + yield conn + + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + # close asyncpg connection + await conn.close() + # cleanup Connector object + connector.close() + + +@pytest.mark.asyncio +async def test_connection_with_asyncpg_iam_auth(conn: asyncpg.Connection) -> None: + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" + ) + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" + ) + + rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") + titles = [row[0] for row in rows] + + assert titles == ["Book One", "Book Two"] From f30a8e0bacd5f0e81b280f59a143463c672e6603 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 5 Jul 2022 13:25:57 +0000 Subject: [PATCH 14/32] chore: fix iam auth system test --- tests/system/test_asyncpg_iam_auth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py index b4c15e43c..e2b74b6fc 100644 --- a/tests/system/test_asyncpg_iam_auth.py +++ b/tests/system/test_asyncpg_iam_auth.py @@ -33,6 +33,7 @@ async def setup() -> AsyncGenerator: "asyncpg", user=os.environ["POSTGRES_IAM_USER"], db=os.environ["POSTGRES_DB"], + enable_iam_auth=True, ) await conn.execute( f"CREATE TABLE IF NOT EXISTS {table_name}" From 3c8f376ee540403051ee3b9ce111c8c2b3a93986 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 5 Jul 2022 14:16:38 +0000 Subject: [PATCH 15/32] chore: add connection pool return option --- google/cloud/sql/connector/asyncpg.py | 27 ++++++++++++--- tests/system/test_asyncpg_connection.py | 45 +++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index 9aeadb3e7..c9c0ccfb5 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -14,7 +14,7 @@ limitations under the License. """ import ssl -from typing import Any, TYPE_CHECKING +from typing import Any, Union, TYPE_CHECKING SERVER_PROXY_PORT = 3307 @@ -24,7 +24,7 @@ async def connect( ip_address: str, ctx: ssl.SSLContext, **kwargs: Any -) -> "asyncpg.Connection": +) -> "Union[asyncpg.Connection, asyncpg.Pool]": """Helper function to create an asyncpg DB-API connection object. :type ip_address: str @@ -35,8 +35,12 @@ async def connect( :param ctx: An SSLContext object created from the Cloud SQL server CA cert and ephemeral cert. - :rtype: asyncpg.Connection - :returns: An asyncpg Connection object for the Cloud SQL instance. + :type kwargs: Any + :param kwargs: Keyword arguments for establishing connection object + or connection pool to Cloud SQL instance. + + :rtype: Union[asyncpg.Connection, asyncpg.Pool] + :returns: An asyncpg Connection or Pool object for the Cloud SQL instance. """ try: import asyncpg @@ -47,6 +51,20 @@ async def connect( user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) + pool = kwargs.pop("pool", False) + # return connection pool if pool is set to True + if pool: + return await asyncpg.create_pool( + user=user, + database=db, + password=passwd, + host=ip_address, + port=SERVER_PROXY_PORT, + ssl=ctx, + direct_tls=True, + **kwargs, + ) + # return regular asyncpg connection return await asyncpg.connect( user=user, database=db, @@ -55,4 +73,5 @@ async def connect( port=SERVER_PROXY_PORT, ssl=ctx, direct_tls=True, + **kwargs, ) diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 1dc955477..0a9a913b0 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -49,6 +49,35 @@ async def setup() -> AsyncGenerator: connector.close() +@pytest.fixture() +async def pool() -> AsyncGenerator: + # initialize Cloud SQL Python Connector object + connector = Connector() + pool: asyncpg.Pool = await connector.connect_async( + os.environ["POSTGRES_CONNECTION_NAME"], + "asyncpg", + user=os.environ["POSTGRES_USER"], + password=os.environ["POSTGRES_PASS"], + db=os.environ["POSTGRES_DB"], + pool=True, + ) + async with pool.acquire() as conn: + await conn.execute( + f"CREATE TABLE IF NOT EXISTS {table_name}" + " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" + ) + + yield pool + + async with pool.acquire() as conn: + await conn.execute(f"DROP TABLE IF EXISTS {table_name}") + + # close asyncpg connection pool + await pool.close() + # cleanup Connector object + connector.close() + + @pytest.mark.asyncio async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: await conn.execute( @@ -62,3 +91,19 @@ async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: titles = [row[0] for row in rows] assert titles == ["Book One", "Book Two"] + + +@pytest.mark.asyncio +async def test_pooled_connection_with_asyncpg(pool: asyncpg.Pool) -> None: + async with pool.acquire() as conn: + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" + ) + await conn.execute( + f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" + ) + + rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") + titles = [row[0] for row in rows] + + assert titles == ["Book One", "Book Two"] From 5f2daf090d161802b8d317fbcca380dc4437028f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 7 Jul 2022 19:53:02 +0000 Subject: [PATCH 16/32] chore: lint --- google/cloud/sql/connector/asyncpg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index c9c0ccfb5..f7a2743ab 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -38,7 +38,7 @@ async def connect( :type kwargs: Any :param kwargs: Keyword arguments for establishing connection object or connection pool to Cloud SQL instance. - + :rtype: Union[asyncpg.Connection, asyncpg.Pool] :returns: An asyncpg Connection or Pool object for the Cloud SQL instance. """ From 90c66f74ec67022b4a14493b204f5a5e8d7337f3 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 12 Jul 2022 14:39:03 +0000 Subject: [PATCH 17/32] chore: remove connection pooling --- google/cloud/sql/connector/asyncpg.py | 27 ++++----------- tests/system/test_asyncpg_connection.py | 45 ------------------------- 2 files changed, 7 insertions(+), 65 deletions(-) diff --git a/google/cloud/sql/connector/asyncpg.py b/google/cloud/sql/connector/asyncpg.py index f7a2743ab..0e03c0a65 100644 --- a/google/cloud/sql/connector/asyncpg.py +++ b/google/cloud/sql/connector/asyncpg.py @@ -14,7 +14,7 @@ limitations under the License. """ import ssl -from typing import Any, Union, TYPE_CHECKING +from typing import Any, TYPE_CHECKING SERVER_PROXY_PORT = 3307 @@ -24,7 +24,7 @@ async def connect( ip_address: str, ctx: ssl.SSLContext, **kwargs: Any -) -> "Union[asyncpg.Connection, asyncpg.Pool]": +) -> "asyncpg.Connection": """Helper function to create an asyncpg DB-API connection object. :type ip_address: str @@ -36,11 +36,11 @@ async def connect( cert and ephemeral cert. :type kwargs: Any - :param kwargs: Keyword arguments for establishing connection object - or connection pool to Cloud SQL instance. + :param kwargs: Keyword arguments for establishing asyncpg connection + object to Cloud SQL instance. - :rtype: Union[asyncpg.Connection, asyncpg.Pool] - :returns: An asyncpg Connection or Pool object for the Cloud SQL instance. + :rtype: asyncpg.Connection + :returns: An asyncpg.Connection object to a Cloud SQL instance. """ try: import asyncpg @@ -51,20 +51,7 @@ async def connect( user = kwargs.pop("user") db = kwargs.pop("db") passwd = kwargs.pop("password", None) - pool = kwargs.pop("pool", False) - # return connection pool if pool is set to True - if pool: - return await asyncpg.create_pool( - user=user, - database=db, - password=passwd, - host=ip_address, - port=SERVER_PROXY_PORT, - ssl=ctx, - direct_tls=True, - **kwargs, - ) - # return regular asyncpg connection + return await asyncpg.connect( user=user, database=db, diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 0a9a913b0..1dc955477 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -49,35 +49,6 @@ async def setup() -> AsyncGenerator: connector.close() -@pytest.fixture() -async def pool() -> AsyncGenerator: - # initialize Cloud SQL Python Connector object - connector = Connector() - pool: asyncpg.Pool = await connector.connect_async( - os.environ["POSTGRES_CONNECTION_NAME"], - "asyncpg", - user=os.environ["POSTGRES_USER"], - password=os.environ["POSTGRES_PASS"], - db=os.environ["POSTGRES_DB"], - pool=True, - ) - async with pool.acquire() as conn: - await conn.execute( - f"CREATE TABLE IF NOT EXISTS {table_name}" - " ( id CHAR(20) NOT NULL, title TEXT NOT NULL );" - ) - - yield pool - - async with pool.acquire() as conn: - await conn.execute(f"DROP TABLE IF EXISTS {table_name}") - - # close asyncpg connection pool - await pool.close() - # cleanup Connector object - connector.close() - - @pytest.mark.asyncio async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: await conn.execute( @@ -91,19 +62,3 @@ async def test_connection_with_asyncpg(conn: asyncpg.Connection) -> None: titles = [row[0] for row in rows] assert titles == ["Book One", "Book Two"] - - -@pytest.mark.asyncio -async def test_pooled_connection_with_asyncpg(pool: asyncpg.Pool) -> None: - async with pool.acquire() as conn: - await conn.execute( - f"INSERT INTO {table_name} (id, title) VALUES ('book1', 'Book One')" - ) - await conn.execute( - f"INSERT INTO {table_name} (id, title) VALUES ('book2', 'Book Two')" - ) - - rows = await conn.fetch(f"SELECT title FROM {table_name} ORDER BY ID") - titles = [row[0] for row in rows] - - assert titles == ["Book One", "Book Two"] From b575839ec9bff7d764b5b88d81c838b5df55fb3b Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 12 Jul 2022 16:27:38 +0000 Subject: [PATCH 18/32] chore: update Connector loop logic --- google/cloud/sql/connector/connector.py | 36 +++++++++++++++++-------- google/cloud/sql/connector/instance.py | 6 ++++- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index a9d9f2551..55ef03ed1 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -27,7 +27,7 @@ import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.utils import generate_keys from google.auth.credentials import Credentials -from threading import Thread +from threading import Thread, current_thread from typing import Any, Dict, Optional, Type from functools import partial @@ -65,13 +65,23 @@ def __init__( enable_iam_auth: bool = False, timeout: int = 30, credentials: Optional[Credentials] = None, + loop: asyncio.AbstractEventLoop = None, ) -> None: - self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) - self._thread.start() - self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - generate_keys(), self._loop - ) + # if event loop is given, use for background tasks + if loop: + print("Event loop given!") + self._loop: asyncio.AbstractEventLoop = loop + self._thread = None + self._keys: asyncio.Task = loop.create_task(generate_keys()) + # if no event loop is given, spin up new loop in background thread + else: + print("No event loop, spinning up loop in thread!") + self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) + self._thread.start() + self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + generate_keys(), self._loop + ) self._instances: Dict[str, Instance] = {} # set default params for connections @@ -106,6 +116,12 @@ def connect( :returns: A DB-API connection to the specified Cloud SQL instance. """ + # check if event loop is running in current thread + if self._loop._thread_id == current_thread().ident and self._loop.is_running(): + # TODO: make custom exception class + raise RuntimeError( + "Connector event loop is running in current thread! Event loop must be attached to a different thread to prevent blocking code!" + ) connect_task = asyncio.run_coroutine_threadsafe( self.connect_async(instance_connection_string, driver, **kwargs), self._loop ) @@ -137,8 +153,6 @@ async def connect_async( :returns: A DB-API connection to the specified Cloud SQL instance. """ - # allow specific event loop to be passed in - loop = kwargs.pop("loop", asyncio.get_running_loop()) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: instance = self._instances[instance_connection_string] @@ -154,7 +168,7 @@ async def connect_async( instance_connection_string, driver, self._keys, - loop, + self._loop, self._credentials, enable_iam_auth, ) @@ -201,7 +215,7 @@ async def get_connection() -> Any: connect_partial = partial( connector, ip_address, instance_data.context, **kwargs ) - return await loop.run_in_executor(None, connect_partial) + return await self._loop.run_in_executor(None, connect_partial) # attempt to make connection to Cloud SQL instance for given timeout try: diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index 718ba72be..da36a7f37 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -251,7 +251,11 @@ def __init__( self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}" self._loop = loop - self._keys = asyncio.wrap_future(keys, loop=self._loop) + self._keys = ( + keys + if isinstance(keys, asyncio.Task) + else asyncio.wrap_future(keys, loop=self._loop) + ) # validate credentials type if not isinstance(credentials, Credentials) and credentials is not None: raise CredentialsTypeError( From f79b4ebac9fed38dccb54cf3bfa06a9836e1e753 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 14:24:52 +0000 Subject: [PATCH 19/32] chore: add custom exception for invalid loop state --- google/cloud/sql/connector/connector.py | 53 +++++++++++++++++-------- google/cloud/sql/connector/instance.py | 3 +- 2 files changed, 38 insertions(+), 18 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 55ef03ed1..9d85c6137 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -27,8 +27,8 @@ import google.cloud.sql.connector.asyncpg as asyncpg from google.cloud.sql.connector.utils import generate_keys from google.auth.credentials import Credentials -from threading import Thread, current_thread -from typing import Any, Dict, Optional, Type +from threading import Thread +from typing import Any, Dict, Optional, Type, Union from functools import partial logger = logging.getLogger(name=__name__) @@ -37,6 +37,16 @@ ASYNC_DRIVERS = ["asyncpg"] +class ConnectorLoopError(Exception): + """ + Raised when Connector.connect is called with Connector._loop + in an invalid state (event loop in current thread). + """ + + def __init__(self, *args: Any) -> None: + super(ConnectorLoopError, self).__init__(self, *args) + + class Connector: """A class to configure and create connections to Cloud SQL instances. @@ -57,6 +67,11 @@ class Connector: :param credentials Credentials object used to authenticate connections to Cloud SQL server. If not specified, Application Default Credentials are used. + + :type loop: asyncio.AbstractEventLoop + :param loop + Event loop to run asyncio tasks, if not specified, defaults to + creating new event loop on background thread. """ def __init__( @@ -69,19 +84,17 @@ def __init__( ) -> None: # if event loop is given, use for background tasks if loop: - print("Event loop given!") self._loop: asyncio.AbstractEventLoop = loop - self._thread = None - self._keys: asyncio.Task = loop.create_task(generate_keys()) + self._thread: Optional[Thread] = None + self._keys: Union[ + asyncio.Task, concurrent.futures.Future + ] = loop.create_task(generate_keys()) # if no event loop is given, spin up new loop in background thread else: - print("No event loop, spinning up loop in thread!") - self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() - self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) + self._loop = asyncio.new_event_loop() + self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() - self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( - generate_keys(), self._loop - ) + self._keys = asyncio.run_coroutine_threadsafe(generate_keys(), self._loop) self._instances: Dict[str, Instance] = {} # set default params for connections @@ -116,12 +129,18 @@ def connect( :returns: A DB-API connection to the specified Cloud SQL instance. """ - # check if event loop is running in current thread - if self._loop._thread_id == current_thread().ident and self._loop.is_running(): - # TODO: make custom exception class - raise RuntimeError( - "Connector event loop is running in current thread! Event loop must be attached to a different thread to prevent blocking code!" - ) + try: + # check if event loop is running in current thread + if self._loop == asyncio.get_running_loop(): + raise ConnectorLoopError( + "Connector event loop is running in current thread!" + "Event loop must be attached to a different thread to prevent blocking code!" + ) + # asyncio.get_running_loop will throw RunTimeError if no running loop is present + except RuntimeError: + pass + + # if event loop is not in current thread, proceed with connection connect_task = asyncio.run_coroutine_threadsafe( self.connect_async(instance_connection_string, driver, **kwargs), self._loop ) diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index da36a7f37..f5ae24f57 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -43,6 +43,7 @@ Dict, Optional, Tuple, + Union, ) import logging @@ -227,7 +228,7 @@ def __init__( self, instance_connection_string: str, driver_name: str, - keys: concurrent.futures.Future, + keys: Union[asyncio.Task, concurrent.futures.Future], loop: asyncio.AbstractEventLoop, credentials: Optional[Credentials] = None, enable_iam_auth: bool = False, From 4d9ce2c93364756ff8c03b1f9d0affb64f92061f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 14:41:45 +0000 Subject: [PATCH 20/32] chore: add create_async_connector function --- google/cloud/sql/connector/connector.py | 40 +++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 94ae2276f..9808911c4 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -270,3 +270,43 @@ async def _close(self) -> None: await asyncio.gather( *[instance.close() for instance in self._instances.values()] ) + + +async def create_async_connector( + ip_type: IPTypes = IPTypes.PUBLIC, + enable_iam_auth: bool = False, + timeout: int = 30, + credentials: Optional[Credentials] = None, + loop: asyncio.AbstractEventLoop = None, +) -> Connector: + """ + Create Connector object for asyncio connections that can auto-detect + and use current thread's running event loop. + + :type ip_type: IPTypes + :param ip_type + The IP type (public or private) used to connect. IP types + can be either IPTypes.PUBLIC or IPTypes.PRIVATE. + + :type enable_iam_auth: bool + :param enable_iam_auth + Enables IAM based authentication (Postgres only). + + :type timeout: int + :param timeout + The time limit for a connection before raising a TimeoutError. + + :type credentials: google.auth.credentials.Credentials + :param credentials + Credentials object used to authenticate connections to Cloud SQL server. + If not specified, Application Default Credentials are used. + + :type loop: asyncio.AbstractEventLoop + :param loop + Event loop to run asyncio tasks, if not specified, defaults + to current thread's running event loop. + """ + # if no loop given, automatically detect running event loop + if loop is None: + loop = asyncio.get_running_loop() + return Connector(ip_type, enable_iam_auth, timeout, credentials, loop) From 85075a9a2ed289aeafe7e8a29b6f18fd51aed6d8 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 14:54:13 +0000 Subject: [PATCH 21/32] chore: update system tests to use create_async_connector --- google/cloud/sql/connector/__init__.py | 4 ++-- tests/system/test_asyncpg_connection.py | 4 ++-- tests/system/test_asyncpg_iam_auth.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/google/cloud/sql/connector/__init__.py b/google/cloud/sql/connector/__init__.py index 0e932169b..527e177c2 100644 --- a/google/cloud/sql/connector/__init__.py +++ b/google/cloud/sql/connector/__init__.py @@ -13,11 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. """ -from .connector import Connector +from .connector import Connector, create_async_connector from .instance import IPTypes -__ALL__ = [Connector, IPTypes] +__ALL__ = [create_async_connector, Connector, IPTypes] try: import pkg_resources diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index 1dc955477..fb850f101 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -19,7 +19,7 @@ import asyncpg import pytest -from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import create_async_connector table_name = f"books_{uuid.uuid4().hex}" @@ -27,7 +27,7 @@ @pytest.fixture(name="conn") async def setup() -> AsyncGenerator: # initialize Cloud SQL Python Connector object - connector = Connector() + connector = await create_async_connector() conn: asyncpg.Connection = await connector.connect_async( os.environ["POSTGRES_CONNECTION_NAME"], "asyncpg", diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py index e2b74b6fc..7bf73abc0 100644 --- a/tests/system/test_asyncpg_iam_auth.py +++ b/tests/system/test_asyncpg_iam_auth.py @@ -19,7 +19,7 @@ import asyncpg import pytest -from google.cloud.sql.connector import Connector +from google.cloud.sql.connector import create_async_connector table_name = f"books_{uuid.uuid4().hex}" @@ -27,7 +27,7 @@ @pytest.fixture(name="conn") async def setup() -> AsyncGenerator: # initialize Cloud SQL Python Connector object - connector = Connector() + connector = await create_async_connector() conn: asyncpg.Connection = await connector.connect_async( os.environ["POSTGRES_IAM_CONNECTION_NAME"], "asyncpg", From 9c68d120a557df0d744d4a6360389172abb99b0e Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 15:06:49 +0000 Subject: [PATCH 22/32] chore: expose Connector.close_async --- google/cloud/sql/connector/connector.py | 6 ++++-- tests/system/test_asyncpg_connection.py | 2 +- tests/system/test_asyncpg_iam_auth.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 9808911c4..ea5517764 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -260,11 +260,13 @@ def __exit__( def close(self) -> None: """Close Connector by stopping tasks and releasing resources.""" - close_future = asyncio.run_coroutine_threadsafe(self._close(), loop=self._loop) + close_future = asyncio.run_coroutine_threadsafe( + self.close_async(), loop=self._loop + ) # Will attempt to safely shut down tasks for 5s close_future.result(timeout=5) - async def _close(self) -> None: + async def close_async(self) -> None: """Helper function to cancel Instances' tasks and close aiohttp.ClientSession.""" await asyncio.gather( diff --git a/tests/system/test_asyncpg_connection.py b/tests/system/test_asyncpg_connection.py index fb850f101..35a1de135 100644 --- a/tests/system/test_asyncpg_connection.py +++ b/tests/system/test_asyncpg_connection.py @@ -46,7 +46,7 @@ async def setup() -> AsyncGenerator: # close asyncpg connection await conn.close() # cleanup Connector object - connector.close() + await connector.close_async() @pytest.mark.asyncio diff --git a/tests/system/test_asyncpg_iam_auth.py b/tests/system/test_asyncpg_iam_auth.py index 7bf73abc0..1229f9f17 100644 --- a/tests/system/test_asyncpg_iam_auth.py +++ b/tests/system/test_asyncpg_iam_auth.py @@ -46,7 +46,7 @@ async def setup() -> AsyncGenerator: # close asyncpg connection await conn.close() # cleanup Connector object - connector.close() + await connector.close_async() @pytest.mark.asyncio From 69bdebcf2cf72d60ff4c975a3551b58627e45770 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 15:14:52 +0000 Subject: [PATCH 23/32] chore: update comments --- google/cloud/sql/connector/connector.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index ea5517764..4098ec369 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -118,7 +118,7 @@ def connect( :type driver: str :param: driver: A string representing the driver to connect with. Supported drivers are - pymysql, pg8000, asyncpg, and pytds. + pymysql, pg8000, and pytds. :param kwargs: Pass in any driver-specific arguments needed to connect to the Cloud @@ -171,6 +171,12 @@ async def connect_async( :returns: A DB-API connection to the specified Cloud SQL instance. """ + # Create an Instance object from the connection string. + # The Instance should verify arguments. + # + # Use the Instance to establish an SSL Connection. + # + # Return a DBAPI connection enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) if instance_connection_string in self._instances: instance = self._instances[instance_connection_string] From 11aeea4e4558c641769c43a5df84ead9df62c85f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 17:21:52 +0000 Subject: [PATCH 24/32] chore: wrap_keys in Connector instead of Instance --- google/cloud/sql/connector/connector.py | 12 ++++++------ google/cloud/sql/connector/instance.py | 13 +++---------- tests/conftest.py | 7 ++++--- tests/unit/test_instance.py | 12 +++++++++--- 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 4098ec369..128807c3e 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -14,7 +14,6 @@ limitations under the License. """ import asyncio -import concurrent import logging from types import TracebackType from google.cloud.sql.connector.instance import ( @@ -28,7 +27,7 @@ from google.cloud.sql.connector.utils import generate_keys from google.auth.credentials import Credentials from threading import Thread -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type from functools import partial logger = logging.getLogger(name=__name__) @@ -85,15 +84,16 @@ def __init__( if loop: self._loop: asyncio.AbstractEventLoop = loop self._thread: Optional[Thread] = None - self._keys: Union[ - asyncio.Task, concurrent.futures.Future - ] = loop.create_task(generate_keys()) + self._keys: asyncio.Future = loop.create_task(generate_keys()) # if no event loop is given, spin up new loop in background thread else: self._loop = asyncio.new_event_loop() self._thread = Thread(target=self._loop.run_forever, daemon=True) self._thread.start() - self._keys = asyncio.run_coroutine_threadsafe(generate_keys(), self._loop) + self._keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), self._loop), + loop=self._loop, + ) self._instances: Dict[str, Instance] = {} # set default params for connections diff --git a/google/cloud/sql/connector/instance.py b/google/cloud/sql/connector/instance.py index f5ae24f57..a1282e147 100644 --- a/google/cloud/sql/connector/instance.py +++ b/google/cloud/sql/connector/instance.py @@ -28,7 +28,6 @@ # Importing libraries import asyncio import aiohttp -import concurrent import datetime from enum import Enum import google.auth @@ -39,11 +38,9 @@ from tempfile import TemporaryDirectory from typing import ( Any, - Awaitable, Dict, Optional, Tuple, - Union, ) import logging @@ -211,7 +208,7 @@ def _client_session(self) -> aiohttp.ClientSession: return self.__client_session _credentials: Optional[Credentials] = None - _keys: Awaitable + _keys: asyncio.Future _instance_connection_string: str _user_agent_string: str @@ -228,7 +225,7 @@ def __init__( self, instance_connection_string: str, driver_name: str, - keys: Union[asyncio.Task, concurrent.futures.Future], + keys: asyncio.Future, loop: asyncio.AbstractEventLoop, credentials: Optional[Credentials] = None, enable_iam_auth: bool = False, @@ -252,11 +249,7 @@ def __init__( self._user_agent_string = f"{APPLICATION_NAME}/{version}+{driver_name}" self._loop = loop - self._keys = ( - keys - if isinstance(keys, asyncio.Task) - else asyncio.wrap_future(keys, loop=self._loop) - ) + self._keys = keys # validate credentials type if not isinstance(credentials, Credentials) and credentials is not None: raise CredentialsTypeError( diff --git a/tests/conftest.py b/tests/conftest.py index 45b6c8d22..e3e2800e2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -154,9 +154,10 @@ async def instance( Instance with mocked API calls. """ # generate client key pair - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) - key_task = asyncio.wrap_future(keys, loop=event_loop) - _, client_key = await key_task + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) + _, client_key = await keys with patch("google.auth.default") as mock_auth: mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 30b1d6fab..63cb6b090 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -50,7 +50,9 @@ async def test_Instance_init( """ connect_string = "test-project:test-region:test-instance" - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) with patch("google.auth.default") as mock_auth: mock_auth.return_value = fake_credentials, None instance = Instance(connect_string, "pymysql", keys, event_loop) @@ -75,7 +77,9 @@ async def test_Instance_init_bad_credentials( throws proper error for bad credentials arg type. """ connect_string = "test-project:test-region:test-instance" - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) with pytest.raises(CredentialsTypeError): instance = Instance(connect_string, "pymysql", keys, event_loop, credentials=1) await instance.close() @@ -356,7 +360,9 @@ async def test_ClientResponseError( Test that detailed error message is applied to ClientResponseError. """ # mock Cloud SQL Admin API calls with exceptions - keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop) + keys = asyncio.wrap_future( + asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop + ) get_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings" post_url = "https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert" with aioresponses() as mocked: From 25b2734495c4b7ec568c795cf8687c35159963fe Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 17:53:52 +0000 Subject: [PATCH 25/32] chore: update keys in fixture --- tests/conftest.py | 6 ++++-- tests/unit/mocks.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e3e2800e2..1085a6729 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,7 +25,7 @@ from aioresponses import aioresponses from mock import patch -from unit.mocks import FakeCSQLInstance # type: ignore +from unit.mocks import FakeCSQLInstance, wait_for_keys # type: ignore from google.cloud.sql.connector import Connector from google.cloud.sql.connector.instance import Instance from google.cloud.sql.connector.utils import generate_keys @@ -193,7 +193,9 @@ async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls mock_instance = FakeCSQLInstance(project, region, instance_name) - _, client_key = connector._keys.result() + _, client_key = asyncio.run_coroutine_threadsafe( + wait_for_keys(connector._keys), connector._loop + ).result() with aioresponses() as mocked: mocked.get( f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings", diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 9e53e506b..72532d72d 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -15,6 +15,7 @@ """ # file containing all mocks used for Cloud SQL Python Connector unit tests +import asyncio import json import ssl from tempfile import TemporaryDirectory @@ -238,3 +239,12 @@ def generate_ephemeral(self, client_bytes: str) -> str: } } ) + + +async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]: + """ + Helper method to await keys of Connector in tests prior to + initializing an Instance object. + """ + keys = await future + return keys From cbf9d6d01520fa187c371aa6d1e44afb6335e845 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 19:53:41 +0000 Subject: [PATCH 26/32] chore: add test coverage --- tests/system/test_connector_object.py | 19 +++++++++++ tests/unit/test_connector.py | 48 ++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/tests/system/test_connector_object.py b/tests/system/test_connector_object.py index d044a78d7..909aa9c4e 100644 --- a/tests/system/test_connector_object.py +++ b/tests/system/test_connector_object.py @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. """ +import asyncio import os import pymysql import sqlalchemy @@ -21,6 +22,7 @@ from google.cloud.sql.connector import Connector import datetime import concurrent.futures +from threading import Thread def init_connection_engine( @@ -122,3 +124,20 @@ def test_connector_as_context_manager() -> None: with pool.connect() as conn: conn.execute("SELECT 1") + + +def test_connector_with_custom_loop() -> None: + """Test that Connector can be used with custom loop in background thread.""" + # create new event loop and start it in thread + loop = asyncio.new_event_loop() + thread = Thread(target=loop.run_forever, daemon=True) + thread.start() + + with Connector(loop=loop) as connector: + pool = init_connection_engine(connector) + + with pool.connect() as conn: + result = conn.execute("SELECT 1").fetchone() + assert result[0] == 1 + # assert that Connector does not start its own thread + assert connector._thread is None diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index ea91a465e..95c683115 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -17,7 +17,8 @@ import pytest # noqa F401 Needed to run the tests import asyncio -from google.cloud.sql.connector import Connector, IPTypes +from google.cloud.sql.connector import Connector, IPTypes, create_async_connector +from google.cloud.sql.connector.connector import ConnectorLoopError from mock import patch from typing import Any @@ -77,6 +78,33 @@ def test_connect_enable_iam_auth_error() -> None: connector._instances = {} +def test_connect_with_unsupported_driver(connector: Connector) -> None: + # try to connect using unsupported driver, should raise KeyError + with pytest.raises(KeyError) as exc_info: + connector.connect( + "my-project:my-region:my-instance", + "bad_driver", + ) + # assert custom error message for unsupported driver is present + assert exc_info.value.args[0] == "Driver 'bad_driver' is not supported." + connector.close() + + +@pytest.mark.asyncio +async def test_connect_ConnectorLoopError() -> None: + """Test that ConnectorLoopError is thrown when Connector.connect + is called with event loop running in current thread.""" + current_loop = asyncio.get_running_loop() + connector = Connector(loop=current_loop) + # try to connect using current thread's loop, should raise error + pytest.raises( + ConnectorLoopError, + connector.connect, + "my-project:my-region:my-instance", + "pg8000", + ) + + def test_Connector_Init() -> None: """Test that Connector __init__ sets default properties properly.""" connector = Connector() @@ -87,6 +115,15 @@ def test_Connector_Init() -> None: connector.close() +def test_Connector_Init_context_manager() -> None: + """Test that Connector as context manager sets default properties properly.""" + with Connector() as connector: + assert connector._ip_type == IPTypes.PUBLIC + assert connector._enable_iam_auth is False + assert connector._timeout == 30 + assert connector._credentials is None + + def test_Connector_connect(connector: Connector) -> None: """Test that Connector.connect can properly return a DB API connection.""" connect_string = "my-project:my-region:my-instance" @@ -98,3 +135,12 @@ def test_Connector_connect(connector: Connector) -> None: ) # verify connector made connection call assert connection is True + + +@pytest.mark.asyncio +async def test_create_async_connector() -> None: + """Test that create_async_connector properly initializes connector + object using current thread's event loop""" + connector = await create_async_connector() + assert connector._loop == asyncio.get_running_loop() + await connector.close_async() From 9a9df43fb3b2c991318f56fcc189debdb1c54d1f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 20:34:08 +0000 Subject: [PATCH 27/32] chore: add async sample to readme --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/README.md b/README.md index f0ba02850..a1492dd4d 100644 --- a/README.md +++ b/README.md @@ -283,6 +283,58 @@ connector.connect( ) ``` +### Async Driver Usage +The Cloud SQL Connector for Python currently supports the +[asyncpg](https://magicstack.github.io/asyncpg) Postgres database driver. +This driver leverages [asyncio](https://docs.python.org/3/library/asyncio.html) +to improve the speed and efficiency of database connections through concurrency. + +The Cloud SQL Connector has an async `create_async_connector` that can be used +and is recommended for async drivers as it returns a `Connector` object +that uses the current thread's running event loop automatically. + +The `create_async_connector` allows all the same input arguments as [`Connector`] +(#configuring-the-connector). + +Once a `Connector` object is returned by `create_async_connector` you can call +its `connect_async` method, just as you would the `connect` method: + +```python +import asyncpg +from google.cloud.sql.connector import create_async_connector + +async def main(): + # intialize Connector object using 'create_async_connector' + connector = await create_async_connector() + + # create connection to Cloud SQL database + conn: asyncpg.Connection = await connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="root", + password="shhh", + db="your-db-name" + # ... additional database driver args + ) + + # insert into Cloud SQL database (example) + await conn.execute("INSERT INTO ratings (title, genre, rating) VALUES ('Batman', 'Action', 8.2)") + + # query Cloud SQL database (examples) + results = await conn.fetch("SELECT * from ratings") + for row in results: + # ... do something with results + + # close asyncpg connection + await conn.close + + # close Cloud SQL Connector + await connector.close_async() +``` + +For more details on interacting with an `asyncpg.Connection`, please visit +the [official documentation](https://magicstack.github.io/asyncpg/current/api/index.html). + ## Support policy ### Major version lifecycle From d44c2e6abc18790e38eddffa7dbb46dd922926a2 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 20:37:31 +0000 Subject: [PATCH 28/32] chore: update README link --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a1492dd4d..8f6cebc43 100644 --- a/README.md +++ b/README.md @@ -293,8 +293,8 @@ The Cloud SQL Connector has an async `create_async_connector` that can be used and is recommended for async drivers as it returns a `Connector` object that uses the current thread's running event loop automatically. -The `create_async_connector` allows all the same input arguments as [`Connector`] -(#configuring-the-connector). +The `create_async_connector` allows all the same input arguments as the [Connector] +(#configuring-the-connector) object. Once a `Connector` object is returned by `create_async_connector` you can call its `connect_async` method, just as you would the `connect` method: From d564fdc105803b0bde13f159073ff1c73101d14c Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 13 Jul 2022 20:42:12 +0000 Subject: [PATCH 29/32] chore: fix link in README --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f6cebc43..42a44f8ea 100644 --- a/README.md +++ b/README.md @@ -293,8 +293,8 @@ The Cloud SQL Connector has an async `create_async_connector` that can be used and is recommended for async drivers as it returns a `Connector` object that uses the current thread's running event loop automatically. -The `create_async_connector` allows all the same input arguments as the [Connector] -(#configuring-the-connector) object. +The `create_async_connector` allows all the same input arguments as the +[Connector](#configuring-the-connector) object. Once a `Connector` object is returned by `create_async_connector` you can call its `connect_async` method, just as you would the `connect` method: From c12178cd54daea1dcfb9198fb8b4420b63be2939 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Sat, 16 Jul 2022 21:28:02 +0000 Subject: [PATCH 30/32] chore: update readme and add async context manager to Connector --- README.md | 88 +++++++++++++++++++------ google/cloud/sql/connector/connector.py | 13 ++++ tests/unit/test_connector.py | 13 ++++ 3 files changed, 93 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 42a44f8ea..114a09a1f 100644 --- a/README.md +++ b/README.md @@ -119,9 +119,9 @@ def getconn() -> pymysql.connections.Connection: conn: pymysql.connections.Connection = connector.connect( "project:region:instance", "pymysql", - user="root", - password="shhh", - db="your-db-name" + user="my-user", + password="my-password", + db="my-db-name" ) return conn @@ -196,9 +196,9 @@ def getconn() -> pymysql.connections.Connection: conn = connector.connect( "project:region:instance", "pymysql", - user="root", - password="shhh", - db="your-db-name" + user="my-user", + password="my-password", + db="my-db-name" ) return conn @@ -253,7 +253,7 @@ connector.connect( "project:region:instance", "pg8000", user="postgres-iam-user@gmail.com", - db="my_database", + db="my-db-name", enable_iam_auth=True, ) ``` @@ -266,7 +266,7 @@ Once you have followed the steps linked above, you can run the following code to connector.connect( "project:region:instance", "pytds", - db="my_database", + db="my-db-name", active_directory_auth=True, server_name="public.[instance].[location].[project].cloudsql.[domain]", ) @@ -276,7 +276,7 @@ Or, if using Private IP: connector.connect( "project:region:instance", "pytds", - db="my_database", + db="my-db-name", active_directory_auth=True, server_name="private.[instance].[location].[project].cloudsql.[domain]", ip_type=IPTypes.PRIVATE @@ -284,14 +284,18 @@ connector.connect( ``` ### Async Driver Usage -The Cloud SQL Connector for Python currently supports the -[asyncpg](https://magicstack.github.io/asyncpg) Postgres database driver. -This driver leverages [asyncio](https://docs.python.org/3/library/asyncio.html) -to improve the speed and efficiency of database connections through concurrency. - -The Cloud SQL Connector has an async `create_async_connector` that can be used -and is recommended for async drivers as it returns a `Connector` object -that uses the current thread's running event loop automatically. +The Cloud SQL Connector is compatible with +[asyncio](https://docs.python.org/3/library/asyncio.html) to improve the speed +and efficiency of database connections through concurrency. You can use all +non-asyncio drivers through the `Connector.connect_async` function, in addition +to the following asyncio database drivers: +- [asyncpg](https://magicstack.github.io/asyncpg) (Postgres) + +The Cloud SQL Connector has a helper `create_async_connector` function that is +recommended for asyncio database connections. It returns a `Connector` +object that uses the current thread's running event loop. This is different +than `Connector()` which by default initializes a new event loop in a +background thread. The `create_async_connector` allows all the same input arguments as the [Connector](#configuring-the-connector) object. @@ -311,16 +315,16 @@ async def main(): conn: asyncpg.Connection = await connector.connect_async( "project:region:instance", # Cloud SQL instance connection name "asyncpg", - user="root", - password="shhh", - db="your-db-name" + user="my-user", + password="my-password", + db="my-db-name" # ... additional database driver args ) # insert into Cloud SQL database (example) await conn.execute("INSERT INTO ratings (title, genre, rating) VALUES ('Batman', 'Action', 8.2)") - # query Cloud SQL database (examples) + # query Cloud SQL database (example) results = await conn.fetch("SELECT * from ratings") for row in results: # ... do something with results @@ -335,6 +339,48 @@ async def main(): For more details on interacting with an `asyncpg.Connection`, please visit the [official documentation](https://magicstack.github.io/asyncpg/current/api/index.html). +### Async Context Manager + +An alternative to using the `create_async_connector` function is initializing +a `Connector` as an async context manager, removing the need for explicit +calls to `connector.close_async()` to cleanup resources. + +**Note:** This alternative requires that the running event loop be +passed in as the `loop` argument to `Connector()`. + +```python +import asyncio +import asyncpg +from google.cloud.sql.connector import Connector + +async def main(): + # get current running event loop to be used with Connector + loop = asyncio.get_running_loop() + # intialize Connector object as async context manager + async with Connector(loop=loop) as connector: + + # create connection to Cloud SQL database + conn: asyncpg.Connection = await connector.connect_async( + "project:region:instance", # Cloud SQL instance connection name + "asyncpg", + user="my-user", + password="my-password", + db="my-db-name" + # ... additional database driver args + ) + + # insert into Cloud SQL database (example) + await conn.execute("INSERT INTO ratings (title, genre, rating) VALUES ('Batman', 'Action', 8.2)") + + # query Cloud SQL database (example) + results = await conn.fetch("SELECT * from ratings") + for row in results: + # ... do something with results + + # close asyncpg connection + await conn.close +``` + ## Support policy ### Major version lifecycle diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 128807c3e..e65617133 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -264,6 +264,19 @@ def __exit__( """Exit context manager by closing Connector""" self.close() + async def __aenter__(self) -> Any: + """Enter async context manager by returning Connector object""" + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit async context manager by closing Connector""" + await self.close_async() + def close(self) -> None: """Close Connector by stopping tasks and releasing resources.""" close_future = asyncio.run_coroutine_threadsafe( diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 95c683115..c48787f5d 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -124,6 +124,19 @@ def test_Connector_Init_context_manager() -> None: assert connector._credentials is None +@pytest.mark.asyncio +async def test_Connector_Init_async_context_manager() -> None: + """Test that Connector as async context manager sets default properties + properly.""" + loop = asyncio.get_running_loop() + async with Connector(loop=loop) as connector: + assert connector._ip_type == IPTypes.PUBLIC + assert connector._enable_iam_auth is False + assert connector._timeout == 30 + assert connector._credentials is None + assert connector._loop == loop + + def test_Connector_connect(connector: Connector) -> None: """Test that Connector.connect can properly return a DB API connection.""" connect_string = "my-project:my-region:my-instance" From 4805a776d923c2b451907657fad666d63724e3d4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 20 Jul 2022 14:54:22 +0000 Subject: [PATCH 31/32] chore: remove helper function from mocks --- tests/conftest.py | 18 +++++++++++++----- tests/unit/mocks.py | 10 ---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1085a6729..eabb153df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,13 +19,13 @@ import pytest # noqa F401 Needed to run the tests from threading import Thread -from typing import Any, Generator, AsyncGenerator +from typing import Any, Generator, AsyncGenerator, Tuple from google.auth.credentials import Credentials, with_scopes_if_required from google.oauth2 import service_account from aioresponses import aioresponses from mock import patch -from unit.mocks import FakeCSQLInstance, wait_for_keys # type: ignore +from unit.mocks import FakeCSQLInstance # type: ignore from google.cloud.sql.connector import Connector from google.cloud.sql.connector.instance import Instance from google.cloud.sql.connector.utils import generate_keys @@ -154,10 +154,9 @@ async def instance( Instance with mocked API calls. """ # generate client key pair - keys = asyncio.wrap_future( - asyncio.run_coroutine_threadsafe(generate_keys(), event_loop), loop=event_loop - ) + keys = event_loop.create_task(generate_keys()) _, client_key = await keys + with patch("google.auth.default") as mock_auth: mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls @@ -193,6 +192,15 @@ async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, mock_auth.return_value = fake_credentials, None # mock Cloud SQL Admin API calls mock_instance = FakeCSQLInstance(project, region, instance_name) + + async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]: + """ + Helper method to await keys of Connector in tests prior to + initializing an Instance object. + """ + keys = await future + return keys + _, client_key = asyncio.run_coroutine_threadsafe( wait_for_keys(connector._keys), connector._loop ).result() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 72532d72d..9e53e506b 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -15,7 +15,6 @@ """ # file containing all mocks used for Cloud SQL Python Connector unit tests -import asyncio import json import ssl from tempfile import TemporaryDirectory @@ -239,12 +238,3 @@ def generate_ephemeral(self, client_bytes: str) -> str: } } ) - - -async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]: - """ - Helper method to await keys of Connector in tests prior to - initializing an Instance object. - """ - keys = await future - return keys From c357e9c6e3be75abdae93dfcd74ada56cfeefd05 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 25 Jul 2022 14:19:40 +0000 Subject: [PATCH 32/32] chore: add comment explaining fixture --- tests/conftest.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index eabb153df..11813b880 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -198,9 +198,12 @@ async def wait_for_keys(future: asyncio.Future) -> Tuple[bytes, str]: Helper method to await keys of Connector in tests prior to initializing an Instance object. """ - keys = await future - return keys + return await future + # converting asyncio.Future into concurrent.Future + # await keys in background thread so that .result() is set + # required because keys are needed for mocks, but are not awaited + # in the code until Instance() is initialized _, client_key = asyncio.run_coroutine_threadsafe( wait_for_keys(connector._keys), connector._loop ).result()