11use crate :: Config ;
22use postgres:: Connection ;
3- use std :: marker :: PhantomData ;
3+ use r2d2_postgres :: PostgresConnectionManager ;
44
5- #[ cfg( test) ]
6- use std:: sync:: { Arc , Mutex , MutexGuard } ;
5+ pub ( crate ) type PoolConnection = r2d2:: PooledConnection < PostgresConnectionManager > ;
6+
7+ const DEFAULT_SCHEMA : & str = "public" ;
78
89#[ derive( Debug , Clone ) ]
9- pub enum Pool {
10- R2D2 ( r2d2:: Pool < r2d2_postgres:: PostgresConnectionManager > ) ,
11- #[ cfg( test) ]
12- Simple ( Arc < Mutex < Connection > > ) ,
10+ pub struct Pool {
11+ pool : r2d2:: Pool < PostgresConnectionManager > ,
1312}
1413
1514impl Pool {
1615 pub fn new ( config : & Config ) -> Result < Pool , PoolError > {
16+ Self :: new_inner ( config, DEFAULT_SCHEMA )
17+ }
18+
19+ #[ cfg( test) ]
20+ pub ( crate ) fn new_with_schema ( config : & Config , schema : & str ) -> Result < Pool , PoolError > {
21+ Self :: new_inner ( config, schema)
22+ }
23+
24+ fn new_inner ( config : & Config , schema : & str ) -> Result < Pool , PoolError > {
1725 crate :: web:: metrics:: MAX_DB_CONNECTIONS . set ( config. max_pool_size as i64 ) ;
1826
19- let manager = r2d2_postgres :: PostgresConnectionManager :: new (
27+ let manager = PostgresConnectionManager :: new (
2028 config. database_url . as_str ( ) ,
2129 r2d2_postgres:: TlsMode :: None ,
2230 )
@@ -25,73 +33,54 @@ impl Pool {
2533 let pool = r2d2:: Pool :: builder ( )
2634 . max_size ( config. max_pool_size )
2735 . min_idle ( Some ( config. min_pool_idle ) )
36+ . connection_customizer ( Box :: new ( SetSchema :: new ( schema) ) )
2837 . build ( manager)
2938 . map_err ( PoolError :: PoolCreationFailed ) ?;
3039
31- Ok ( Pool :: R2D2 ( pool) )
32- }
33-
34- #[ cfg( test) ]
35- pub ( crate ) fn new_simple ( conn : Arc < Mutex < Connection > > ) -> Self {
36- Pool :: Simple ( conn)
40+ Ok ( Pool { pool } )
3741 }
3842
39- pub fn get ( & self ) -> Result < DerefConnection < ' _ > , PoolError > {
40- match self {
41- Self :: R2D2 ( r2d2) => match r2d2. get ( ) {
42- Ok ( conn) => Ok ( DerefConnection :: Connection ( conn, PhantomData ) ) ,
43- Err ( err) => {
44- crate :: web:: metrics:: FAILED_DB_CONNECTIONS . inc ( ) ;
45- Err ( PoolError :: ConnectionError ( err) )
46- }
47- } ,
48-
49- #[ cfg( test) ]
50- Self :: Simple ( mutex) => Ok ( DerefConnection :: Guard (
51- mutex. lock ( ) . expect ( "failed to lock the connection" ) ,
52- ) ) ,
43+ pub fn get ( & self ) -> Result < PoolConnection , PoolError > {
44+ match self . pool . get ( ) {
45+ Ok ( conn) => Ok ( conn) ,
46+ Err ( err) => {
47+ crate :: web:: metrics:: FAILED_DB_CONNECTIONS . inc ( ) ;
48+ Err ( PoolError :: ConnectionError ( err) )
49+ }
5350 }
5451 }
5552
5653 pub ( crate ) fn used_connections ( & self ) -> u32 {
57- match self {
58- Self :: R2D2 ( conn) => conn. state ( ) . connections - conn. state ( ) . idle_connections ,
59-
60- #[ cfg( test) ]
61- Self :: Simple ( ..) => 0 ,
62- }
54+ self . pool . state ( ) . connections - self . pool . state ( ) . idle_connections
6355 }
6456
6557 pub ( crate ) fn idle_connections ( & self ) -> u32 {
66- match self {
67- Self :: R2D2 ( conn) => conn. state ( ) . idle_connections ,
68-
69- #[ cfg( test) ]
70- Self :: Simple ( ..) => 0 ,
71- }
58+ self . pool . state ( ) . idle_connections
7259 }
7360}
7461
75- pub enum DerefConnection < ' a > {
76- Connection (
77- r2d2:: PooledConnection < r2d2_postgres:: PostgresConnectionManager > ,
78- PhantomData < & ' a ( ) > ,
79- ) ,
80-
81- #[ cfg( test) ]
82- Guard ( MutexGuard < ' a , Connection > ) ,
62+ #[ derive( Debug ) ]
63+ struct SetSchema {
64+ schema : String ,
8365}
8466
85- impl < ' a > std:: ops:: Deref for DerefConnection < ' a > {
86- type Target = Connection ;
87-
88- fn deref ( & self ) -> & Connection {
89- match self {
90- Self :: Connection ( conn, ..) => conn,
67+ impl SetSchema {
68+ fn new ( schema : & str ) -> Self {
69+ Self {
70+ schema : schema. into ( ) ,
71+ }
72+ }
73+ }
9174
92- #[ cfg( test) ]
93- Self :: Guard ( guard) => & guard,
75+ impl r2d2:: CustomizeConnection < Connection , postgres:: Error > for SetSchema {
76+ fn on_acquire ( & self , conn : & mut Connection ) -> Result < ( ) , postgres:: Error > {
77+ if self . schema != DEFAULT_SCHEMA {
78+ conn. execute (
79+ & format ! ( "SET search_path TO {}, {};" , self . schema, DEFAULT_SCHEMA ) ,
80+ & [ ] ,
81+ ) ?;
9482 }
83+ Ok ( ( ) )
9584 }
9685}
9786
0 commit comments