Skip to content

Commit 282cc68

Browse files
author
Eugene Shershen
committed
codeql-analysis.yml
1 parent 9a5ce68 commit 282cc68

File tree

5 files changed

+185
-14
lines changed

5 files changed

+185
-14
lines changed

psqlpy_sqlalchemy/connection.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def execute(
2222
"""Execute a query with optional parameters"""
2323
try:
2424
if parameters is None:
25-
self._result = self._psqlpy_connection.fetch(query)
25+
query_result = self._psqlpy_connection.fetch(query)
2626
else:
2727
if isinstance(parameters, (list, tuple)):
2828
param_dict = {
@@ -31,17 +31,17 @@ def execute(
3131
query = self._convert_positional_to_named(
3232
query, len(parameters)
3333
)
34-
self._result = self._psqlpy_connection.fetch(
34+
query_result = self._psqlpy_connection.fetch(
3535
query, param_dict
3636
)
3737
else:
38-
self._result = self._psqlpy_connection.fetch(
38+
query_result = self._psqlpy_connection.fetch(
3939
query, parameters
4040
)
4141

42-
# Process the result
43-
if self._result:
44-
self._rows = self._result.result()
42+
# Process the result - call .result() on the QueryResult object
43+
if query_result:
44+
self._rows = query_result.result()
4545
self.rowcount = len(self._rows) if self._rows else 0
4646
self._row_index = 0
4747

psqlpy_sqlalchemy/dbapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@ class PsqlpyDBAPI:
77
# DBAPI 2.0 module attributes
88
apilevel = "2.0"
99
threadsafety = 2 # Threads may share the module and connections
10-
paramstyle = "pyformat" # PostgreSQL uses %(name)s style parameters
10+
paramstyle = (
11+
"numeric_dollar" # PostgreSQL uses $1, $2, etc. style parameters
12+
)
1113

1214
# Exception hierarchy (DBAPI 2.0 standard)
1315
Warning = psqlpy.Error

psqlpy_sqlalchemy/dialect.py

Lines changed: 174 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,81 @@
1+
import asyncio
12
from typing import Any, Dict, List, Tuple
23

34
import psqlpy
4-
from sqlalchemy.engine import default
5+
from sqlalchemy import util
6+
from sqlalchemy.dialects.postgresql.base import INTERVAL, PGDialect
7+
from sqlalchemy.dialects.postgresql.json import JSONPathType
58
from sqlalchemy.engine.url import URL
9+
from sqlalchemy.sql import sqltypes
610

711
from .connection import PsqlpyConnection
812
from .dbapi import PsqlpyDBAPI
913

1014

11-
class PsqlpyDialect(default.DefaultDialect):
15+
# Custom type classes with render_bind_cast for better PostgreSQL compatibility
16+
class _PGString(sqltypes.String):
17+
render_bind_cast = True
18+
19+
20+
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
21+
__visit_name__ = "json_int_index"
22+
render_bind_cast = True
23+
24+
25+
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
26+
__visit_name__ = "json_str_index"
27+
render_bind_cast = True
28+
29+
30+
class _PGJSONPathType(JSONPathType):
31+
pass
32+
33+
34+
class _PGInterval(INTERVAL):
35+
render_bind_cast = True
36+
37+
38+
class _PGTimeStamp(sqltypes.DateTime):
39+
render_bind_cast = True
40+
41+
42+
class _PGDate(sqltypes.Date):
43+
render_bind_cast = True
44+
45+
46+
class _PGTime(sqltypes.Time):
47+
render_bind_cast = True
48+
49+
50+
class _PGInteger(sqltypes.Integer):
51+
render_bind_cast = True
52+
53+
54+
class _PGSmallInteger(sqltypes.SmallInteger):
55+
render_bind_cast = True
56+
57+
58+
class _PGBigInteger(sqltypes.BigInteger):
59+
render_bind_cast = True
60+
61+
62+
class _PGBoolean(sqltypes.Boolean):
63+
render_bind_cast = True
64+
65+
66+
class _PGNullType(sqltypes.NullType):
67+
render_bind_cast = True
68+
69+
70+
class PsqlpyDialect(PGDialect):
1271
"""SQLAlchemy dialect for psqlpy PostgreSQL driver"""
1372

1473
name = "postgresql"
1574
driver = "psqlpy"
1675

1776
# Dialect capabilities
1877
supports_statement_cache = True
78+
supports_server_side_cursors = True
1979
supports_multivalues_insert = True
2080
supports_unicode_statements = True
2181
supports_unicode_binds = True
@@ -31,6 +91,7 @@ class PsqlpyDialect(default.DefaultDialect):
3191
update_returning = True
3292
delete_returning = True
3393
favor_returning_over_lastrowid = True
94+
default_paramstyle = "numeric_dollar"
3495

3596
# Connection pooling
3697
supports_sane_rowcount = True
@@ -40,11 +101,41 @@ class PsqlpyDialect(default.DefaultDialect):
40101
supports_isolation_level = True
41102
default_isolation_level = "READ_COMMITTED"
42103

104+
# Comprehensive colspecs mapping for better PostgreSQL type handling
105+
colspecs = util.update_copy(
106+
PGDialect.colspecs,
107+
{
108+
sqltypes.String: _PGString,
109+
sqltypes.JSON.JSONPathType: _PGJSONPathType,
110+
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
111+
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
112+
sqltypes.Interval: _PGInterval,
113+
INTERVAL: _PGInterval,
114+
sqltypes.Date: _PGDate,
115+
sqltypes.DateTime: _PGTimeStamp,
116+
sqltypes.Time: _PGTime,
117+
sqltypes.Integer: _PGInteger,
118+
sqltypes.SmallInteger: _PGSmallInteger,
119+
sqltypes.BigInteger: _PGBigInteger,
120+
sqltypes.Boolean: _PGBoolean,
121+
sqltypes.NullType: _PGNullType,
122+
},
123+
)
124+
43125
@classmethod
44126
def import_dbapi(cls):
45127
"""Import the psqlpy module as DBAPI"""
46128
return PsqlpyDBAPI()
47129

130+
@util.memoized_property
131+
def _isolation_lookup(self) -> Dict[str, Any]:
132+
"""Mapping of SQLAlchemy isolation levels to psqlpy isolation levels"""
133+
return {
134+
"READ_COMMITTED": psqlpy.IsolationLevel.ReadCommitted,
135+
"REPEATABLE_READ": psqlpy.IsolationLevel.RepeatableRead,
136+
"SERIALIZABLE": psqlpy.IsolationLevel.Serializable,
137+
}
138+
48139
def create_connect_args(
49140
self, url: URL
50141
) -> Tuple[List[Any], Dict[str, Any]]:
@@ -89,8 +180,10 @@ def create_connect_args(
89180
def connect(self, *cargs, **cparams):
90181
"""Create a connection to the database"""
91182
try:
92-
# Use psqlpy.connect to create a connection
93-
raw_connection = psqlpy.connect(**cparams)
183+
# psqlpy.connect returns a coroutine that needs to be awaited
184+
# Since SQLAlchemy dialects are synchronous, we use asyncio.run()
185+
connection_coro = psqlpy.connect(**cparams)
186+
raw_connection = asyncio.run(connection_coro)
94187
# Wrap it in our DBAPI-compatible connection
95188
return PsqlpyConnection(raw_connection)
96189
except Exception as e:
@@ -150,7 +243,15 @@ def get_isolation_level(self, dbapi_connection):
150243
return self.default_isolation_level
151244

152245
def set_isolation_level(self, dbapi_connection, level):
153-
"""Set the isolation level"""
246+
"""Set the isolation level using psqlpy enums"""
247+
if hasattr(dbapi_connection, "set_isolation_level"):
248+
# Use psqlpy's native isolation level setting if available
249+
psqlpy_level = self._isolation_lookup.get(level)
250+
if psqlpy_level is not None:
251+
dbapi_connection.set_isolation_level(psqlpy_level)
252+
return
253+
254+
# Fallback to SQL-based approach
154255
try:
155256
cursor = dbapi_connection.cursor()
156257
level_map = {
@@ -176,3 +277,71 @@ def _handle_exception(self, e):
176277
def get_default_isolation_level(self, dbapi_conn):
177278
"""Get the default isolation level for new connections"""
178279
return self.default_isolation_level
280+
281+
def set_readonly(self, connection, value):
282+
"""Set the readonly state of the connection"""
283+
if hasattr(connection, "readonly"):
284+
if value is True:
285+
connection.readonly = psqlpy.ReadVariant.ReadOnly
286+
else:
287+
connection.readonly = psqlpy.ReadVariant.ReadWrite
288+
else:
289+
# Fallback to SQL-based approach
290+
try:
291+
cursor = connection.cursor()
292+
if value:
293+
cursor.execute("SET TRANSACTION READ ONLY")
294+
else:
295+
cursor.execute("SET TRANSACTION READ WRITE")
296+
except Exception as e:
297+
raise self._handle_exception(e)
298+
299+
def get_readonly(self, connection):
300+
"""Get the readonly state of the connection"""
301+
if hasattr(connection, "readonly"):
302+
return connection.readonly == psqlpy.ReadVariant.ReadOnly
303+
return False
304+
305+
def set_deferrable(self, connection, value):
306+
"""Set the deferrable state of the connection"""
307+
if hasattr(connection, "deferrable"):
308+
connection.deferrable = value
309+
else:
310+
# Fallback to SQL-based approach
311+
try:
312+
cursor = connection.cursor()
313+
if value:
314+
cursor.execute("SET TRANSACTION DEFERRABLE")
315+
else:
316+
cursor.execute("SET TRANSACTION NOT DEFERRABLE")
317+
except Exception as e:
318+
raise self._handle_exception(e)
319+
320+
def get_deferrable(self, connection):
321+
"""Get the deferrable state of the connection"""
322+
if hasattr(connection, "deferrable"):
323+
return connection.deferrable
324+
return False
325+
326+
def has_table(self, connection, table_name, schema=None):
327+
"""Check if a table exists in the database"""
328+
if schema is None:
329+
schema = "public"
330+
331+
query = """
332+
SELECT EXISTS (
333+
SELECT 1
334+
FROM information_schema.tables
335+
WHERE table_schema = %s
336+
AND table_name = %s
337+
)
338+
"""
339+
340+
try:
341+
cursor = connection.cursor()
342+
cursor.execute(query, (schema, table_name))
343+
result = cursor.fetchone()
344+
return result[0] if result else False
345+
except Exception:
346+
# If we can't check, assume table doesn't exist
347+
return False

tests/test_dbapi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_dbapi_attributes(self):
2020
"""Test DBAPI 2.0 module attributes"""
2121
self.assertEqual(self.dbapi.apilevel, "2.0")
2222
self.assertEqual(self.dbapi.threadsafety, 2)
23-
self.assertEqual(self.dbapi.paramstyle, "pyformat")
23+
self.assertEqual(self.dbapi.paramstyle, "numeric_dollar")
2424

2525
def test_exception_hierarchy(self):
2626
"""Test that all required DBAPI exceptions are available"""

tests/test_dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_dbapi_interface(self):
118118
# Test DBAPI attributes
119119
self.assertEqual(dbapi.apilevel, "2.0")
120120
self.assertEqual(dbapi.threadsafety, 2)
121-
self.assertEqual(dbapi.paramstyle, "pyformat")
121+
self.assertEqual(dbapi.paramstyle, "numeric_dollar")
122122

123123
# Test exception hierarchy
124124
exceptions = [

0 commit comments

Comments
 (0)