Skip to content

Commit 8549b42

Browse files
committed
implement removal
1 parent b39498c commit 8549b42

File tree

1 file changed

+100
-15
lines changed

1 file changed

+100
-15
lines changed

proxy/src/cache/project_info.rs

Lines changed: 100 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,11 @@ pub struct ProjectInfoCache {
3636
}
3737

3838
type RefCount = Mutex<usize>;
39+
3940
// This is rather hacky.
4041
// We use an ordered map of (K, V) -> RefCount.
4142
// We use range queries over `(K, _)..(K+1, _)` to do the invalidation.
42-
// We use the RefCount to know when to remove the mappings.
43+
// We use the RefCount to know when to remove entries.
4344
type RefCountMultiSet<K, V> = SkipMap<KeyValue<K, V>, RefCount>;
4445

4546
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Debug)]
@@ -70,6 +71,37 @@ struct Entry<T> {
7071
value: T,
7172
}
7273

74+
impl<T> Entry<T> {
75+
fn dec_ref_counts(
76+
self,
77+
project2ep: &RefCountMultiSet<ProjectIdInt, EndpointIdInt>,
78+
account2ep: &RefCountMultiSet<AccountIdInt, EndpointIdInt>,
79+
endpoint_id: EndpointIdInt,
80+
) {
81+
if let Some(project_id) = self.project_id {
82+
dec_ref_count(project2ep, project_id, endpoint_id);
83+
}
84+
if let Some(account_id) = self.account_id {
85+
dec_ref_count(account2ep, account_id, endpoint_id);
86+
}
87+
}
88+
}
89+
90+
fn dec_ref_count<Id: Ord + Send + 'static>(
91+
id2ep: &RefCountMultiSet<Id, EndpointIdInt>,
92+
id: Id,
93+
endpoint_id: EndpointIdInt,
94+
) {
95+
if let Some(entry) = id2ep.get(&KeyValue(id, endpoint_id)) {
96+
let mut count = entry.value().lock_propagate_poison();
97+
*count -= 1;
98+
if *count == 0 {
99+
// remove the entry while holding the lock
100+
entry.remove();
101+
}
102+
}
103+
}
104+
73105
impl ProjectInfoCache {
74106
pub fn invalidate_endpoint_access(&self, endpoint_id: EndpointIdInt) {
75107
info!("invalidating endpoint access for `{endpoint_id}`");
@@ -119,26 +151,44 @@ impl ProjectInfoCache {
119151
.capacity
120152
.set(CacheKind::ProjectInfoEndpoints, config.size as i64);
121153

122-
let project2ep = Arc::new(RefCountMultiSet::new());
123-
let account2ep = Arc::new(RefCountMultiSet::new());
154+
let project2ep = Arc::new(RefCountMultiSet::<ProjectIdInt, EndpointIdInt>::new());
155+
let account2ep = Arc::new(RefCountMultiSet::<AccountIdInt, EndpointIdInt>::new());
156+
let project2ep1 = Arc::clone(&project2ep);
157+
let project2ep2 = Arc::clone(&project2ep);
158+
let account2ep1 = Arc::clone(&account2ep);
159+
let account2ep2 = Arc::clone(&account2ep);
124160

125161
// we cache errors for 30 seconds, unless retry_at is set.
126162
let expiry = CplaneExpiry::default();
127163
Self {
128164
role_controls: Cache::builder()
129165
.name("role_access_controls")
130-
.eviction_listener(|_k, _v, cause| {
131-
eviction_listener(CacheKind::ProjectInfoRoles, cause);
132-
})
166+
.eviction_listener(
167+
move |k, v: ControlPlaneResult<Entry<RoleAccessControl>>, cause| {
168+
eviction_listener(CacheKind::ProjectInfoRoles, cause);
169+
170+
let (endpoint_id, _): (EndpointIdInt, RoleNameInt) = *k;
171+
if let Ok(v) = v {
172+
v.dec_ref_counts(&project2ep1, &account2ep1, endpoint_id);
173+
}
174+
},
175+
)
133176
.max_capacity(config.size * config.max_roles)
134177
.time_to_live(config.ttl)
135178
.expire_after(expiry)
136179
.build(),
137180
ep_controls: Cache::builder()
138181
.name("endpoint_access_controls")
139-
.eviction_listener(|_k, _v, cause| {
140-
eviction_listener(CacheKind::ProjectInfoEndpoints, cause);
141-
})
182+
.eviction_listener(
183+
move |k, v: ControlPlaneResult<Entry<EndpointAccessControl>>, cause| {
184+
eviction_listener(CacheKind::ProjectInfoEndpoints, cause);
185+
186+
let endpoint_id: EndpointIdInt = *k;
187+
if let Ok(v) = v {
188+
v.dec_ref_counts(&project2ep2, &account2ep2, endpoint_id);
189+
}
190+
},
191+
)
142192
.max_capacity(config.size)
143193
.time_to_live(config.ttl)
144194
.expire_after(expiry)
@@ -188,11 +238,12 @@ impl ProjectInfoCache {
188238
controls: EndpointAccessControl,
189239
role_controls: RoleAccessControl,
190240
) {
241+
// 2 corresponds to how many cache inserts we do.
191242
if let Some(account_id) = account_id {
192-
self.insert_account2endpoint(account_id, endpoint_id);
243+
self.inc_account2ep_ref(account_id, endpoint_id, 2);
193244
}
194245
if let Some(project_id) = project_id {
195-
self.insert_project2endpoint(project_id, endpoint_id);
246+
self.inc_project2ep_ref(project_id, endpoint_id, 2);
196247
}
197248

198249
debug!(
@@ -256,18 +307,18 @@ impl ProjectInfoCache {
256307
.insert((endpoint_id, role_name), Err(msg));
257308
}
258309

259-
fn insert_project2endpoint(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt) {
310+
fn inc_project2ep_ref(&self, project_id: ProjectIdInt, endpoint_id: EndpointIdInt, x: usize) {
260311
let entry = self
261312
.project2ep
262313
.get_or_insert(KeyValue(project_id, endpoint_id), Mutex::new(0));
263-
*entry.value().lock_propagate_poison() += 1;
314+
*entry.value().lock_propagate_poison() += x;
264315
}
265316

266-
fn insert_account2endpoint(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt) {
317+
fn inc_account2ep_ref(&self, account_id: AccountIdInt, endpoint_id: EndpointIdInt, x: usize) {
267318
let entry = self
268319
.account2ep
269320
.get_or_insert(KeyValue(account_id, endpoint_id), Mutex::new(0));
270-
*entry.value().lock_propagate_poison() += 1;
321+
*entry.value().lock_propagate_poison() += x;
271322
}
272323

273324
pub fn maybe_invalidate_role_secret(&self, _endpoint_id: &EndpointId, _role_name: &RoleName) {
@@ -332,6 +383,16 @@ mod tests {
332383
},
333384
);
334385

386+
cache.ep_controls.run_pending_tasks();
387+
cache.role_controls.run_pending_tasks();
388+
389+
// check the project mappings are there
390+
assert_eq!(cache.project2ep.len(), 1);
391+
392+
// check the ref counts
393+
let entry = cache.project2ep.front().unwrap();
394+
assert_eq!(*entry.value().lock_propagate_poison(), 2);
395+
335396
cache.insert_endpoint_access(
336397
account_id,
337398
project_id,
@@ -348,6 +409,17 @@ mod tests {
348409
},
349410
);
350411

412+
cache.ep_controls.run_pending_tasks();
413+
cache.role_controls.run_pending_tasks();
414+
415+
// check the project mappings are still there
416+
assert_eq!(cache.project2ep.len(), 1);
417+
418+
// check the ref counts
419+
let entry = cache.project2ep.front().unwrap();
420+
assert_eq!(*entry.value().lock_propagate_poison(), 3);
421+
422+
// check both entries exist
351423
let cached = cache.get_role_secret(&endpoint_id, &user1).unwrap();
352424
assert_eq!(cached.unwrap().secret, secret1);
353425

@@ -375,13 +447,26 @@ mod tests {
375447
},
376448
);
377449

450+
cache.ep_controls.run_pending_tasks();
378451
cache.role_controls.run_pending_tasks();
452+
379453
assert_eq!(cache.role_controls.entry_count(), 2);
380454

455+
// check the project mappings are still there
456+
assert_eq!(cache.project2ep.len(), 1);
457+
458+
// check the ref counts are unchanged.
459+
let entry = cache.project2ep.front().unwrap();
460+
assert_eq!(*entry.value().lock_propagate_poison(), 3);
461+
381462
tokio::time::sleep(Duration::from_secs(2)).await;
382463

464+
cache.ep_controls.run_pending_tasks();
383465
cache.role_controls.run_pending_tasks();
384466
assert_eq!(cache.role_controls.entry_count(), 0);
467+
468+
// check the project/account mappings are no longer there
469+
assert!(cache.project2ep.is_empty());
385470
}
386471

387472
#[tokio::test]

0 commit comments

Comments
 (0)