@@ -36,10 +36,11 @@ pub struct ProjectInfoCache {
36
36
}
37
37
38
38
type RefCount = Mutex < usize > ;
39
+
39
40
// This is rather hacky.
40
41
// We use an ordered map of (K, V) -> RefCount.
41
42
// We use range queries over `(K, _)..(K+1, _)` to do the invalidation.
42
- // We use the RefCount to know when to remove the mappings .
43
+ // We use the RefCount to know when to remove entries .
43
44
type RefCountMultiSet < K , V > = SkipMap < KeyValue < K , V > , RefCount > ;
44
45
45
46
#[ derive( PartialEq , Eq , PartialOrd , Ord , Clone , Copy , Debug ) ]
@@ -70,6 +71,37 @@ struct Entry<T> {
70
71
value : T ,
71
72
}
72
73
74
+ impl < T > Entry < T > {
75
+ fn dec_ref_counts (
76
+ self ,
77
+ project2ep : & RefCountMultiSet < ProjectIdInt , EndpointIdInt > ,
78
+ account2ep : & RefCountMultiSet < AccountIdInt , EndpointIdInt > ,
79
+ endpoint_id : EndpointIdInt ,
80
+ ) {
81
+ if let Some ( project_id) = self . project_id {
82
+ dec_ref_count ( project2ep, project_id, endpoint_id) ;
83
+ }
84
+ if let Some ( account_id) = self . account_id {
85
+ dec_ref_count ( account2ep, account_id, endpoint_id) ;
86
+ }
87
+ }
88
+ }
89
+
90
+ fn dec_ref_count < Id : Ord + Send + ' static > (
91
+ id2ep : & RefCountMultiSet < Id , EndpointIdInt > ,
92
+ id : Id ,
93
+ endpoint_id : EndpointIdInt ,
94
+ ) {
95
+ if let Some ( entry) = id2ep. get ( & KeyValue ( id, endpoint_id) ) {
96
+ let mut count = entry. value ( ) . lock_propagate_poison ( ) ;
97
+ * count -= 1 ;
98
+ if * count == 0 {
99
+ // remove the entry while holding the lock
100
+ entry. remove ( ) ;
101
+ }
102
+ }
103
+ }
104
+
73
105
impl ProjectInfoCache {
74
106
pub fn invalidate_endpoint_access ( & self , endpoint_id : EndpointIdInt ) {
75
107
info ! ( "invalidating endpoint access for `{endpoint_id}`" ) ;
@@ -119,26 +151,44 @@ impl ProjectInfoCache {
119
151
. capacity
120
152
. set ( CacheKind :: ProjectInfoEndpoints , config. size as i64 ) ;
121
153
122
- let project2ep = Arc :: new ( RefCountMultiSet :: new ( ) ) ;
123
- let account2ep = Arc :: new ( RefCountMultiSet :: new ( ) ) ;
154
+ let project2ep = Arc :: new ( RefCountMultiSet :: < ProjectIdInt , EndpointIdInt > :: new ( ) ) ;
155
+ let account2ep = Arc :: new ( RefCountMultiSet :: < AccountIdInt , EndpointIdInt > :: new ( ) ) ;
156
+ let project2ep1 = Arc :: clone ( & project2ep) ;
157
+ let project2ep2 = Arc :: clone ( & project2ep) ;
158
+ let account2ep1 = Arc :: clone ( & account2ep) ;
159
+ let account2ep2 = Arc :: clone ( & account2ep) ;
124
160
125
161
// we cache errors for 30 seconds, unless retry_at is set.
126
162
let expiry = CplaneExpiry :: default ( ) ;
127
163
Self {
128
164
role_controls : Cache :: builder ( )
129
165
. name ( "role_access_controls" )
130
- . eviction_listener ( |_k, _v, cause| {
131
- eviction_listener ( CacheKind :: ProjectInfoRoles , cause) ;
132
- } )
166
+ . eviction_listener (
167
+ move |k, v : ControlPlaneResult < Entry < RoleAccessControl > > , cause| {
168
+ eviction_listener ( CacheKind :: ProjectInfoRoles , cause) ;
169
+
170
+ let ( endpoint_id, _) : ( EndpointIdInt , RoleNameInt ) = * k;
171
+ if let Ok ( v) = v {
172
+ v. dec_ref_counts ( & project2ep1, & account2ep1, endpoint_id) ;
173
+ }
174
+ } ,
175
+ )
133
176
. max_capacity ( config. size * config. max_roles )
134
177
. time_to_live ( config. ttl )
135
178
. expire_after ( expiry)
136
179
. build ( ) ,
137
180
ep_controls : Cache :: builder ( )
138
181
. name ( "endpoint_access_controls" )
139
- . eviction_listener ( |_k, _v, cause| {
140
- eviction_listener ( CacheKind :: ProjectInfoEndpoints , cause) ;
141
- } )
182
+ . eviction_listener (
183
+ move |k, v : ControlPlaneResult < Entry < EndpointAccessControl > > , cause| {
184
+ eviction_listener ( CacheKind :: ProjectInfoEndpoints , cause) ;
185
+
186
+ let endpoint_id: EndpointIdInt = * k;
187
+ if let Ok ( v) = v {
188
+ v. dec_ref_counts ( & project2ep2, & account2ep2, endpoint_id) ;
189
+ }
190
+ } ,
191
+ )
142
192
. max_capacity ( config. size )
143
193
. time_to_live ( config. ttl )
144
194
. expire_after ( expiry)
@@ -188,11 +238,12 @@ impl ProjectInfoCache {
188
238
controls : EndpointAccessControl ,
189
239
role_controls : RoleAccessControl ,
190
240
) {
241
+ // 2 corresponds to how many cache inserts we do.
191
242
if let Some ( account_id) = account_id {
192
- self . insert_account2endpoint ( account_id, endpoint_id) ;
243
+ self . inc_account2ep_ref ( account_id, endpoint_id, 2 ) ;
193
244
}
194
245
if let Some ( project_id) = project_id {
195
- self . insert_project2endpoint ( project_id, endpoint_id) ;
246
+ self . inc_project2ep_ref ( project_id, endpoint_id, 2 ) ;
196
247
}
197
248
198
249
debug ! (
@@ -256,18 +307,18 @@ impl ProjectInfoCache {
256
307
. insert ( ( endpoint_id, role_name) , Err ( msg) ) ;
257
308
}
258
309
259
- fn insert_project2endpoint ( & self , project_id : ProjectIdInt , endpoint_id : EndpointIdInt ) {
310
+ fn inc_project2ep_ref ( & self , project_id : ProjectIdInt , endpoint_id : EndpointIdInt , x : usize ) {
260
311
let entry = self
261
312
. project2ep
262
313
. get_or_insert ( KeyValue ( project_id, endpoint_id) , Mutex :: new ( 0 ) ) ;
263
- * entry. value ( ) . lock_propagate_poison ( ) += 1 ;
314
+ * entry. value ( ) . lock_propagate_poison ( ) += x ;
264
315
}
265
316
266
- fn insert_account2endpoint ( & self , account_id : AccountIdInt , endpoint_id : EndpointIdInt ) {
317
+ fn inc_account2ep_ref ( & self , account_id : AccountIdInt , endpoint_id : EndpointIdInt , x : usize ) {
267
318
let entry = self
268
319
. account2ep
269
320
. get_or_insert ( KeyValue ( account_id, endpoint_id) , Mutex :: new ( 0 ) ) ;
270
- * entry. value ( ) . lock_propagate_poison ( ) += 1 ;
321
+ * entry. value ( ) . lock_propagate_poison ( ) += x ;
271
322
}
272
323
273
324
pub fn maybe_invalidate_role_secret ( & self , _endpoint_id : & EndpointId , _role_name : & RoleName ) {
@@ -332,6 +383,16 @@ mod tests {
332
383
} ,
333
384
) ;
334
385
386
+ cache. ep_controls . run_pending_tasks ( ) ;
387
+ cache. role_controls . run_pending_tasks ( ) ;
388
+
389
+ // check the project mappings are there
390
+ assert_eq ! ( cache. project2ep. len( ) , 1 ) ;
391
+
392
+ // check the ref counts
393
+ let entry = cache. project2ep . front ( ) . unwrap ( ) ;
394
+ assert_eq ! ( * entry. value( ) . lock_propagate_poison( ) , 2 ) ;
395
+
335
396
cache. insert_endpoint_access (
336
397
account_id,
337
398
project_id,
@@ -348,6 +409,17 @@ mod tests {
348
409
} ,
349
410
) ;
350
411
412
+ cache. ep_controls . run_pending_tasks ( ) ;
413
+ cache. role_controls . run_pending_tasks ( ) ;
414
+
415
+ // check the project mappings are still there
416
+ assert_eq ! ( cache. project2ep. len( ) , 1 ) ;
417
+
418
+ // check the ref counts
419
+ let entry = cache. project2ep . front ( ) . unwrap ( ) ;
420
+ assert_eq ! ( * entry. value( ) . lock_propagate_poison( ) , 3 ) ;
421
+
422
+ // check both entries exist
351
423
let cached = cache. get_role_secret ( & endpoint_id, & user1) . unwrap ( ) ;
352
424
assert_eq ! ( cached. unwrap( ) . secret, secret1) ;
353
425
@@ -375,13 +447,26 @@ mod tests {
375
447
} ,
376
448
) ;
377
449
450
+ cache. ep_controls . run_pending_tasks ( ) ;
378
451
cache. role_controls . run_pending_tasks ( ) ;
452
+
379
453
assert_eq ! ( cache. role_controls. entry_count( ) , 2 ) ;
380
454
455
+ // check the project mappings are still there
456
+ assert_eq ! ( cache. project2ep. len( ) , 1 ) ;
457
+
458
+ // check the ref counts are unchanged.
459
+ let entry = cache. project2ep . front ( ) . unwrap ( ) ;
460
+ assert_eq ! ( * entry. value( ) . lock_propagate_poison( ) , 3 ) ;
461
+
381
462
tokio:: time:: sleep ( Duration :: from_secs ( 2 ) ) . await ;
382
463
464
+ cache. ep_controls . run_pending_tasks ( ) ;
383
465
cache. role_controls . run_pending_tasks ( ) ;
384
466
assert_eq ! ( cache. role_controls. entry_count( ) , 0 ) ;
467
+
468
+ // check the project/account mappings are no longer there
469
+ assert ! ( cache. project2ep. is_empty( ) ) ;
385
470
}
386
471
387
472
#[ tokio:: test]
0 commit comments