|
18 | 18 | import logging |
19 | 19 | from typing import ( |
20 | 20 | Any, |
| 21 | + Awaitable, |
21 | 22 | Callable, |
22 | 23 | Dict, |
23 | 24 | Generic, |
@@ -346,15 +347,15 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): |
346 | 347 | """Wraps an existing cache to support bulk fetching of keys. |
347 | 348 |
|
348 | 349 | Given an iterable of keys it looks in the cache to find any hits, then passes |
349 | | - the tuple of missing keys to the wrapped function. |
| 350 | + the set of missing keys to the wrapped function. |
350 | 351 |
|
351 | | - Once wrapped, the function returns a Deferred which resolves to the list |
352 | | - of results. |
| 352 | + Once wrapped, the function returns a Deferred which resolves to a Dict mapping from |
| 353 | + input key to output value. |
353 | 354 | """ |
354 | 355 |
|
355 | 356 | def __init__( |
356 | 357 | self, |
357 | | - orig: Callable[..., Any], |
| 358 | + orig: Callable[..., Awaitable[Dict]], |
358 | 359 | cached_method_name: str, |
359 | 360 | list_name: str, |
360 | 361 | num_args: Optional[int] = None, |
@@ -385,13 +386,13 @@ def __init__( |
385 | 386 |
|
386 | 387 | def __get__( |
387 | 388 | self, obj: Optional[Any], objtype: Optional[Type] = None |
388 | | - ) -> Callable[..., Any]: |
| 389 | + ) -> Callable[..., "defer.Deferred[Dict[Hashable, Any]]"]: |
389 | 390 | cached_method = getattr(obj, self.cached_method_name) |
390 | 391 | cache: DeferredCache[CacheKey, Any] = cached_method.cache |
391 | 392 | num_args = cached_method.num_args |
392 | 393 |
|
393 | 394 | @functools.wraps(self.orig) |
394 | | - def wrapped(*args: Any, **kwargs: Any) -> Any: |
| 395 | + def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": |
395 | 396 | # If we're passed a cache_context then we'll want to call its |
396 | 397 | # invalidate() whenever we are invalidated |
397 | 398 | invalidate_callback = kwargs.pop("on_invalidate", None) |
@@ -444,39 +445,38 @@ def arg_to_cache_key(arg: Hashable) -> Hashable: |
444 | 445 | deferred: "defer.Deferred[Any]" = defer.Deferred() |
445 | 446 | deferreds_map[arg] = deferred |
446 | 447 | key = arg_to_cache_key(arg) |
447 | | - cache.set(key, deferred, callback=invalidate_callback) |
| 448 | + cached_defers.append( |
| 449 | + cache.set(key, deferred, callback=invalidate_callback) |
| 450 | + ) |
448 | 451 |
|
449 | 452 | def complete_all(res: Dict[Hashable, Any]) -> None: |
450 | | - # the wrapped function has completed. It returns a |
451 | | - # a dict. We can now resolve the observable deferreds in |
452 | | - # the cache and update our own result map. |
453 | | - for e in missing: |
| 453 | + # the wrapped function has completed. It returns a dict. |
| 454 | + # We can now update our own result map, and then resolve the |
| 455 | + # observable deferreds in the cache. |
| 456 | + for e, d1 in deferreds_map.items(): |
454 | 457 | val = res.get(e, None) |
455 | | - deferreds_map[e].callback(val) |
| 458 | + # make sure we update the results map before running the |
| 459 | + # deferreds, because as soon as we run the last deferred, the |
| 460 | + # gatherResults() below will complete and return the result |
| 461 | + # dict to our caller. |
456 | 462 | results[e] = val |
| 463 | + d1.callback(val) |
457 | 464 |
|
458 | | - def errback(f: Failure) -> Failure: |
459 | | - # the wrapped function has failed. Invalidate any cache |
460 | | - # entries we're supposed to be populating, and fail |
461 | | - # their deferreds. |
462 | | - for e in missing: |
463 | | - key = arg_to_cache_key(e) |
464 | | - cache.invalidate(key) |
465 | | - deferreds_map[e].errback(f) |
466 | | - |
467 | | - # return the failure, to propagate to our caller. |
468 | | - return f |
| 465 | + def errback_all(f: Failure) -> None: |
| 466 | + # the wrapped function has failed. Propagate the failure into |
| 467 | + # the cache, which will invalidate the entry, and cause the |
| 468 | + # relevant cached_deferreds to fail, which will propagate the |
| 469 | + # failure to our caller. |
| 470 | + for d1 in deferreds_map.values(): |
| 471 | + d1.errback(f) |
469 | 472 |
|
470 | 473 | args_to_call = dict(arg_dict) |
471 | | - # copy the missing set before sending it to the callee, to guard against |
472 | | - # modification. |
473 | | - args_to_call[self.list_name] = tuple(missing) |
474 | | - |
475 | | - cached_defers.append( |
476 | | - defer.maybeDeferred( |
477 | | - preserve_fn(self.orig), **args_to_call |
478 | | - ).addCallbacks(complete_all, errback) |
479 | | - ) |
| 474 | + args_to_call[self.list_name] = missing |
| 475 | + |
| 476 | + # dispatch the call, and attach the two handlers |
| 477 | + defer.maybeDeferred( |
| 478 | + preserve_fn(self.orig), **args_to_call |
| 479 | + ).addCallbacks(complete_all, errback_all) |
480 | 480 |
|
481 | 481 | if cached_defers: |
482 | 482 | d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks( |
|
0 commit comments