Skip to content

Commit c7c3f33

Browse files
author
Eugene Shershen
committed
add connection and dbapi tests
1 parent 291d336 commit c7c3f33

File tree

6 files changed

+1264
-111
lines changed

6 files changed

+1264
-111
lines changed

psqlpy_sqlalchemy/connection.py

Lines changed: 2 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,6 @@ async def _prepare_execute(
8484

8585
if casting_matches:
8686
# This is a known limitation: SQLAlchemy can't handle named parameters with explicit PostgreSQL casting
87-
import logging
88-
89-
logging.getLogger(__name__)
90-
9187
raise RuntimeError(
9288
f"Named parameters with explicit PostgreSQL casting are not supported. "
9389
f"Found casting parameters: {casting_matches} in query: {converted_query[:100]}... "
@@ -199,18 +195,7 @@ def _convert_named_params_with_casting(
199195
200196
And converts the parameters dict to a list in the correct order.
201197
"""
202-
# Add debugging logging for CI troubleshooting
203-
import logging
204-
205-
logger = logging.getLogger(__name__)
206-
207-
logger.debug(
208-
f"Parameter conversion called - Query: {querystring!r}, "
209-
f"Parameters: {parameters!r}, Parameters type: {type(parameters)}"
210-
)
211-
212198
if parameters is None or not isinstance(parameters, dict):
213-
logger.debug("Parameters is None or not dict, returning as-is")
214199
return querystring, parameters
215200

216201
import re
@@ -222,14 +207,7 @@ def _convert_named_params_with_casting(
222207
# Find all parameter references in the query
223208
matches = list(re.finditer(param_pattern, querystring))
224209

225-
logger.debug(f"Found {len(matches)} parameter matches in query")
226-
for i, match in enumerate(matches):
227-
logger.debug(
228-
f" Match {i + 1}: '{match.group(0)}' -> param='{match.group(1)}', cast='{match.group(2)}'"
229-
)
230-
231210
if not matches:
232-
logger.debug("No parameter matches found, returning as-is")
233211
return querystring, parameters
234212

235213
# Build the conversion mapping and new parameter list
@@ -249,28 +227,8 @@ def _convert_named_params_with_casting(
249227

250228
# Defensive check: ensure all parameters found in query are available
251229
if missing_params:
252-
# Enhanced error message with more debugging information
253-
available_params = list(parameters.keys()) if parameters else []
254-
found_params = [m.group(1) for m in matches]
255-
256-
# Log additional debugging information for CI troubleshooting
257-
import logging
258-
259-
logger = logging.getLogger(__name__)
260-
logger.error(
261-
f"Parameter conversion error - Missing parameters: {missing_params}. "
262-
f"Query: {querystring!r}. "
263-
f"Found in query: {found_params}. "
264-
f"Available in dict: {available_params}. "
265-
f"Parameters dict: {parameters!r}"
266-
)
267-
268230
# Instead of raising an error, return the original query and parameters
269231
# This prevents partial conversion which can cause SQL syntax errors
270-
logger.warning(
271-
"Returning original query due to missing parameters. "
272-
"This may indicate a parameter processing issue."
273-
)
274232
return querystring, parameters
275233

276234
# Convert the query string by replacing each parameter with its positional equivalent
@@ -334,16 +292,6 @@ def _convert_named_params_with_casting(
334292
f"Converted query: {converted_query}, Original query: {querystring}"
335293
)
336294

337-
# Log final conversion results for debugging
338-
logger.debug(
339-
f"Parameter conversion completed - "
340-
f"Original query: {querystring!r}, "
341-
f"Converted query: {converted_query!r}, "
342-
f"Original params: {parameters!r}, "
343-
f"Converted params: {converted_params!r}, "
344-
f"Parameter order: {param_order}"
345-
)
346-
347295
return converted_query, converted_params
348296

349297
@property
@@ -481,7 +429,6 @@ def fetchall(self):
481429
return []
482430

483431
def __iter__(self):
484-
"""Enhanced async iteration with better error handling"""
485432
if self._closed or self._cursor is None:
486433
return
487434

@@ -490,15 +437,10 @@ def __iter__(self):
490437
try:
491438
result = self.await_(iterator.__anext__())
492439
rows = self._convert_result(result=result)
493-
if rows:
494-
yield from rows
495-
else:
496-
break
440+
# Yield individual rows, not the entire result
441+
yield from rows
497442
except StopAsyncIteration:
498443
break
499-
except Exception:
500-
# Stop iteration on any error
501-
break
502444

503445

504446
class AsyncAdapt_psqlpy_connection(AsyncAdapt_dbapi_connection):

psqlpy_sqlalchemy/dbapi.py

Lines changed: 54 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,18 @@ def __init__(self, psqlpy) -> None:
1212
self.apilevel = "2.0"
1313
self.threadsafety = 2 # Threads may share the module and connections
1414

15-
self.Warning = psqlpy.Error
16-
self.Error = psqlpy.Error
17-
self.InterfaceError = psqlpy.Error
18-
self.DatabaseError = psqlpy.Error
19-
self.DataError = psqlpy.Error
20-
self.OperationalError = psqlpy.Error
21-
self.IntegrityError = psqlpy.Error
22-
self.InternalError = psqlpy.Error
23-
self.ProgrammingError = psqlpy.Error
24-
self.NotSupportedError = psqlpy.Error
15+
# Single reusable exception class for all error types
16+
_error_class = psqlpy.Error
17+
self.Warning = _error_class
18+
self.Error = _error_class
19+
self.InterfaceError = _error_class
20+
self.DatabaseError = _error_class
21+
self.DataError = _error_class
22+
self.OperationalError = _error_class
23+
self.IntegrityError = _error_class
24+
self.InternalError = _error_class
25+
self.ProgrammingError = _error_class
26+
self.NotSupportedError = _error_class
2527

2628
for k, v in self.psqlpy.__dict__.items():
2729
if k != "connect":
@@ -39,34 +41,36 @@ def connect(self, *arg, **kw):
3941
# Add other server_settings mappings as needed
4042

4143
# Filter out any other unsupported parameters that SQLAlchemy might pass
42-
supported_params = {
43-
"dsn",
44-
"username",
45-
"password",
46-
"host",
47-
"hosts",
48-
"port",
49-
"ports",
50-
"db_name",
51-
"target_session_attrs",
52-
"options",
53-
"application_name",
54-
"connect_timeout_sec",
55-
"connect_timeout_nanosec",
56-
"tcp_user_timeout_sec",
57-
"tcp_user_timeout_nanosec",
58-
"keepalives",
59-
"keepalives_idle_sec",
60-
"keepalives_idle_nanosec",
61-
"keepalives_interval_sec",
62-
"keepalives_interval_nanosec",
63-
"keepalives_retries",
64-
"load_balance_hosts",
65-
"max_db_pool_size",
66-
"conn_recycling_method",
67-
"ssl_mode",
68-
"ca_file",
69-
}
44+
supported_params = frozenset(
45+
{
46+
"dsn",
47+
"username",
48+
"password",
49+
"host",
50+
"hosts",
51+
"port",
52+
"ports",
53+
"db_name",
54+
"target_session_attrs",
55+
"options",
56+
"application_name",
57+
"connect_timeout_sec",
58+
"connect_timeout_nanosec",
59+
"tcp_user_timeout_sec",
60+
"tcp_user_timeout_nanosec",
61+
"keepalives",
62+
"keepalives_idle_sec",
63+
"keepalives_idle_nanosec",
64+
"keepalives_interval_sec",
65+
"keepalives_interval_nanosec",
66+
"keepalives_retries",
67+
"load_balance_hosts",
68+
"max_db_pool_size",
69+
"conn_recycling_method",
70+
"ssl_mode",
71+
"ca_file",
72+
}
73+
)
7074

7175
filtered_kw = {k: v for k, v in kw.items() if k in supported_params}
7276

@@ -90,17 +94,18 @@ def __init__(self):
9094

9195
self._adapt_dbapi = PSQLPyAdaptDBAPI(psqlpy)
9296

93-
# Copy attributes from psqlpy for compatibility
94-
self.Warning = psqlpy.Error
95-
self.Error = psqlpy.Error
96-
self.InterfaceError = psqlpy.Error
97-
self.DatabaseError = psqlpy.Error
98-
self.DataError = psqlpy.Error
99-
self.OperationalError = psqlpy.Error
100-
self.IntegrityError = psqlpy.Error
101-
self.InternalError = psqlpy.Error
102-
self.ProgrammingError = psqlpy.Error
103-
self.NotSupportedError = psqlpy.Error
97+
# Single reusable exception class for all error types
98+
_error_class = psqlpy.Error
99+
self.Warning = _error_class
100+
self.Error = _error_class
101+
self.InterfaceError = _error_class
102+
self.DatabaseError = _error_class
103+
self.DataError = _error_class
104+
self.OperationalError = _error_class
105+
self.IntegrityError = _error_class
106+
self.InternalError = _error_class
107+
self.ProgrammingError = _error_class
108+
self.NotSupportedError = _error_class
104109

105110
# Type constructors
106111
def Date(self, year, month, day):

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ dev = [
3939
"greenlet>=1.0.0",
4040
"ruff>=0.1.0",
4141
"mypy>=1.0.0",
42-
"sqlmodel>=0.0.8",
42+
"sqlmodel>=0.0.14",
4343
"pytest-cov",
4444
"fastapi>=0.68.0",
4545
"starlette>=0.14.0",

0 commit comments

Comments
 (0)