Skip to content

Commit c81fd92

Browse files
authored
Add polling with overhead to python sdk (#109)
* add polling to python sdk * addressing PR comments * fix linter * run all tests * updated polling policies * fix * start_time is a timestamp * cleanup
1 parent 77d50e0 commit c81fd92

File tree

10 files changed

+94
-44
lines changed

10 files changed

+94
-44
lines changed

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ jobs:
3535
CLIENT_SECRET: ${{ secrets.CLIENT_SECRET }}
3636
CLIENT_CREDENTIALS_URL: ${{ secrets.CLIENT_CREDENTIALS_URL }}
3737
run: |
38-
python tests/integration.py
38+
python -m unittest

examples/create_engine.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,16 @@
1616

1717
from argparse import ArgumentParser
1818
import json
19-
import time
2019
from urllib.request import HTTPError
2120
from railib import api, config, show
2221
from railib.api import EngineSize
2322

2423

25-
# Answers if the given state is a terminal state.
26-
def is_term_state(state: str) -> bool:
27-
return state == "PROVISIONED" or ("FAILED" in state)
28-
29-
3024
def run(engine: str, size: str, profile: str):
3125
cfg = config.read(profile=profile)
3226
ctx = api.Context(**cfg)
33-
rsp = api.create_engine(ctx, engine, EngineSize(size))
34-
while True: # wait for request to reach terminal state
35-
time.sleep(3)
36-
rsp = api.get_engine(ctx, engine)
37-
if is_term_state(rsp["state"]):
38-
break
39-
print(json.dumps(rsp, indent=2))
27+
api.create_engine_wait(ctx, engine, EngineSize(size))
28+
print(json.dumps(api.get_engine(ctx, engine), indent=2))
4029

4130

4231
if __name__ == "__main__":

railib/api.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,42 @@ def _parse_arrow_results(files: List[TransactionAsyncFile]):
325325
results.append({"relationId": file.name, "table": table})
326326
return results
327327

328+
# polling with specified overhead
329+
# delay is the overhead % of the time the transaction has been running so far
330+
331+
332+
def poll_with_specified_overhead(
333+
f,
334+
overhead_rate: float,
335+
start_time: int = time.time(),
336+
timeout: int = None,
337+
max_tries: int = None,
338+
max_delay: int = 120,
339+
):
340+
tries = 0
341+
max_time = time.time() + timeout if timeout else None
342+
343+
while True:
344+
if f():
345+
break
346+
347+
if max_tries is not None and tries >= max_tries:
348+
raise Exception(f'max tries {max_tries} exhausted')
349+
350+
if max_time is not None and time.time() >= max_time:
351+
raise Exception(f'timed out after {timeout} seconds')
352+
353+
tries += 1
354+
duration = min((time.time() - start_time) * overhead_rate, max_delay)
355+
if tries == 1:
356+
time.sleep(0.5)
357+
else:
358+
time.sleep(duration)
359+
360+
361+
def is_engine_term_state(state: str) -> bool:
362+
return state == "PROVISIONED" or ("FAILED" in state)
363+
328364

329365
def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
330366
data = {"region": ctx.region, "name": engine, "size": size.value}
@@ -333,6 +369,16 @@ def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, *
333369
return json.loads(rsp.read())
334370

335371

372+
def create_engine_wait(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
373+
create_engine(ctx, engine, size, **kwargs)
374+
poll_with_specified_overhead(
375+
lambda: is_engine_term_state(get_engine(ctx, engine)["state"]),
376+
overhead_rate=0.2,
377+
timeout=30 * 60,
378+
)
379+
return get_engine(ctx, engine)
380+
381+
336382
def create_user(ctx: Context, email: str, roles: List[Role] = None, **kwargs):
337383
rs = roles or []
338384
data = {"email": email, "roles": [r.value for r in rs]}
@@ -836,6 +882,7 @@ def exec(
836882
readonly: bool = True,
837883
**kwargs
838884
) -> TransactionAsyncResponse:
885+
start_time = time.time()
839886
txn = exec_async(ctx, database, engine, command, inputs=inputs, readonly=readonly)
840887
# in case of if short-path, return results directly, no need to poll for
841888
# state
@@ -844,15 +891,17 @@ def exec(
844891

845892
rsp = TransactionAsyncResponse()
846893
txn = get_transaction(ctx, txn.transaction["id"], **kwargs)
847-
while True:
848-
time.sleep(1)
849-
txn = get_transaction(ctx, txn["id"], **kwargs)
850-
if is_txn_term_state(txn["state"]):
851-
rsp.transaction = txn
852-
rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
853-
rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
854-
rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
855-
break
894+
895+
poll_with_specified_overhead(
896+
lambda: is_txn_term_state(get_transaction(ctx, txn["id"], **kwargs)["state"]),
897+
overhead_rate=0.2,
898+
start_time=start_time,
899+
)
900+
901+
rsp.transaction = get_transaction(ctx, txn["id"], **kwargs)
902+
rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
903+
rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
904+
rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
856905

857906
return rsp
858907

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
ed25519==1.5
22
grpcio-tools==1.47.0
3-
protobuf==3.20.1
3+
protobuf==3.20.2
44
pyarrow==6.0.1
55
requests-toolbelt==0.9.1
66

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
"ed25519==1.5",
3636
"pyarrow>=6.0.1",
3737
"requests-toolbelt==0.9.1",
38-
"protobuf==3.20.1"],
38+
"protobuf==3.20.2"],
3939
license="http://www.apache.org/licenses/LICENSE-2.0",
4040
long_description="Enables access to the RelationalAI REST APIs from Python",
4141
long_description_content_type="text/markdown",
File renamed without changes.
File renamed without changes.
File renamed without changes.

tests/integration.py renamed to test/test_integration.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import json
2-
from time import sleep
32
import unittest
43
import os
54
import uuid
@@ -8,22 +7,6 @@
87
from pathlib import Path
98
from railib import api, config
109

11-
# TODO: create_engine_wait should be added to API
12-
# with exponential backoff
13-
14-
15-
def create_engine_wait(ctx: api.Context, engine: str):
16-
state = api.create_engine(ctx, engine, headers=custom_headers)["compute"]["state"]
17-
18-
count = 0
19-
while not ("PROVISIONED" == state):
20-
if count > 12:
21-
return
22-
23-
count += 1
24-
sleep(30)
25-
state = api.get_engine(ctx, engine)["state"]
26-
2710

2811
# Get creds from env vars if exists
2912
client_id = os.getenv("CLIENT_ID")
@@ -58,8 +41,10 @@ def create_engine_wait(ctx: api.Context, engine: str):
5841

5942
class TestTransactionAsync(unittest.TestCase):
6043
def setUp(self):
61-
create_engine_wait(ctx, engine)
62-
api.create_database(ctx, dbname)
44+
rsp = api.create_engine_wait(ctx, engine, headers=custom_headers)
45+
self.assertEqual("PROVISIONED", rsp["state"])
46+
rsp = api.create_database(ctx, dbname)
47+
self.assertEqual("CREATED", rsp["database"]["state"])
6348

6449
def test_v2_exec(self):
6550
cmd = "x, x^2, x^3, x^4 from x in {1; 2; 3; 4; 5}"

test/test_unit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import unittest
2+
3+
from railib import api
4+
5+
6+
class TestPolling(unittest.TestCase):
7+
def test_timeout_exception(self):
8+
try:
9+
api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, timeout=1)
10+
except Exception as e:
11+
self.assertEqual('timed out after 1 seconds', str(e))
12+
13+
def test_max_tries_exception(self):
14+
try:
15+
api.poll_with_specified_overhead(lambda: False, overhead_rate=0.1, max_tries=1)
16+
except Exception as e:
17+
self.assertEqual('max tries 1 exhausted', str(e))
18+
19+
def test_validation(self):
20+
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1)
21+
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1)
22+
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, max_tries=1)
23+
api.poll_with_specified_overhead(lambda: True, overhead_rate=0.1, timeout=1, max_tries=1)
24+
25+
26+
if __name__ == '__main__':
27+
unittest.main()

0 commit comments

Comments
 (0)