Skip to content

Commit a7c7861

Browse files
feat: validate driver with database engine (#1108)
Removing ambiguous behavior if a user configures a database driver to connect to an incompatible database engine. i.e. pymysql (MySQL driver) -> Cloud SQL for PostgreSQL database This behavior would timeout trying to connect and then throw an ambiguous timeout error. The Connector.connect already know which database engine the Cloud SQL database is prior to attempting to connect via the driver. So we can validate if the driver is compatible and throw an actionable error if not.
1 parent 821cef8 commit a7c7861

File tree

4 files changed

+81
-1
lines changed

4 files changed

+81
-1
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import google.cloud.sql.connector.asyncpg as asyncpg
3232
from google.cloud.sql.connector.client import CloudSQLClient
33+
from google.cloud.sql.connector.enums import DriverMapping
3334
from google.cloud.sql.connector.exceptions import ConnectorLoopError
3435
from google.cloud.sql.connector.exceptions import DnsNameResolutionError
3536
from google.cloud.sql.connector.instance import IPTypes
@@ -332,6 +333,8 @@ async def connect_async(
332333
# attempt to make connection to Cloud SQL instance
333334
try:
334335
conn_info = await cache.connect_info()
336+
# validate driver matches intended database engine
337+
DriverMapping.validate_engine(driver, conn_info.database_version)
335338
ip_address = conn_info.get_preferred_ip(ip_type)
336339
# resolve DNS name into IP address for PSC
337340
if ip_type.value == "PSC":
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from enum import Enum
16+
17+
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
18+
19+
20+
class DriverMapping(Enum):
21+
"""Maps a given database driver to it's corresponding database engine."""
22+
23+
ASYNCPG = "POSTGRES"
24+
PG8000 = "POSTGRES"
25+
PYMYSQL = "MYSQL"
26+
PYTDS = "SQLSERVER"
27+
28+
@staticmethod
29+
def validate_engine(driver: str, engine_version: str) -> None:
30+
"""Validate that the given driver is compatible with the given engine.
31+
32+
Args:
33+
driver (str): Database driver being used. (i.e. "pg8000")
34+
engine_version (str): Database engine version. (i.e. "POSTGRES_16")
35+
36+
Raises:
37+
IncompatibleDriverError: If the given driver is not compatible with
38+
the given engine.
39+
"""
40+
mapping = DriverMapping[driver.upper()]
41+
if not engine_version.startswith(mapping.value):
42+
raise IncompatibleDriverError(
43+
f"Database driver '{driver}' is incompatible with database "
44+
f"version '{engine_version}'. Given driver can "
45+
f"only be used with Cloud SQL {mapping.value} databases."
46+
)

google/cloud/sql/connector/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,10 @@ class RefreshNotValidError(Exception):
7070
"""
7171

7272
pass
73+
74+
75+
class IncompatibleDriverError(Exception):
76+
"""
77+
Exception to be raised when the database driver given is for the wrong
78+
database engine. (i.e. asyncpg for a MySQL database)
79+
"""

tests/unit/test_connector.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
import asyncio
1718
from typing import Union
1819

@@ -25,6 +26,7 @@
2526
from google.cloud.sql.connector import IPTypes
2627
from google.cloud.sql.connector.client import CloudSQLClient
2728
from google.cloud.sql.connector.exceptions import ConnectorLoopError
29+
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
2830
from google.cloud.sql.connector.instance import RefreshAheadCache
2931

3032

@@ -46,10 +48,32 @@ def test_connect_enable_iam_auth_error(
4648
"If you require both for your use case, please use a new "
4749
"connector.Connector object."
4850
)
49-
# remove cache entrry to avoid destructor warnings
51+
# remove cache entry to avoid destructor warnings
5052
connector._cache = {}
5153

5254

55+
async def test_connect_incompatible_driver_error(
56+
fake_credentials: Credentials,
57+
fake_client: CloudSQLClient,
58+
) -> None:
59+
"""Test that calling connect() with driver that is incompatible with
60+
database version throws error."""
61+
connect_string = "test-project:test-region:test-instance"
62+
async with Connector(
63+
credentials=fake_credentials, loop=asyncio.get_running_loop()
64+
) as connector:
65+
connector._client = fake_client
66+
# try to connect using pymysql driver to a Postgres database
67+
with pytest.raises(IncompatibleDriverError) as exc_info:
68+
await connector.connect_async(connect_string, "pymysql")
69+
assert (
70+
exc_info.value.args[0]
71+
== "Database driver 'pymysql' is incompatible with database version"
72+
" 'POSTGRES_15'. Given driver can only be used with Cloud SQL MYSQL"
73+
" databases."
74+
)
75+
76+
5377
def test_connect_with_unsupported_driver(fake_credentials: Credentials) -> None:
5478
with Connector(credentials=fake_credentials) as connector:
5579
# try to connect using unsupported driver, should raise KeyError

0 commit comments

Comments
 (0)