Skip to content

Commit 918f11d

Browse files
authored
oauth token caching (#110)
* adding support to oauth token caching * fix linter * cleanup * addressing PR comments * make sure integration tests use cached token
1 parent c81fd92 commit 918f11d

File tree

3 files changed

+48
-6
lines changed

3 files changed

+48
-6
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@ jobs:
3535
CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }}
3636
CLIENT_CREDENTIALS_URL: ${{ secrets.CLIENT_CREDENTIALS_URL }}
3737
run: |
38+
mkdir -p ~/.rai
3839
python -m unittest

railib/credentials.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,11 @@ def __init__(self, akey: str, pkey: str):
4040

4141
# Represents an OAuth access token.
4242
class AccessToken:
43-
def __init__(self, access_token: str, expires_in: int):
44-
self.token = access_token
43+
def __init__(self, access_token: str, scope: str, expires_in: int, created_on: float = time.time()):
44+
self.access_token = access_token
45+
self.scope = scope
4546
self.expires_in = expires_in
46-
self.created_on = round(time.time())
47+
self.created_on = created_on
4748

4849
def is_expired(self):
4950
return (

railib/rest.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from datetime import datetime
2020
import hashlib
2121
import json
22+
from os import path
2223
from urllib.parse import urlencode, urlsplit, quote
2324
from urllib.request import Request, urlopen
2425

@@ -40,6 +41,7 @@
4041
GRANT_TYPE_KEY = "grant_type"
4142
CLIENT_CREDENTIALS_KEY = "client_credentials"
4243
EXPIRES_IN_KEY = "expires_in"
44+
SCOPE = "scope"
4345

4446

4547
_empty = bytes("", encoding="utf8")
@@ -119,13 +121,50 @@ def _print_request(req: Request, level=0):
119121
print(json.dumps(json.loads(req.data.decode("utf8")), indent=2))
120122

121123

124+
def _cache_file() -> str:
125+
return path.join(path.expanduser('~'), '.rai', 'tokens.json')
126+
127+
128+
# Read oauth cache
129+
def _read_cache() -> dict:
130+
try:
131+
with open(_cache_file(), 'r') as cache:
132+
return json.loads(cache.read())
133+
except Exception:
134+
return {}
135+
136+
137+
# Read access token from cache
138+
def _read_token_cache(creds: ClientCredentials) -> AccessToken:
139+
try:
140+
cache = _read_cache()
141+
return AccessToken(**cache[creds.client_id])
142+
except Exception:
143+
return None
144+
145+
146+
# write access token to cache
147+
def _write_token_cache(creds: ClientCredentials):
148+
try:
149+
cache = _read_cache()
150+
cache[creds.client_id] = creds.access_token
151+
152+
with open(_cache_file(), 'w') as f:
153+
f.write(json.dumps(cache, default=vars))
154+
except Exception:
155+
pass
156+
157+
122158
# Returns the current access token if valid, otherwise requests new token.
123159
def _get_access_token(ctx: Context, url: str) -> AccessToken:
124160
creds = ctx.credentials
125161
assert isinstance(creds, ClientCredentials)
126162
if creds.access_token is None or creds.access_token.is_expired():
127-
creds.access_token = _request_access_token(ctx, url)
128-
return creds.access_token.token
163+
creds.access_token = _read_token_cache(creds)
164+
if creds.access_token is None or creds.access_token.is_expired():
165+
creds.access_token = _request_access_token(ctx, url)
166+
_write_token_cache(creds)
167+
return creds.access_token.access_token
129168

130169

131170
def _request_access_token(ctx: Context, url: str) -> AccessToken:
@@ -158,7 +197,8 @@ def _request_access_token(ctx: Context, url: str) -> AccessToken:
158197
token = result.get(ACCESS_KEY_TOKEN_KEY, None)
159198
if token is not None:
160199
expires_in = result.get(EXPIRES_IN_KEY, None)
161-
return AccessToken(token, expires_in)
200+
scope = result.get(SCOPE, None)
201+
return AccessToken(token, scope, expires_in)
162202
raise Exception("failed to get the access token")
163203

164204

0 commit comments

Comments
 (0)