Skip to content

Commit 1ee5293

Browse files
committed
address pr comments
1 parent 528f2e9 commit 1ee5293

File tree

9 files changed

+142
-93
lines changed

9 files changed

+142
-93
lines changed

internal/auth/streaming/conn_reauth_credentials_listener.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,7 @@ type ConnReAuthCredentialsListener struct {
2727
// It calls the reAuth function with the new credentials.
2828
// If the reAuth function returns an error, it calls the onErr function with the error.
2929
func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
30-
if c.conn.IsClosed() {
31-
return
32-
}
33-
34-
if c.reAuth == nil {
30+
if c.conn == nil || c.conn.IsClosed() || c.manager == nil || c.reAuth == nil {
3531
return
3632
}
3733

@@ -41,17 +37,20 @@ func (c *ConnReAuthCredentialsListener) OnNext(credentials auth.Credentials) {
4137
// The connection pool hook will re-authenticate the connection when it is
4238
// returned to the pool in a clean, idle state.
4339
c.manager.MarkForReAuth(c.conn, func(err error) {
40+
// err is from connection acquisition (timeout, etc.)
4441
if err != nil {
42+
// Report the connection-acquisition error through OnError
4543
c.OnError(err)
4644
return
4745
}
46+
// err is from reauth command execution
4847
err = c.reAuth(c.conn, credentials)
4948
if err != nil {
49+
// Report the reauth command error through OnError
5050
c.OnError(err)
5151
return
5252
}
5353
})
54-
5554
}
5655

5756
// OnError is called when an error occurs.

internal/auth/streaming/manager.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ type Manager struct {
1515
}
1616

1717
func NewManager(pl pool.Pooler, reAuthTimeout time.Duration) *Manager {
18-
return &Manager{
18+
m := &Manager{
1919
pool: pl,
2020
poolHookRef: NewReAuthPoolHook(pl.Size(), reAuthTimeout),
2121
credentialsListeners: NewCredentialsListeners(),
2222
}
23+
m.poolHookRef.manager = m
24+
return m
2325
}
2426

2527
func (m *Manager) PoolHook() pool.PoolHook {
@@ -35,6 +37,10 @@ func (m *Manager) Listener(
3537
return nil, errors.New("poolCn cannot be nil")
3638
}
3739
connID := poolCn.GetID()
40+
// If the underlying network connection is reconnected, the streaming credentials listener keeps working,
41+
// so the existing listener can be fetched from the cache and reused.
42+
43+
// Subscribing an already-subscribed listener to a StreamingCredentialsProvider SHOULD be a no-op.
3844
listener, ok := m.credentialsListeners.Get(connID)
3945
if !ok || listener == nil {
4046
newCredListener := &ConnReAuthCredentialsListener{
@@ -54,3 +60,7 @@ func (m *Manager) MarkForReAuth(poolCn *pool.Conn, reAuthFn func(error)) {
5460
connID := poolCn.GetID()
5561
m.poolHookRef.MarkForReAuth(connID, reAuthFn)
5662
}
63+
64+
func (m *Manager) RemoveListener(connID uint64) {
65+
m.credentialsListeners.Remove(connID)
66+
}

internal/auth/streaming/pool_hook.go

Lines changed: 55 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"sync"
66
"time"
77

8+
"github.com/redis/go-redis/v9/internal"
89
"github.com/redis/go-redis/v9/internal/pool"
910
)
1011

@@ -17,6 +18,9 @@ type ReAuthPoolHook struct {
1718
// conn id -> bool
1819
scheduledReAuth map[uint64]bool
1920
scheduledLock sync.RWMutex
21+
22+
// for cleanup
23+
manager *Manager
2024
}
2125

2226
func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHook {
@@ -32,7 +36,6 @@ func NewReAuthPoolHook(poolSize int, reAuthTimeout time.Duration) *ReAuthPoolHoo
3236
workers: workers,
3337
reAuthTimeout: reAuthTimeout,
3438
}
35-
3639
}
3740

3841
func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
@@ -41,27 +44,22 @@ func (r *ReAuthPoolHook) MarkForReAuth(connID uint64, reAuthFn func(error)) {
4144
r.shouldReAuth[connID] = reAuthFn
4245
}
4346

44-
func (r *ReAuthPoolHook) ClearReAuthMark(connID uint64) {
45-
r.shouldReAuthLock.Lock()
46-
defer r.shouldReAuthLock.Unlock()
47-
delete(r.shouldReAuth, connID)
48-
}
49-
5047
func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (accept bool, err error) {
48+
connID := conn.GetID()
5149
r.shouldReAuthLock.RLock()
52-
_, ok := r.shouldReAuth[conn.GetID()]
50+
_, shouldReAuth := r.shouldReAuth[connID]
5351
r.shouldReAuthLock.RUnlock()
5452
// This connection was marked for reauth while in the pool,
5553
// reject the connection
56-
if ok {
54+
if shouldReAuth {
5755
// simply reject the connection, it will be re-authenticated in OnPut
5856
return false, nil
5957
}
6058
r.scheduledLock.RLock()
61-
hasScheduled, ok := r.scheduledReAuth[conn.GetID()]
59+
_, hasScheduled := r.scheduledReAuth[connID]
6260
r.scheduledLock.RUnlock()
6361
// has scheduled reauth, reject the connection
64-
if ok && hasScheduled {
62+
if hasScheduled {
6563
// simply reject the connection, it currently has a reauth scheduled
6664
// and the worker is waiting for slot to execute the reauth
6765
return false, nil
@@ -70,22 +68,38 @@ func (r *ReAuthPoolHook) OnGet(_ context.Context, conn *pool.Conn, _ bool) (acce
7068
}
7169

7270
func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool, error) {
71+
if conn == nil {
72+
// noop
73+
return true, false, nil
74+
}
75+
connID := conn.GetID()
7376
// Check if reauth is needed and get the function with proper locking
7477
r.shouldReAuthLock.RLock()
75-
reAuthFn, ok := r.shouldReAuth[conn.GetID()]
78+
r.scheduledLock.RLock()
79+
reAuthFn, ok := r.shouldReAuth[connID]
7680
r.shouldReAuthLock.RUnlock()
7781

7882
if ok {
83+
r.shouldReAuthLock.Lock()
7984
r.scheduledLock.Lock()
80-
r.scheduledReAuth[conn.GetID()] = true
85+
r.scheduledReAuth[connID] = true
86+
delete(r.shouldReAuth, connID)
8187
r.scheduledLock.Unlock()
82-
// Clear the mark immediately to prevent duplicate reauth attempts
83-
r.ClearReAuthMark(conn.GetID())
88+
r.shouldReAuthLock.Unlock()
8489
go func() {
8590
<-r.workers
91+
// If the connection is nil or already closed, return the worker slot and skip the reauth
92+
if conn == nil || (conn != nil && conn.IsClosed()) {
93+
r.workers <- struct{}{}
94+
return
95+
}
8696
defer func() {
97+
if rec := recover(); rec != nil {
98+
// Recover from a panic in the reauth worker and log it; the cleanup below still runs
99+
internal.Logger.Printf(context.Background(), "panic in reauth worker: %v", rec)
100+
}
87101
r.scheduledLock.Lock()
88-
delete(r.scheduledReAuth, conn.GetID())
102+
delete(r.scheduledReAuth, connID)
89103
r.scheduledLock.Unlock()
90104
r.workers <- struct{}{}
91105
}()
@@ -96,7 +110,7 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
96110
// Try to acquire the connection
97111
// We need to ensure the connection is both Usable and not Used
98112
// to prevent data races with concurrent operations
99-
const baseDelay = time.Microsecond
113+
const baseDelay = 10 * time.Microsecond
100114
acquired := false
101115
attempt := 0
102116
for !acquired {
@@ -108,36 +122,33 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
108122
return
109123
default:
110124
// Try to acquire: set Usable=false, then check Used
111-
if conn.Usable.CompareAndSwap(true, false) {
112-
if !conn.Used.Load() {
125+
if conn.CompareAndSwapUsable(true, false) {
126+
if !conn.IsUsed() {
113127
acquired = true
114128
} else {
115129
// Release Usable and retry with exponential backoff
116-
conn.Usable.Store(true)
117-
if attempt > 0 {
118-
// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
119-
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
120-
time.Sleep(delay)
121-
}
122-
attempt++
123-
}
124-
} else {
125-
// Connection not usable, retry with exponential backoff
126-
if attempt > 0 {
127-
// Exponential backoff: 1, 2, 4, 8... up to 512 microseconds
128-
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
129-
time.Sleep(delay)
130+
// todo(ndyakov): think of a better way to do this without the need
131+
// to release the connection, but just wait till it is not used
132+
conn.SetUsable(true)
130133
}
134+
}
135+
if !acquired {
136+
// Exponential backoff: 10, 20, 40, 80... up to 5120 microseconds; attempt%10 then wraps the delay back to 10
137+
delay := baseDelay * time.Duration(1<<uint(attempt%10)) // Cap exponential growth
138+
time.Sleep(delay)
131139
attempt++
132140
}
133141
}
134142
}
135143

136-
// Successfully acquired the connection, perform reauth
137-
reAuthFn(nil)
144+
// Skip the reauth if the connection is already closed
145+
if !conn.IsClosed() {
146+
// Successfully acquired the connection, perform reauth
147+
reAuthFn(nil)
148+
}
138149

139150
// Release the connection
140-
conn.Usable.Store(true)
151+
conn.SetUsable(true)
141152
}()
142153
}
143154

@@ -147,10 +158,16 @@ func (r *ReAuthPoolHook) OnPut(_ context.Context, conn *pool.Conn) (bool, bool,
147158
}
148159

149160
func (r *ReAuthPoolHook) OnRemove(_ context.Context, conn *pool.Conn, _ error) {
161+
connID := conn.GetID()
162+
r.shouldReAuthLock.Lock()
150163
r.scheduledLock.Lock()
151-
delete(r.scheduledReAuth, conn.GetID())
164+
delete(r.scheduledReAuth, connID)
165+
delete(r.shouldReAuth, connID)
152166
r.scheduledLock.Unlock()
153-
r.ClearReAuthMark(conn.GetID())
167+
r.shouldReAuthLock.Unlock()
168+
if r.manager != nil {
169+
r.manager.RemoveListener(connID)
170+
}
154171
}
155172

156173
var _ pool.PoolHook = (*ReAuthPoolHook)(nil)

internal/pool/conn.go

Lines changed: 44 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,15 @@ type Conn struct {
6868
// is not in use. That way, the connection won't be used to send multiple commands at the same time and
6969
// potentially corrupt the command stream.
7070

71-
// Usable flag to mark connection as safe for use
71+
// usable flag to mark connection as safe for use
7272
// It is false before initialization and after a handoff is marked
7373
// It will be false during other background operations like re-authentication
74-
Usable atomic.Bool
74+
usable atomic.Bool
7575

76-
// Used flag to mark connection as used when a command is going to be
76+
// used flag to mark connection as used when a command is going to be
7777
// processed on that connection. This is used to prevent a race condition with
7878
// background operations that may execute commands, like re-authentication.
79-
Used atomic.Bool
79+
used atomic.Bool
8080

8181
// Inited flag to mark connection as initialized, this is almost the same as usable
8282
// but it is used to make sure we don't initialize a network connection twice
@@ -142,7 +142,7 @@ func NewConnWithBufferSize(netConn net.Conn, readBufSize, writeBufSize int) *Con
142142
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
143143

144144
// Initialize atomic state
145-
cn.Usable.Store(false) // false initially, set to true after initialization
145+
cn.usable.Store(false) // false initially, set to true after initialization
146146
cn.handoffRetriesAtomic.Store(0) // 0 initially
147147

148148
// Initialize handoff state atomically
@@ -167,6 +167,42 @@ func (cn *Conn) SetUsedAt(tm time.Time) {
167167
atomic.StoreInt64(&cn.usedAt, tm.Unix())
168168
}
169169

170+
// Usable
171+
172+
// CompareAndSwapUsable atomically compares and swaps the usable flag (lock-free).
173+
func (cn *Conn) CompareAndSwapUsable(old, new bool) bool {
174+
return cn.usable.CompareAndSwap(old, new)
175+
}
176+
177+
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
178+
func (cn *Conn) IsUsable() bool {
179+
return cn.usable.Load()
180+
}
181+
182+
// SetUsable sets the usable flag for the connection (lock-free).
183+
// prefer CompareAndSwapUsable() when possible
184+
func (cn *Conn) SetUsable(usable bool) {
185+
cn.usable.Store(usable)
186+
}
187+
188+
// Used
189+
190+
// CompareAndSwapUsed atomically compares and swaps the used flag (lock-free).
191+
func (cn *Conn) CompareAndSwapUsed(old, new bool) bool {
192+
return cn.used.CompareAndSwap(old, new)
193+
}
194+
195+
// IsUsed returns true if the connection is currently in use (lock-free).
196+
func (cn *Conn) IsUsed() bool {
197+
return cn.used.Load()
198+
}
199+
200+
// SetUsed sets the used flag for the connection (lock-free).
201+
// prefer CompareAndSwapUsed() when possible
202+
func (cn *Conn) SetUsed(val bool) {
203+
cn.used.Store(val)
204+
}
205+
170206
// getNetConn returns the current network connection using atomic load (lock-free).
171207
// This is the fast path for accessing netConn without mutex overhead.
172208
func (cn *Conn) getNetConn() net.Conn {
@@ -184,18 +220,6 @@ func (cn *Conn) setNetConn(netConn net.Conn) {
184220
cn.netConnAtomic.Store(&atomicNetConn{conn: netConn})
185221
}
186222

187-
// Lock-free helper methods for handoff state management
188-
189-
// isUsable returns true if the connection is safe to use (lock-free).
190-
func (cn *Conn) isUsable() bool {
191-
return cn.Usable.Load()
192-
}
193-
194-
// setUsable sets the usable flag atomically (lock-free).
195-
func (cn *Conn) setUsable(usable bool) {
196-
cn.Usable.Store(usable)
197-
}
198-
199223
// getHandoffState returns the current handoff state atomically (lock-free).
200224
func (cn *Conn) getHandoffState() *HandoffState {
201225
state := cn.handoffStateAtomic.Load()
@@ -240,11 +264,6 @@ func (cn *Conn) incrementHandoffRetries(delta int) int {
240264
return int(cn.handoffRetriesAtomic.Add(uint32(delta)))
241265
}
242266

243-
// IsUsable returns true if the connection is safe to use for new commands (lock-free).
244-
func (cn *Conn) IsUsable() bool {
245-
return cn.isUsable()
246-
}
247-
248267
// IsPooled returns true if the connection is managed by a pool and will be pooled on Put.
249268
func (cn *Conn) IsPooled() bool {
250269
return cn.pooled
@@ -259,11 +278,6 @@ func (cn *Conn) IsInited() bool {
259278
return cn.Inited.Load()
260279
}
261280

262-
// SetUsable sets the usable flag for the connection (lock-free).
263-
func (cn *Conn) SetUsable(usable bool) {
264-
cn.setUsable(usable)
265-
}
266-
267281
// SetRelaxedTimeout sets relaxed timeouts for this connection during maintenanceNotifications upgrades.
268282
// These timeouts will be used for all subsequent commands until the deadline expires.
269283
// Uses atomic operations for lock-free access.
@@ -494,11 +508,10 @@ func (cn *Conn) MarkQueuedForHandoff() error {
494508
// first we need to mark the connection as not usable
495509
// to prevent the pool from returning it to the caller
496510
if !connAcquired {
497-
if cn.Usable.CompareAndSwap(true, false) {
498-
connAcquired = true
499-
} else {
511+
if !cn.usable.CompareAndSwap(true, false) {
500512
continue
501513
}
514+
connAcquired = true
502515
}
503516

504517
currentState := cn.getHandoffState()
@@ -568,7 +581,7 @@ func (cn *Conn) ClearHandoffState() {
568581
cn.setHandoffState(cleanState)
569582
cn.setHandoffRetries(0)
570583
// Clearing handoff state also means the connection is usable again
571-
cn.setUsable(true)
584+
cn.SetUsable(true)
572585
}
573586

574587
// IncrementAndGetHandoffRetries atomically increments and returns handoff retries (lock-free).

0 commit comments

Comments
 (0)