1+ import asyncio
12from typing import Any , Dict , List , Tuple
23
34import psqlpy
4- from sqlalchemy .engine import default
5+ from sqlalchemy import util
6+ from sqlalchemy .dialects .postgresql .base import INTERVAL , PGDialect
7+ from sqlalchemy .dialects .postgresql .json import JSONPathType
58from sqlalchemy .engine .url import URL
9+ from sqlalchemy .sql import sqltypes
610
711from .connection import PsqlpyConnection
812from .dbapi import PsqlpyDBAPI
913
1014
11- class PsqlpyDialect (default .DefaultDialect ):
15+ # Custom type classes with render_bind_cast for better PostgreSQL compatibility
16+ class _PGString (sqltypes .String ):
17+ render_bind_cast = True
18+
19+
20+ class _PGJSONIntIndexType (sqltypes .JSON .JSONIntIndexType ):
21+ __visit_name__ = "json_int_index"
22+ render_bind_cast = True
23+
24+
25+ class _PGJSONStrIndexType (sqltypes .JSON .JSONStrIndexType ):
26+ __visit_name__ = "json_str_index"
27+ render_bind_cast = True
28+
29+
30+ class _PGJSONPathType (JSONPathType ):
31+ pass
32+
33+
34+ class _PGInterval (INTERVAL ):
35+ render_bind_cast = True
36+
37+
38+ class _PGTimeStamp (sqltypes .DateTime ):
39+ render_bind_cast = True
40+
41+
42+ class _PGDate (sqltypes .Date ):
43+ render_bind_cast = True
44+
45+
46+ class _PGTime (sqltypes .Time ):
47+ render_bind_cast = True
48+
49+
50+ class _PGInteger (sqltypes .Integer ):
51+ render_bind_cast = True
52+
53+
54+ class _PGSmallInteger (sqltypes .SmallInteger ):
55+ render_bind_cast = True
56+
57+
58+ class _PGBigInteger (sqltypes .BigInteger ):
59+ render_bind_cast = True
60+
61+
62+ class _PGBoolean (sqltypes .Boolean ):
63+ render_bind_cast = True
64+
65+
66+ class _PGNullType (sqltypes .NullType ):
67+ render_bind_cast = True
68+
69+
70+ class PsqlpyDialect (PGDialect ):
1271 """SQLAlchemy dialect for psqlpy PostgreSQL driver"""
1372
1473 name = "postgresql"
1574 driver = "psqlpy"
1675
1776 # Dialect capabilities
1877 supports_statement_cache = True
78+ supports_server_side_cursors = True
1979 supports_multivalues_insert = True
2080 supports_unicode_statements = True
2181 supports_unicode_binds = True
@@ -31,6 +91,7 @@ class PsqlpyDialect(default.DefaultDialect):
3191 update_returning = True
3292 delete_returning = True
3393 favor_returning_over_lastrowid = True
94+ default_paramstyle = "numeric_dollar"
3495
3596 # Connection pooling
3697 supports_sane_rowcount = True
@@ -40,11 +101,41 @@ class PsqlpyDialect(default.DefaultDialect):
40101 supports_isolation_level = True
41102 default_isolation_level = "READ_COMMITTED"
42103
104+ # Comprehensive colspecs mapping for better PostgreSQL type handling
105+ colspecs = util .update_copy (
106+ PGDialect .colspecs ,
107+ {
108+ sqltypes .String : _PGString ,
109+ sqltypes .JSON .JSONPathType : _PGJSONPathType ,
110+ sqltypes .JSON .JSONIntIndexType : _PGJSONIntIndexType ,
111+ sqltypes .JSON .JSONStrIndexType : _PGJSONStrIndexType ,
112+ sqltypes .Interval : _PGInterval ,
113+ INTERVAL : _PGInterval ,
114+ sqltypes .Date : _PGDate ,
115+ sqltypes .DateTime : _PGTimeStamp ,
116+ sqltypes .Time : _PGTime ,
117+ sqltypes .Integer : _PGInteger ,
118+ sqltypes .SmallInteger : _PGSmallInteger ,
119+ sqltypes .BigInteger : _PGBigInteger ,
120+ sqltypes .Boolean : _PGBoolean ,
121+ sqltypes .NullType : _PGNullType ,
122+ },
123+ )
124+
43125 @classmethod
44126 def import_dbapi (cls ):
45127 """Import the psqlpy module as DBAPI"""
46128 return PsqlpyDBAPI ()
47129
130+ @util .memoized_property
131+ def _isolation_lookup (self ) -> Dict [str , Any ]:
132+ """Mapping of SQLAlchemy isolation levels to psqlpy isolation levels"""
133+ return {
134+ "READ_COMMITTED" : psqlpy .IsolationLevel .ReadCommitted ,
135+ "REPEATABLE_READ" : psqlpy .IsolationLevel .RepeatableRead ,
136+ "SERIALIZABLE" : psqlpy .IsolationLevel .Serializable ,
137+ }
138+
48139 def create_connect_args (
49140 self , url : URL
50141 ) -> Tuple [List [Any ], Dict [str , Any ]]:
@@ -89,8 +180,10 @@ def create_connect_args(
89180 def connect (self , * cargs , ** cparams ):
90181 """Create a connection to the database"""
91182 try :
92- # Use psqlpy.connect to create a connection
93- raw_connection = psqlpy .connect (** cparams )
183+ # psqlpy.connect returns a coroutine that needs to be awaited
184+ # Since SQLAlchemy dialects are synchronous, we use asyncio.run()
185+ connection_coro = psqlpy .connect (** cparams )
186+ raw_connection = asyncio .run (connection_coro )
94187 # Wrap it in our DBAPI-compatible connection
95188 return PsqlpyConnection (raw_connection )
96189 except Exception as e :
@@ -150,7 +243,15 @@ def get_isolation_level(self, dbapi_connection):
150243 return self .default_isolation_level
151244
152245 def set_isolation_level (self , dbapi_connection , level ):
153- """Set the isolation level"""
246+ """Set the isolation level using psqlpy enums"""
247+ if hasattr (dbapi_connection , "set_isolation_level" ):
248+ # Use psqlpy's native isolation level setting if available
249+ psqlpy_level = self ._isolation_lookup .get (level )
250+ if psqlpy_level is not None :
251+ dbapi_connection .set_isolation_level (psqlpy_level )
252+ return
253+
254+ # Fallback to SQL-based approach
154255 try :
155256 cursor = dbapi_connection .cursor ()
156257 level_map = {
@@ -176,3 +277,71 @@ def _handle_exception(self, e):
176277 def get_default_isolation_level (self , dbapi_conn ):
177278 """Get the default isolation level for new connections"""
178279 return self .default_isolation_level
280+
281+ def set_readonly (self , connection , value ):
282+ """Set the readonly state of the connection"""
283+ if hasattr (connection , "readonly" ):
284+ if value is True :
285+ connection .readonly = psqlpy .ReadVariant .ReadOnly
286+ else :
287+ connection .readonly = psqlpy .ReadVariant .ReadWrite
288+ else :
289+ # Fallback to SQL-based approach
290+ try :
291+ cursor = connection .cursor ()
292+ if value :
293+ cursor .execute ("SET TRANSACTION READ ONLY" )
294+ else :
295+ cursor .execute ("SET TRANSACTION READ WRITE" )
296+ except Exception as e :
297+ raise self ._handle_exception (e )
298+
299+ def get_readonly (self , connection ):
300+ """Get the readonly state of the connection"""
301+ if hasattr (connection , "readonly" ):
302+ return connection .readonly == psqlpy .ReadVariant .ReadOnly
303+ return False
304+
305+ def set_deferrable (self , connection , value ):
306+ """Set the deferrable state of the connection"""
307+ if hasattr (connection , "deferrable" ):
308+ connection .deferrable = value
309+ else :
310+ # Fallback to SQL-based approach
311+ try :
312+ cursor = connection .cursor ()
313+ if value :
314+ cursor .execute ("SET TRANSACTION DEFERRABLE" )
315+ else :
316+ cursor .execute ("SET TRANSACTION NOT DEFERRABLE" )
317+ except Exception as e :
318+ raise self ._handle_exception (e )
319+
320+ def get_deferrable (self , connection ):
321+ """Get the deferrable state of the connection"""
322+ if hasattr (connection , "deferrable" ):
323+ return connection .deferrable
324+ return False
325+
326+ def has_table (self , connection , table_name , schema = None ):
327+ """Check if a table exists in the database"""
328+ if schema is None :
329+ schema = "public"
330+
331+ query = """
332+ SELECT EXISTS (
333+ SELECT 1
334+ FROM information_schema.tables
335+ WHERE table_schema = %s
336+ AND table_name = %s
337+ )
338+ """
339+
340+ try :
341+ cursor = connection .cursor ()
342+ cursor .execute (query , (schema , table_name ))
343+ result = cursor .fetchone ()
344+ return result [0 ] if result else False
345+ except Exception :
346+ # If we can't check, assume table doesn't exist
347+ return False
0 commit comments