11import pickle
2- from typing import TypeVar
2+ from typing import Dict , Optional , TypeVar , Union
33
44from redis .asyncio import ConnectionPool , Redis
55from taskiq import AsyncResultBackend
66from taskiq .abc .result_backend import TaskiqResult
77
8+ from taskiq_redis .exceptions import (
9+ DuplicateExpireTimeSelectedError ,
10+ ExpireTimeMustBeMoreThanZeroError ,
11+ )
12+
813_ReturnType = TypeVar ("_ReturnType" )
914
1015
1116class RedisAsyncResultBackend (AsyncResultBackend [_ReturnType ]):
1217 """Async result based on redis."""
1318
14- def __init__ (self , redis_url : str , keep_results : bool = True ):
19+ def __init__ (
20+ self ,
21+ redis_url : str ,
22+ keep_results : bool = True ,
23+ result_ex_time : Optional [int ] = None ,
24+ result_px_time : Optional [int ] = None ,
25+ ):
1526 """
1627 Constructs a new result backend.
1728
1829 :param redis_url: url to redis.
1930 :param keep_results: flag to not remove results from Redis after reading.
31+ :param result_ex_time: expire time in seconds for result.
32+ :param result_px_time: expire time in milliseconds for result.
33+
34+ :raises DuplicateExpireTimeSelectedError: if result_ex_time
35+ and result_px_time are selected.
36+ :raises ExpireTimeMustBeMoreThanZeroError: if result_ex_time
37+ and result_px_time are equal zero.
2038 """
2139 self .redis_pool = ConnectionPool .from_url (redis_url )
2240 self .keep_results = keep_results
41+ self .result_ex_time = result_ex_time
42+ self .result_px_time = result_px_time
43+
44+ if self .result_ex_time == 0 or self .result_px_time == 0 :
45+ raise ExpireTimeMustBeMoreThanZeroError (
46+ "You must select one expire time param and it must be more than zero." ,
47+ )
48+
49+ if self .result_ex_time and self .result_px_time :
50+ raise DuplicateExpireTimeSelectedError (
51+ "Choose either result_ex_time or result_px_time." ,
52+ )
53+
54+ if not self .result_ex_time and not self .result_px_time :
55+ self .result_ex_time = 60
2356
2457 async def shutdown (self ) -> None :
2558 """Closes redis connection."""
@@ -40,19 +73,17 @@ async def set_result(
4073 :param task_id: ID of the task.
4174 :param result: TaskiqResult instance.
4275 """
43- result_dict = result .dict (exclude = {"return_value" })
44-
45- for result_key , result_value in result_dict .items ():
46- result_dict [result_key ] = pickle .dumps (result_value )
47- # This trick will preserve original returned value.
48- # It helps when you return not serializable classes.
49- result_dict ["return_value" ] = pickle .dumps (result .return_value )
76+ redis_set_params : Dict [str , Union [str , bytes , int ]] = {
77+ "name" : task_id ,
78+ "value" : pickle .dumps (result ),
79+ }
80+ if self .result_ex_time :
81+ redis_set_params ["ex" ] = self .result_ex_time
82+ elif self .result_px_time :
83+ redis_set_params ["px" ] = self .result_px_time
5084
5185 async with Redis (connection_pool = self .redis_pool ) as redis :
52- await redis .hset (
53- task_id ,
54- mapping = result_dict ,
55- )
86+ await redis .set (** redis_set_params )
5687
5788 async def is_result_ready (self , task_id : str ) -> bool :
5889 """
@@ -77,23 +108,19 @@ async def get_result( # noqa: WPS210
77108 :param with_logs: if True it will download task's logs.
78109 :return: task's return value.
79110 """
80- fields = list (TaskiqResult .__fields__ .keys ())
81-
82- if not with_logs :
83- fields .remove ("log" )
84-
85111 async with Redis (connection_pool = self .redis_pool ) as redis :
86- result_values = await redis .hmget (
87- name = task_id ,
88- keys = fields ,
89- )
112+ if self .keep_results :
113+ result_value = await redis .get (
114+ name = task_id ,
115+ )
116+ else :
117+ result_value = await redis .getdel (
118+ name = task_id ,
119+ )
90120
91- if not self .keep_results :
92- await redis .delete (task_id )
121+ taskiq_result : TaskiqResult [_ReturnType ] = pickle .loads (result_value )
93122
94- result = {
95- result_key : pickle .loads (result_value )
96- for result_value , result_key in zip (result_values , fields )
97- }
123+ if not with_logs :
124+ taskiq_result .log = None
98125
99- return TaskiqResult ( ** result )
126+ return taskiq_result
0 commit comments