55 "sync"
66 "time"
77
8+ "github.com/redis/go-redis/v9/internal"
89 "github.com/redis/go-redis/v9/internal/pool"
910)
1011
@@ -17,6 +18,9 @@ type ReAuthPoolHook struct {
1718 // conn id -> bool
1819 scheduledReAuth map [uint64 ]bool
1920 scheduledLock sync.RWMutex
21+
22+ // for cleanup
23+ manager * Manager
2024}
2125
2226func NewReAuthPoolHook (poolSize int , reAuthTimeout time.Duration ) * ReAuthPoolHook {
@@ -32,7 +36,6 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
3236 workers : workers ,
3337 reAuthTimeout : reAuthTimeout ,
3438 }
35-
3639}
3740
3841func (r * ReAuthPoolHook ) MarkForReAuth (connID uint64 , reAuthFn func (error )) {
@@ -41,27 +44,22 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
4144 r .shouldReAuth [connID ] = reAuthFn
4245}
4346
44- func (r * ReAuthPoolHook ) ClearReAuthMark (connID uint64 ) {
45- r .shouldReAuthLock .Lock ()
46- defer r .shouldReAuthLock .Unlock ()
47- delete (r .shouldReAuth , connID )
48- }
49-
5047func (r * ReAuthPoolHook ) OnGet (_ context.Context , conn * pool.Conn , _ bool ) (accept bool , err error ) {
48+ connID := conn .GetID ()
5149 r .shouldReAuthLock .RLock ()
52- _ , ok := r .shouldReAuth [conn . GetID () ]
50+ _ , shouldReAuth := r .shouldReAuth [connID ]
5351 r .shouldReAuthLock .RUnlock ()
5452 // This connection was marked for reauth while in the pool,
5553 // reject the connection
56- if ok {
54+ if shouldReAuth {
5755 // simply reject the connection, it will be re-authenticated in OnPut
5856 return false , nil
5957 }
6058 r .scheduledLock .RLock ()
61- hasScheduled , ok := r .scheduledReAuth [conn . GetID () ]
59+ _ , hasScheduled := r .scheduledReAuth [connID ]
6260 r .scheduledLock .RUnlock ()
6361 // has scheduled reauth, reject the connection
64- if ok && hasScheduled {
62+ if hasScheduled {
6563 // simply reject the connection, it currently has a reauth scheduled
6664 // and the worker is waiting for slot to execute the reauth
6765 return false , nil
@@ -70,22 +68,38 @@ func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (acce
7068}
7169
7270func (r * ReAuthPoolHook ) OnPut (_ context.Context , conn * pool.Conn ) (bool , bool , error ) {
71+ if conn == nil {
72+ // noop
73+ return true , false , nil
74+ }
75+ connID := conn .GetID ()
7376 // Check if reauth is needed and get the function with proper locking
7477 r .shouldReAuthLock .RLock ()
75- reAuthFn , ok := r .shouldReAuth [conn .GetID ()]
78+ r .scheduledLock .RLock ()
79+ reAuthFn , ok := r .shouldReAuth [connID ]
7680 r .shouldReAuthLock .RUnlock ()
7781
7882 if ok {
83+ r .shouldReAuthLock .Lock ()
7984 r .scheduledLock .Lock ()
80- r .scheduledReAuth [conn .GetID ()] = true
85+ r .scheduledReAuth [connID ] = true
86+ delete (r .shouldReAuth , connID )
8187 r .scheduledLock .Unlock ()
82- // Clear the mark immediately to prevent duplicate reauth attempts
83- r .ClearReAuthMark (conn .GetID ())
88+ r .shouldReAuthLock .Unlock ()
8489 go func () {
8590 <- r .workers
91+ // safety first
92+ if conn == nil || (conn != nil && conn .IsClosed ()) {
93+ r .workers <- struct {}{}
94+ return
95+ }
8696 defer func () {
97+ if rec := recover (); rec != nil {
98+ // once again - safety first
99+ internal .Logger .Printf (context .Background (), "panic in reauth worker: %v" , rec )
100+ }
87101 r .scheduledLock .Lock ()
88- delete (r .scheduledReAuth , conn . GetID () )
102+ delete (r .scheduledReAuth , connID )
89103 r .scheduledLock .Unlock ()
90104 r .workers <- struct {}{}
91105 }()
@@ -96,7 +110,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
96110 // Try to acquire the connection
97111 // We need to ensure the connection is both Usable and not Used
98112 // to prevent data races with concurrent operations
99- const baseDelay = time .Microsecond
113+ const baseDelay = 10 * time .Microsecond
100114 acquired := false
101115 attempt := 0
102116 for ! acquired {
@@ -108,36 +122,33 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
108122 return
109123 default :
110124 // Try to acquire: set Usable=false, then check Used
111- if conn .Usable . CompareAndSwap (true , false ) {
112- if ! conn .Used . Load () {
125+ if conn .CompareAndSwapUsable (true , false ) {
126+ if ! conn .IsUsed () {
113127 acquired = true
114128 } else {
115129 // Release Usable and retry with exponential backoff
116- conn .Usable .Store (true )
117- if attempt > 0 {
118- // Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
119- delay := baseDelay * time .Duration (1 << uint (attempt % 10 )) // Cap exponential growth
120- time .Sleep (delay )
121- }
122- attempt ++
123- }
124- } else {
125- // Connection not usable, retry with exponential backoff
126- if attempt > 0 {
127- // Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
128- delay := baseDelay * time .Duration (1 << uint (attempt % 10 )) // Cap exponential growth
129- time .Sleep (delay )
130+ // todo(ndyakov): think of a better way to do this without the need
131+ // to release the connection, but just wait till it is not used
132+ conn .SetUsable (true )
130133 }
134+ }
135+ if ! acquired {
136+ // Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds
137+ delay := baseDelay * time .Duration (1 << uint (attempt % 10 )) // Cap exponential growth
138+ time .Sleep (delay )
131139 attempt ++
132140 }
133141 }
134142 }
135143
136- // Successfully acquired the connection, perform reauth
137- reAuthFn (nil )
144+ // safety first
145+ if ! conn .IsClosed () {
146+ // Successfully acquired the connection, perform reauth
147+ reAuthFn (nil )
148+ }
138149
139150 // Release the connection
140- conn .Usable . Store (true )
151+ conn .SetUsable (true )
141152 }()
142153 }
143154
@@ -147,10 +158,16 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
147158}
148159
149160func (r * ReAuthPoolHook ) OnRemove (_ context.Context , conn * pool.Conn , _ error ) {
161+ connID := conn .GetID ()
162+ r .shouldReAuthLock .Lock ()
150163 r .scheduledLock .Lock ()
151- delete (r .scheduledReAuth , conn .GetID ())
164+ delete (r .scheduledReAuth , connID )
165+ delete (r .shouldReAuth , connID )
152166 r .scheduledLock .Unlock ()
153- r .ClearReAuthMark (conn .GetID ())
167+ r .shouldReAuthLock .Unlock ()
168+ if r .manager != nil {
169+ r .manager .RemoveListener (connID )
170+ }
154171}
155172
156173var _ pool.PoolHook = (* ReAuthPoolHook )(nil )
0 commit comments