|
19 | 19 | from datetime import datetime |
20 | 20 | import hashlib |
21 | 21 | import json |
| 22 | +from os import path |
22 | 23 | from urllib.parse import urlencode, urlsplit, quote |
23 | 24 | from urllib.request import Request, urlopen |
24 | 25 |
|
|
40 | 41 | GRANT_TYPE_KEY = "grant_type" |
41 | 42 | CLIENT_CREDENTIALS_KEY = "client_credentials" |
42 | 43 | EXPIRES_IN_KEY = "expires_in" |
| 44 | +SCOPE = "scope" |
43 | 45 |
|
44 | 46 |
|
45 | 47 | _empty = bytes("", encoding="utf8") |
@@ -119,13 +121,50 @@ def _print_request(req: Request, level=0): |
119 | 121 | print(json.dumps(json.loads(req.data.decode("utf8")), indent=2)) |
120 | 122 |
|
121 | 123 |
|
| 124 | +def _cache_file() -> str: |
| 125 | + return path.join(path.expanduser('~'), '.rai', 'tokens.json') |
| 126 | + |
| 127 | + |
| 128 | +# Read oauth cache |
| 129 | +def _read_cache() -> dict: |
| 130 | + try: |
| 131 | + with open(_cache_file(), 'r') as cache: |
| 132 | + return json.loads(cache.read()) |
| 133 | + except Exception: |
| 134 | + return {} |
| 135 | + |
| 136 | + |
| 137 | +# Read access token from cache |
| 138 | +def _read_token_cache(creds: ClientCredentials) -> AccessToken: |
| 139 | + try: |
| 140 | + cache = _read_cache() |
| 141 | + return AccessToken(**cache[creds.client_id]) |
| 142 | + except Exception: |
| 143 | + return None |
| 144 | + |
| 145 | + |
| 146 | +# write access token to cache |
| 147 | +def _write_token_cache(creds: ClientCredentials): |
| 148 | + try: |
| 149 | + cache = _read_cache() |
| 150 | + cache[creds.client_id] = creds.access_token |
| 151 | + |
| 152 | + with open(_cache_file(), 'w') as f: |
| 153 | + f.write(json.dumps(cache, default=vars)) |
| 154 | + except Exception: |
| 155 | + pass |
| 156 | + |
| 157 | + |
122 | 158 | # Returns the current access token if valid, otherwise requests new token. |
123 | 159 | def _get_access_token(ctx: Context, url: str) -> AccessToken: |
124 | 160 | creds = ctx.credentials |
125 | 161 | assert isinstance(creds, ClientCredentials) |
126 | 162 | if creds.access_token is None or creds.access_token.is_expired(): |
127 | | - creds.access_token = _request_access_token(ctx, url) |
128 | | - return creds.access_token.token |
| 163 | + creds.access_token = _read_token_cache(creds) |
| 164 | + if creds.access_token is None or creds.access_token.is_expired(): |
| 165 | + creds.access_token = _request_access_token(ctx, url) |
| 166 | + _write_token_cache(creds) |
| 167 | + return creds.access_token.access_token |
129 | 168 |
|
130 | 169 |
|
131 | 170 | def _request_access_token(ctx: Context, url: str) -> AccessToken: |
@@ -158,7 +197,8 @@ def _request_access_token(ctx: Context, url: str) -> AccessToken: |
158 | 197 | token = result.get(ACCESS_KEY_TOKEN_KEY, None) |
159 | 198 | if token is not None: |
160 | 199 | expires_in = result.get(EXPIRES_IN_KEY, None) |
161 | | - return AccessToken(token, expires_in) |
| 200 | + scope = result.get(SCOPE, None) |
| 201 | + return AccessToken(token, scope, expires_in) |
162 | 202 | raise Exception("failed to get the access token") |
163 | 203 |
|
164 | 204 |
|
|
0 commit comments