Skip to content

Commit 77d50e0

Browse files
authored
Add kwargs parameters support to SDK API (#107)
* adding kwargs params * add custom headers to integration tests * custom headers
1 parent 898e507 commit 77d50e0

File tree

2 files changed

+62
-58
lines changed

2 files changed

+62
-58
lines changed

railib/api.py

Lines changed: 59 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -326,26 +326,26 @@ def _parse_arrow_results(files: List[TransactionAsyncFile]):
326326
return results
327327

328328

329-
def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS):
329+
def create_engine(ctx: Context, engine: str, size: EngineSize = EngineSize.XS, **kwargs):
330330
data = {"region": ctx.region, "name": engine, "size": size.value}
331331
url = _mkurl(ctx, PATH_ENGINE)
332-
rsp = rest.put(ctx, url, data)
332+
rsp = rest.put(ctx, url, data, **kwargs)
333333
return json.loads(rsp.read())
334334

335335

336-
def create_user(ctx: Context, email: str, roles: List[Role] = None):
336+
def create_user(ctx: Context, email: str, roles: List[Role] = None, **kwargs):
337337
rs = roles or []
338338
data = {"email": email, "roles": [r.value for r in rs]}
339339
url = _mkurl(ctx, PATH_USER)
340-
rsp = rest.post(ctx, url, data)
340+
rsp = rest.post(ctx, url, data, **kwargs)
341341
return json.loads(rsp.read())
342342

343343

344-
def create_oauth_client(ctx: Context, name: str, permissions: List[Permission] = None):
344+
def create_oauth_client(ctx: Context, name: str, permissions: List[Permission] = None, **kwargs):
345345
ps = permissions or []
346346
data = {"name": name, "permissions": ps}
347347
url = _mkurl(ctx, PATH_OAUTH_CLIENT)
348-
rsp = rest.post(ctx, url, data)
348+
rsp = rest.post(ctx, url, data, **kwargs)
349349
return json.loads(rsp.read())["client"]
350350

351351

@@ -358,78 +358,78 @@ def _create_mode(source_database: str, overwrite: bool) -> Mode:
358358
return result
359359

360360

361-
def delete_database(ctx: Context, database: str) -> dict:
361+
def delete_database(ctx: Context, database: str, **kwargs) -> dict:
362362
data = {"name": database}
363363
url = _mkurl(ctx, PATH_DATABASE)
364-
rsp = rest.delete(ctx, url, data)
364+
rsp = rest.delete(ctx, url, data, **kwargs)
365365
return json.loads(rsp.read())
366366

367367

368-
def delete_engine(ctx: Context, engine: str) -> dict:
368+
def delete_engine(ctx: Context, engine: str, **kwargs) -> dict:
369369
data = {"name": engine}
370370
url = _mkurl(ctx, PATH_ENGINE)
371-
rsp = rest.delete(ctx, url, data)
371+
rsp = rest.delete(ctx, url, data, **kwargs)
372372
return json.loads(rsp.read())
373373

374374

375-
def delete_user(ctx: Context, id: str) -> dict:
375+
def delete_user(ctx: Context, id: str, **kwargs) -> dict:
376376
url = _mkurl(ctx, f"{PATH_USER}/{id}")
377-
rsp = rest.delete(ctx, url, None)
377+
rsp = rest.delete(ctx, url, None, **kwargs)
378378
return json.loads(rsp.read())
379379

380380

381-
def disable_user(ctx: Context, userid: str) -> dict:
382-
return update_user(ctx, userid, status="INACTIVE")
381+
def disable_user(ctx: Context, userid: str, **kwargs) -> dict:
382+
return update_user(ctx, userid, status="INACTIVE", **kwargs)
383383

384384

385-
def delete_oauth_client(ctx: Context, id: str) -> dict:
385+
def delete_oauth_client(ctx: Context, id: str, **kwargs) -> dict:
386386
url = _mkurl(ctx, f"{PATH_OAUTH_CLIENT}/{id}")
387-
rsp = rest.delete(ctx, url, None)
387+
rsp = rest.delete(ctx, url, None, **kwargs)
388388
return json.loads(rsp.read())
389389

390390

391-
def enable_user(ctx: Context, userid: str) -> dict:
392-
return update_user(ctx, userid, status="ACTIVE")
391+
def enable_user(ctx: Context, userid: str, **kwargs) -> dict:
392+
return update_user(ctx, userid, status="ACTIVE", **kwargs)
393393

394394

395-
def get_engine(ctx: Context, engine: str) -> dict:
396-
return _get_resource(ctx, PATH_ENGINE, name=engine, deleted_on="", key="computes")
395+
def get_engine(ctx: Context, engine: str, **kwargs) -> dict:
396+
return _get_resource(ctx, PATH_ENGINE, name=engine, deleted_on="", key="computes", **kwargs)
397397

398398

399-
def get_database(ctx: Context, database: str) -> dict:
400-
return _get_resource(ctx, PATH_DATABASE, name=database, key="databases")
399+
def get_database(ctx: Context, database: str, **kwargs) -> dict:
400+
return _get_resource(ctx, PATH_DATABASE, name=database, key="databases", **kwargs)
401401

402402

403-
def get_oauth_client(ctx: Context, id: str) -> dict:
404-
return _get_resource(ctx, f"{PATH_OAUTH_CLIENT}/{id}", key="client")
403+
def get_oauth_client(ctx: Context, id: str, **kwargs) -> dict:
404+
return _get_resource(ctx, f"{PATH_OAUTH_CLIENT}/{id}", key="client", **kwargs)
405405

406406

407-
def get_transaction(ctx: Context, id: str) -> dict:
408-
return _get_resource(ctx, f"{PATH_TRANSACTIONS}/{id}", key="transaction")
407+
def get_transaction(ctx: Context, id: str, **kwargs) -> dict:
408+
return _get_resource(ctx, f"{PATH_TRANSACTIONS}/{id}", key="transaction", **kwargs)
409409

410410

411-
def get_transaction_metadata(ctx: Context, id: str) -> list:
411+
def get_transaction_metadata(ctx: Context, id: str, **kwargs) -> list:
412412
headers = {"Accept": "application/x-protobuf"}
413413
url = _mkurl(ctx, f"{PATH_TRANSACTIONS}/{id}/metadata")
414-
rsp = rest.get(ctx, url, headers=headers)
414+
rsp = rest.get(ctx, url, headers=headers, **kwargs)
415415
content_type = rsp.headers.get("content-type", "")
416416
if "application/x-protobuf" in content_type:
417417
return _parse_metadata_proto(rsp.read())
418418

419419
raise Exception(f"invalid content type for metadata proto: {content_type}")
420420

421421

422-
def list_transactions(ctx: Context) -> list:
423-
return _get_collection(ctx, PATH_TRANSACTIONS, key="transactions")
422+
def list_transactions(ctx: Context, **kwargs) -> list:
423+
return _get_collection(ctx, PATH_TRANSACTIONS, key="transactions", **kwargs)
424424

425425

426-
def get_transaction_problems(ctx: Context, id: str) -> list:
427-
return _get_collection(ctx, f"{PATH_TRANSACTIONS}/{id}/problems")
426+
def get_transaction_problems(ctx: Context, id: str, **kwargs) -> list:
427+
return _get_collection(ctx, f"{PATH_TRANSACTIONS}/{id}/problems", **kwargs)
428428

429429

430-
def get_transaction_results(ctx: Context, id: str) -> list:
430+
def get_transaction_results(ctx: Context, id: str, **kwargs) -> list:
431431
url = _mkurl(ctx, f"{PATH_TRANSACTIONS}/{id}/results")
432-
rsp = rest.get(ctx, url)
432+
rsp = rest.get(ctx, url, **kwargs)
433433
content_type = rsp.headers.get("content-type", "")
434434
if "multipart/form-data" in content_type:
435435
parts = _parse_multipart_form(content_type, rsp.read())
@@ -442,20 +442,20 @@ def get_transaction_results(ctx: Context, id: str) -> list:
442442
# deprecated, get_transaction_results should be called instead
443443

444444

445-
def get_transaction_results_and_problems(ctx: Context, id: str) -> list:
445+
def get_transaction_results_and_problems(ctx: Context, id: str, **kwargs) -> list:
446446
rsp = TransactionAsyncResponse()
447-
rsp.problems = get_transaction_problems(ctx, id)
448-
rsp.results = get_transaction_results(ctx, id)
447+
rsp.problems = get_transaction_problems(ctx, id, **kwargs)
448+
rsp.results = get_transaction_results(ctx, id, **kwargs)
449449
return rsp
450450

451451

452-
def cancel_transaction(ctx: Context, id: str) -> dict:
453-
rsp = rest.post(ctx, _mkurl(ctx, f"{PATH_TRANSACTIONS}/{id}/cancel"), {})
452+
def cancel_transaction(ctx: Context, id: str, **kwargs) -> dict:
453+
rsp = rest.post(ctx, _mkurl(ctx, f"{PATH_TRANSACTIONS}/{id}/cancel"), {}, **kwargs)
454454
return json.loads(rsp.read())
455455

456456

457-
def get_user(ctx: Context, userid: str) -> dict:
458-
return _get_resource(ctx, f"{PATH_USER}/{userid}", name=userid)
457+
def get_user(ctx: Context, userid: str, **kwargs) -> dict:
458+
return _get_resource(ctx, f"{PATH_USER}/{userid}", name=userid, **kwargs)
459459

460460

461461
def list_engines(ctx: Context, state=None) -> list:
@@ -472,22 +472,22 @@ def list_databases(ctx: Context, state=None) -> list:
472472
return _get_collection(ctx, PATH_DATABASE, key="databases", **kwargs)
473473

474474

475-
def list_users(ctx: Context) -> list:
476-
return _get_collection(ctx, PATH_USER, key="users")
475+
def list_users(ctx: Context, **kwargs) -> list:
476+
return _get_collection(ctx, PATH_USER, key="users", **kwargs)
477477

478478

479-
def list_oauth_clients(ctx: Context) -> list:
480-
return _get_collection(ctx, PATH_OAUTH_CLIENT, key="clients")
479+
def list_oauth_clients(ctx: Context, **kwargs) -> list:
480+
return _get_collection(ctx, PATH_OAUTH_CLIENT, key="clients", **kwargs)
481481

482482

483-
def update_user(ctx: Context, userid: str, status: str = None, roles=None):
483+
def update_user(ctx: Context, userid: str, status: str = None, roles=None, **kwargs):
484484
data = {}
485485
if status is not None:
486486
data["status"] = status
487487
if roles is not None:
488488
data["roles"] = roles
489489
url = _mkurl(ctx, f"{PATH_USER}/{userid}")
490-
rsp = rest.patch(ctx, url, data)
490+
rsp = rest.patch(ctx, url, data, **kwargs)
491491
return json.loads(rsp.read())
492492

493493

@@ -578,14 +578,14 @@ def data(self):
578578
result["engine_name"] = self.engine
579579
return result
580580

581-
def run(self, ctx: Context, command: str, language: str, inputs: dict = None) -> Union[dict, list]:
581+
def run(self, ctx: Context, command: str, language: str, inputs: dict = None, **kwargs) -> Union[dict, list]:
582582
data = self.data
583583
data["query"] = command
584584
data["language"] = language
585585
if inputs is not None:
586586
inputs = [_query_action_input(k, v) for k, v in inputs.items()]
587587
data["v1_inputs"] = inputs
588-
rsp = rest.post(ctx, _mkurl(ctx, PATH_TRANSACTIONS), data)
588+
rsp = rest.post(ctx, _mkurl(ctx, PATH_TRANSACTIONS), data, **kwargs)
589589
content_type = rsp.headers.get("content-type", None)
590590
content = rsp.read()
591591
# todo: response model should be based on status code (200 v. 201)
@@ -668,10 +668,10 @@ def _list_models(ctx: Context, database: str, engine: str) -> dict:
668668
return models
669669

670670

671-
def create_database(ctx: Context, database: str, source: str = None) -> dict:
671+
def create_database(ctx: Context, database: str, source: str = None, **kwargs) -> dict:
672672
data = {"name": database, "source_name": source}
673673
url = _mkurl(ctx, PATH_DATABASE)
674-
rsp = rest.put(ctx, url, data)
674+
rsp = rest.put(ctx, url, data, **kwargs)
675675
return json.loads(rsp.read())
676676

677677

@@ -834,6 +834,7 @@ def exec(
834834
command: str,
835835
inputs: dict = None,
836836
readonly: bool = True,
837+
**kwargs
837838
) -> TransactionAsyncResponse:
838839
txn = exec_async(ctx, database, engine, command, inputs=inputs, readonly=readonly)
839840
# in case of if short-path, return results directly, no need to poll for
@@ -842,15 +843,15 @@ def exec(
842843
return txn
843844

844845
rsp = TransactionAsyncResponse()
845-
txn = get_transaction(ctx, txn.transaction["id"])
846+
txn = get_transaction(ctx, txn.transaction["id"], **kwargs)
846847
while True:
847848
time.sleep(1)
848-
txn = get_transaction(ctx, txn["id"])
849+
txn = get_transaction(ctx, txn["id"], **kwargs)
849850
if is_txn_term_state(txn["state"]):
850851
rsp.transaction = txn
851-
rsp.metadata = get_transaction_metadata(ctx, txn["id"])
852-
rsp.problems = get_transaction_problems(ctx, txn["id"])
853-
rsp.results = get_transaction_results(ctx, txn["id"])
852+
rsp.metadata = get_transaction_metadata(ctx, txn["id"], **kwargs)
853+
rsp.problems = get_transaction_problems(ctx, txn["id"], **kwargs)
854+
rsp.results = get_transaction_results(ctx, txn["id"], **kwargs)
854855
break
855856

856857
return rsp
@@ -864,9 +865,10 @@ def exec_async(
864865
language: str = "",
865866
readonly: bool = True,
866867
inputs: dict = None,
868+
**kwargs,
867869
) -> TransactionAsyncResponse:
868870
tx = TransactionAsync(database, engine, readonly=readonly)
869-
rsp = tx.run(ctx, command, language=language, inputs=inputs)
871+
rsp = tx.run(ctx, command, language=language, inputs=inputs, **kwargs)
870872

871873
if isinstance(rsp, dict):
872874
return TransactionAsyncResponse(rsp, None, None, None)

tests/integration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from time import sleep
23
import unittest
34
import os
@@ -12,7 +13,7 @@
1213

1314

1415
def create_engine_wait(ctx: api.Context, engine: str):
15-
state = api.create_engine(ctx, engine)["compute"]["state"]
16+
state = api.create_engine(ctx, engine, headers=custom_headers)["compute"]["state"]
1617

1718
count = 0
1819
while not ("PROVISIONED" == state):
@@ -28,6 +29,7 @@ def create_engine_wait(ctx: api.Context, engine: str):
2829
client_id = os.getenv("CLIENT_ID")
2930
client_secret = os.getenv("CLIENT_SECRET")
3031
client_credentials_url = os.getenv("CLIENT_CREDENTIALS_URL")
32+
custom_headers = json.loads(os.getenv('CUSTOM_HEADERS', '{}'))
3133

3234
if client_id is None:
3335
print("using config from path")

0 commit comments

Comments
 (0)