1
1
use std:: convert:: Infallible ;
2
- use std:: sync:: Arc ;
2
+ use std:: sync:: { Arc , Mutex } ;
3
3
4
- use crossbeam_skiplist:: SkipSet ;
4
+ use crossbeam_skiplist:: SkipMap ;
5
5
use crossbeam_skiplist:: equivalent:: { Comparable , Equivalent } ;
6
6
use moka:: sync:: Cache ;
7
7
use tracing:: { debug, info} ;
@@ -12,6 +12,7 @@ use crate::cache::common::{
12
12
use crate :: config:: ProjectInfoCacheOptions ;
13
13
use crate :: control_plane:: messages:: { ControlPlaneErrorMessage , Reason } ;
14
14
use crate :: control_plane:: { EndpointAccessControl , RoleAccessControl } ;
15
+ use crate :: ext:: LockExt ;
15
16
use crate :: intern:: { AccountIdInt , EndpointIdInt , ProjectIdInt , RoleNameInt } ;
16
17
use crate :: metrics:: { CacheKind , Metrics } ;
17
18
use crate :: types:: { EndpointId , RoleName } ;
@@ -28,16 +29,18 @@ pub struct ProjectInfoCache {
28
29
Cache < ( EndpointIdInt , RoleNameInt ) , ControlPlaneResult < Entry < RoleAccessControl > > > ,
29
30
ep_controls : Cache < EndpointIdInt , ControlPlaneResult < Entry < EndpointAccessControl > > > ,
30
31
31
- project2ep : Arc < MultiSet < ProjectIdInt , EndpointIdInt > > ,
32
- account2ep : Arc < MultiSet < AccountIdInt , EndpointIdInt > > ,
32
+ project2ep : Arc < RefCountMultiSet < ProjectIdInt , EndpointIdInt > > ,
33
+ account2ep : Arc < RefCountMultiSet < AccountIdInt , EndpointIdInt > > ,
33
34
34
35
config : ProjectInfoCacheOptions ,
35
36
}
36
37
38
+ type RefCount = Mutex < usize > ;
37
39
// This is rather hacky.
38
- // We use an ordered set of (K, V).
40
+ // We use an ordered map of (K, V) -> RefCount .
39
41
// We use range queries over `(K, _)..(K+1, _)` to do the invalidation.
40
- type MultiSet < K , V > = SkipSet < KeyValue < K , V > > ;
42
+ // We use the RefCount to know when to remove the mappings.
43
+ type RefCountMultiSet < K , V > = SkipMap < KeyValue < K , V > , RefCount > ;
41
44
42
45
#[ derive( PartialEq , Eq , PartialOrd , Ord , Clone , Copy , Debug ) ]
43
46
struct KeyValue < K , V > ( K , V ) ;
@@ -77,15 +80,15 @@ impl ProjectInfoCache {
77
80
info ! ( "invalidating endpoint access for project `{project_id}`" ) ;
78
81
79
82
for entry in self . project2ep . range ( Key :: prefix ( & project_id) ) {
80
- self . ep_controls . invalidate ( & entry. 1 ) ;
83
+ self . ep_controls . invalidate ( & entry. key ( ) . 1 ) ;
81
84
}
82
85
}
83
86
84
87
pub fn invalidate_endpoint_access_for_org ( & self , account_id : AccountIdInt ) {
85
88
info ! ( "invalidating endpoint access for org `{account_id}`" ) ;
86
89
87
90
for entry in self . account2ep . range ( Key :: prefix ( & account_id) ) {
88
- self . ep_controls . invalidate ( & entry. 1 ) ;
91
+ self . ep_controls . invalidate ( & entry. key ( ) . 1 ) ;
89
92
}
90
93
}
91
94
@@ -100,7 +103,7 @@ impl ProjectInfoCache {
100
103
) ;
101
104
102
105
for entry in self . project2ep . range ( Key :: prefix ( & project_id) ) {
103
- self . role_controls . invalidate ( & ( entry. 1 , role_name) ) ;
106
+ self . role_controls . invalidate ( & ( entry. key ( ) . 1 , role_name) ) ;
104
107
}
105
108
}
106
109
}
@@ -116,8 +119,8 @@ impl ProjectInfoCache {
116
119
. capacity
117
120
. set ( CacheKind :: ProjectInfoEndpoints , config. size as i64 ) ;
118
121
119
- let project2ep = Arc :: new ( MultiSet :: new ( ) ) ;
120
- let account2ep = Arc :: new ( MultiSet :: new ( ) ) ;
122
+ let project2ep = Arc :: new ( RefCountMultiSet :: new ( ) ) ;
123
+ let account2ep = Arc :: new ( RefCountMultiSet :: new ( ) ) ;
121
124
122
125
// we cache errors for 30 seconds, unless retry_at is set.
123
126
let expiry = CplaneExpiry :: default ( ) ;
@@ -254,11 +257,17 @@ impl ProjectInfoCache {
254
257
}
255
258
256
259
fn insert_project2endpoint ( & self , project_id : ProjectIdInt , endpoint_id : EndpointIdInt ) {
257
- self . project2ep . insert ( KeyValue ( project_id, endpoint_id) ) ;
260
+ let entry = self
261
+ . project2ep
262
+ . get_or_insert ( KeyValue ( project_id, endpoint_id) , Mutex :: new ( 0 ) ) ;
263
+ * entry. value ( ) . lock_propagate_poison ( ) += 1 ;
258
264
}
259
265
260
266
fn insert_account2endpoint ( & self , account_id : AccountIdInt , endpoint_id : EndpointIdInt ) {
261
- self . account2ep . insert ( KeyValue ( account_id, endpoint_id) ) ;
267
+ let entry = self
268
+ . account2ep
269
+ . get_or_insert ( KeyValue ( account_id, endpoint_id) , Mutex :: new ( 0 ) ) ;
270
+ * entry. value ( ) . lock_propagate_poison ( ) += 1 ;
262
271
}
263
272
264
273
pub fn maybe_invalidate_role_secret ( & self , _endpoint_id : & EndpointId , _role_name : & RoleName ) {
0 commit comments