Replace the _TryAgain control-flow exception in RetryWithBackoff with
exception suppression: _Attempt.__exit__ now returns True for a retryable
failure and the __iter__ generator advances to the next attempt itself.
Add unit tests covering first-try success, eventual success, exhaustion of
retries, and jittered backoff.

NOTE(review): this patch was reconstructed from a copy whose line breaks and
indentation had been collapsed. Line counts were checked against every hunk
header, but the leading whitespace of the utils.py lines and the position of
the blank context line after `time.sleep(delay)` are inferred -- confirm
against the target tree before applying.

diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_retry_with_backoff.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_retry_with_backoff.py
new file mode 100644
index 0000000000..58f22052d0
--- /dev/null
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/tests/test_retry_with_backoff.py
@@ -0,0 +1,81 @@
+import sys
+import unittest
+from unittest.mock import call, patch
+
+
+# Ensure package import when running from repo root
+sys.path.insert(0, "aws/lambda/pytorch-auto-revert")
+
+from pytorch_auto_revert.utils import RetryWithBackoff  # noqa: E402
+
+
+def run_with_retry(op, **kwargs):
+    """Helper mirroring how RetryWithBackoff is used in code.
+
+    for attempt in RetryWithBackoff(...):
+        with attempt:
+            return op()
+    """
+    for attempt in RetryWithBackoff(**kwargs):
+        with attempt:
+            return op()
+
+
+EX = ValueError("boom")
+
+
+class UnstableOp:
+    def __init__(self, fail_times: int, exc: Exception = EX):
+        self.fail_times = fail_times
+        self.calls = 0
+        self.exc = exc
+
+    def __call__(self):
+        self.calls += 1
+        if self.calls <= self.fail_times:
+            raise self.exc
+        return 42
+
+
+class TestRetryWithBackoff(unittest.TestCase):
+    @patch("pytorch_auto_revert.utils.time.sleep")
+    def test_success_first_try_no_sleep(self, sleep_mock):
+        op = UnstableOp(fail_times=0)
+        res = run_with_retry(op, max_retries=3, base_delay=0.1, jitter=False)
+        self.assertEqual(res, 42)
+        self.assertEqual(op.calls, 1)
+        sleep_mock.assert_not_called()
+
+    @patch("pytorch_auto_revert.utils.time.sleep")
+    def test_eventual_success_after_retries(self, sleep_mock):
+        op = UnstableOp(fail_times=2)
+        res = run_with_retry(op, max_retries=5, base_delay=0.1, jitter=False)
+        self.assertEqual(res, 42)
+        self.assertEqual(op.calls, 3)
+        # Backoff without jitter: 0.1, 0.2
+        self.assertEqual(sleep_mock.call_args_list, [call(0.1), call(0.2)])
+
+    @patch("pytorch_auto_revert.utils.time.sleep")
+    def test_raises_after_max_retries(self, sleep_mock):
+        op = UnstableOp(fail_times=10)
+        with self.assertRaises(ValueError):
+            run_with_retry(op, max_retries=3, base_delay=0.1, jitter=False)
+        # Two sleeps for attempts 1 and 2; none after final failed attempt
+        self.assertEqual(sleep_mock.call_args_list, [call(0.1), call(0.2)])
+        self.assertEqual(op.calls, 3)
+
+    @patch("pytorch_auto_revert.utils.random.uniform", side_effect=lambda a, b: b)
+    @patch("pytorch_auto_revert.utils.time.sleep")
+    def test_jitter_applied_to_backoff(self, sleep_mock, _uniform_mock):
+        op = UnstableOp(fail_times=1)
+        res = run_with_retry(op, max_retries=3, base_delay=0.2, jitter=True)
+        self.assertEqual(res, 42)
+        self.assertEqual(op.calls, 2)
+        # With max jitter (10%), expected delay = 0.2 * (1 + 0.1) = 0.22
+        # Allow tiny floating point drift
+        self.assertEqual(len(sleep_mock.call_args_list), 1)
+        self.assertAlmostEqual(sleep_mock.call_args_list[0].args[0], 0.22, places=6)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py
index 4c4b7285ce..88756a70a2 100644
--- a/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py
+++ b/aws/lambda/pytorch-auto-revert/pytorch_auto_revert/utils.py
@@ -80,8 +80,8 @@ def __exit__(self, exc_type, exc, tb):
             delay += random.uniform(0, 0.1 * delay)
         time.sleep(delay)
 
-        # Tell the iterator to yield another attempt
-        raise _TryAgain()
+        # Swallow the original exception so the iterator can decide to retry
+        return True
 
 
 class RetryWithBackoff:
@@ -101,12 +101,12 @@ def __iter__(self):
         self._attempt = 1
         self._done = False
         while True:
-            try:
-                yield _Attempt(self)
-                # If the with-block succeeded, stop iterating.
-                if self._done:
-                    return
-                return  # defensive: stop if block exited cleanly
-            except _TryAgain:
-                self._attempt += 1
-                continue
+            # Yield a context manager for the attempt; if the with-block
+            # raised but is retryable, __exit__ returns True to suppress it.
+            yield _Attempt(self)
+            # If the with-block succeeded, stop iterating.
+            if self._done:
+                return
+            # Otherwise, a retryable exception occurred and was suppressed.
+            # Move to the next attempt.
+            self._attempt += 1