11import pickle
2- from typing import TypeVar
2+ from typing import Dict , Optional , TypeVar , Union
33
44from redis .asyncio import ConnectionPool , Redis
55from taskiq import AsyncResultBackend
66from taskiq .abc .result_backend import TaskiqResult
77
8+ from taskiq_redis .exceptions import (
9+ DuplicateExpireTimeSelectedError ,
10+ ExpireTimeMustBeMoreThanZeroError ,
11+ )
12+
813_ReturnType = TypeVar ("_ReturnType" )
914
1015
1116class RedisAsyncResultBackend (AsyncResultBackend [_ReturnType ]):
1217 """Async result based on redis."""
1318
14- def __init__ (self , redis_url : str , keep_results : bool = True ):
19+ def __init__ (
20+ self ,
21+ redis_url : str ,
22+ keep_results : bool = True ,
23+ result_ex_time : Optional [int ] = None ,
24+ result_px_time : Optional [int ] = None ,
25+ ):
1526 """
1627 Constructs a new result backend.
1728
1829 :param redis_url: url to redis.
1930 :param keep_results: flag to not remove results from Redis after reading.
31+ :param result_ex_time: expire time in seconds for result.
32+ :param result_px_time: expire time in milliseconds for result.
33+
34+ :raises DuplicateExpireTimeSelectedError: if result_ex_time
35+ and result_px_time are selected.
36+ :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
37+ and result_px_time are equal zero.
2038 """
2139 self .redis_pool = ConnectionPool .from_url (redis_url )
2240 self .keep_results = keep_results
41+ self .result_ex_time = result_ex_time
42+ self .result_px_time = result_px_time
43+
44+ if self .result_ex_time == 0 or self .result_px_time == 0 :
45+ raise ExpireTimeMustBeMoreThanZeroError (
46+ "You must select one expire time param and it must be more than zero." ,
47+ )
48+
49+ if self .result_ex_time and self .result_px_time :
50+ raise DuplicateExpireTimeSelectedError (
51+ "Choose either result_ex_time or result_px_time." ,
52+ )
53+
54+ if not self .result_ex_time and not self .result_px_time :
55+ self .result_ex_time = 60
2356
2457 async def shutdown (self ) -> None :
2558 """Closes redis connection."""
2659 await self .redis_pool .disconnect ()
60+ await super ().shutdown ()
2761
2862 async def set_result (
2963 self ,
@@ -39,19 +73,17 @@ async def set_result(
3973 :param task_id: ID of the task.
4074 :param result: TaskiqResult instance.
4175 """
42- result_dict = result .dict (exclude = {"return_value" })
43-
44- for result_key , result_value in result_dict .items ():
45- result_dict [result_key ] = pickle .dumps (result_value )
46- # This trick will preserve original returned value.
47- # It helps when you return not serializable classes.
48- result_dict ["return_value" ] = pickle .dumps (result .return_value )
76+ redis_set_params : Dict [str , Union [str , bytes , int ]] = {
77+ "name" : task_id ,
78+ "value" : pickle .dumps (result ),
79+ }
80+ if self .result_ex_time :
81+ redis_set_params ["ex" ] = self .result_ex_time
82+ elif self .result_px_time :
83+ redis_set_params ["px" ] = self .result_px_time
4984
5085 async with Redis (connection_pool = self .redis_pool ) as redis :
51- await redis .hset (
52- task_id ,
53- mapping = result_dict ,
54- )
86+ await redis .set (** redis_set_params )
5587
5688 async def is_result_ready (self , task_id : str ) -> bool :
5789 """
@@ -76,23 +108,19 @@ async def get_result( # noqa: WPS210
76108 :param with_logs: if True it will download task's logs.
77109 :return: task's return value.
78110 """
79- fields = list (TaskiqResult .__fields__ .keys ())
80-
81- if not with_logs :
82- fields .remove ("log" )
83-
84111 async with Redis (connection_pool = self .redis_pool ) as redis :
85- result_values = await redis .hmget (
86- name = task_id ,
87- keys = fields ,
88- )
112+ if self .keep_results :
113+ result_value = await redis .get (
114+ name = task_id ,
115+ )
116+ else :
117+ result_value = await redis .getdel (
118+ name = task_id ,
119+ )
89120
90- if not self .keep_results :
91- await redis .delete (task_id )
121+ taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads (result_value )
92122
93- result = {
94- result_key : pickle .loads (result_value )
95- for result_value , result_key in zip (result_values , fields )
96- }
123+ if not with_logs :
124+ taskiq_result .log = None
97125
98- return TaskiqResult ( ** result )
126+ return taskiq_result
0 commit comments