Skip to content

Commit 1d63bb1

Browse files
Implement connection service file functionality (#1223)
1 parent 5b14653 commit 1d63bb1

File tree

3 files changed

+274
-9
lines changed

3 files changed

+274
-9
lines changed

asyncpg/connect_utils.py

Lines changed: 148 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import annotations
88

99
import asyncio
10+
import configparser
1011
import collections
1112
from collections.abc import Callable
1213
import enum
@@ -87,6 +88,9 @@ class SSLNegotiation(compat.StrEnum):
8788
PGPASSFILE = '.pgpass'
8889

8990

91+
PG_SERVICEFILE = '.pg_service.conf'
92+
93+
9094
def _read_password_file(passfile: pathlib.Path) \
9195
-> typing.List[typing.Tuple[str, ...]]:
9296

@@ -271,6 +275,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
271275

272276
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
273277
password, passfile, database, ssl,
278+
service, servicefile,
274279
direct_tls, server_settings,
275280
target_session_attrs, krbsrvname, gsslib):
276281
# `auth_hosts` is the version of host information for the purposes
@@ -283,6 +288,32 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
283288
if dsn:
284289
parsed = urllib.parse.urlparse(dsn)
285290

291+
query = None
292+
if parsed.query:
293+
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
294+
for key, val in query.items():
295+
if isinstance(val, list):
296+
query[key] = val[-1]
297+
298+
if 'service' in query:
299+
val = query.pop('service')
300+
if not service and val:
301+
service = val
302+
303+
connection_service_file = servicefile
304+
305+
if connection_service_file is None:
306+
connection_service_file = os.getenv('PGSERVICEFILE')
307+
308+
if connection_service_file is None:
309+
homedir = compat.get_pg_home_directory()
310+
if homedir:
311+
connection_service_file = homedir / PG_SERVICEFILE
312+
else:
313+
connection_service_file = None
314+
else:
315+
connection_service_file = pathlib.Path(connection_service_file)
316+
286317
if parsed.scheme not in {'postgresql', 'postgres'}:
287318
raise exceptions.ClientConfigurationError(
288319
'invalid DSN: scheme is expected to be either '
@@ -317,11 +348,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
317348
if password is None and dsn_password:
318349
password = urllib.parse.unquote(dsn_password)
319350

320-
if parsed.query:
321-
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
322-
for key, val in query.items():
323-
if isinstance(val, list):
324-
query[key] = val[-1]
351+
if query:
325352

326353
if 'port' in query:
327354
val = query.pop('port')
@@ -408,12 +435,124 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
408435
if gsslib is None:
409436
gsslib = val
410437

438+
if 'service' in query:
439+
val = query.pop('service')
440+
if service is None:
441+
service = val
442+
411443
if query:
412444
if server_settings is None:
413445
server_settings = query
414446
else:
415447
server_settings = {**query, **server_settings}
416448

449+
if connection_service_file is not None and service is not None:
450+
pg_service = configparser.ConfigParser()
451+
pg_service.read(connection_service_file)
452+
if service in pg_service.sections():
453+
service_params = pg_service[service]
454+
if 'port' in service_params:
455+
val = service_params.pop('port')
456+
if not port and val:
457+
port = [int(p) for p in val.split(',')]
458+
459+
if 'host' in service_params:
460+
val = service_params.pop('host')
461+
if not host and val:
462+
host, port = _parse_hostlist(val, port)
463+
464+
if 'dbname' in service_params:
465+
val = service_params.pop('dbname')
466+
if database is None:
467+
database = val
468+
469+
if 'database' in service_params:
470+
val = service_params.pop('database')
471+
if database is None:
472+
database = val
473+
474+
if 'user' in service_params:
475+
val = service_params.pop('user')
476+
if user is None:
477+
user = val
478+
479+
if 'password' in service_params:
480+
val = service_params.pop('password')
481+
if password is None:
482+
password = val
483+
484+
if 'passfile' in service_params:
485+
val = service_params.pop('passfile')
486+
if passfile is None:
487+
passfile = val
488+
489+
if 'sslmode' in service_params:
490+
val = service_params.pop('sslmode')
491+
if ssl is None:
492+
ssl = val
493+
494+
if 'sslcert' in service_params:
495+
val = service_params.pop('sslcert')
496+
if sslcert is None:
497+
sslcert = val
498+
499+
if 'sslkey' in service_params:
500+
val = service_params.pop('sslkey')
501+
if sslkey is None:
502+
sslkey = val
503+
504+
if 'sslrootcert' in service_params:
505+
val = service_params.pop('sslrootcert')
506+
if sslrootcert is None:
507+
sslrootcert = val
508+
509+
if 'sslnegotiation' in service_params:
510+
val = service_params.pop('sslnegotiation')
511+
if sslnegotiation is None:
512+
sslnegotiation = val
513+
514+
if 'sslcrl' in service_params:
515+
val = service_params.pop('sslcrl')
516+
if sslcrl is None:
517+
sslcrl = val
518+
519+
if 'sslpassword' in service_params:
520+
val = service_params.pop('sslpassword')
521+
if sslpassword is None:
522+
sslpassword = val
523+
524+
if 'ssl_min_protocol_version' in service_params:
525+
val = service_params.pop(
526+
'ssl_min_protocol_version'
527+
)
528+
if ssl_min_protocol_version is None:
529+
ssl_min_protocol_version = val
530+
531+
if 'ssl_max_protocol_version' in service_params:
532+
val = service_params.pop(
533+
'ssl_max_protocol_version'
534+
)
535+
if ssl_max_protocol_version is None:
536+
ssl_max_protocol_version = val
537+
538+
if 'target_session_attrs' in service_params:
539+
dsn_target_session_attrs = service_params.pop(
540+
'target_session_attrs'
541+
)
542+
if target_session_attrs is None:
543+
target_session_attrs = dsn_target_session_attrs
544+
545+
if 'krbsrvname' in service_params:
546+
val = service_params.pop('krbsrvname')
547+
if krbsrvname is None:
548+
krbsrvname = val
549+
550+
if 'gsslib' in service_params:
551+
val = service_params.pop('gsslib')
552+
if gsslib is None:
553+
gsslib = val
554+
if not service:
555+
service = os.environ.get('PGSERVICE')
417556
if not host:
418557
hostspec = os.environ.get('PGHOST')
419558
if hostspec:
@@ -726,7 +865,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
726865
max_cached_statement_lifetime,
727866
max_cacheable_statement_size,
728867
ssl, direct_tls, server_settings,
729-
target_session_attrs, krbsrvname, gsslib):
868+
target_session_attrs, krbsrvname, gsslib,
869+
service, servicefile):
730870
local_vars = locals()
731871
for var_name in {'max_cacheable_statement_size',
732872
'max_cached_statement_lifetime',
@@ -756,7 +896,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
756896
direct_tls=direct_tls, database=database,
757897
server_settings=server_settings,
758898
target_session_attrs=target_session_attrs,
759-
krbsrvname=krbsrvname, gsslib=gsslib)
899+
krbsrvname=krbsrvname, gsslib=gsslib,
900+
service=service, servicefile=servicefile)
760901

761902
config = _ClientConfiguration(
762903
command_timeout=command_timeout,

asyncpg/connection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2083,6 +2083,8 @@ async def _do_execute(
20832083
async def connect(dsn=None, *,
20842084
host=None, port=None,
20852085
user=None, password=None, passfile=None,
2086+
service=None,
2087+
servicefile=None,
20862088
database=None,
20872089
loop=None,
20882090
timeout=60,
@@ -2192,6 +2194,14 @@ async def connect(dsn=None, *,
21922194
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
21932195
on Windows).
21942196
2197+
:param service:
2198+
The name of the postgres connection service stored in the postgres
2199+
connection service file.
2200+
2201+
:param servicefile:
2202+
The location of the connnection service file used to store
2203+
connection parameters.
2204+
21952205
:param loop:
21962206
An asyncio event loop instance. If ``None``, the default
21972207
event loop will be used.
@@ -2404,6 +2414,9 @@ async def connect(dsn=None, *,
24042414
.. versionchanged:: 0.30.0
24052415
Added the *krbsrvname* and *gsslib* parameters.
24062416
2417+
.. versionchanged:: 0.31.0
2418+
Added the *servicefile* and *service* parameters.
2419+
24072420
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
24082421
.. _create_default_context:
24092422
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
@@ -2437,6 +2450,8 @@ async def connect(dsn=None, *,
24372450
user=user,
24382451
password=password,
24392452
passfile=passfile,
2453+
service=service,
2454+
servicefile=servicefile,
24402455
ssl=ssl,
24412456
direct_tls=direct_tls,
24422457
database=database,

0 commit comments

Comments
 (0)