Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions railib/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ def __init__(self, akey: str, pkey: str):

# Represents an OAuth access token.
class AccessToken:
def __init__(self, access_token: str, expires_in: int):
self.token = access_token
def __init__(self, access_token: str, scope: str, expires_in: int, created_on: float = time.time()):
self.access_token = access_token
self.scope = scope
self.expires_in = expires_in
self.created_on = round(time.time())
self.created_on = created_on

def is_expired(self):
return (
Expand Down
48 changes: 45 additions & 3 deletions railib/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import datetime
import hashlib
import json
from os import path
from urllib.parse import urlencode, urlsplit, quote
from urllib.request import Request, urlopen

Expand All @@ -40,6 +41,7 @@
GRANT_TYPE_KEY = "grant_type"
CLIENT_CREDENTIALS_KEY = "client_credentials"
EXPIRES_IN_KEY = "expires_in"
SCOPE = "scope"


_empty = bytes("", encoding="utf8")
Expand Down Expand Up @@ -119,13 +121,52 @@ def _print_request(req: Request, level=0):
print(json.dumps(json.loads(req.data.decode("utf8")), indent=2))


def _cache_file() -> str:
return path.join(path.expanduser('~'), '.rai', 'tokens.json')


# Read oauth cache
def _read_cache() -> dict:
try:
with open(_cache_file(), 'r') as cache:
return json.loads(cache.read())
except Exception:
return None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see we swallow these exceptions in these cache functions (the go sdk does sth similar as well). Is that on purpose?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as far as I do understand, if the cache is not accessible for some reason we should not block the user and instead we request a new token



# Read access token from cache
def _read_token_cache(creds: ClientCredentials) -> AccessToken:
try:
cache = _read_cache()
return AccessToken(**cache[creds.client_id])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume if the cache doesn't contain that client id an exception is thrown?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep that's true and we just by-pass it, if _read_token_cache function return None we will request a new token anyway

except Exception:
return None


# write access token to cache
def _write_token_cache(creds: ClientCredentials):
try:
cache = _read_cache()
if cache:
cache[creds.client_id] = creds.access_token
else:
cache = {creds.client_id: creds.access_token}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could avoid this branch by returning {} from _read_cache's exception path

with open(_cache_file(), 'w') as f:
f.write(json.dumps(cache, default=vars))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's default=vars needed for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is one way to make python class json serializable

except Exception:
pass


# Returns the current access token if valid, otherwise requests new token.
def _get_access_token(ctx: Context, url: str) -> AccessToken:
creds = ctx.credentials
assert isinstance(creds, ClientCredentials)
if creds.access_token is None or creds.access_token.is_expired():
creds.access_token = _request_access_token(ctx, url)
return creds.access_token.token
creds.access_token = _read_token_cache(creds)
if creds.access_token is None or creds.access_token.is_expired():
creds.access_token = _request_access_token(ctx, url)
_write_token_cache(creds)
return creds.access_token.access_token


def _request_access_token(ctx: Context, url: str) -> AccessToken:
Expand Down Expand Up @@ -158,7 +199,8 @@ def _request_access_token(ctx: Context, url: str) -> AccessToken:
token = result.get(ACCESS_KEY_TOKEN_KEY, None)
if token is not None:
expires_in = result.get(EXPIRES_IN_KEY, None)
return AccessToken(token, expires_in)
scope = result.get(SCOPE, None)
return AccessToken(token, scope, expires_in)
raise Exception("failed to get the access token")


Expand Down