Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## v0.6.14

* Added a retry mechanism for HTTP failures raised as `URLError`.
* Defaults to `0` retries
* Configurable through `Context` (example `Context(**cfg, retries=3)` to set retries to 3)

## v0.6.13

* Improved debug logging for `exec`.
Expand Down
2 changes: 1 addition & 1 deletion railib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version_info__ = (0, 6, 13)
__version_info__ = (0, 6, 14)
__version__ = ".".join(map(str, __version_info__))
3 changes: 2 additions & 1 deletion railib/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def __init__(
region: str = None,
credentials=None,
audience: str = None,
retries: int = 0,
):
super().__init__(region=region, credentials=credentials)
super().__init__(region=region, credentials=credentials, retries=retries)
self.host = host
self.port = port or "443"
self.scheme = scheme or "https"
Expand Down
32 changes: 30 additions & 2 deletions railib/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
import json
import logging
from os import path
import socket
import time
from urllib.error import URLError
from urllib.parse import urlencode, urlsplit, quote
from urllib.request import Request, urlopen

Expand Down Expand Up @@ -45,10 +48,14 @@

# Context contains the state required to make rAI REST API calls.
class Context(object):
def __init__(self, region: str = None, credentials: Credentials = None):
def __init__(self, region: str = None, credentials: Credentials = None, retries: int = 0):
if retries < 0:
raise ValueError("Retries must be a non-negative integer")

self.region = region or "us-east"
self.credentials = credentials
self.service = "transaction"
self.retries = retries


# Answers if the keys of the passed dict contain a case insensitive match
Expand Down Expand Up @@ -211,6 +218,27 @@ def _authenticate(ctx: Context, req: Request) -> Request:
raise Exception("unknown credential type")


# Issues an HTTP request and retries if failed due to URLError.
def _urlopen_with_retry(req: Request, retries: int = 0):
if retries < 0:
raise ValueError("Retries must be a non-negative integer")

attempts = retries + 1

for attempt in range(attempts):
try:
return urlopen(req)
except URLError as e:
if isinstance(e.reason, socket.timeout):
logger.warning(f"Timeout occurred (attempt {attempt + 1}/{attempts}): {req.full_url}")
else:
logger.warning(f"URLError occurred {e.reason} (attempt {attempt + 1}/{attempts}): {req.full_url}")

if attempt == attempts - 1:
logger.error(f"Failed to connect to {req.full_url} after {attempts} attempt{'s' if attempts > 1 else ''}")
raise e


# Issues an RAI REST API request, and returns response contents if successful.
def request(ctx: Context, method: str, url: str, headers={}, data=None, **kwargs):
headers = _default_headers(url, headers)
Expand All @@ -220,7 +248,7 @@ def request(ctx: Context, method: str, url: str, headers={}, data=None, **kwargs
req = Request(method=method, url=url, headers=headers, data=data)
req = _authenticate(ctx, req)
_print_request(req)
rsp = urlopen(req)
rsp = _urlopen_with_retry(req, ctx.retries)

# logging
content_type = headers["Content-Type"] if "Content-Type" in headers else ""
Expand Down
65 changes: 65 additions & 0 deletions test/test_unit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import socket
import unittest
from unittest.mock import patch, MagicMock
from urllib.error import URLError
from urllib.request import Request

from railib import api
from railib.rest import _urlopen_with_retry


class TestPolling(unittest.TestCase):
Expand All @@ -23,5 +28,65 @@ def test_validation(self):
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1)


@patch('railib.rest.urlopen')
class TestURLOpenWithRetry(unittest.TestCase):

def test_successful_response(self, mock_urlopen):
# Set up the mock urlopen to return a successful response
mock_response = MagicMock()
mock_urlopen.return_value = mock_response
mock_response.read.return_value = b'Hello, World!'

req = Request('https://example.com')

response = _urlopen_with_retry(req)
self.assertEqual(response.read(), b'Hello, World!')
mock_urlopen.assert_called_once_with(req)

response = _urlopen_with_retry(req, 3)
self.assertEqual(response.read(), b'Hello, World!')
self.assertEqual(mock_urlopen.call_count, 2)

def test_negative_retries(self, _):
req = Request('https://example.com')

with self.assertRaises(Exception) as e:
_urlopen_with_retry(req, -1)

self.assertIn("Retries must be a non-negative integer", str(e.exception))

def test_timeout_retry(self, mock_urlopen):
# Set up the mock urlopen to raise a socket timeout error
mock_urlopen.side_effect = URLError(socket.timeout())

req = Request('https://example.com')
with self.assertLogs() as log:
with self.assertRaises(Exception):
_urlopen_with_retry(req, 2)

self.assertEqual(mock_urlopen.call_count, 3) # Expect 1 original call and 2 calls for retries
self.assertEqual(len(log.output), 4) # Expect 3 log messages for retries and 1 for failure to connect
self.assertIn('Timeout occurred', log.output[0])
self.assertIn('Timeout occurred', log.output[1])
self.assertIn('Timeout occurred', log.output[2])
self.assertIn('Failed to connect to', log.output[3])

def test_other_error_retry(self, mock_urlopen):
# Set up the mock urlopen to raise a non-timeout URLError
mock_urlopen.side_effect = URLError('Some other error')

req = Request('https://example.com')
with self.assertLogs() as log:
with self.assertRaises(Exception):
_urlopen_with_retry(req, retries=2)

self.assertEqual(mock_urlopen.call_count, 3) # Expect 3 calls for retries
self.assertEqual(len(log.output), 4) # Expect 3 log messages for retries and 1 for failure to connect
self.assertIn('URLError occurred', log.output[0])
self.assertIn('URLError occurred', log.output[1])
self.assertIn('URLError occurred', log.output[2])
self.assertIn('Failed to connect to', log.output[3])


if __name__ == '__main__':
unittest.main()