Skip to content

Commit 73efcae

Browse files
authored
[autorevert] fix RetryWithBackoff, add tests (#7243)
a followup to #7241 fixes the logic and adds unit tests
1 parent a3efc82 commit 73efcae

File tree

2 files changed

+92
-11
lines changed

2 files changed

+92
-11
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import sys
2+
import unittest
3+
from unittest.mock import call, patch
4+
5+
6+
# Ensure package import when running from repo root
7+
sys.path.insert(0, "aws/lambda/pytorch-auto-revert")
8+
9+
from pytorch_auto_revert.utils import RetryWithBackoff # noqa: E402
10+
11+
12+
def run_with_retry(op, **kwargs):
13+
"""Helper mirroring how RetryWithBackoff is used in code.
14+
15+
for attempt in RetryWithBackoff(...):
16+
with attempt:
17+
return op()
18+
"""
19+
for attempt in RetryWithBackoff(**kwargs):
20+
with attempt:
21+
return op()
22+
23+
24+
EX = ValueError("boom")
25+
26+
27+
class UnstableOp:
28+
def __init__(self, fail_times: int, exc: Exception = EX):
29+
self.fail_times = fail_times
30+
self.calls = 0
31+
self.exc = exc
32+
33+
def __call__(self):
34+
self.calls += 1
35+
if self.calls <= self.fail_times:
36+
raise self.exc
37+
return 42
38+
39+
40+
class TestRetryWithBackoff(unittest.TestCase):
41+
@patch("pytorch_auto_revert.utils.time.sleep")
42+
def test_success_first_try_no_sleep(self, sleep_mock):
43+
op = UnstableOp(fail_times=0)
44+
res = run_with_retry(op, max_retries=3, base_delay=0.1, jitter=False)
45+
self.assertEqual(res, 42)
46+
self.assertEqual(op.calls, 1)
47+
sleep_mock.assert_not_called()
48+
49+
@patch("pytorch_auto_revert.utils.time.sleep")
50+
def test_eventual_success_after_retries(self, sleep_mock):
51+
op = UnstableOp(fail_times=2)
52+
res = run_with_retry(op, max_retries=5, base_delay=0.1, jitter=False)
53+
self.assertEqual(res, 42)
54+
self.assertEqual(op.calls, 3)
55+
# Backoff without jitter: 0.1, 0.2
56+
self.assertEqual(sleep_mock.call_args_list, [call(0.1), call(0.2)])
57+
58+
@patch("pytorch_auto_revert.utils.time.sleep")
59+
def test_raises_after_max_retries(self, sleep_mock):
60+
op = UnstableOp(fail_times=10)
61+
with self.assertRaises(ValueError):
62+
run_with_retry(op, max_retries=3, base_delay=0.1, jitter=False)
63+
# Two sleeps for attempts 1 and 2; none after final failed attempt
64+
self.assertEqual(sleep_mock.call_args_list, [call(0.1), call(0.2)])
65+
self.assertEqual(op.calls, 3)
66+
67+
@patch("pytorch_auto_revert.utils.random.uniform", side_effect=lambda a, b: b)
68+
@patch("pytorch_auto_revert.utils.time.sleep")
69+
def test_jitter_applied_to_backoff(self, sleep_mock, _uniform_mock):
70+
op = UnstableOp(fail_times=1)
71+
res = run_with_retry(op, max_retries=3, base_delay=0.2, jitter=True)
72+
self.assertEqual(res, 42)
73+
self.assertEqual(op.calls, 2)
74+
# With max jitter (10%), expected delay = 0.2 * (1 + 0.1) = 0.22
75+
# Allow tiny floating point drift
76+
self.assertEqual(len(sleep_mock.call_args_list), 1)
77+
self.assertAlmostEqual(sleep_mock.call_args_list[0].args[0], 0.22, places=6)
78+
79+
80+
if __name__ == "__main__":
81+
unittest.main()

aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def __exit__(self, exc_type, exc, tb):
8080
delay += random.uniform(0, 0.1 * delay)
8181
time.sleep(delay)
8282

83-
# Tell the iterator to yield another attempt
84-
raise _TryAgain()
83+
# Swallow the original exception so the iterator can decide to retry
84+
return True
8585

8686

8787
class RetryWithBackoff:
@@ -101,12 +101,12 @@ def __iter__(self):
101101
self._attempt = 1
102102
self._done = False
103103
while True:
104-
try:
105-
yield _Attempt(self)
106-
# If the with-block succeeded, stop iterating.
107-
if self._done:
108-
return
109-
return # defensive: stop if block exited cleanly
110-
except _TryAgain:
111-
self._attempt += 1
112-
continue
104+
# Yield a context manager for the attempt; if the with-block
105+
# raised but is retryable, __exit__ returns True to suppress it.
106+
yield _Attempt(self)
107+
# If the with-block succeeded, stop iterating.
108+
if self._done:
109+
return
110+
# Otherwise, a retryable exception occurred and was suppressed.
111+
# Move to the next attempt.
112+
self._attempt += 1

0 commit comments

Comments
 (0)