diff --git a/.github/workflows/lint-flake8.yml b/.github/workflows/lint-flake8.yml index 7cb4596b..f6787b9d 100644 --- a/.github/workflows/lint-flake8.yml +++ b/.github/workflows/lint-flake8.yml @@ -9,10 +9,10 @@ jobs: steps: - uses: actions/checkout@v1 - - name: Set up Python 3.6 + - name: Set up Python 3.7 uses: actions/setup-python@v1 with: - python-version: 3.6 + python-version: 3.7 - name: Install dependencies run: | python -m pip install -U pip setuptools diff --git a/.github/workflows/typecheck-mypy.yml b/.github/workflows/typecheck-mypy.yml index 32ec4557..1293625b 100644 --- a/.github/workflows/typecheck-mypy.yml +++ b/.github/workflows/typecheck-mypy.yml @@ -24,4 +24,4 @@ jobs: else echo "::add-matcher::.github/workflows/mypy-matcher.json" fi - python -m mypy --no-color-output src/ai/backend || exit 0 + python -m mypy --no-color-output src/ai/backend diff --git a/.travis.yml b/.travis.yml index dc0b5cbd..a8015a7b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,6 @@ stages: # build matrix for test stage python: - - "3.6" - "3.7" - "3.8" os: @@ -47,13 +46,6 @@ jobs: username: "__token__" password: secure: "nXxyiMTQZgdLnpw+hZBm2nHtlMV9prg5bl+3lB4Q/pnWWaW4VvAU6U2Lw/gljAaD3jxOV+RWKOCdt6ZWmQc9M8Fh5mcTlq9IjcMgk0R39onsP7YP7UJUh7saqZG1EkruglCHwCjcz3XwmRyJ+GKIANDH6jRooEmGQt/b8sR0ZIuMxx9ANNPozGEIxcrEqkO2CT1NQzEYc969danjYoyRImDUyDLKTJKd5ZkC7vwmT9z1chm0oxbZMdBJbL26g3TEr7dq1gQAiiLB5lhFVxklWqYlthlWl5qvmtgcn9ZNh1OA2WF8jTwDaafXoYHOotfq82ASRZI3dOckJQM6bsEJEPh5tTIvJJNxMmPTomHCRmc8/sNfOOoPPTLhjXVGE1BxL4u3DXZt0VAw80mkQseXu9wtzNEdZqCxGlSzycyut4cLtXpWXZDN/zqDYczAPUAYeRi2XbxT06OHhczmtn7WPGp2O/HYrXzHrMjAho0tNdch/62hJycEYAMRN0iQSnB2Gs2Ja7h6WUmf6lw2P4qS8gOSKuBJ3Z5Q0glbS2m28oCDZjP6zBqCwYucMZfUqF/aKiVei0NQp1dvjKUBqMJVogesuOAvtDVo+wN3rp2pcTntEKJHqYbNL9fOzwErJM8r/ZUMGC0HkdyTcnPS7uGkRF5WlzFl1cVBNmHzburc+N4=" - allow_failures: - - python: "3.8" - os: linux - fast_finish: true - # exclude the duplicate default test stage - exclude: - - python: "3.6" notifications: webhooks: diff --git a/appveyor.yml b/appveyor.yml index 08e90819..1cd5a84f 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,8 +1,6 @@ version: 1.0.dev{build} environment: matrix: - - PYTHON: "C:\\Python36" - - PYTHON: "C:\\Python36-x64" - PYTHON: "C:\\Python37" - PYTHON: "C:\\Python37-x64" - PYTHON: "C:\\Python38" diff --git a/changes/97.breaking b/changes/97.breaking new file mode 100644 index 00000000..09dd8aaf --- /dev/null +++ b/changes/97.breaking @@ -0,0 +1 @@ +Drop support for Python 3.6 diff --git a/changes/97.feature b/changes/97.feature new file mode 100644 index 00000000..8f4dcac4 --- /dev/null +++ b/changes/97.feature @@ -0,0 +1 @@ +Support APIv5's new GraphQL schema and kernel/session naming changes diff --git a/setup.cfg b/setup.cfg index 83303562..25637b5f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,3 +13,7 @@ norecursedirs = venv virtualenv .git timeout = 5 markers = integration: Test cases that require real manager (and agents) to be running on http://localhost:8081. + +[mypy] +ignore_missing_imports = true +namespace_packages = true diff --git a/setup.py b/setup.py index 7f450b6a..4e38fa6a 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ 'tabulate~=0.8.6', 'tqdm~=4.42', 'yarl~=1.4.2', + 'typing-extensions~=3.7.4', ] build_requires = [ 'wheel>=0.34.2', @@ -82,7 +83,6 @@ def read_src_version(): 'Intended Audience :: Developers', 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Operating System :: POSIX', @@ -94,7 +94,7 @@ def read_src_version(): ], package_dir={'': 'src'}, packages=find_namespace_packages(where='src', include='ai.backend.*'), - python_requires='>=3.6', + python_requires='>=3.7', setup_requires=setup_requires, install_requires=install_requires, extras_require={ diff --git a/src/ai/backend/client/cli/admin/images.py b/src/ai/backend/client/cli/admin/images.py index e6650ba5..a2175935 100644 --- a/src/ai/backend/client/cli/admin/images.py +++ b/src/ai/backend/client/cli/admin/images.py @@ -1,18 +1,22 @@ +import json import sys + import click from tabulate import tabulate +from tqdm import tqdm from . import admin -from ...session import Session +from ...compat import asyncio_run +from ...session import Session, AsyncSession from ..pretty import print_done, print_warn, print_fail, print_error @admin.command() @click.option('--operation', is_flag=True, help='Get operational images only') -def images(operation): - ''' +def images(operation: bool) -> None: + """ Show the list of registered images in this cluster. - ''' + """ fields = [ ('Name', 'name'), ('Registry', 'registry'), @@ -32,7 +36,7 @@ def images(operation): print_warn('There are no registered images.') return print(tabulate((item.values() for item in items), - headers=(item[0] for item in fields), + headers=[item[0] for item in fields], floatfmt=',.0f')) @@ -40,25 +44,53 @@ def images(operation): @click.option('-r', '--registry', type=str, default=None, help='The name (usually hostname or "lablup") ' 'of the Docker registry configured.') -def rescan_images(registry): - '''Update the kernel image metadata from all configured docker registries.''' - with Session() as session: - try: - result = session.Image.rescan_images(registry) - except Exception as e: - print_error(e) - sys.exit(1) - if result['ok']: - print_done("Updated the image metadata from the configured registries.") - else: - print_fail(f"Rescanning has failed: {result['msg']}") +def rescan_images(registry: str) -> None: + """ + Update the kernel image metadata from all configured docker registries. + """ + + async def rescan_images_impl(registry: str) -> None: + async with AsyncSession() as session: + try: + result = await session.Image.rescan_images(registry) + except Exception as e: + print_error(e) + sys.exit(1) + if not result['ok']: + print_fail(f"Failed to begin registry scanning: {result['msg']}") + sys.exit(1) + print_done("Started updating the image metadata from the configured registries.") + task_id = result['task_id'] + bgtask = session.BackgroundTask(task_id) + try: + completion_msg_func = lambda: print_done("Finished registry scanning.") + with tqdm(unit='image') as pbar: + async with bgtask.listen_events() as response: + async for ev in response: + data = json.loads(ev.data) + if ev.event == 'task_updated': + pbar.total = data['total_progress'] + pbar.write(data['message']) + pbar.update(data['current_progress'] - pbar.n) + elif ev.event == 'task_failed': + error_msg = data['message'] + completion_msg_func = \ + lambda: print_fail(f"Error occurred: {error_msg}") + elif ev.event == 'task_cancelled': + completion_msg_func = \ + lambda: print_warn("Registry scanning has been " + "cancelled in the middle.") + finally: + completion_msg_func() + + asyncio_run(rescan_images_impl(registry)) @admin.command() @click.argument('alias', type=str) @click.argument('target', type=str) def alias_image(alias, target): - '''Add an image alias.''' + """Add an image alias.""" with Session() as session: try: result = session.Image.alias_image(alias, target) @@ -74,7 +106,7 @@ def alias_image(alias, target): @admin.command() @click.argument('alias', type=str) def dealias_image(alias): - '''Remove an image alias.''' + """Remove an image alias.""" with Session() as session: try: result = session.Image.dealias_image(alias) diff --git a/src/ai/backend/client/cli/app.py b/src/ai/backend/client/cli/app.py index 94f29ed1..2f7f86e1 100644 --- a/src/ai/backend/client/cli/app.py +++ b/src/ai/backend/client/cli/app.py @@ -128,7 +128,7 @@ class ProxyRunnerContext: protocol: str host: str port: int - args: Dict[str, str] + args: Dict[str, Union[None, str, List[str]]] envs: Dict[str, str] api_session: Optional[AsyncSession] local_server: Optional[asyncio.AbstractServer] @@ -150,7 +150,7 @@ def __init__(self, host: str, port: int, self.exit_code = 0 self.args, self.envs = {}, {} - if len(args) > 0: + if args is not None and len(args) > 0: for argline in args: tokens = [] for token in shlex.shlex(argline, @@ -168,7 +168,7 @@ def __init__(self, host: str, port: int, self.args[tokens[0]] = tokens[1] else: self.args[tokens[0]] = tokens[1:] - if len(envs) > 0: + if envs is not None and len(envs) > 0: for envline in envs: split = envline.strip().split('=', maxsplit=2) if len(split) == 2: @@ -178,6 +178,7 @@ def __init__(self, host: str, port: int, async def handle_connection(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + assert self.api_session is not None p = WSProxy(self.api_session, self.session_name, self.app_name, self.protocol, self.args, self.envs, @@ -232,6 +233,7 @@ async def __aexit__(self, *exc_info) -> None: print_info("Shutting down....") self.local_server.close() await self.local_server.wait_closed() + assert self.api_session is not None await self.api_session.__aexit__(*exc_info) assert self.api_session.closed if self.local_server is not None: diff --git a/src/ai/backend/client/cli/proxy.py b/src/ai/backend/client/cli/proxy.py index 3c49df71..74c1c3a2 100644 --- a/src/ai/backend/client/cli/proxy.py +++ b/src/ai/backend/client/cli/proxy.py @@ -1,6 +1,13 @@ +from __future__ import annotations + import asyncio import json import re +from typing import ( + Union, + Tuple, + AsyncIterator, +) import aiohttp from aiohttp import web @@ -19,6 +26,8 @@ class WebSocketProxy: 'upstream_buffer', 'upstream_buffer_task', ) + upstream_buffer: asyncio.Queue[Tuple[Union[str, bytes], aiohttp.WSMsgType]] + def __init__(self, up_conn: aiohttp.ClientWebSocketResponse, down_conn: web.WebSocketResponse): self.up_conn = up_conn @@ -182,18 +191,15 @@ async def websocket_handler(request): reason="Internal Server Error") -async def startup_proxy(app): +async def proxy_context(app: web.Application) -> AsyncIterator[None]: app['client_session'] = AsyncSession() - - -async def cleanup_proxy(app): - await app['client_session'].close() + async with app['client_session']: + yield def create_proxy_app(): app = web.Application() - app.on_startup.append(startup_proxy) - app.on_cleanup.append(cleanup_proxy) + app.cleanup_ctx.append(proxy_context) app.router.add_route("GET", r'/stream/{path:.*$}', websocket_handler) app.router.add_route("GET", r'/wsproxy/{path:.*$}', websocket_handler) diff --git a/src/ai/backend/client/cli/run.py b/src/ai/backend/client/cli/run.py index d128b742..7f0e35a7 100644 --- a/src/ai/backend/client/cli/run.py +++ b/src/ai/backend/client/cli/run.py @@ -22,7 +22,7 @@ from ..compat import asyncio_run, current_loop from ..exceptions import BackendError, BackendAPIError from ..session import Session, AsyncSession, is_legacy_server -from ..utils import undefined +from ..types import undefined from .pretty import ( print_info, print_wait, print_done, print_error, print_fail, print_warn, format_info, @@ -881,7 +881,7 @@ def start(image, name, owner, # base args help='Set the owner of the target session explicitly.') # job scheduling options @click.option('--type', 'type_', metavar='SESSTYPE', - type=click.Choice(['batch', 'interactive', undefined]), + type=click.Choice(['batch', 'interactive', undefined]), # type: ignore default=undefined, help='Either batch or interactive') @click.option('-i', '--image', default=undefined, @@ -1145,9 +1145,9 @@ def events(name, owner_access_key): async def _run_events(): async with AsyncSession() as session: compute_session = session.ComputeSession(name, owner_access_key) - async with compute_session.stream_events() as sse_response: - async for ev in sse_response.fetch_events(): - print(click.style(ev['event'], fg='cyan', bold=True), json.loads(ev['data'])) + async with compute_session.listen_events() as response: + async for ev in response: + print(click.style(ev.event, fg='cyan', bold=True), json.loads(ev.data)) try: asyncio_run(_run_events()) diff --git a/src/ai/backend/client/compat.py b/src/ai/backend/client/compat.py index 6cb3acf4..ed8c0631 100644 --- a/src/ai/backend/client/compat.py +++ b/src/ai/backend/client/compat.py @@ -1,9 +1,16 @@ -''' +""" A compatibility module for backported codes from Python 3.6+ standard library. -''' +""" import asyncio +__all__ = ( + 'current_loop', + 'all_tasks', + 'asyncio_run', + 'asyncio_run_forever', +) + if hasattr(asyncio, 'get_running_loop'): # Python 3.7+ current_loop = asyncio.get_running_loop @@ -60,11 +67,11 @@ def _asyncio_run(coro, *, debug=False): def asyncio_run_forever(server_context, *, debug=False): - ''' + """ A proposed-but-not-implemented asyncio.run_forever() API based on @vxgmichel's idea. See discussions on https://github.com/python/asyncio/pull/465 - ''' + """ loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) loop.set_debug(debug) diff --git a/src/ai/backend/client/config.py b/src/ai/backend/client/config.py index be9a6608..b0000430 100644 --- a/src/ai/backend/client/config.py +++ b/src/ai/backend/client/config.py @@ -5,6 +5,8 @@ from typing import ( Any, Callable, Iterable, Union, List, Tuple, Sequence, + Mapping, + cast, ) import appdirs @@ -39,7 +41,7 @@ def parse_api_version(value: str) -> Tuple[int, str]: def get_env(key: str, default: Any = _undefined, *, clean: Callable[[str], Any] = lambda v: v): - ''' + """ Retrieves a configuration value from the environment variables. The given *key* is uppercased and prefixed by ``"BACKEND_"`` and then ``"SORNA_"`` if the former does not exist. @@ -52,7 +54,7 @@ def get_env(key: str, default: Any = _undefined, *, The default is returning the value as-is. :returns: The value processed by the *clean* function. - ''' + """ key = key.upper() v = os.environ.get('BACKEND_' + key) if v is None: @@ -73,7 +75,7 @@ def bool_env(v: str) -> bool: raise ValueError('Unrecognized value of boolean environment variable', v) -def _clean_urls(v: str) -> List[URL]: +def _clean_urls(v: Union[URL, str]) -> List[URL]: if isinstance(v, URL): return [v] if isinstance(v, str): @@ -95,7 +97,7 @@ def _clean_tokens(v): class APIConfig: - ''' + """ Represents a set of API client configurations. The access key and secret key are mandatory -- they must be set in either environment variables or as the explicit arguments. @@ -129,9 +131,9 @@ class APIConfig: access key) to be automatically mounted upon any :func:`Kernel.get_or_create() ` calls. - ''' + """ - DEFAULTS = { + DEFAULTS: Mapping[str, Any] = { 'endpoint': 'https://api.backend.ai', 'endpoint_type': 'api', 'version': f'v{API_VERSION[0]}.{API_VERSION[1]}', @@ -141,9 +143,12 @@ class APIConfig: 'connection_timeout': 10.0, 'read_timeout': None, } - ''' + """ The default values except the access and secret keys. - ''' + """ + + _group: str + _hash_type: str def __init__(self, *, endpoint: Union[URL, str] = None, @@ -159,17 +164,17 @@ def __init__(self, *, skip_sslcert_validation: bool = None, connection_timeout: float = None, read_timeout: float = None) -> None: - from . import get_user_agent # noqa; to avoid circular imports + from . import get_user_agent self._endpoints = ( _clean_urls(endpoint) if endpoint else get_env('ENDPOINT', self.DEFAULTS['endpoint'], clean=_clean_urls)) random.shuffle(self._endpoints) - self._endpoint_type = endpoint_type if endpoint_type \ + self._endpoint_type = endpoint_type if endpoint_type is not None \ else get_env('ENDPOINT_TYPE', self.DEFAULTS['endpoint_type']) - self._domain = domain if domain else get_env('DOMAIN', self.DEFAULTS['domain']) - self._group = group if group else get_env('GROUP', self.DEFAULTS['group']) - self._version = version if version else self.DEFAULTS['version'] - self._user_agent = user_agent if user_agent else get_user_agent() + self._domain = domain if domain is not None else get_env('DOMAIN', self.DEFAULTS['domain']) + self._group = group if group is not None else get_env('GROUP', self.DEFAULTS['group']) + self._version = version if version is not None else self.DEFAULTS['version'] + self._user_agent = user_agent if user_agent is not None else get_user_agent() if self._endpoint_type == 'api': self._access_key = access_key if access_key is not None \ else get_env('ACCESS_KEY', '') @@ -178,8 +183,8 @@ def __init__(self, *, else: self._access_key = 'dummy' self._secret_key = 'dummy' - self._hash_type = hash_type.lower() if hash_type else \ - self.DEFAULTS['hash_type'] + self._hash_type = hash_type.lower() if hash_type is not None else \ + cast(str, self.DEFAULTS['hash_type']) arg_vfolders = set(vfolder_mounts) if vfolder_mounts else set() env_vfolders = set(get_env('VFOLDER_MOUNTS', [], clean=_clean_tokens)) self._vfolder_mounts = [*(arg_vfolders | env_vfolders)] @@ -198,16 +203,16 @@ def is_anonymous(self) -> bool: @property def endpoint(self) -> URL: - ''' + """ The currently active endpoint URL. This may change if there are multiple configured endpoints and the current one is not accessible. - ''' + """ return self._endpoints[0] @property def endpoints(self) -> Sequence[URL]: - '''All configured endpoint URLs.''' + """All configured endpoint URLs.""" return self._endpoints def rotate_endpoints(self): @@ -217,74 +222,74 @@ def rotate_endpoints(self): @property def endpoint_type(self) -> str: - ''' + """ The configured endpoint type. - ''' + """ return self._endpoint_type @property def domain(self) -> str: - '''The configured domain.''' + """The configured domain.""" return self._domain @property def group(self) -> str: - '''The configured group.''' + """The configured group.""" return self._group @property def user_agent(self) -> str: - '''The configured user agent string.''' + """The configured user agent string.""" return self._user_agent @property def access_key(self) -> str: - '''The configured API access key.''' + """The configured API access key.""" return self._access_key @property def secret_key(self) -> str: - '''The configured API secret key.''' + """The configured API secret key.""" return self._secret_key @property def version(self) -> str: - '''The configured API protocol version.''' + """The configured API protocol version.""" return self._version @property def hash_type(self) -> str: - '''The configured hash algorithm for API authentication signatures.''' + """The configured hash algorithm for API authentication signatures.""" return self._hash_type @property - def vfolder_mounts(self) -> Tuple[str, ...]: - '''The configured auto-mounted vfolder list.''' + def vfolder_mounts(self) -> Sequence[str]: + """The configured auto-mounted vfolder list.""" return self._vfolder_mounts @property def skip_sslcert_validation(self) -> bool: - '''Whether to skip SSL certificate validation for the API gateway.''' + """Whether to skip SSL certificate validation for the API gateway.""" return self._skip_sslcert_validation @property def connection_timeout(self) -> float: - '''The maximum allowed duration for making TCP connections to the server.''' + """The maximum allowed duration for making TCP connections to the server.""" return self._connection_timeout @property def read_timeout(self) -> float: - '''The maximum allowed waiting time for the first byte of the response from the server.''' + """The maximum allowed waiting time for the first byte of the response from the server.""" return self._read_timeout def get_config(): - ''' + """ Returns the configuration for the current process. If there is no explicitly set :class:`APIConfig` instance, it will generate a new one from the current environment variables and defaults. - ''' + """ global _config if _config is None: _config = APIConfig() @@ -292,8 +297,8 @@ def get_config(): def set_config(conf: APIConfig): - ''' + """ Sets the configuration used throughout the current process. - ''' + """ global _config _config = conf diff --git a/src/ai/backend/client/etcd.py b/src/ai/backend/client/etcd.py deleted file mode 100644 index fb809bc4..00000000 --- a/src/ai/backend/client/etcd.py +++ /dev/null @@ -1,74 +0,0 @@ -from .base import api_function -from .request import Request - -__all__ = ( - 'EtcdConfig', -) - - -class EtcdConfig: - ''' - Provides a way to get or set ETCD configurations. - - .. note:: - - All methods in this function class require your API access key to - have the *superadmin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' - - @api_function - @classmethod - async def get(cls, key: str, prefix: bool = False) -> dict: - ''' - Get configuration from ETCD with given key. - - :param key: Name of the key to fetch. - :param prefix: get all keys prefixed with the give key. - ''' - rqst = Request(cls.session, 'POST', '/config/get') - rqst.set_json({ - 'key': key, - 'prefix': prefix, - }) - async with rqst.fetch() as resp: - data = await resp.json() - return data.get('result', None) - - @api_function - @classmethod - async def set(cls, key: str, value: str) -> dict: - ''' - Set configuration into ETCD with given key and value. - - :param key: Name of the key to set. - :param value: Value to set. - ''' - rqst = Request(cls.session, 'POST', '/config/set') - rqst.set_json({ - 'key': key, - 'value': value, - }) - async with rqst.fetch() as resp: - data = await resp.json() - return data - - @api_function - @classmethod - async def delete(cls, key: str, prefix: bool = False) -> dict: - ''' - Delete configuration from ETCD with given key. - - :param key: Name of the key to delete. - :param prefix: delete all keys prefixed with the give key. - ''' - rqst = Request(cls.session, 'POST', '/config/delete') - rqst.set_json({ - 'key': key, - 'prefix': prefix, - }) - async with rqst.fetch() as resp: - data = await resp.json() - return data diff --git a/src/ai/backend/client/func/admin.py b/src/ai/backend/client/func/admin.py index 47c92a64..57170ca1 100644 --- a/src/ai/backend/client/func/admin.py +++ b/src/ai/backend/client/func/admin.py @@ -1,15 +1,16 @@ from typing import Any, Mapping, Optional -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'Admin', ) -class Admin: - ''' +class Admin(BaseFunction): + """ Provides the function interface for making admin GrapQL queries. .. note:: @@ -17,17 +18,14 @@ class Admin: Depending on the privilege of your API access key, you may or may not have access to querying/mutating server-side resources of other users. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def query(cls, query: str, variables: Optional[Mapping[str, Any]] = None, ) -> Any: - ''' + """ Sends the GraphQL query and returns the response. :param query: The GraphQL query string. @@ -36,12 +34,12 @@ async def query(cls, query: str, in the query. :returns: The object parsed from the response JSON string. - ''' + """ gql_query = { 'query': query, 'variables': variables if variables else {}, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json(gql_query) async with rqst.fetch() as resp: return await resp.json() diff --git a/src/ai/backend/client/func/agent.py b/src/ai/backend/client/func/agent.py index b498339f..23d92fc4 100644 --- a/src/ai/backend/client/func/agent.py +++ b/src/ai/backend/client/func/agent.py @@ -1,8 +1,9 @@ import textwrap from typing import Iterable, Sequence -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'Agent', @@ -10,8 +11,8 @@ ) -class Agent: - ''' +class Agent(BaseFunction): + """ Provides a shortcut of :func:`Admin.query() ` that fetches various agent information. @@ -20,10 +21,7 @@ class Agent: All methods in this function class require your API access key to have the *admin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod @@ -32,7 +30,7 @@ async def list_with_limit(cls, offset, status: str = 'ALIVE', fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetches the list of agents with the given status with limit and offset for pagination. @@ -42,7 +40,7 @@ async def list_with_limit(cls, status (one of ``'ALIVE'``, ``'TERMINATED'``, ``'LOST'``, etc.) :param fields: Additional per-agent query fields to fetch. - ''' + """ if fields is None: fields = ( 'id', @@ -65,7 +63,7 @@ async def list_with_limit(cls, 'offset': offset, 'status': status, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -81,14 +79,14 @@ async def detail(cls, agent_id: str, fields: Iterable[str] = None) -> Sequence[d fields = ('id', 'status', 'addr', 'region', 'first_contact', 'cpu_cur_pct', 'mem_cur_bytes', 'available_slots', 'occupied_slots') - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($agent_id: String!) { agent(agent_id: $agent_id) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'agent_id': agent_id} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -98,8 +96,8 @@ async def detail(cls, agent_id: str, fields: Iterable[str] = None) -> Sequence[d return data['agent'] -class AgentWatcher: - ''' +class AgentWatcher(BaseFunction): + """ Provides a shortcut of :func:`Admin.query() ` that manipulate agent status. @@ -107,18 +105,15 @@ class AgentWatcher: All methods in this function class require you to have the *superadmin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def get_status(cls, agent_id: str) -> dict: - ''' + """ Get agent and watcher status. - ''' - rqst = Request(cls.session, 'GET', '/resource/watcher') + """ + rqst = Request(api_session.get(), 'GET', '/resource/watcher') rqst.set_json({'agent_id': agent_id}) async with rqst.fetch() as resp: data = await resp.json() @@ -130,10 +125,10 @@ async def get_status(cls, agent_id: str) -> dict: @api_function @classmethod async def agent_start(cls, agent_id: str) -> dict: - ''' + """ Start agent. - ''' - rqst = Request(cls.session, 'POST', '/resource/watcher/agent/start') + """ + rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/start') rqst.set_json({'agent_id': agent_id}) async with rqst.fetch() as resp: data = await resp.json() @@ -145,10 +140,10 @@ async def agent_start(cls, agent_id: str) -> dict: @api_function @classmethod async def agent_stop(cls, agent_id: str) -> dict: - ''' + """ Stop agent. - ''' - rqst = Request(cls.session, 'POST', '/resource/watcher/agent/stop') + """ + rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/stop') rqst.set_json({'agent_id': agent_id}) async with rqst.fetch() as resp: data = await resp.json() @@ -160,10 +155,10 @@ async def agent_stop(cls, agent_id: str) -> dict: @api_function @classmethod async def agent_restart(cls, agent_id: str) -> dict: - ''' + """ Restart agent. - ''' - rqst = Request(cls.session, 'POST', '/resource/watcher/agent/restart') + """ + rqst = Request(api_session.get(), 'POST', '/resource/watcher/agent/restart') rqst.set_json({'agent_id': agent_id}) async with rqst.fetch() as resp: data = await resp.json() diff --git a/src/ai/backend/client/func/auth.py b/src/ai/backend/client/func/auth.py index 97f331ad..e38601c0 100644 --- a/src/ai/backend/client/func/auth.py +++ b/src/ai/backend/client/func/auth.py @@ -1,26 +1,27 @@ -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'Auth', ) -class Auth: - ''' +class Auth(BaseFunction): + """ Provides the function interface for login session management and authorization. - ''' + """ @api_function @classmethod async def login(cls, user_id: str, password: str) -> dict: - ''' + """ Log-in into the endpoint with the given user ID and password. It creates a server-side web session and return a dictionary with ``"authenticated"`` boolean field and JSON-encoded raw cookie data. - ''' - rqst = Request(cls.session, 'POST', '/server/login') + """ + rqst = Request(api_session.get(), 'POST', '/server/login') rqst.set_json({ 'username': user_id, 'password': password, @@ -36,11 +37,11 @@ async def login(cls, user_id: str, password: str) -> dict: @api_function @classmethod async def logout(cls) -> None: - ''' + """ Log-out from the endpoint. It clears the server-side web session. - ''' - rqst = Request(cls.session, 'POST', '/server/logout') + """ + rqst = Request(api_session.get(), 'POST', '/server/logout') async with rqst.fetch() as resp: resp.raw_response.raise_for_status() @@ -50,7 +51,7 @@ async def update_password(cls, old_password: str, new_password: str, new_passwor """ Update user's password. This API works only for account owner. """ - rqst = Request(cls.session, 'POST', '/auth/update-password') + rqst = Request(api_session.get(), 'POST', '/auth/update-password') rqst.set_json({ 'old_password': old_password, 'new_password': new_password, diff --git a/src/ai/backend/client/func/base.py b/src/ai/backend/client/func/base.py index a6d3ac1c..f350e1d3 100644 --- a/src/ai/backend/client/func/base.py +++ b/src/ai/backend/client/func/base.py @@ -1,4 +1,5 @@ import functools +from ..session import api_session, AsyncSession __all__ = ( 'APIFunctionMeta', @@ -11,39 +12,38 @@ def _wrap_method(cls, orig_name, meth): @functools.wraps(meth) def _method(*args, **kwargs): - assert cls.session is not None, \ - 'You must use API wrapper functions via a Session object.' # We need to keep the original attributes so that they could be correctly # bound to the class/instance at runtime. func = getattr(cls, orig_name) coro = func(*args, **kwargs) - if hasattr(cls.session, 'worker_thread'): - return cls.session.worker_thread.execute(coro) - else: + _api_session = api_session.get() + if _api_session is None: + raise RuntimeError("API functions must be called inside the context of a valid API session") + if isinstance(_api_session, AsyncSession): return coro + else: + return _api_session.worker_thread.execute(coro) return _method def api_function(meth): - ''' + """ Mark the wrapped method as the API function method. - ''' + """ setattr(meth, '_backend_api', True) return meth class APIFunctionMeta(type): - ''' + """ Converts all methods marked with :func:`api_function` into session-aware methods that are either plain Python functions or coroutines. - ''' + """ _async = True def __init__(cls, name, bases, attrs, **kwargs): - assert 'session' in attrs, \ - 'An API function class must define the session attribute.' super().__init__(name, bases, attrs) for attr_name, attr_value in attrs.items(): if hasattr(attr_value, '_backend_api'): @@ -54,8 +54,4 @@ def __init__(cls, name, bases, attrs, **kwargs): class BaseFunction(metaclass=APIFunctionMeta): - ''' - The class used to build API functions proxies bound to specific session - instances. - ''' - session = None + pass diff --git a/src/ai/backend/client/func/bgtask.py b/src/ai/backend/client/func/bgtask.py new file mode 100644 index 00000000..49283070 --- /dev/null +++ b/src/ai/backend/client/func/bgtask.py @@ -0,0 +1,37 @@ +from typing import Union +from uuid import UUID + +from .base import BaseFunction +from ..request import ( + Request, + SSEContextManager, +) +from ..session import api_session + + +class BackgroundTask(BaseFunction): + """ + Provides server-sent events streaming functions. + """ + + task_id: UUID + + def __init__(self, task_id: Union[UUID, str]) -> None: + self.task_id = task_id if isinstance(task_id, UUID) else UUID(task_id) + + # only supported in AsyncAPISession + def listen_events(self) -> SSEContextManager: + """ + Opens an event stream of the background task updates. + + :returns: a context manager that returns an :class:`SSEResponse` object. + """ + params = { + 'task_id': str(self.task_id), + } + request = Request( + api_session.get(), + 'GET', '/events/background-task', + params=params, + ) + return request.connect_events() diff --git a/src/ai/backend/client/func/domain.py b/src/ai/backend/client/func/domain.py index 4defbaf9..8ba5a221 100644 --- a/src/ai/backend/client/func/domain.py +++ b/src/ai/backend/client/func/domain.py @@ -1,16 +1,17 @@ import textwrap from typing import Iterable, Sequence -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'Domain', ) -class Domain: - ''' +class Domain(BaseFunction): + """ Provides a shortcut of :func:`Admin.query() ` that fetches various domain information. @@ -19,30 +20,27 @@ class Domain: All methods in this function class require your API access key to have the *admin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetches the list of domains. :param fields: Additional per-domain query fields to fetch. - ''' + """ if fields is None: fields = ('name', 'description', 'is_active', 'created_at', 'total_resource_slots', 'allowed_vfolder_hosts', 'allowed_docker_registries', 'integration_id') - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query { domains {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, }) @@ -53,24 +51,24 @@ async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]: @api_function @classmethod async def detail(cls, name: str, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetch information of a domain with name. :param name: Name of the domain to fetch. :param fields: Additional per-domain query fields to fetch. - ''' + """ if fields is None: fields = ('name', 'description', 'is_active', 'created_at', 'total_resource_slots', 'allowed_vfolder_hosts', 'allowed_docker_registries', 'integration_id',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($name: String) { domain(name: $name) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'name': name} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -87,19 +85,19 @@ async def create(cls, name: str, description: str = '', is_active: bool = True, allowed_docker_registries: Iterable[str] = None, integration_id: str = None, fields: Iterable[str] = None) -> dict: - ''' + """ Creates a new domain with the given options. You need an admin privilege for this operation. - ''' + """ if fields is None: fields = ('name',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ mutation($name: String!, $input: DomainInput!) { create_domain(name: $name, props: $input) { ok msg domain {$fields} } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = { 'name': name, @@ -112,7 +110,7 @@ async def create(cls, name: str, description: str = '', is_active: bool = True, 'integration_id': integration_id, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -129,17 +127,17 @@ async def update(cls, name: str, new_name: str = None, description: str = None, allowed_docker_registries: Iterable[str] = None, integration_id: str = None, fields: Iterable[str] = None) -> dict: - ''' + """ Update existing domain. You need an admin privilege for this operation. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($name: String!, $input: ModifyDomainInput!) { modify_domain(name: $name, props: $input) { ok msg } } - ''') + """) variables = { 'name': name, 'input': { @@ -152,7 +150,7 @@ async def update(cls, name: str, new_name: str = None, description: str = None, 'integration_id': integration_id, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -164,18 +162,18 @@ async def update(cls, name: str, new_name: str = None, description: str = None, @api_function @classmethod async def delete(cls, name: str): - ''' + """ Deletes an existing domain. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($name: String!) { delete_domain(name: $name) { ok msg } } - ''') + """) variables = {'name': name} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, diff --git a/src/ai/backend/client/func/dotfile.py b/src/ai/backend/client/func/dotfile.py index 123448df..6552a1ba 100644 --- a/src/ai/backend/client/func/dotfile.py +++ b/src/ai/backend/client/func/dotfile.py @@ -1,7 +1,8 @@ from typing import List, Mapping -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( @@ -9,9 +10,8 @@ ) -class Dotfile: +class Dotfile(BaseFunction): - session = None @api_function @classmethod async def create(cls, @@ -20,7 +20,7 @@ async def create(cls, permission: str, owner_access_key: str = None, ) -> 'Dotfile': - rqst = Request(cls.session, + rqst = Request(api_session.get(), 'POST', '/user-config/dotfiles') body = { 'data': data, @@ -29,18 +29,16 @@ async def create(cls, } rqst.set_json(body) async with rqst.fetch() as resp: - if resp.status == 200: - await resp.json() - return cls(path, owner_access_key=owner_access_key) + await resp.json() + return cls(path, owner_access_key=owner_access_key) @api_function @classmethod async def list_dotfiles(cls) -> 'List[Mapping[str, str]]': - rqst = Request(cls.session, + rqst = Request(api_session.get(), 'GET', '/user-config/dotfiles') async with rqst.fetch() as resp: - if resp.status == 200: - return await resp.json() + return await resp.json() def __init__(self, path: str, owner_access_key: str = None): self.path = path @@ -51,12 +49,11 @@ async def get(self) -> str: params = {'path': self.path} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'GET', f'/user-config/dotfiles', params=params) async with rqst.fetch() as resp: - if resp.status == 200: - return await resp.json() + return await resp.json() @api_function async def update(self, data: str, permission: str): @@ -67,7 +64,7 @@ async def update(self, data: str, permission: str): } if self.owner_access_key: body['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'PATCH', f'/user-config/dotfiles') rqst.set_json(body) @@ -79,7 +76,7 @@ async def delete(self): params = {'path': self.path} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'DELETE', f'/user-config/dotfiles', params=params) diff --git a/src/ai/backend/client/func/etcd.py b/src/ai/backend/client/func/etcd.py index 3a90620a..1472f2c2 100644 --- a/src/ai/backend/client/func/etcd.py +++ b/src/ai/backend/client/func/etcd.py @@ -1,34 +1,32 @@ -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'EtcdConfig', ) -class EtcdConfig: - ''' +class EtcdConfig(BaseFunction): + """ Provides a way to get or set ETCD configurations. .. note:: All methods in this function class require your API access key to have the *superadmin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def get(cls, key: str, prefix: bool = False) -> dict: - ''' + """ Get configuration from ETCD with given key. :param key: Name of the key to fetch. :param prefix: get all keys prefixed with the give key. - ''' - rqst = Request(cls.session, 'POST', '/config/get') + """ + rqst = Request(api_session.get(), 'POST', '/config/get') rqst.set_json({ 'key': key, 'prefix': prefix, @@ -40,13 +38,13 @@ async def get(cls, key: str, prefix: bool = False) -> dict: @api_function @classmethod async def set(cls, key: str, value: str) -> dict: - ''' + """ Set configuration into ETCD with given key and value. :param key: Name of the key to set. :param value: Value to set. - ''' - rqst = Request(cls.session, 'POST', '/config/set') + """ + rqst = Request(api_session.get(), 'POST', '/config/set') rqst.set_json({ 'key': key, 'value': value, @@ -58,13 +56,13 @@ async def set(cls, key: str, value: str) -> dict: @api_function @classmethod async def delete(cls, key: str, prefix: bool = False) -> dict: - ''' + """ Delete configuration from ETCD with given key. :param key: Name of the key to delete. :param prefix: delete all keys prefixed with the give key. - ''' - rqst = Request(cls.session, 'POST', '/config/delete') + """ + rqst = Request(api_session.get(), 'POST', '/config/delete') rqst.set_json({ 'key': key, 'prefix': prefix, diff --git a/src/ai/backend/client/func/group.py b/src/ai/backend/client/func/group.py index 8b4b3a94..fa99c864 100644 --- a/src/ai/backend/client/func/group.py +++ b/src/ai/backend/client/func/group.py @@ -1,16 +1,17 @@ import textwrap from typing import Iterable, Sequence -from ai.backend.client.func.base import api_function -from ai.backend.client.request import Request +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session __all__ = ( 'Group', ) -class Group: - ''' +class Group(BaseFunction): + """ Provides a shortcut of :func:`Group.query() ` that fetches various group information. @@ -18,69 +19,66 @@ class Group: All methods in this function class require your API access key to have the *admin* privilege. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def list(cls, domain_name: str, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetches the list of groups. :param domain_name: Name of domain to list groups. :param fields: Additional per-group query fields to fetch. - ''' + """ if fields is None: fields = ('id', 'name', 'description', 'is_active', 'created_at', 'domain_name', 'total_resource_slots', 'allowed_vfolder_hosts', 'integration_id') - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($domain_name: String) { groups(domain_name: $domain_name) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'domain_name': domain_name} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['groups'] + return data['groups'] @api_function @classmethod async def detail(cls, gid: str, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetch information of a group with group ID. :param gid: ID of the group to fetch. :param fields: Additional per-group query fields to fetch. - ''' + """ if fields is None: fields = ('id', 'name', 'description', 'is_active', 'created_at', 'domain_name', 'total_resource_slots', 'allowed_vfolder_hosts', 'integration_id') - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($gid: String!) { group(id: $gid) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'gid': gid} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['group'] + return data['group'] @api_function @classmethod @@ -89,19 +87,19 @@ async def create(cls, domain_name: str, name: str, description: str = '', allowed_vfolder_hosts: Iterable[str] = None, integration_id: str = None, fields: Iterable[str] = None) -> dict: - ''' + """ Creates a new group with the given options. You need an admin privilege for this operation. - ''' + """ if fields is None: fields = ('id', 'domain_name', 'name',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ mutation($name: String!, $input: GroupInput!) { create_group(name: $name, props: $input) { ok msg group {$fields} } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = { 'name': name, @@ -114,14 +112,14 @@ async def create(cls, domain_name: str, name: str, description: str = '', 'integration_id': integration_id, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['create_group'] + return data['create_group'] @api_function @classmethod @@ -130,17 +128,17 @@ async def update(cls, gid: str, name: str = None, description: str = None, allowed_vfolder_hosts: Iterable[str] = None, integration_id: str = None, fields: Iterable[str] = None) -> dict: - ''' + """ Update existing group. You need an admin privilege for this operation. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($gid: String!, $input: ModifyGroupInput!) { modify_group(gid: $gid, props: $input) { ok msg } } - ''') + """) variables = { 'gid': gid, 'input': { @@ -152,53 +150,53 @@ async def update(cls, gid: str, name: str = None, description: str = None, 'integration_id': integration_id, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['modify_group'] + return data['modify_group'] @api_function @classmethod async def delete(cls, gid: str): - ''' + """ Deletes an existing group. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($gid: String!) { delete_group(gid: $gid) { ok msg } } - ''') + """) variables = {'gid': gid} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['delete_group'] + return data['delete_group'] @api_function @classmethod async def add_users(cls, gid: str, user_uuids: Iterable[str], fields: Iterable[str] = None) -> dict: - ''' + """ Add users to a group. You need an admin privilege for this operation. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($gid: String!, $input: ModifyGroupInput!) { modify_group(gid: $gid, props: $input) { ok msg } } - ''') + """) variables = { 'gid': gid, 'input': { @@ -206,30 +204,30 @@ async def add_users(cls, gid: str, user_uuids: Iterable[str], 'user_uuids': user_uuids, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['modify_group'] + return data['modify_group'] @api_function @classmethod async def remove_users(cls, gid: str, user_uuids: Iterable[str], fields: Iterable[str] = None) -> dict: - ''' + """ Remove users from a group. You need an admin privilege for this operation. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($gid: String!, $input: ModifyGroupInput!) { modify_group(gid: $gid, props: $input) { ok msg } } - ''') + """) variables = { 'gid': gid, 'input': { @@ -237,11 +235,11 @@ async def remove_users(cls, gid: str, user_uuids: Iterable[str], 'user_uuids': user_uuids, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['modify_group'] + return data['modify_group'] diff --git a/src/ai/backend/client/func/image.py b/src/ai/backend/client/func/image.py index 308875b5..26c745ee 100644 --- a/src/ai/backend/client/func/image.py +++ b/src/ai/backend/client/func/image.py @@ -1,31 +1,29 @@ from typing import Iterable, Sequence -from ai.backend.client.func.base import api_function -from ai.backend.client.request import Request +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session __all__ = ( 'Image', ) -class Image: - ''' +class Image(BaseFunction): + """ Provides a shortcut of :func:`Admin.query() ` that fetches the information about available images. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def list(cls, operation: bool = False, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetches the list of registered images in this cluster. - ''' + """ if fields is None: fields = ( @@ -42,34 +40,34 @@ async def list(cls, variables = { 'is_operation': operation, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['images'] + return data['images'] @api_function @classmethod async def rescan_images(cls, registry: str): q = 'mutation($registry: String) {' \ ' rescan_images(registry:$registry) {' \ - ' ok msg' \ + ' ok msg task_id' \ ' }' \ '}' variables = { 'registry': registry, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['rescan_images'] + return data['rescan_images'] @api_function @classmethod @@ -83,14 +81,14 @@ async def alias_image(cls, alias: str, target: str) -> dict: 'alias': alias, 'target': target, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['alias_image'] + return data['alias_image'] @api_function @classmethod @@ -103,26 +101,28 @@ async def dealias_image(cls, alias: str) -> dict: variables = { 'alias': alias, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, }) async with rqst.fetch() as resp: data = await resp.json() - return data['dealias_image'] + return data['dealias_image'] @api_function @classmethod async def get_image_import_form(cls) -> dict: - rqst = Request(cls.session, 'GET', '/image/import') + rqst = Request(api_session.get(), 'GET', '/image/import') async with rqst.fetch() as resp: - return await resp.json() + data = await resp.json() + return data @api_function @classmethod async def build(cls, **kwargs) -> dict: - rqst = Request(cls.session, 'POST', '/image/import') + rqst = Request(api_session.get(), 'POST', '/image/import') rqst.set_json(kwargs) async with rqst.fetch() as resp: - return await resp.json() + data = await resp.json() + return data diff --git a/src/ai/backend/client/func/keypair.py b/src/ai/backend/client/func/keypair.py index 716add15..cf7c0a89 100644 --- a/src/ai/backend/client/func/keypair.py +++ b/src/ai/backend/client/func/keypair.py @@ -1,20 +1,22 @@ -from typing import Iterable, Sequence, Union +from typing import ( + Any, Iterable, Union, + Sequence, + Dict, +) -from ai.backend.client.func.base import api_function -from ai.backend.client.request import Request +from .base import api_function, BaseFunction +from ..request import Request +from ..session import api_session __all__ = ( 'KeyPair', ) -class KeyPair: - ''' +class KeyPair(BaseFunction): + """ Provides interactions with keypairs. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ def __init__(self, access_key: str): self.access_key = access_key @@ -27,10 +29,10 @@ async def create(cls, user_id: Union[int, str], resource_policy: str = None, rate_limit: int = None, fields: Iterable[str] = None) -> dict: - ''' + """ Creates a new keypair with the given options. You need an admin privilege for this operation. - ''' + """ if fields is None: fields = ('access_key', 'secret_key') uid_type = 'Int!' if isinstance(user_id, int) else 'String!' @@ -49,7 +51,7 @@ async def create(cls, user_id: Union[int, str], 'rate_limit': rate_limit, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -83,7 +85,7 @@ async def update(cls, access_key: str, 'rate_limit': rate_limit, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -106,7 +108,7 @@ async def delete(cls, access_key: str): variables = { 'access_key': access_key, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -120,10 +122,10 @@ async def delete(cls, access_key: str): async def list(cls, user_id: Union[int, str] = None, is_active: bool = None, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Lists the keypairs. You need an admin privilege for this operation. - ''' + """ if fields is None: fields = ( 'access_key', 'secret_key', @@ -143,12 +145,12 @@ async def list(cls, user_id: Union[int, str] = None, ' }' \ '}' q = q.replace('$fields', ' '.join(fields)) - variables = { + variables: Dict[str, Any] = { 'is_active': is_active, } if user_id is not None: variables['email'] = user_id - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -159,13 +161,13 @@ async def list(cls, user_id: Union[int, str] = None, @api_function async def info(self, fields: Iterable[str] = None) -> dict: - ''' + """ Returns the keypair's information such as resource limits. :param fields: Additional per-agent query fields to fetch. .. versionadded:: 18.12 - ''' + """ if fields is None: fields = ( 'access_key', 'secret_key', @@ -177,7 +179,7 @@ async def info(self, fields: Iterable[str] = None) -> dict: ' }' \ '}' q = q.replace('$fields', ' '.join(fields)) - rqst = Request(self.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, }) @@ -188,10 +190,10 @@ async def info(self, fields: Iterable[str] = None) -> dict: @api_function @classmethod async def activate(cls, access_key: str) -> dict: - ''' + """ Activates this keypair. You need an admin privilege for this operation. - ''' + """ q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) {' + \ ' modify_keypair(access_key: $access_key, props: $input) {' \ ' ok msg' \ @@ -206,7 +208,7 @@ async def activate(cls, access_key: str) -> dict: 'rate_limit': None, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -218,12 +220,12 @@ async def activate(cls, access_key: str) -> dict: @api_function @classmethod async def deactivate(cls, access_key: str) -> dict: - ''' + """ Deactivates this keypair. Deactivated keypairs cannot make any API requests unless activated again by an administrator. You need an admin privilege for this operation. - ''' + """ q = 'mutation($access_key: String!, $input: ModifyKeyPairInput!) {' + \ ' modify_keypair(access_key: $access_key, props: $input) {' \ ' ok msg' \ @@ -238,7 +240,7 @@ async def deactivate(cls, access_key: str) -> dict: 'rate_limit': None, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, diff --git a/src/ai/backend/client/func/keypair_resource_policy.py b/src/ai/backend/client/func/keypair_resource_policy.py index c3725805..18cc2b2e 100644 --- a/src/ai/backend/client/func/keypair_resource_policy.py +++ b/src/ai/backend/client/func/keypair_resource_policy.py @@ -1,19 +1,18 @@ from typing import Iterable, Sequence -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'KeypairResourcePolicy' ) -class KeypairResourcePolicy: +class KeypairResourcePolicy(BaseFunction): """ Provides interactions with keypair resource policy. """ - session = None - """The client session instance that this function class is bound to.""" def __init__(self, access_key: str): self.access_key = access_key @@ -56,7 +55,7 @@ async def create(cls, name: str, 'allowed_vfolder_hosts': allowed_vfolder_hosts, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -99,7 +98,7 @@ async def update(cls, name: str, 'allowed_vfolder_hosts': allowed_vfolder_hosts, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -124,7 +123,7 @@ async def delete(cls, name: str) -> dict: variables = { 'name': name, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, @@ -153,7 +152,7 @@ async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]: ' }' \ '}' q = q.replace('$fields', ' '.join(fields)) - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, }) @@ -186,7 +185,7 @@ async def info(self, name: str, fields: Iterable[str] = None) -> dict: variables = { 'name': name, } - rqst = Request(self.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': q, 'variables': variables, diff --git a/src/ai/backend/client/func/manager.py b/src/ai/backend/client/func/manager.py index 6aad6d3d..0aab6167 100644 --- a/src/ai/backend/client/func/manager.py +++ b/src/ai/backend/client/func/manager.py @@ -1,24 +1,22 @@ -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session -class Manager: - ''' +class Manager(BaseFunction): + """ Provides controlling of the gateway/manager servers. .. versionadded:: 18.12 - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod async def status(cls): - ''' + """ Returns the current status of the configured API server. - ''' - rqst = Request(cls.session, 'GET', '/manager/status') + """ + rqst = Request(api_session.get(), 'GET', '/manager/status') rqst.set_json({ 'status': 'running', }) @@ -28,7 +26,7 @@ async def status(cls): @api_function @classmethod async def freeze(cls, force_kill: bool = False): - ''' + """ Freezes the configured API server. Any API clients will no longer be able to create new compute sessions nor create and modify vfolders/keypairs/etc. @@ -39,8 +37,8 @@ async def freeze(cls, force_kill: bool = False): compute sessions forcibly. If not set, clients who have running compute session are still able to interact with them though they cannot create new compute sessions. - ''' - rqst = Request(cls.session, 'PUT', '/manager/status') + """ + rqst = Request(api_session.get(), 'PUT', '/manager/status') rqst.set_json({ 'status': 'frozen', 'force_kill': force_kill, @@ -51,10 +49,10 @@ async def freeze(cls, force_kill: bool = False): @api_function @classmethod async def unfreeze(cls): - ''' + """ Unfreezes the configured API server so that it resumes to normal operation. - ''' - rqst = Request(cls.session, 'PUT', '/manager/status') + """ + rqst = Request(api_session.get(), 'PUT', '/manager/status') rqst.set_json({ 'status': 'running', }) diff --git a/src/ai/backend/client/func/resource.py b/src/ai/backend/client/func/resource.py index 63531bf0..3f0d5996 100644 --- a/src/ai/backend/client/func/resource.py +++ b/src/ai/backend/client/func/resource.py @@ -1,64 +1,60 @@ from typing import Sequence -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'Resource' ) -class Resource: +class Resource(BaseFunction): """ Provides interactions with resource. """ - session = None - """The client session instance that this function class is bound to.""" - - # def __init__(self, access_key: str): - # self.access_key = access_key @api_function @classmethod async def list(cls): - ''' + """ Lists all resource presets. - ''' - rqst = Request(cls.session, 'GET', '/resource/presets') + """ + rqst = Request(api_session.get(), 'GET', '/resource/presets') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def check_presets(cls): - ''' + """ Lists all resource presets in the current scaling group with additiona information. - ''' - rqst = Request(cls.session, 'POST', '/resource/check-presets') + """ + rqst = Request(api_session.get(), 'POST', '/resource/check-presets') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def get_docker_registries(cls): - ''' + """ Lists all registered docker registries. - ''' - rqst = Request(cls.session, 'GET', '/config/docker-registries') + """ + rqst = Request(api_session.get(), 'GET', '/config/docker-registries') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def usage_per_month(cls, month: str, group_ids: Sequence[str]): - ''' + """ Get usage statistics for groups specified by `group_ids` at specific `month`. :param month: The month you want to get the statistics (yyyymm). :param group_ids: Groups IDs to be included in the result. - ''' - rqst = Request(cls.session, 'GET', '/resource/usage/month') + """ + rqst = Request(api_session.get(), 'GET', '/resource/usage/month') rqst.set_json({ 'month': month, 'group_ids': group_ids, @@ -69,15 +65,15 @@ async def usage_per_month(cls, month: str, group_ids: Sequence[str]): @api_function @classmethod async def usage_per_period(cls, group_id: str, start_date: str, end_date: str): - ''' + """ Get usage statistics for a group specified by `group_id` for time betweeen `start_date` and `end_date`. :param start_date: start date in string format (yyyymmdd). :param end_date: end date in string format (yyyymmdd). :param group_id: Groups ID to list usage statistics. - ''' - rqst = Request(cls.session, 'GET', '/resource/usage/period') + """ + rqst = Request(api_session.get(), 'GET', '/resource/usage/period') rqst.set_json({ 'group_id': group_id, 'start_date': start_date, @@ -89,37 +85,37 @@ async def usage_per_period(cls, group_id: str, start_date: str, end_date: str): @api_function @classmethod async def get_resource_slots(cls): - ''' + """ Get supported resource slots of Backend.AI server. - ''' - rqst = Request(cls.session, 'GET', '/config/resource-slots') + """ + rqst = Request(api_session.get(), 'GET', '/config/resource-slots') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def get_vfolder_types(cls): - rqst = Request(cls.session, 'GET', '/config/vfolder-types') + rqst = Request(api_session.get(), 'GET', '/config/vfolder-types') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def recalculate_usage(cls): - rqst = Request(cls.session, 'POST', '/resource/recalculate-usage') + rqst = Request(api_session.get(), 'POST', '/resource/recalculate-usage') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def user_monthly_stats(cls): - rqst = Request(cls.session, 'GET', '/resource/stats/user/month') + rqst = Request(api_session.get(), 'GET', '/resource/stats/user/month') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def admin_monthly_stats(cls): - rqst = Request(cls.session, 'GET', '/resource/stats/admin/month') + rqst = Request(api_session.get(), 'GET', '/resource/stats/admin/month') async with rqst.fetch() as resp: return await resp.json() diff --git a/src/ai/backend/client/func/scaling_group.py b/src/ai/backend/client/func/scaling_group.py index d31150db..bed1fe9c 100644 --- a/src/ai/backend/client/func/scaling_group.py +++ b/src/ai/backend/client/func/scaling_group.py @@ -2,25 +2,23 @@ import textwrap from typing import Iterable, Mapping, Sequence -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'ScalingGroup', ) -class ScalingGroup: - ''' +class ScalingGroup(BaseFunction): + """ Provides getting scaling-group information required for the current user. The scaling-group is an opaque server-side configuration which splits the whole cluster into several partitions, so that server administrators can apply different auto-scaling policies and operation standards to each partition of agent sets. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ def __init__(self, name: str): self.name = name @@ -28,11 +26,11 @@ def __init__(self, name: str): @api_function @classmethod async def list_available(cls, group: str): - ''' + """ List available scaling groups for the current user, considering the user, the user's domain, and the designated user group. - ''' - rqst = Request(cls.session, 'GET', '/scaling-groups', + """ + rqst = Request(api_session.get(), 'GET', '/scaling-groups', params={'group': group}) async with rqst.fetch() as resp: return await resp.json() @@ -40,25 +38,25 @@ async def list_available(cls, group: str): @api_function @classmethod async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ List available scaling groups for the current user, considering the user, the user's domain, and the designated user group. - ''' + """ if fields is None: fields = ('name', 'description', 'is_active', 'created_at', 'driver', 'driver_opts', 'scheduler', 'scheduler_opts',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($is_active: Boolean) { scaling_groups(is_active: $is_active) { $fields } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'is_active': None} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables @@ -70,25 +68,25 @@ async def list(cls, fields: Iterable[str] = None) -> Sequence[dict]: @api_function @classmethod async def detail(cls, name: str, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetch information of a scaling group by name. :param name: Name of the scaling group. :param fields: Additional per-scaling-group query fields. - ''' + """ if fields is None: fields = ('name', 'description', 'is_active', 'created_at', 'driver', 'driver_opts', 'scheduler', 'scheduler_opts',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($name: String) { scaling_group(name: $name) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'name': name} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -103,18 +101,18 @@ async def create(cls, name: str, description: str = '', is_active: bool = True, driver: str = None, driver_opts: Mapping[str, str] = None, scheduler: str = None, scheduler_opts: Mapping[str, str] = None, fields: Iterable[str] = None) -> dict: - ''' + """ Creates a new scaling group with the given options. - ''' + """ if fields is None: fields = ('name',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ mutation($name: String!, $input: ScalingGroupInput!) { create_scaling_group(name: $name, props: $input) { ok msg scaling_group {$fields} } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = { 'name': name, @@ -127,7 +125,7 @@ async def create(cls, name: str, description: str = '', is_active: bool = True, 'scheduler_opts': json.dumps(scheduler_opts), }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -142,18 +140,18 @@ async def update(cls, name: str, description: str = '', is_active: bool = True, driver: str = None, driver_opts: Mapping[str, str] = None, scheduler: str = None, scheduler_opts: Mapping[str, str] = None, fields: Iterable[str] = None) -> dict: - ''' + """ Update existing scaling group. - ''' + """ if fields is None: fields = ('name',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ mutation($name: String!, $input: ModifyScalingGroupInput!) { modify_scaling_group(name: $name, props: $input) { ok msg } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = { 'name': name, @@ -166,7 +164,7 @@ async def update(cls, name: str, description: str = '', is_active: bool = True, 'scheduler_opts': json.dumps(scheduler_opts), }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -178,18 +176,18 @@ async def update(cls, name: str, description: str = '', is_active: bool = True, @api_function @classmethod async def delete(cls, name: str): - ''' + """ Deletes an existing scaling group. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($name: String!) { delete_scaling_group(name: $name) { ok msg } } - ''') + """) variables = {'name': name} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -201,22 +199,22 @@ async def delete(cls, name: str): @api_function @classmethod async def associate_domain(cls, scaling_group: str, domain: str): - ''' + """ Associate scaling_group with domain. :param scaling_group: The name of a scaling group. :param domain: The name of a domain. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($scaling_group: String!, $domain: String!) { associate_scaling_group_with_domain( scaling_group: $scaling_group, domain: $domain) { ok msg } } - ''') + """) variables = {'scaling_group': scaling_group, 'domain': domain} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -228,22 +226,22 @@ async def associate_domain(cls, scaling_group: str, domain: str): @api_function @classmethod async def dissociate_domain(cls, scaling_group: str, domain: str): - ''' + """ Dissociate scaling_group from domain. :param scaling_group: The name of a scaling group. :param domain: The name of a domain. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($scaling_group: String!, $domain: String!) { disassociate_scaling_group_with_domain( scaling_group: $scaling_group, domain: $domain) { ok msg } } - ''') + """) variables = {'scaling_group': scaling_group, 'domain': domain} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -255,20 +253,20 @@ async def dissociate_domain(cls, scaling_group: str, domain: str): @api_function @classmethod async def dissociate_all_domain(cls, domain: str): - ''' + """ Dissociate all scaling_groups from domain. :param domain: The name of a domain. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($domain: String!) { disassociate_all_scaling_groups_with_domain(domain: $domain) { ok msg } } - ''') + """) variables = {'domain': domain} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -280,22 +278,22 @@ async def dissociate_all_domain(cls, domain: str): @api_function @classmethod async def associate_group(cls, scaling_group: str, group_id: str): - ''' + """ Associate scaling_group with group. :param scaling_group: The name of a scaling group. :param group_id: The ID of a group. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($scaling_group: String!, $user_group: String!) { associate_scaling_group_with_user_group( scaling_group: $scaling_group, user_group: $user_group) { ok msg } } - ''') + """) variables = {'scaling_group': scaling_group, 'user_group': group_id} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -307,22 +305,22 @@ async def associate_group(cls, scaling_group: str, group_id: str): @api_function @classmethod async def dissociate_group(cls, scaling_group: str, group_id: str): - ''' + """ Dissociate scaling_group from group. :param scaling_group: The name of a scaling group. :param group_id: The ID of a group. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($scaling_group: String!, $user_group: String!) { disassociate_scaling_group_with_user_group( scaling_group: $scaling_group, user_group: $user_group) { ok msg } } - ''') + """) variables = {'scaling_group': scaling_group, 'user_group': group_id} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -334,20 +332,20 @@ async def dissociate_group(cls, scaling_group: str, group_id: str): @api_function @classmethod async def dissociate_all_group(cls, group_id: str): - ''' + """ Dissociate all scaling_groups from group. :param group_id: The ID of a group. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($group_id: String!) { disassociate_all_scaling_groups_with_group(user_group: $group_id) { ok msg } } - ''') + """) variables = {'group_id': group_id} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, diff --git a/src/ai/backend/client/func/session.py b/src/ai/backend/client/func/session.py index afe0e603..bcf2e505 100644 --- a/src/ai/backend/client/func/session.py +++ b/src/ai/backend/client/func/session.py @@ -1,13 +1,16 @@ +from __future__ import annotations + import json import os import secrets import tarfile import tempfile from typing import ( - Iterable, Union, - AsyncGenerator, - Mapping, + Any, Iterable, Optional, Union, + AsyncIterator, + Mapping, Dict, Sequence, List, + cast, ) from pathlib import Path @@ -15,16 +18,19 @@ from aiohttp import hdrs from tqdm import tqdm -from .base import api_function +from .base import api_function, BaseFunction from ..compat import current_loop from ..config import DEFAULT_CHUNK_SIZE from ..exceptions import BackendClientError from ..request import ( Request, AttachedFile, WebSocketResponse, - SSEResponse, + SSEContextManager, + WebSocketContextManager, ) -from ..utils import undefined, ProgressReportingReader +from ..session import api_session +from ..utils import ProgressReportingReader +from ..types import Undefined, undefined from ..versioning import get_naming __all__ = ( @@ -32,18 +38,18 @@ ) -def drop(d, dropval): - newd = {} +def drop(d: Mapping[str, Any], value_to_drop: Any) -> Mapping[str, Any]: + modified: Dict[str, Any] = {} for k, v in d.items(): if isinstance(v, Mapping) or isinstance(v, dict): - newd[k] = drop(v, dropval) - elif v != dropval: - newd[k] = v - return newd + modified[k] = drop(v, value_to_drop) + elif v != value_to_drop: + modified[k] = v + return modified -class ComputeSession: - ''' +class ComputeSession(BaseFunction): + """ Provides various interactions with compute sessions in Backend.AI. The term 'kernel' is now deprecated and we prefer 'compute sessions'. @@ -55,15 +61,20 @@ class ComputeSession: So it is the user's responsibility to distribute uploaded files to multiple containers using explicit copies or virtual folders which are commonly mounted to all containers belonging to the same compute session. - ''' + """ - session = None - '''The client session instance that this function class is bound to.''' + name: str + owner_access_key: Optional[str] + created: bool + status: str + service_ports: List[str] + domain: str + group: str @api_function @classmethod async def hello(cls) -> str: - rqst = Request(cls.session, 'GET', '/') + rqst = Request(api_session.get(), 'GET', '/') async with rqst.fetch() as resp: return await resp.json() @@ -72,9 +83,9 @@ async def hello(cls) -> str: async def get_task_logs( cls, task_id: str, *, chunk_size: int = 8192 - ) -> AsyncGenerator[bytes, None]: - prefix = get_naming(cls.session.api_version, 'path') - rqst = Request(cls.session, 'GET', f'/{prefix}/_/logs', params={ + ) -> AsyncIterator[bytes]: + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request(api_session.get(), 'GET', f'/{prefix}/_/logs', params={ 'taskId': task_id, }) async with rqst.fetch() as resp: @@ -86,27 +97,30 @@ async def get_task_logs( @api_function @classmethod - async def get_or_create(cls, image: str, *, - name: str = None, - type_: str = 'interactive', - enqueue_only: bool = False, - max_wait: int = 0, - no_reuse: bool = False, - mounts: Iterable[str] = None, - mount_map: Mapping[str, str] = None, - envs: Mapping[str, str] = None, - startup_command: str = None, - resources: Mapping[str, int] = None, - resource_opts: Mapping[str, int] = None, - cluster_size: int = 1, - domain_name: str = None, - group_name: str = None, - bootstrap_script: str = None, - tag: str = None, - scaling_group: str = None, - owner_access_key: str = None, - preopen_ports: List[int] = None) -> 'ComputeSession': - ''' + async def get_or_create( + cls, + image: str, *, + name: str = None, + type_: str = 'interactive', + enqueue_only: bool = False, + max_wait: int = 0, + no_reuse: bool = False, + mounts: List[str] = None, + mount_map: Mapping[str, str] = None, + envs: Mapping[str, str] = None, + startup_command: str = None, + resources: Mapping[str, int] = None, + resource_opts: Mapping[str, int] = None, + cluster_size: int = 1, + domain_name: str = None, + group_name: str = None, + bootstrap_script: str = None, + tag: str = None, + scaling_group: str = None, + owner_access_key: str = None, + preopen_ports: List[int] = None, + ) -> ComputeSession: + """ Get-or-creates a compute session. If *name* is ``None``, it creates a new compute session as long as the server has enough resources and your API key has remaining quota. @@ -159,8 +173,8 @@ async def get_or_create(cls, image: str, *, available to administrators) :returns: The :class:`ComputeSession` instance. - ''' - if name: + """ + if name is not None: assert 4 <= len(name) <= 64, \ 'Client session token should be 4 to 64 characters long.' else: @@ -175,16 +189,16 @@ async def get_or_create(cls, image: str, *, resource_opts = {} if domain_name is None: # Even if config.domain is None, it can be guessed in the manager by user information. - domain_name = cls.session.config.domain + domain_name = api_session.get().config.domain if group_name is None: - group_name = cls.session.config.group + group_name = api_session.get().config.group - mounts.extend(cls.session.config.vfolder_mounts) - prefix = get_naming(cls.session.api_version, 'path') - rqst = Request(cls.session, 'POST', f'/{prefix}') - params = { + mounts.extend(api_session.get().config.vfolder_mounts) + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request(api_session.get(), 'POST', f'/{prefix}') + params: Dict[str, Any] = { 'tag': tag, - get_naming(cls.session.api_version, 'name_arg'): name, + get_naming(api_session.get().api_version, 'name_arg'): name, 'config': { 'mounts': mounts, 'environ': envs, @@ -194,7 +208,7 @@ async def get_or_create(cls, image: str, *, 'scalingGroup': scaling_group, }, } - if cls.session.api_version >= (5, '20191215'): + if api_session.get().api_version >= (5, '20191215'): params['config'].update({ 'mount_map': mount_map, 'preopen_ports': preopen_ports, @@ -202,7 +216,7 @@ async def get_or_create(cls, image: str, *, params.update({ 'bootstrap_script': bootstrap_script, }) - if cls.session.api_version >= (4, '20190615'): + if api_session.get().api_version >= (4, '20190615'): params.update({ 'owner_access_key': owner_access_key, 'domain': domain_name, @@ -213,7 +227,7 @@ async def get_or_create(cls, image: str, *, 'reuseIfExists': not no_reuse, 'startupCommand': startup_command, }) - if cls.session.api_version > (4, '20181215'): + if api_session.get().api_version > (4, '20181215'): params['image'] = image else: params['lang'] = image @@ -230,27 +244,30 @@ async def get_or_create(cls, image: str, *, @api_function @classmethod - async def create_from_template(cls, template_id: str, *, - name: str = undefined, - type_: str = undefined, - enqueue_only: bool = undefined, - max_wait: int = undefined, - no_reuse: bool = undefined, - image: str = undefined, - mounts: Iterable[str] = undefined, - mount_map: Mapping[str, str] = undefined, - envs: Mapping[str, str] = undefined, - startup_command: str = undefined, - resources: Mapping[str, int] = undefined, - resource_opts: Mapping[str, int] = undefined, - cluster_size: int = undefined, - domain_name: str = undefined, - group_name: str = undefined, - bootstrap_script: str = undefined, - tag: str = undefined, - scaling_group: str = undefined, - owner_access_key: str = undefined) -> 'ComputeSession': - ''' + async def create_from_template( + cls, + template_id: str, *, + name: Union[str, Undefined] = undefined, + type_: Union[str, Undefined] = undefined, + enqueue_only: Union[bool, Undefined] = undefined, + max_wait: Union[int, Undefined] = undefined, + no_reuse: Union[bool, Undefined] = undefined, + image: Union[str, Undefined] = undefined, + mounts: Union[List[str], Undefined] = undefined, + mount_map: Union[Mapping[str, str], Undefined] = undefined, + envs: Union[Mapping[str, str], Undefined] = undefined, + startup_command: Union[str, Undefined] = undefined, + resources: Union[Mapping[str, int], Undefined] = undefined, + resource_opts: Union[Mapping[str, int], Undefined] = undefined, + cluster_size: Union[int, Undefined] = undefined, + domain_name: Union[str, Undefined] = undefined, + group_name: Union[str, Undefined] = undefined, + bootstrap_script: Union[str, Undefined] = undefined, + tag: Union[str, Undefined] = undefined, + scaling_group: Union[str, Undefined] = undefined, + owner_access_key: Union[str, Undefined] = undefined, + ) -> ComputeSession: + """ Get-or-creates a compute session from template. All other parameters provided will be overwritten to template, including vfolder mounts (not appended!). @@ -306,29 +323,32 @@ async def create_from_template(cls, template_id: str, *, available to administrators) :returns: The :class:`ComputeSession` instance. - ''' - if name: + """ + if name is not undefined: assert 4 <= len(name) <= 64, \ 'Client session token should be 4 to 64 characters long.' else: name = f'pysdk-{secrets.token_urlsafe(8)}' - if domain_name is None: + if domain_name is undefined: # Even if config.domain is None, it can be guessed in the manager by user information. - domain_name = cls.session.config.domain - if group_name is None: - group_name = cls.session.config.group - if cls.session.config.vfolder_mounts: - mounts.extend(cls.session.config.vfolder_mounts) - prefix = get_naming(cls.session.api_version, 'path') - rqst = Request(cls.session, 'POST', f'/{prefix}/_/create-from-template') + domain_name = api_session.get().config.domain + if group_name is undefined: + group_name = api_session.get().config.group + if mounts is undefined: + mounts = [] + if api_session.get().config.vfolder_mounts: + mounts.extend(api_session.get().config.vfolder_mounts) + prefix = get_naming(api_session.get().api_version, 'path') + rqst = Request(api_session.get(), 'POST', f'/{prefix}/_/create-from-template') + params: Dict[str, Any] params = { 'template_id': template_id, 'tag': tag, 'image': image, 'domain': domain_name, 'group': group_name, - get_naming(cls.session.api_version, 'name_arg'): name, + get_naming(api_session.get().api_version, 'name_arg'): name, 'bootstrap_script': bootstrap_script, 'enqueueOnly': enqueue_only, 'maxWaitSeconds': max_wait, @@ -346,11 +366,11 @@ async def create_from_template(cls, template_id: str, *, 'scalingGroup': scaling_group, }, } - params = drop(params, undefined) + params = cast(Dict[str, Any], drop(params, undefined)) rqst.set_json(params) async with rqst.fetch() as resp: data = await resp.json() - o = cls(name, owner_access_key) + o = cls(name, owner_access_key if owner_access_key is not undefined else None) o.created = data.get('created', True) # True is for legacy o.status = data.get('status', 'RUNNING') o.service_ports = data.get('servicePorts', []) @@ -364,19 +384,19 @@ def __init__(self, name: str, owner_access_key: str = None): @api_function async def destroy(self, *, forced: bool = False): - ''' + """ Destroys the compute session. Since the server literally kills the container(s), all ongoing executions are forcibly interrupted. - ''' + """ params = {} - if self.owner_access_key: + if self.owner_access_key is not None: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') if forced: params['forced'] = 'true' rqst = Request( - self.session, + api_session.get(), 'DELETE', f'/{prefix}/{self.name}', params=params, ) @@ -386,17 +406,17 @@ async def destroy(self, *, forced: bool = False): @api_function async def restart(self): - ''' + """ Restarts the compute session. The server force-destroys the current running container(s), but keeps their temporary scratch directories intact. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'PATCH', f'/{prefix}/{self.name}', params=params, ) @@ -405,17 +425,17 @@ async def restart(self): @api_function async def interrupt(self): - ''' + """ Tries to interrupt the current ongoing code execution. This may fail without any explicit errors depending on the code being executed. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}/interrupt', params=params, ) @@ -424,7 +444,7 @@ async def interrupt(self): @api_function async def complete(self, code: str, opts: dict = None) -> Iterable[str]: - ''' + """ Gets the auto-completion candidates from the given code string, as if a user has pressed the tab key just after the code in IDEs. @@ -437,14 +457,14 @@ async def complete(self, code: str, opts: dict = None) -> Iterable[str]: such as row, col, line and the remainder text. :returns: An ordered list of strings. - ''' + """ opts = {} if opts is None else opts params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}/complete', params=params, ) @@ -462,15 +482,15 @@ async def complete(self, code: str, opts: dict = None) -> Iterable[str]: @api_function async def get_info(self): - ''' + """ Retrieves a brief information about the compute session. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'GET', f'/{prefix}/{self.name}', params=params, ) @@ -479,15 +499,15 @@ async def get_info(self): @api_function async def get_logs(self): - ''' + """ Retrieves the console log of the compute session container. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'GET', f'/{prefix}/{self.name}/logs', params=params, ) @@ -499,7 +519,7 @@ async def execute(self, run_id: str = None, code: str = None, mode: str = 'query', opts: dict = None): - ''' + """ Executes a code snippet directly in the compute session or sends a set of build/clean/execute commands to the compute session. @@ -521,17 +541,17 @@ async def execute(self, run_id: str = None, for details. :returns: :ref:`An execution result object ` - ''' + """ opts = opts if opts is not None else {} params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') if mode in {'query', 'continue', 'input'}: assert code is not None, \ 'The code argument must be a valid string even when empty.' rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}', params=params, ) @@ -542,7 +562,7 @@ async def execute(self, run_id: str = None, }) elif mode == 'batch': rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}', params=params, ) @@ -559,7 +579,7 @@ async def execute(self, run_id: str = None, }) elif mode == 'complete': rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}', params=params, ) @@ -581,7 +601,7 @@ async def execute(self, run_id: str = None, async def upload(self, files: Sequence[Union[str, Path]], basedir: Union[str, Path] = None, show_progress: bool = False): - ''' + """ Uploads the given list of files to the compute session. You may refer them in the batch-mode execution or from the code executed in the server afterwards. @@ -598,11 +618,11 @@ async def upload(self, files: Sequence[Union[str, Path]], :param basedir: The directory prefix where the files reside. The default value is the current working directory. :param show_progress: Displays a progress bar during uploads. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') base_path = ( Path.cwd() if basedir is None else Path(basedir).resolve() @@ -610,7 +630,7 @@ async def upload(self, files: Sequence[Union[str, Path]], files = [Path(file).resolve() for file in files] total_size = 0 for file_path in files: - total_size += file_path.stat().st_size + total_size += Path(file_path).stat().st_size tqdm_obj = tqdm(desc='Uploading files', unit='bytes', unit_scale=True, total=total_size, @@ -620,7 +640,7 @@ async def upload(self, files: Sequence[Union[str, Path]], for file_path in files: try: attachments.append(AttachedFile( - str(file_path.relative_to(base_path)), + str(Path(file_path).relative_to(base_path)), ProgressReportingReader(str(file_path), tqdm_instance=tqdm_obj), 'application/octet-stream', @@ -631,7 +651,7 @@ async def upload(self, files: Sequence[Union[str, Path]], raise ValueError(msg) from None rqst = Request( - self.session, + api_session.get(), 'POST', f'/{prefix}/{self.name}/upload', params=params, ) @@ -643,7 +663,7 @@ async def upload(self, files: Sequence[Union[str, Path]], async def download(self, files: Sequence[Union[str, Path]], dest: Union[str, Path] = '.', show_progress: bool = False): - ''' + """ Downloads the given list of files from the compute session. :param files: The list of file paths in the compute session. @@ -651,13 +671,13 @@ async def download(self, files: Sequence[Union[str, Path]], ``/home/work`` in the compute session container. :param dest: The destination directory in the client-side. :param show_progress: Displays a progress bar during downloads. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'GET', f'/{prefix}/{self.name}/download', params=params, ) @@ -674,7 +694,7 @@ async def download(self, files: Sequence[Union[str, Path]], reader = aiohttp.MultipartReader.from_response(resp.raw_response) with tqdm_obj as pbar: while True: - part = await reader.next() + part = cast(aiohttp.BodyPartReader, await reader.next()) if part is None: break assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() == 'identity' @@ -698,18 +718,18 @@ async def download(self, files: Sequence[Union[str, Path]], @api_function async def list_files(self, path: Union[str, Path] = '.'): - ''' + """ Gets the list of files in the given path inside the compute session container. :param path: The directory path in the compute session. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') rqst = Request( - self.session, + api_session.get(), 'GET', f'/{prefix}/{self.name}/files', params=params, ) @@ -724,68 +744,70 @@ async def stream_app_info(self): params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') api_rqst = Request( - self.session, + api_session.get(), 'GET', f'/stream/{prefix}/{self.name}/apps', params=params, ) async with api_rqst.fetch() as resp: return await resp.json() - # only supported in AsyncKernel - def stream_events(self) -> SSEResponse: - ''' + # only supported in AsyncAPISession + def listen_events(self) -> SSEContextManager: + """ Opens the stream of the kernel lifecycle events. Only the master kernel of each session is monitored. :returns: a :class:`StreamEvents` object. - ''' + """ params = { - get_naming(self.session.api_version, 'event_name_arg'): self.name, + get_naming(api_session.get().api_version, 'event_name_arg'): self.name, } if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + path = get_naming(api_session.get().api_version, 'session_events') request = Request( - self.session, - 'GET', f'/stream/{prefix}/_/events', + api_session.get(), + 'GET', path, params=params, ) return request.connect_events() - # only supported in AsyncKernel - def stream_pty(self) -> 'StreamPty': - ''' + stream_events = listen_events # legacy alias + + # only supported in AsyncAPISession + def stream_pty(self) -> WebSocketContextManager: + """ Opens a pseudo-terminal of the kernel (if supported) streamed via websockets. :returns: a :class:`StreamPty` object. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') request = Request( - self.session, + api_session.get(), 'GET', f'/stream/{prefix}/{self.name}/pty', params=params, ) return request.connect_websocket(response_cls=StreamPty) - # only supported in AsyncKernel + # only supported in AsyncAPISession def stream_execute(self, code: str = '', *, mode: str = 'query', - opts: dict = None) -> WebSocketResponse: - ''' + opts: dict = None) -> WebSocketContextManager: + """ Executes a code snippet in the streaming mode. Since the returned websocket represents a run loop, there is no need to specify *run_id* explicitly. - ''' + """ params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - prefix = get_naming(self.session.api_version, 'path') + prefix = get_naming(api_session.get().api_version, 'path') opts = {} if opts is None else opts if mode == 'query': opts = {} @@ -800,7 +822,7 @@ def stream_execute(self, code: str = '', *, msg = 'Invalid stream-execution mode: {0}'.format(mode) raise BackendClientError(msg) request = Request( - self.session, + api_session.get(), 'GET', f'/stream/{prefix}/{self.name}/execute', params=params, ) @@ -816,10 +838,10 @@ async def send_code(ws): class StreamPty(WebSocketResponse): - ''' + """ A derivative class of :class:`~ai.backend.client.request.WebSocketResponse` which provides additional functions to control the terminal. - ''' + """ __slots__ = ('ws', ) diff --git a/src/ai/backend/client/func/session_template.py b/src/ai/backend/client/func/session_template.py index d754e941..c3c9ad1c 100644 --- a/src/ai/backend/client/func/session_template.py +++ b/src/ai/backend/client/func/session_template.py @@ -1,12 +1,16 @@ -from typing import List, Mapping +from typing import Any, List, Mapping -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session +__all__ = ( + 'SessionTemplate', +) -class SessionTemplate: - session = None +class SessionTemplate(BaseFunction): + @api_function @classmethod async def create(cls, @@ -15,13 +19,13 @@ async def create(cls, group_name: str = None, owner_access_key: str = None, ) -> 'SessionTemplate': - rqst = Request(cls.session, + rqst = Request(api_session.get(), 'POST', '/template/session') if domain_name is None: # Even if config.domain is None, it can be guessed in the manager by user information. - domain_name = cls.session.config.domain + domain_name = api_session.get().config.domain if group_name is None: - group_name = cls.session.config.group + group_name = api_session.get().config.group body = { 'payload': template, 'group_name': group_name, @@ -30,20 +34,17 @@ async def create(cls, } rqst.set_json(body) async with rqst.fetch() as resp: - if resp.status == 200: - response = await resp.json() - - return cls(response['id'], owner_access_key=owner_access_key) + response = await resp.json() + return cls(response['id'], owner_access_key=owner_access_key) @api_function @classmethod - async def list_templates(cls, list_all: bool = False) -> 'List[Mapping[str, str]]': - rqst = Request(cls.session, + async def list_templates(cls, list_all: bool = False) -> List[Mapping[str, str]]: + rqst = Request(api_session.get(), 'GET', '/template/session') rqst.set_json({'all': list_all}) async with rqst.fetch() as resp: - if resp.status == 200: - return await resp.json() + return await resp.json() def __init__(self, template_id: str, owner_access_key: str = None): self.template_id = template_id @@ -54,35 +55,33 @@ async def get(self, body_format: str = 'yaml') -> str: params = {'format': body_format} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'GET', f'/template/session/{self.template_id}', params=params) async with rqst.fetch() as resp: - if resp.status == 200: - return await resp.text() + data = await resp.text() + return data @api_function - async def put(self, template: str): + async def put(self, template: str) -> Any: body = { 'payload': template } if self.owner_access_key: body['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'PUT', f'/template/session/{self.template_id}') rqst.set_json(body) - async with rqst.fetch() as resp: return await resp.json() @api_function - async def delete(self): + async def delete(self) -> Any: params = {} if self.owner_access_key: params['owner_access_key'] = self.owner_access_key - rqst = Request(self.session, + rqst = Request(api_session.get(), 'DELETE', f'/template/session/{self.template_id}', params=params) - async with rqst.fetch() as resp: return await resp.json() diff --git a/src/ai/backend/client/func/system.py b/src/ai/backend/client/func/system.py index b88a4ba1..3a9ed104 100644 --- a/src/ai/backend/client/func/system.py +++ b/src/ai/backend/client/func/system.py @@ -1,29 +1,30 @@ from typing import Mapping -from .base import api_function +from .base import api_function, BaseFunction from ..request import Request +from ..session import api_session __all__ = ( 'System', ) -class System: - ''' +class System(BaseFunction): + """ Provides the function interface for the API endpoint's system information. - ''' + """ @api_function @classmethod async def get_versions(cls) -> Mapping[str, str]: - rqst = Request(cls.session, 'GET', '/') + rqst = Request(api_session.get(), 'GET', '/') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def get_manager_version(cls) -> str: - rqst = Request(cls.session, 'GET', '/') + rqst = Request(api_session.get(), 'GET', '/') async with rqst.fetch() as resp: ret = await resp.json() return ret['manager'] @@ -31,7 +32,7 @@ async def get_manager_version(cls) -> str: @api_function @classmethod async def get_api_version(cls) -> str: - rqst = Request(cls.session, 'GET', '/') + rqst = Request(api_session.get(), 'GET', '/') async with rqst.fetch() as resp: ret = await resp.json() return ret['version'] diff --git a/src/ai/backend/client/func/user.py b/src/ai/backend/client/func/user.py index 7b0be7b2..bd4668a1 100644 --- a/src/ai/backend/client/func/user.py +++ b/src/ai/backend/client/func/user.py @@ -1,22 +1,20 @@ import textwrap from typing import Iterable, Sequence -from .base import api_function -from ..request import Request +from .base import api_function, BaseFunction from ..auth import AuthToken, AuthTokenTypes +from ..request import Request +from ..session import api_session __all__ = ( 'User', ) -class User: - ''' +class User(BaseFunction): + """ Provides interactions with users. - ''' - - session = None - '''The client session instance that this function class is bound to.''' + """ @api_function @classmethod @@ -30,10 +28,10 @@ async def authorize(cls, username: str, password: str, *, Its functionality will be expanded in the future to support multiple types of authentication methods. """ - rqst = Request(cls.session, 'POST', '/auth/authorize') + rqst = Request(api_session.get(), 'POST', '/auth/authorize') rqst.set_json({ 'type': token_type.value, - 'domain': cls.session.config.domain, + 'domain': api_session.get().config.domain, 'username': username, 'password': password, }) @@ -47,23 +45,23 @@ async def authorize(cls, username: str, password: str, *, @api_function @classmethod async def list(cls, is_active: bool = None, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetches the list of users. Domain admins can only get domain users. :param is_active: Fetches active or inactive users only if not None. :param fields: Additional per-user query fields to fetch. - ''' + """ if fields is None: fields = ('uuid', 'username', 'email', 'need_password_change', 'is_active', 'created_at', 'domain_name', 'role') - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($is_active: Boolean) { users(is_active: $is_active) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'is_active': is_active} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -75,31 +73,31 @@ async def list(cls, is_active: bool = None, fields: Iterable[str] = None) -> Seq @api_function @classmethod async def detail(cls, email: str = None, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetch information of a user. If email is not specified, requester's information will be returned. :param email: Email of the user to fetch. :param fields: Additional per-user query fields to fetch. - ''' + """ if fields is None: fields = ('uuid', 'username', 'email', 'need_password_change', 'is_active', 'created_at', 'domain_name', 'role') if email is None: - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query { user {$fields} } - ''') + """) else: - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($email: String) { user(email: $email) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'email': email} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') if email is None: rqst.set_json({ 'query': query, @@ -116,31 +114,31 @@ async def detail(cls, email: str = None, fields: Iterable[str] = None) -> Sequen @api_function @classmethod async def detail_by_uuid(cls, user_uuid: str = None, fields: Iterable[str] = None) -> Sequence[dict]: - ''' + """ Fetch information of a user by user's uuid. If user_uuid is not specified, requester's information will be returned. :param user_uuid: UUID of the user to fetch. :param fields: Additional per-user query fields to fetch. - ''' + """ if fields is None: fields = ('uuid', 'username', 'email', 'need_password_change', 'is_active', 'created_at', 'domain_name', 'role') if user_uuid is None: - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query { user {$fields} } - ''') + """) else: - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ query($user_id: String) { user_from_uuid(user_id: $user_id) {$fields} } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = {'user_id': user_uuid} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') if user_uuid is None: rqst.set_json({ 'query': query, @@ -163,19 +161,19 @@ async def create(cls, domain_name: str, email: str, password: str, description: str = '', group_ids: Iterable[str] = None, fields: Iterable[str] = None) -> dict: - ''' + """ Creates a new user with the given options. You need an admin privilege for this operation. - ''' + """ if fields is None: fields = ('domain_name', 'email', 'username',) - query = textwrap.dedent('''\ + query = textwrap.dedent("""\ mutation($email: String!, $input: UserInput!) { create_user(email: $email, props: $input) { ok msg user {$fields} } } - ''') + """) query = query.replace('$fields', ' '.join(fields)) variables = { 'email': email, @@ -191,7 +189,7 @@ async def create(cls, domain_name: str, email: str, password: str, 'group_ids': group_ids, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -207,17 +205,17 @@ async def update(cls, email: str, password: str = None, username: str = None, is_active: bool = None, need_password_change: bool = None, description: str = None, group_ids: Iterable[str] = None, fields: Iterable[str] = None) -> dict: - ''' + """ Update existing user. You need an admin privilege for this operation. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($email: String!, $input: ModifyUserInput!) { modify_user(email: $email, props: $input) { ok msg } } - ''') + """) variables = { 'email': email, 'input': { @@ -232,7 +230,7 @@ async def update(cls, email: str, password: str = None, username: str = None, 'group_ids': group_ids, }, } - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, @@ -244,18 +242,18 @@ async def update(cls, email: str, password: str = None, username: str = None, @api_function @classmethod async def delete(cls, email: str): - ''' + """ Deletes an existing user. - ''' - query = textwrap.dedent('''\ + """ + query = textwrap.dedent("""\ mutation($email: String!) { delete_user(email: $email) { ok msg } } - ''') + """) variables = {'email': email} - rqst = Request(cls.session, 'POST', '/admin/graphql') + rqst = Request(api_session.get(), 'POST', '/admin/graphql') rqst.set_json({ 'query': query, 'variables': variables, diff --git a/src/ai/backend/client/func/vfolder.py b/src/ai/backend/client/func/vfolder.py index 6ff38d42..050b812d 100644 --- a/src/ai/backend/client/func/vfolder.py +++ b/src/ai/backend/client/func/vfolder.py @@ -1,15 +1,20 @@ from pathlib import Path -from typing import Sequence, Union +from typing import ( + Union, + Sequence, List, + cast, +) import aiohttp from aiohttp import hdrs from tqdm import tqdm -from .base import api_function +from .base import api_function, BaseFunction from ..compat import current_loop from ..config import DEFAULT_CHUNK_SIZE from ..exceptions import BackendAPIError from ..request import Request, AttachedFile +from ..session import api_session from ..utils import ProgressReportingReader __all__ = ( @@ -17,18 +22,21 @@ ) -class VFolder: - - session = None - '''The client session instance that this function class is bound to.''' +class VFolder(BaseFunction): def __init__(self, name: str): self.name = name @api_function @classmethod - async def create(cls, name: str, host: str = None, unmanaged_path: str = None, group: str = None): - rqst = Request(cls.session, 'POST', '/folders') + async def create( + cls, + name: str, + host: str = None, + unmanaged_path: str = None, + group: str = None, + ): + rqst = Request(api_session.get(), 'POST', '/folders') rqst.set_json({ 'name': name, 'host': host, @@ -41,7 +49,7 @@ async def create(cls, name: str, host: str = None, unmanaged_path: str = None, g @api_function @classmethod async def delete_by_id(cls, oid): - rqst = Request(cls.session, 'DELETE', '/folders') + rqst = Request(api_session.get(), 'DELETE', '/folders') rqst.set_json({'id': oid}) async with rqst.fetch(): return {} @@ -49,7 +57,7 @@ async def delete_by_id(cls, oid): @api_function @classmethod async def list(cls, list_all=False): - rqst = Request(cls.session, 'GET', '/folders') + rqst = Request(api_session.get(), 'GET', '/folders') rqst.set_json({'all': list_all}) async with rqst.fetch() as resp: return await resp.json() @@ -57,39 +65,39 @@ async def list(cls, list_all=False): @api_function @classmethod async def list_hosts(cls): - rqst = Request(cls.session, 'GET', '/folders/_/hosts') + rqst = Request(api_session.get(), 'GET', '/folders/_/hosts') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def list_all_hosts(cls): - rqst = Request(cls.session, 'GET', '/folders/_/all_hosts') + rqst = Request(api_session.get(), 'GET', '/folders/_/all_hosts') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def list_allowed_types(cls): - rqst = Request(cls.session, 'GET', '/folders/_/allowed_types') + rqst = Request(api_session.get(), 'GET', '/folders/_/allowed_types') async with rqst.fetch() as resp: return await resp.json() @api_function async def info(self): - rqst = Request(self.session, 'GET', '/folders/{0}'.format(self.name)) + rqst = Request(api_session.get(), 'GET', '/folders/{0}'.format(self.name)) async with rqst.fetch() as resp: return await resp.json() @api_function async def delete(self): - rqst = Request(self.session, 'DELETE', '/folders/{0}'.format(self.name)) + rqst = Request(api_session.get(), 'DELETE', '/folders/{0}'.format(self.name)) async with rqst.fetch(): return {} @api_function async def rename(self, new_name): - rqst = Request(self.session, 'POST', '/folders/{0}/rename'.format(self.name)) + rqst = Request(api_session.get(), 'POST', '/folders/{0}/rename'.format(self.name)) rqst.set_json({ 'new_name': new_name, }) @@ -106,7 +114,7 @@ async def upload(self, files: Sequence[Union[str, Path]], files = [Path(file).resolve() for file in files] total_size = 0 for file_path in files: - total_size += file_path.stat().st_size + total_size += Path(file_path).stat().st_size tqdm_obj = tqdm(desc='Uploading files', unit='bytes', unit_scale=True, total=total_size, @@ -116,7 +124,7 @@ async def upload(self, files: Sequence[Union[str, Path]], for file_path in files: try: attachments.append(AttachedFile( - str(file_path.relative_to(base_path)), + str(Path(file_path).relative_to(base_path)), ProgressReportingReader(str(file_path), tqdm_instance=tqdm_obj), 'application/octet-stream', @@ -126,7 +134,7 @@ async def upload(self, files: Sequence[Union[str, Path]], .format(file_path, base_path) raise ValueError(msg) from None - rqst = Request(self.session, + rqst = Request(api_session.get(), 'POST', '/folders/{}/upload'.format(self.name)) rqst.attach_files(attachments) async with rqst.fetch() as resp: @@ -134,7 +142,7 @@ async def upload(self, files: Sequence[Union[str, Path]], @api_function async def mkdir(self, path: Union[str, Path]): - rqst = Request(self.session, 'POST', + rqst = Request(api_session.get(), 'POST', '/folders/{}/mkdir'.format(self.name)) rqst.set_json({ 'path': path, @@ -144,7 +152,7 @@ async def mkdir(self, path: Union[str, Path]): @api_function async def request_download(self, filename: Union[str, Path]): - rqst = Request(self.session, 'POST', + rqst = Request(api_session.get(), 'POST', '/folders/{}/request_download'.format(self.name)) rqst.set_json({ 'file': filename @@ -156,7 +164,7 @@ async def request_download(self, filename: Union[str, Path]): async def delete_files(self, files: Sequence[Union[str, Path]], recursive: bool = False): - rqst = Request(self.session, 'DELETE', + rqst = Request(api_session.get(), 'DELETE', '/folders/{}/delete_files'.format(self.name)) rqst.set_json({ 'files': files, @@ -169,12 +177,12 @@ async def delete_files(self, async def download(self, files: Sequence[Union[str, Path]], show_progress: bool = False): - rqst = Request(self.session, 'GET', + rqst = Request(api_session.get(), 'GET', '/folders/{}/download'.format(self.name)) rqst.set_json({ 'files': files, }) - file_names = [] + file_names: List[str] = [] async with rqst.fetch() as resp: if resp.status // 100 != 2: raise BackendAPIError(resp.status, resp.reason, @@ -189,7 +197,7 @@ async def download(self, files: Sequence[Union[str, Path]], loop = current_loop() acc_bytes = 0 while True: - part = await reader.next() + part = cast(aiohttp.BodyPartReader, await reader.next()) if part is None: break assert part.headers.get(hdrs.CONTENT_ENCODING, 'identity').lower() in ( @@ -211,7 +219,7 @@ async def download(self, files: Sequence[Union[str, Path]], @api_function async def list_files(self, path: Union[str, Path] = '.'): - rqst = Request(self.session, 'GET', '/folders/{}/files'.format(self.name)) + rqst = Request(api_session.get(), 'GET', '/folders/{}/files'.format(self.name)) rqst.set_json({ 'path': path, }) @@ -220,7 +228,7 @@ async def list_files(self, path: Union[str, Path] = '.'): @api_function async def invite(self, perm: str, emails: Sequence[str]): - rqst = Request(self.session, 'POST', '/folders/{}/invite'.format(self.name)) + rqst = Request(api_session.get(), 'POST', '/folders/{}/invite'.format(self.name)) rqst.set_json({ 'perm': perm, 'user_ids': emails, }) @@ -230,14 +238,14 @@ async def invite(self, perm: str, emails: Sequence[str]): @api_function @classmethod async def invitations(cls): - rqst = Request(cls.session, 'GET', '/folders/invitations/list') + rqst = Request(api_session.get(), 'GET', '/folders/invitations/list') async with rqst.fetch() as resp: return await resp.json() @api_function @classmethod async def accept_invitation(cls, inv_id: str): - rqst = Request(cls.session, 'POST', '/folders/invitations/accept') + rqst = Request(api_session.get(), 'POST', '/folders/invitations/accept') rqst.set_json({'inv_id': inv_id}) async with rqst.fetch() as resp: return await resp.json() @@ -245,7 +253,7 @@ async def accept_invitation(cls, inv_id: str): @api_function @classmethod async def delete_invitation(cls, inv_id: str): - rqst = Request(cls.session, 'DELETE', '/folders/invitations/delete') + rqst = Request(api_session.get(), 'DELETE', '/folders/invitations/delete') rqst.set_json({'inv_id': inv_id}) async with rqst.fetch() as resp: return await resp.json() @@ -253,7 +261,7 @@ async def delete_invitation(cls, inv_id: str): @api_function @classmethod async def get_fstab_contents(cls, agent_id=None): - rqst = Request(cls.session, 'GET', '/folders/_/fstab') + rqst = Request(api_session.get(), 'GET', '/folders/_/fstab') rqst.set_json({ 'agent_id': agent_id, }) @@ -263,7 +271,7 @@ async def get_fstab_contents(cls, agent_id=None): @api_function @classmethod async def list_mounts(cls): - rqst = Request(cls.session, 'GET', '/folders/_/mounts') + rqst = Request(api_session.get(), 'GET', '/folders/_/mounts') async with rqst.fetch() as resp: return await resp.json() @@ -271,7 +279,7 @@ async def list_mounts(cls): @classmethod async def mount_host(cls, name: str, fs_location: str, options=None, edit_fstab: bool = False): - rqst = Request(cls.session, 'POST', '/folders/_/mounts') + rqst = Request(api_session.get(), 'POST', '/folders/_/mounts') rqst.set_json({ 'name': name, 'fs_location': fs_location, @@ -284,7 +292,7 @@ async def mount_host(cls, name: str, fs_location: str, options=None, @api_function @classmethod async def umount_host(cls, name: str, edit_fstab: bool = False): - rqst = Request(cls.session, 'DELETE', '/folders/_/mounts') + rqst = Request(api_session.get(), 'DELETE', '/folders/_/mounts') rqst.set_json({ 'name': name, 'edit_fstab': edit_fstab, diff --git a/src/ai/backend/client/py.typed b/src/ai/backend/client/py.typed new file mode 100644 index 00000000..48cdce85 --- /dev/null +++ b/src/ai/backend/client/py.typed @@ -0,0 +1 @@ +placeholder diff --git a/src/ai/backend/client/request.py b/src/ai/backend/client/request.py index 3b7e55f2..6a0454d3 100644 --- a/src/ai/backend/client/request.py +++ b/src/ai/backend/client/request.py @@ -1,32 +1,46 @@ +from __future__ import annotations + +import asyncio from collections import OrderedDict, namedtuple from datetime import datetime from decimal import Decimal import functools import io import logging +import json as modjson from pathlib import Path import sys -from typing import Any, Callable, Mapping, Sequence, Union +from typing import ( + Any, Callable, Optional, Union, + Awaitable, AsyncIterator, Type, TypeVar, + Mapping, Sequence, List, + cast, +) import aiohttp from aiohttp.client import _RequestContextManager, _WSRequestContextManager import aiohttp.web import appdirs +import attr from dateutil.tz import tzutc from multidict import CIMultiDict -import json as modjson +from yarl import URL from .auth import generate_signature from .exceptions import BackendClientError, BackendAPIError -from .session import BaseSession, Session as SyncSession, AsyncSession +from .session import BaseSession, Session as SyncSession, AsyncSession, api_session log = logging.getLogger('ai.backend.client.request') __all__ = [ 'Request', + 'BaseResponse', 'Response', 'WebSocketResponse', 'SSEResponse', + 'FetchContextManager', + 'WebSocketContextManager', + 'SSEContextManager', 'AttachedFile', ] @@ -37,13 +51,13 @@ io.IOBase, None, ] -''' +""" The type alias for the set of allowed types for request content. -''' +""" AttachedFile = namedtuple('AttachedFile', 'filename stream content_type') -''' +""" A struct that represents an attached file to the API request. :param str filename: The name of file to store. It may include paths @@ -54,12 +68,19 @@ :param str content_type: The content type for the stream. For arbitrary binary data, use "application/octet-stream". -''' +""" + + +_T = TypeVar('_T') + + +async def _coro_return(val: _T) -> _T: + return val class ExtendedJSONEncoder(modjson.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, Path): return str(obj) if isinstance(obj, Decimal): @@ -68,9 +89,9 @@ def default(self, obj): class Request: - ''' + """ The API request object. - ''' + """ __slots__ = ( 'config', 'session', 'method', 'path', @@ -79,19 +100,27 @@ class Request: 'reporthook', ) + _content: RequestContent + _attached_files: Optional[Sequence[AttachedFile]] + + date: Optional[datetime] + _allowed_methods = frozenset([ 'GET', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE', 'OPTIONS']) - def __init__(self, session: BaseSession, - method: str = 'GET', - path: str = None, - content: RequestContent = None, *, - content_type: str = None, - params: Mapping[str, str] = None, - reporthook: Callable = None) -> None: - ''' + def __init__( + self, + session: BaseSession, + method: str = 'GET', + path: str = None, + content: RequestContent = None, *, + content_type: str = None, + params: Mapping[str, str] = None, + reporthook: Callable = None, + ) -> None: + """ Initialize an API request. :param BaseSession session: The session where this request is executed on. @@ -104,11 +133,11 @@ def __init__(self, session: BaseSession, :param str content_type: Explicitly set the content type. See also :func:`Request.set_content`. - ''' - self.session = session - self.config = session.config + """ + self.session = api_session.get() + self.config = self.session.config self.method = method - if path.startswith('/'): + if path is not None and path.startswith('/'): path = path[1:] self.path = path self.params = params @@ -124,18 +153,22 @@ def __init__(self, session: BaseSession, @property def content(self) -> RequestContent: - ''' + """ Retrieves the content in the original form. Private codes should NOT use this as it incurs duplicate encoding/decoding. - ''' + """ return self._content - def set_content(self, value: RequestContent, *, - content_type: str = None): - ''' + def set_content( + self, + value: RequestContent, + *, + content_type: str = None, + ) -> None: + """ Sets the content of the request. - ''' + """ assert self._attached_files is None, \ 'cannot set content because you already attached files.' guessed_content_type = 'application/octet-stream' @@ -151,28 +184,34 @@ def set_content(self, value: RequestContent, *, self.content_type = (content_type if content_type is not None else guessed_content_type) - def set_json(self, value: object): - ''' + def set_json(self, value: Any) -> None: + """ A shortcut for set_content() with JSON objects. - ''' + """ self.set_content(modjson.dumps(value, cls=ExtendedJSONEncoder), content_type='application/json') - def attach_files(self, files: Sequence[AttachedFile]): - ''' + def attach_files(self, files: Sequence[AttachedFile]) -> None: + """ Attach a list of files represented as AttachedFile. - ''' + """ assert not self._content, 'content must be empty to attach files.' self.content_type = 'multipart/form-data' self._attached_files = files - def _sign(self, rel_url, access_key=None, secret_key=None, hash_type=None): - ''' + def _sign( + self, + rel_url: URL, + access_key: str = None, + secret_key: str = None, + hash_type: str = None, + ) -> None: + """ Calculates the signature of the given request and adds the Authorization HTTP header. It should be called at the very end of request preparation and before sending the request to the server. - ''' + """ if access_key is None: access_key = self.config.access_key if secret_key is None: @@ -188,8 +227,8 @@ def _sign(self, rel_url, access_key=None, secret_key=None, hash_type=None): elif self.config.endpoint_type == 'session': local_state_path = Path(appdirs.user_state_dir('backend.ai', 'Lablup')) try: - self.session.aiohttp_session.cookie_jar.load( - local_state_path / 'cookie.dat') + cookie_jar = cast(aiohttp.CookieJar, self.session.aiohttp_session.cookie_jar) + cookie_jar.load(local_state_path / 'cookie.dat') except (IOError, PermissionError): pass else: @@ -211,9 +250,9 @@ def _pack_content(self): else: return self._content - def _build_url(self): + def _build_url(self) -> URL: base_url = self.config.endpoint.path.rstrip('/') - query_path = self.path.lstrip('/') if len(self.path) > 0 else '' + query_path = self.path.lstrip('/') if self.path is not None and len(self.path) > 0 else '' if self.config.endpoint_type == 'session': if not query_path.startswith('server'): query_path = 'func/{0}'.format(query_path) @@ -225,8 +264,8 @@ def _build_url(self): # TODO: attach rate-limit information - def fetch(self, **kwargs) -> 'FetchContextManager': - ''' + def fetch(self, **kwargs) -> FetchContextManager: + """ Sends the request to the server and reads the response. You may use this method either with plain synchronous Session or @@ -251,10 +290,11 @@ def fetch(self, **kwargs) -> 'FetchContextManager': rqst = Request(sess, 'GET', ...) async with rqst.fetch() as resp: print(await resp.text()) - ''' + """ assert self.method in self._allowed_methods, \ 'Disallowed HTTP method: {}'.format(self.method) self.date = datetime.now(tzutc()) + assert self.date is not None self.headers['Date'] = self.date.isoformat() if self.content_type is not None and 'Content-Type' not in self.headers: self.headers['Content-Type'] = self.content_type @@ -278,19 +318,20 @@ def _rqst_ctx_builder(): return FetchContextManager(self.session, _rqst_ctx_builder, **kwargs) - def connect_websocket(self, **kwargs) -> 'WebSocketContextManager': - ''' + def connect_websocket(self, **kwargs) -> WebSocketContextManager: + """ Creates a WebSocket connection. .. warning:: This method only works with :class:`~ai.backend.client.session.AsyncSession`. - ''' + """ assert isinstance(self.session, AsyncSession), \ 'Cannot use websockets with sessions in the synchronous mode' assert self.method == 'GET', 'Invalid websocket method' self.date = datetime.now(tzutc()) + assert self.date is not None self.headers['Date'] = self.date.isoformat() # websocket is always a "binary" stream. self.content_type = 'application/octet-stream' @@ -306,19 +347,20 @@ def _ws_ctx_builder(): return WebSocketContextManager(self.session, _ws_ctx_builder, **kwargs) - def connect_events(self, **kwargs) -> 'SSEContextManager': - ''' + def connect_events(self, **kwargs) -> SSEContextManager: + """ Creates a Server-Sent Events connection. .. warning:: This method only works with :class:`~ai.backend.client.session.AsyncSession`. - ''' + """ assert isinstance(self.session, AsyncSession), \ 'Cannot use event streams with sessions in the synchronous mode' assert self.method == 'GET', 'Invalid event stream method' self.date = datetime.now(tzutc()) + assert self.date is not None self.headers['Date'] = self.date.isoformat() self.content_type = 'application/octet-stream' @@ -340,8 +382,58 @@ def _rqst_ctx_builder(): return SSEContextManager(self.session, _rqst_ctx_builder, **kwargs) -class Response: - ''' +class AsyncResponseMixin: + + _session: BaseSession + _raw_response: aiohttp.ClientResponse + + async def text(self) -> str: + return await self._raw_response.text() + + async def json(self, *, loads=modjson.loads) -> Any: + loads = functools.partial(loads, object_pairs_hook=OrderedDict) + return await self._raw_response.json(loads=loads) + + async def read(self, n: int = -1) -> bytes: + return await self._raw_response.content.read(n) + + async def readall(self) -> bytes: + return await self._raw_response.content.read(-1) + + +class SyncResponseMixin: + + _session: BaseSession + _raw_response: aiohttp.ClientResponse + + def text(self) -> str: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.text() + ) + + def json(self, *, loads=modjson.loads) -> Any: + loads = functools.partial(loads, object_pairs_hook=OrderedDict) + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.json(loads=loads) + ) + + def read(self, n: int = -1) -> bytes: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.content.read(n) + ) + + def readall(self) -> bytes: + sync_session = cast(SyncSession, self._session) + return sync_session.worker_thread.execute( + self._raw_response.content.read(-1) + ) + + +class BaseResponse: + """ Represents the Backend.AI API response. Also serves as a high-level wrapper of :class:`aiohttp.ClientResponse`. @@ -349,15 +441,24 @@ class Response: :func:`text`, :func:`json` methods return the resolved content directly with plain synchronous Session while they return the coroutines with AsyncSession. - ''' + """ __slots__ = ( '_session', '_raw_response', '_async_mode', ) - def __init__(self, session: BaseSession, - underlying_response: aiohttp.ClientResponse, *, - async_mode: bool = False): + _session: BaseSession + _raw_response: aiohttp.ClientResponse + _async_mode: bool + + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + *, + async_mode: bool = False, + **kwargs, + ) -> None: self._session = session self._raw_response = underlying_response self._async_mode = async_mode @@ -372,7 +473,9 @@ def status(self) -> int: @property def reason(self) -> str: - return self._raw_response.reason + if self._raw_response.reason is not None: + return self._raw_response.reason + return '' @property def headers(self) -> Mapping[str, str]: @@ -387,46 +490,24 @@ def content_type(self) -> str: return self._raw_response.content_type @property - def content_length(self) -> int: + def content_length(self) -> Optional[int]: return self._raw_response.content_length @property def content(self) -> aiohttp.StreamReader: return self._raw_response.content - def text(self) -> str: - if self._async_mode: - return self._raw_response.text() - else: - return self._session.worker_thread.execute(self._raw_response.text()) - - def json(self, *, loads=modjson.loads) -> Any: - loads = functools.partial(loads, object_pairs_hook=OrderedDict) - if self._async_mode: - return self._raw_response.json(loads=loads) - else: - return self._session.worker_thread.execute( - self._raw_response.json(loads=loads)) - - def read(self, n=-1) -> bytes: - return self._session.worker_thread.execute(self.aread(n)) - async def aread(self, n=-1) -> bytes: - return await self._raw_response.content.read(n) - - def readall(self) -> bytes: - return self._session.worker_thread.execute(self._areadall()) - - async def areadall(self) -> bytes: - return await self._raw_response.content.read(-1) +class Response(AsyncResponseMixin, BaseResponse): + pass class FetchContextManager: - ''' + """ The context manager returned by :func:`Request.fetch`. - It provides both synchronouse and asynchronous contex manager interfaces. - ''' + It provides both synchronous and asynchronous context manager interfaces. + """ __slots__ = ( 'session', 'rqst_ctx_builder', 'response_cls', @@ -435,34 +516,36 @@ class FetchContextManager: '_rqst_ctx', ) - def __init__(self, session: BaseSession, - rqst_ctx_builder: Callable[[], _RequestContextManager], *, - response_cls: Response = Response, - check_status: bool = True): + _rqst_ctx: Optional[_RequestContextManager] + + def __init__( + self, + session: BaseSession, + rqst_ctx_builder: Callable[[], _RequestContextManager], + *, + response_cls: Type[Response] = Response, + check_status: bool = True, + ) -> None: self.session = session self.rqst_ctx_builder = rqst_ctx_builder - self.response_cls = response_cls self.check_status = check_status - self._async_mode = True + self.response_cls = response_cls + self._async_mode = isinstance(session, AsyncSession) self._rqst_ctx = None - def __enter__(self): - assert isinstance(self.session, SyncSession) - self._async_mode = False - return self.session.worker_thread.execute(self.__aenter__()) - - async def __aenter__(self): + async def __aenter__(self) -> Response: max_retries = len(self.session.config.endpoints) retry_count = 0 while True: try: retry_count += 1 self._rqst_ctx = self.rqst_ctx_builder() + assert self._rqst_ctx is not None raw_resp = await self._rqst_ctx.__aenter__() if self.check_status and raw_resp.status // 100 != 2: msg = await raw_resp.text() await raw_resp.__aexit__(None, None, None) - raise BackendAPIError(raw_resp.status, raw_resp.reason, msg) + raise BackendAPIError(raw_resp.status, raw_resp.reason or '', msg) return self.response_cls(self.session, raw_resp, async_mode=self._async_mode) except aiohttp.ClientConnectionError as e: @@ -480,62 +563,71 @@ async def __aenter__(self): await raw_resp.__aexit__(*sys.exc_info()) raise BackendClientError(msg) from e - def __exit__(self, *args): - return self.session.worker_thread.execute(self.__aexit__(*args)) - - async def __aexit__(self, *args): - ret = await self._rqst_ctx.__aexit__(*args) + async def __aexit__(self, *exc_info) -> Optional[bool]: + assert self._rqst_ctx is not None + ret = await self._rqst_ctx.__aexit__(*exc_info) self._rqst_ctx = None return ret -class WebSocketResponse: - ''' +class WebSocketResponse(BaseResponse): + """ A high-level wrapper of :class:`aiohttp.ClientWebSocketResponse`. - ''' + """ - __slots__ = ('_session', '_raw_ws', ) + __slots__ = ('_raw_ws', ) - def __init__(self, session: BaseSession, - underlying_ws: aiohttp.ClientWebSocketResponse): - self._session = session - self._raw_ws = underlying_ws + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + **kwargs, + ) -> None: + # Unfortunately, aiohttp.ClientWebSocketResponse is not a subclass of aiohttp.ClientResponse. + # Since we block methods that require ClientResponse-specific methods, we just force-typecast. + super().__init__(session, underlying_response, **kwargs) + self._raw_ws = cast(aiohttp.ClientWebSocketResponse, underlying_response) @property - def session(self) -> BaseSession: - return self._session + def content_type(self) -> str: + raise AttributeError("WebSocketResponse does not have an explicit content type.") + + @property + def content_length(self) -> Optional[int]: + raise AttributeError("WebSocketResponse does not have a fixed content length.") @property - def raw_weboscket(self) -> aiohttp.ClientWebSocketResponse: + def content(self) -> aiohttp.StreamReader: + raise AttributeError("WebSocketResponse does not support reading the content.") + + @property + def raw_websocket(self) -> aiohttp.ClientWebSocketResponse: return self._raw_ws @property - def closed(self): + def closed(self) -> bool: return self._raw_ws.closed - async def close(self): + async def close(self) -> None: await self._raw_ws.close() - def __aiter__(self): + def __aiter__(self) -> AsyncIterator[aiohttp.WSMessage]: return self._raw_ws.__aiter__() - async def __anext__(self): - return await self._raw_ws.__anext__() - - def exception(self): + def exception(self) -> Optional[BaseException]: return self._raw_ws.exception() - async def send_str(self, raw_str: str): + async def send_str(self, raw_str: str) -> None: if self._raw_ws.closed: raise aiohttp.ServerDisconnectedError('server disconnected') await self._raw_ws.send_str(raw_str) - async def send_json(self, obj: Any): + async def send_json(self, obj: Any) -> None: if self._raw_ws.closed: raise aiohttp.ServerDisconnectedError('server disconnected') await self._raw_ws.send_json(obj) - async def send_bytes(self, data: bytes): + async def send_bytes(self, data: bytes) -> None: if self._raw_ws.closed: raise aiohttp.ServerDisconnectedError('server disconnected') await self._raw_ws.send_bytes(data) @@ -557,9 +649,9 @@ async def receive_bytes(self) -> bytes: class WebSocketContextManager: - ''' + """ The context manager returned by :func:`Request.connect_websocket`. - ''' + """ __slots__ = ( 'session', 'ws_ctx_builder', 'response_cls', @@ -567,23 +659,30 @@ class WebSocketContextManager: '_ws_ctx', ) - def __init__(self, session: BaseSession, - ws_ctx_builder: Callable[[], _WSRequestContextManager], *, - on_enter: Callable = None, - response_cls: WebSocketResponse = WebSocketResponse): + _ws_ctx: Optional[_WSRequestContextManager] + + def __init__( + self, + session: BaseSession, + ws_ctx_builder: Callable[[], _WSRequestContextManager], + *, + on_enter: Callable = None, + response_cls: Type[WebSocketResponse] = WebSocketResponse, + ) -> None: self.session = session self.ws_ctx_builder = ws_ctx_builder self.response_cls = response_cls self.on_enter = on_enter self._ws_ctx = None - async def __aenter__(self): + async def __aenter__(self) -> WebSocketResponse: max_retries = len(self.session.config.endpoints) retry_count = 0 while True: try: retry_count += 1 self._ws_ctx = self.ws_ctx_builder() + assert self._ws_ctx is not None raw_ws = await self._ws_ctx.__aenter__() except aiohttp.ClientConnectionError as e: if retry_count == max_retries: @@ -601,68 +700,105 @@ async def __aenter__(self): else: break - wrapped_ws = self.response_cls(self.session, raw_ws) + wrapped_ws = self.response_cls(self.session, cast(aiohttp.ClientResponse, raw_ws)) if self.on_enter is not None: await self.on_enter(wrapped_ws) return wrapped_ws - async def __aexit__(self, *args): + async def __aexit__(self, *args) -> Optional[bool]: + assert self._ws_ctx is not None ret = await self._ws_ctx.__aexit__(*args) self._ws_ctx = None return ret -class SSEResponse(Response): +@attr.s(auto_attribs=True, slots=True, frozen=True) +class SSEMessage: + event: str + data: str + id: Optional[str] = None + retry: Optional[int] = None + + +class SSEResponse(BaseResponse): __slots__ = ( - '_session', '_raw_response', '_async_mode', - '_auto_reconnect', + '_auto_reconnect', '_retry', '_connector', ) - def __init__(self, session: BaseSession, - underlying_response: aiohttp.ClientResponse): - super().__init__(session, underlying_response, async_mode=True) - - async def fetch_events(self): - msg_lines = [] + def __init__( + self, + session: BaseSession, + underlying_response: aiohttp.ClientResponse, + *, + connector: Callable[[], Awaitable[aiohttp.ClientResponse]], + auto_reconnect: bool = True, + default_retry: int = 5, + **kwargs, + ) -> None: + super().__init__(session, underlying_response, async_mode=True, **kwargs) + self._auto_reconnect = auto_reconnect + self._retry = default_retry + self._connector = connector + + async def fetch_events(self) -> AsyncIterator[SSEMessage]: + msg_lines: List[str] = [] + server_closed = False while True: - line = await self._raw_response.content.readline() - if not line: + received_line = await self._raw_response.content.readline() + if not received_line: # connection closed - break - line = line.strip(b'\r\n') - if line.startswith(b':'): + if self._auto_reconnect and not server_closed: + await asyncio.sleep(self._retry) + self._raw_response = await self._connector() + continue + else: + break + received_line = received_line.strip(b'\r\n') + if received_line.startswith(b':'): # comment continue - if not line: + if not received_line: # message boundary if len(msg_lines) == 0: continue - evdata = { - 'event': 'message', - 'data': '', - } + event_type = 'message' + event_id = None + event_retry = None data_lines = [] try: - for line in msg_lines: - hdr, text = line.split(':', maxsplit=1) + for stored_line in msg_lines: + hdr, text = stored_line.split(':', maxsplit=1) text = text.lstrip(' ') if hdr == 'data': data_lines.append(text) elif hdr == 'event': - evdata['event'] = text + event_type = text elif hdr == 'id': - evdata['id'] = text + event_id = text elif hdr == 'retry': - evdata['retry'] = int(text) + event_retry = int(text) except (IndexError, ValueError): log.exception('SSEResponse: parsing-error') continue - evdata['data'] = '\n'.join(data_lines) + event_data = '\n'.join(data_lines) msg_lines.clear() - yield evdata + if event_retry is not None: + self._retry = event_retry + yield SSEMessage( + event=event_type, + data=event_data, + id=event_id, + retry=event_retry, + ) + if event_type == 'server_close': + server_closed = True + break else: - msg_lines.append(line.decode('utf-8')) + msg_lines.append(received_line.decode('utf-8')) + + def __aiter__(self) -> AsyncIterator[SSEMessage]: + return self.fetch_events() class SSEContextManager: @@ -672,26 +808,39 @@ class SSEContextManager: '_rqst_ctx', ) - def __init__(self, session: BaseSession, - rqst_ctx_builder: Callable[[], _RequestContextManager], *, - response_cls: SSEResponse = SSEResponse): + _rqst_ctx: Optional[_RequestContextManager] + + def __init__( + self, + session: BaseSession, + rqst_ctx_builder: Callable[[], _RequestContextManager], + *, + response_cls: Type[SSEResponse] = SSEResponse, + ) -> None: self.session = session self.rqst_ctx_builder = rqst_ctx_builder self.response_cls = response_cls self._rqst_ctx = None - async def __aenter__(self): + async def reconnect(self) -> aiohttp.ClientResponse: + if self._rqst_ctx is not None: + await self._rqst_ctx.__aexit__(None, None, None) + self._rqst_ctx = self.rqst_ctx_builder() + assert self._rqst_ctx is not None + raw_resp = await self._rqst_ctx.__aenter__() + if raw_resp.status // 100 != 2: + msg = await raw_resp.text() + raise BackendAPIError(raw_resp.status, raw_resp.reason or '', msg) + return raw_resp + + async def __aenter__(self) -> SSEResponse: max_retries = len(self.session.config.endpoints) retry_count = 0 while True: try: retry_count += 1 - self._rqst_ctx = self.rqst_ctx_builder() - raw_resp = await self._rqst_ctx.__aenter__() - if raw_resp.status // 100 != 2: - msg = await raw_resp.text() - raise BackendAPIError(raw_resp.status, raw_resp.reason, msg) - return self.response_cls(self.session, raw_resp) + raw_resp = await self.reconnect() + return self.response_cls(self.session, raw_resp, connector=self.reconnect) except aiohttp.ClientConnectionError as e: if retry_count == max_retries: msg = 'Request to the API endpoint has failed.\n' \ @@ -706,7 +855,8 @@ async def __aenter__(self): '\u279c {!r}'.format(e) raise BackendClientError(msg) from e - async def __aexit__(self, *args): + async def __aexit__(self, *args) -> Optional[bool]: + assert self._rqst_ctx is not None ret = await self._rqst_ctx.__aexit__(*args) self._rqst_ctx = None return ret diff --git a/src/ai/backend/client/session.py b/src/ai/backend/client/session.py index 1df15dad..e8db2a62 100644 --- a/src/ai/backend/client/session.py +++ b/src/ai/backend/client/session.py @@ -1,7 +1,17 @@ +from __future__ import annotations + import abc import asyncio +from contextvars import Context, ContextVar, copy_context import threading -from typing import Tuple +from typing import ( + Any, + Awaitable, + Coroutine, + Tuple, + Union, +) +from typing_extensions import Literal # for Python 3.7 import queue import warnings @@ -10,16 +20,21 @@ from .config import APIConfig, get_config, parse_api_version from .exceptions import APIVersionWarning +from .types import Sentinel, sentinel __all__ = ( 'BaseSession', 'Session', 'AsyncSession', + 'api_session', ) -def is_legacy_server(): +api_session: ContextVar[BaseSession] = ContextVar('api_session') + + +def is_legacy_server() -> bool: """ Determine execution mode. @@ -63,41 +78,44 @@ async def _negotiate_api_version( return client_version -async def _close_aiohttp_session(session: aiohttp.ClientSession): +async def _close_aiohttp_session(session: aiohttp.ClientSession) -> None: # This is a hacky workaround for premature closing of SSL transports # on Windows Proactor event loops. # Thanks to Vadim Markovtsev's comment on the aiohttp issue #1925. # (https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-592596034) transports = 0 all_is_lost = asyncio.Event() - if len(session.connector._conns) == 0: + if session.connector is None: all_is_lost.set() - for conn in session.connector._conns.values(): - for handler, _ in conn: - proto = getattr(handler.transport, "_ssl_protocol", None) - if proto is None: - continue - transports += 1 - orig_lost = proto.connection_lost - orig_eof_received = proto.eof_received - - def connection_lost(exc): - orig_lost(exc) - nonlocal transports - transports -= 1 - if transports == 0: - all_is_lost.set() - - def eof_received(): - try: - orig_eof_received() - except AttributeError: - # It may happen that eof_received() is called after - # _app_protocol and _transport are set to None. - pass - - proto.connection_lost = connection_lost - proto.eof_received = eof_received + else: + if len(session.connector._conns) == 0: + all_is_lost.set() + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + def connection_lost(exc): + orig_lost(exc) + nonlocal transports + transports -= 1 + if transports == 0: + all_is_lost.set() + + def eof_received(): + try: + orig_eof_received() + except AttributeError: + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + proto.connection_lost = connection_lost + proto.eof_received = eof_received await session.close() if transports > 0: await all_is_lost.wait() @@ -105,23 +123,25 @@ def eof_received(): class _SyncWorkerThread(threading.Thread): - sentinel = object() + work_queue: queue.Queue[Union[Tuple[Coroutine, Context], Sentinel]] + done_queue: queue.Queue[Any] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.work_queue = queue.Queue() self.done_queue = queue.Queue() - def run(self): + def run(self) -> None: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: while True: - coro = self.work_queue.get() - if coro is self.sentinel: + item = self.work_queue.get() + if item is sentinel: break + coro, ctx = item try: - result = loop.run_until_complete(coro) + result = ctx.run(loop.run_until_complete, coro) except Exception as e: self.done_queue.put_nowait(e) else: @@ -132,13 +152,17 @@ def run(self): finally: loop.stop() - def execute(self, coro): - self.work_queue.put(coro) - result = self.done_queue.get() - self.done_queue.task_done() - if isinstance(result, Exception): - raise result - return result + def execute(self, coro: Coroutine) -> Any: + ctx = copy_context() # preserve context for the worker thread + try: + self.work_queue.put((coro, ctx)) + result = self.done_queue.get() + self.done_queue.task_done() + if isinstance(result, Exception): + raise result + return result + finally: + del ctx class BaseSession(metaclass=abc.ABCMeta): @@ -147,26 +171,75 @@ class BaseSession(metaclass=abc.ABCMeta): """ __slots__ = ( - '_config', '_closed', 'aiohttp_session', - 'api_version', + '_config', '_closed', '_context_token', + 'aiohttp_session', 'api_version', 'System', 'Manager', 'Admin', 'Agent', 'AgentWatcher', 'ScalingGroup', 'Image', 'ComputeSession', 'SessionTemplate', 'Domain', 'Group', 'Auth', 'User', 'KeyPair', + 'BackgroundTask', 'EtcdConfig', 'Resource', 'KeypairResourcePolicy', - 'VFolder', 'Dotfile' + 'VFolder', 'Dotfile', ) aiohttp_session: aiohttp.ClientSession api_version: Tuple[int, str] - def __init__(self, *, config: APIConfig = None): + def __init__(self, *, config: APIConfig = None) -> None: self._closed = False self._config = config if config else get_config() + from .func.system import System + from .func.admin import Admin + from .func.agent import Agent, AgentWatcher + from .func.auth import Auth + from .func.bgtask import BackgroundTask + from .func.domain import Domain + from .func.etcd import EtcdConfig + from .func.group import Group + from .func.image import Image + from .func.session import ComputeSession + from .func.keypair import KeyPair + from .func.manager import Manager + from .func.resource import Resource + from .func.keypair_resource_policy import KeypairResourcePolicy + from .func.scaling_group import ScalingGroup + from .func.session_template import SessionTemplate + from .func.user import User + from .func.vfolder import VFolder + from .func.dotfile import Dotfile + + self.System = System + self.Admin = Admin + self.Agent = Agent + self.AgentWatcher = AgentWatcher + self.Auth = Auth + self.BackgroundTask = BackgroundTask + self.EtcdConfig = EtcdConfig + self.Domain = Domain + self.Group = Group + self.Image = Image + self.ComputeSession = ComputeSession + self.KeyPair = KeyPair + self.Manager = Manager + self.Resource = Resource + self.KeypairResourcePolicy = KeypairResourcePolicy + self.User = User + self.ScalingGroup = ScalingGroup + self.SessionTemplate = SessionTemplate + self.VFolder = VFolder + self.Dotfile = Dotfile + + @abc.abstractmethod + def open(self) -> Union[None, Awaitable[None]]: + """ + Initializes the session and perform version negotiation. + """ + raise NotImplementedError + @abc.abstractmethod - def close(self): + def close(self) -> Union[None, Awaitable[None]]: """ Terminates the session and releases underlying resources. """ @@ -180,22 +253,33 @@ def closed(self) -> bool: return self._closed @property - def config(self): + def config(self) -> APIConfig: """ The configuration used by this session object. """ return self._config + def __enter__(self) -> BaseSession: + raise NotImplementedError + + def __exit__(self, *exc_info) -> Literal[False]: + return False + + async def __aenter__(self) -> BaseSession: + raise NotImplementedError + + async def __aexit__(self, *exc_info) -> Literal[False]: + return False + class Session(BaseSession): """ - An API client session that makes API requests synchronously. - You may call (almost) all function proxy methods like a plain Python function. - It provides a context manager interface to ensure closing of the session - upon errors and scope exits. + A context manager for API client sessions that makes API requests synchronously. + You may call simple request-response APIs like a plain Python function, + but cannot use streaming APIs based on WebSocket and Server-Sent Events. """ - __slots__ = BaseSession.__slots__ + ( + __slots__ = ( '_worker_thread', ) @@ -213,440 +297,81 @@ async def _create_aiohttp_session() -> aiohttp.ClientSession: self.aiohttp_session = self.worker_thread.execute(_create_aiohttp_session()) - from .func.base import BaseFunction - from .func.system import System - from .func.admin import Admin - from .func.agent import Agent, AgentWatcher - from .func.auth import Auth - from .func.etcd import EtcdConfig - from .func.domain import Domain - from .func.group import Group - from .func.image import Image - from .func.session import ComputeSession - from .func.keypair import KeyPair - from .func.manager import Manager - from .func.resource import Resource - from .func.keypair_resource_policy import KeypairResourcePolicy - from .func.scaling_group import ScalingGroup - from .func.session_template import SessionTemplate - from .func.user import User - from .func.vfolder import VFolder - from .func.dotfile import Dotfile + def open(self) -> None: + self._context_token = api_session.set(self) + self.api_version = self.worker_thread.execute( + _negotiate_api_version(self.aiohttp_session, self.config)) - self.System = type('System', (BaseFunction, ), { - **System.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.system.System` function proxy - bound to this session. - ''' - - self.Admin = type('Admin', (BaseFunction, ), { - **Admin.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.admin.Admin` function proxy - bound to this session. - ''' - - self.Agent = type('Agent', (BaseFunction, ), { - **Agent.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.Agent` function proxy - bound to this session. - ''' - - self.AgentWatcher = type('AgentWatcher', (BaseFunction, ), { - **AgentWatcher.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.AgentWatcher` function proxy - bound to this session. - ''' - - self.Auth = type('Auth', (BaseFunction, ), { - **Auth.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.Auth` function proxy - bound to this session. - ''' - - self.EtcdConfig = type('EtcdConfig', (BaseFunction, ), { - **EtcdConfig.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.EtcdConfig` function proxy - bound to this session. - ''' - - self.Domain = type('Domain', (BaseFunction, ), { - **Domain.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.Domain` function proxy - bound to this session. - ''' - - self.Group = type('Group', (BaseFunction, ), { - **Group.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.Group` function proxy - bound to this session. - ''' - - self.Image = type('Image', (BaseFunction, ), { - **Image.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.image.Image` function proxy - bound to this session. - ''' - - self.ComputeSession = type('ComputeSession', (BaseFunction, ), { - **ComputeSession.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.kernel.ComputeSession` function proxy - bound to this session. - ''' - - self.KeyPair = type('KeyPair', (BaseFunction, ), { - **KeyPair.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.keypair.KeyPair` function proxy - bound to this session. - ''' - - self.Manager = type('Manager', (BaseFunction, ), { - **Manager.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.manager.Manager` function proxy - bound to this session. - ''' - - self.Resource = type('Resource', (BaseFunction, ), { - **Resource.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.resource.Resource` function proxy - bound to this session. - ''' - - self.KeypairResourcePolicy = type('KeypairResourcePolicy', (BaseFunction, ), { - **KeypairResourcePolicy.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.keypair_resource_policy.KeypairResourcePolicy` function proxy - bound to this session. - ''' - - self.User = type('User', (BaseFunction, ), { - **User.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.user.User` function proxy - bound to this session. - ''' - - self.ScalingGroup = type('ScalingGroup', (BaseFunction, ), { - **ScalingGroup.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.scaling_group.ScalingGroup` function proxy - bound to this session. - ''' - - self.SessionTemplate = type('SessionTemplate', (BaseFunction, ), { - **SessionTemplate.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.session_template.SessionTemplate` function proxy - bound to this session. - ''' - - self.VFolder = type('VFolder', (BaseFunction, ), { - **VFolder.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.vfolder.VFolder` function proxy - bound to this session. - ''' - - self.Dotfile = type('Dotfile', (BaseFunction, ), { - **Dotfile.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.dotfile.Dotfile` function proxy - bound to this session. - ''' - - def close(self): - ''' + def close(self) -> None: + """ Terminates the session. It schedules the ``close()`` coroutine of the underlying aiohttp session and then enqueues a sentinel object to indicate termination. Then it waits until the worker thread to self-terminate by joining. - ''' + """ if self._closed: return self._closed = True - self._worker_thread.work_queue.put(_close_aiohttp_session(self.aiohttp_session)) - self._worker_thread.work_queue.put(self.worker_thread.sentinel) + self._worker_thread.execute(_close_aiohttp_session(self.aiohttp_session)) + self._worker_thread.work_queue.put(sentinel) self._worker_thread.join() + api_session.reset(self._context_token) @property def worker_thread(self): - ''' + """ The thread that internally executes the asynchronous implementations of the given API functions. - ''' + """ return self._worker_thread - def __enter__(self): + def __enter__(self) -> Session: assert not self.closed, 'Cannot reuse closed session' - self.api_version = self.worker_thread.execute( - _negotiate_api_version(self.aiohttp_session, self.config)) + self.open() return self - def __exit__(self, exc_type, exc_obj, exc_tb): + def __exit__(self, *exc_info) -> Literal[False]: self.close() - return False + return False # raise up the inner exception class AsyncSession(BaseSession): - ''' - An API client session that makes API requests asynchronously using coroutines. - You may call all function proxy methods like a coroutine. - It provides an async context manager interface to ensure closing of the session - upon errors and scope exits. - ''' - - __slots__ = BaseSession.__slots__ + () + """ + A context manager for API client sessions that makes API requests asynchronously. + You may call all APIs as coroutines. + WebSocket-based APIs and SSE-based APIs returns special response types. + """ def __init__(self, *, config: APIConfig = None): super().__init__(config=config) - ssl = None if self._config.skip_sslcert_validation: ssl = False connector = aiohttp.TCPConnector(ssl=ssl) self.aiohttp_session = aiohttp.ClientSession(connector=connector) - from .func.base import BaseFunction - from .func.system import System - from .func.admin import Admin - from .func.agent import Agent, AgentWatcher - from .func.auth import Auth - from .func.etcd import EtcdConfig - from .func.group import Group - from .func.image import Image - from .func.session import ComputeSession - from .func.keypair import KeyPair - from .func.manager import Manager - from .func.resource import Resource - from .func.keypair_resource_policy import KeypairResourcePolicy - from .func.scaling_group import ScalingGroup - from .func.session_template import SessionTemplate - from .func.user import User - from .func.vfolder import VFolder - from .func.dotfile import Dotfile + async def _aopen(self) -> None: + self._context_token = api_session.set(self) + self.api_version = await _negotiate_api_version(self.aiohttp_session, self.config) + + def open(self) -> Awaitable[None]: + return self._aopen() - self.System = type('System', (BaseFunction, ), { - **System.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.system.System` function proxy - bound to this session. - ''' - - self.Admin = type('Admin', (BaseFunction, ), { - **Admin.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.admin.Admin` function proxy - bound to this session. - ''' - - self.Agent = type('Agent', (BaseFunction, ), { - **Agent.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.Agent` function proxy - bound to this session. - ''' - - self.AgentWatcher = type('AgentWatcher', (BaseFunction, ), { - **AgentWatcher.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.AgentWatcher` function proxy - bound to this session. - ''' - - self.Auth = type('Auth', (BaseFunction, ), { - **Auth.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.Auth` function proxy - bound to this session. - ''' - - self.EtcdConfig = type('EtcdConfig', (BaseFunction, ), { - **EtcdConfig.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.EtcdConfig` function proxy - bound to this session. - ''' - - self.Group = type('Group', (BaseFunction, ), { - **Group.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.agent.Group` function proxy - bound to this session. - ''' - - self.Image = type('Image', (BaseFunction, ), { - **Image.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.image.Image` function proxy - bound to this session. - ''' - - self.ComputeSession = type('ComputeSession', (BaseFunction, ), { - **ComputeSession.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.kernel.ComputeSession` function proxy - bound to this session. - ''' - - self.KeyPair = type('KeyPair', (BaseFunction, ), { - **KeyPair.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.keypair.KeyPair` function proxy - bound to this session. - ''' - - self.Manager = type('Manager', (BaseFunction, ), { - **Manager.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.manager.Manager` function proxy - bound to this session. - ''' - - self.Resource = type('Resource', (BaseFunction, ), { - **Resource.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.resource.Resource` function proxy - bound to this session. - ''' - - self.KeypairResourcePolicy = type('KeypairResourcePolicy', (BaseFunction, ), { - **KeypairResourcePolicy.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.keypair_resource_policy.KeypairResourcePolicy` function proxy - bound to this session. - ''' - - self.User = type('User', (BaseFunction, ), { - **User.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.user.User` function proxy - bound to this session. - ''' - - self.ScalingGroup = type('ScalingGroup', (BaseFunction, ), { - **ScalingGroup.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.scaling_group.ScalingGroup` function proxy - bound to this session. - ''' - - self.SessionTemplate = type('SessionTemplate', (BaseFunction, ), { - **SessionTemplate.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.session_template.SessionTemplate` function proxy - bound to this session. - ''' - - self.VFolder = type('VFolder', (BaseFunction, ), { - **VFolder.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.vfolder.VFolder` function proxy - bound to this session. - ''' - self.Dotfile = type('Dotfile', (BaseFunction, ), { - **Dotfile.__dict__, - 'session': self, - }) - ''' - The :class:`~ai.backend.client.dotfile.Dotfile` function proxy - bound to this session. - ''' - - async def close(self): + async def _aclose(self) -> None: if self._closed: return self._closed = True await _close_aiohttp_session(self.aiohttp_session) + api_session.reset(self._context_token) - async def __aenter__(self): + def close(self) -> Awaitable[None]: + return self._aclose() + + async def __aenter__(self) -> AsyncSession: assert not self.closed, 'Cannot reuse closed session' - self.api_version = await _negotiate_api_version(self.aiohttp_session, self.config) + await self.open() return self - async def __aexit__(self, exc_type, exc_obj, exc_tb): + async def __aexit__(self, *exc_info) -> Literal[False]: await self.close() - return False + return False # raise up the inner exception diff --git a/src/ai/backend/client/test_utils.py b/src/ai/backend/client/test_utils.py index df24aa79..e739e0e6 100644 --- a/src/ai/backend/client/test_utils.py +++ b/src/ai/backend/client/test_utils.py @@ -1,6 +1,6 @@ -''' +""" A support module to async mocks in Python versiosn prior to 3.8. -''' +""" from unittest import mock try: @@ -8,11 +8,11 @@ # Python 3.8 also adds magic-mocking async iterators and async context managers. from unittest.mock import AsyncMock except ImportError: - from asynctest import CoroutineMock as AsyncMock # noqa + from asynctest import CoroutineMock as AsyncMock # type: ignore class AsyncContextMock(mock.Mock): - ''' + """ Provides a mock that can be used: async with mock(): @@ -34,7 +34,7 @@ class AsyncContextMock(mock.Mock): # resp.status is 200 result = await resp.json() # result is {'hello': 'world'} - ''' + """ async def __aenter__(self): return self @@ -44,12 +44,12 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class AsyncContextMagicMock(mock.MagicMock): - ''' + """ Provides a magic mock that can be used: async with mock(): ... - ''' + """ async def __aenter__(self): return self @@ -59,7 +59,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): class AsyncContextCoroutineMock(AsyncMock): - ''' + """ Provides a mock that can be used: async with (await mock(...)): @@ -81,7 +81,7 @@ class AsyncContextCoroutineMock(AsyncMock): # resp.status is 200 result = await resp.json() # result is {'hello': 'world'} - ''' + """ async def __aenter__(self): return self diff --git a/src/ai/backend/client/types.py b/src/ai/backend/client/types.py new file mode 100644 index 00000000..85149b15 --- /dev/null +++ b/src/ai/backend/client/types.py @@ -0,0 +1,19 @@ +import enum + + +class Sentinel(enum.Enum): + """ + A special type to represent a special value to indicate closing/shutdown of queues. + """ + token = 0 + + +class Undefined(enum.Enum): + """ + A special type to represent an undefined value. + """ + token = 0 + + +sentinel = Sentinel.token +undefined = Undefined.token diff --git a/src/ai/backend/client/utils.py b/src/ai/backend/client/utils.py index 2749c0ec..4d4abe89 100644 --- a/src/ai/backend/client/utils.py +++ b/src/ai/backend/client/utils.py @@ -4,25 +4,6 @@ from tqdm import tqdm -class Singleton(type): - _instances = {} - - def __call__(cls, *args, **kwargs): - if cls not in cls._instances: - cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) - return cls._instances[cls] - - -class Undefined(metaclass=Singleton): - ''' - A special type to represent an undefined value. - ''' - pass - - -undefined = Undefined() - - class ProgressReportingReader(io.BufferedReader): def __init__(self, file_path, *, tqdm_instance=None): diff --git a/src/ai/backend/client/versioning.py b/src/ai/backend/client/versioning.py index f25c8448..1c1f7867 100644 --- a/src/ai/backend/client/versioning.py +++ b/src/ai/backend/client/versioning.py @@ -6,6 +6,7 @@ naming_profile = { 'path': ('kernel', 'session'), + 'session_events': ('/stream/kernel/_/events', '/events/session'), 'name_arg': ('clientSessionToken', 'name'), 'event_name_arg': ('sessionId', 'name'), 'name_gql_field': ('sess_id', 'session_name'), diff --git a/tests/cli/test_cli_proxy.py b/tests/cli/test_cli_proxy.py index 37e6d2a8..c51983fb 100644 --- a/tests/cli/test_cli_proxy.py +++ b/tests/cli/test_cli_proxy.py @@ -7,7 +7,8 @@ @pytest.fixture -def api_app(event_loop): +async def api_app_fixture(unused_tcp_port_factory): + api_port = unused_tcp_port_factory() app = web.Application() recv_queue = [] @@ -34,55 +35,46 @@ async def echo_web(request): app.router.add_route('GET', r'/stream/echo', echo_ws) app.router.add_route('POST', r'/echo', echo_web) runner = web.AppRunner(app) - - async def start(port): - await runner.setup() - site = web.TCPSite(runner, '127.0.0.1', port) - await site.start() - return app, recv_queue - - async def shutdown(): - await runner.cleanup() - + await runner.setup() + site = web.TCPSite(runner, '127.0.0.1', api_port) + await site.start() try: - yield start + yield app, recv_queue, api_port finally: - event_loop.run_until_complete(shutdown()) + await runner.cleanup() @pytest.fixture -def proxy_app(event_loop): +async def proxy_app_fixture(unused_tcp_port_factory): app = create_proxy_app() runner = web.AppRunner(app) - - async def start(port): - await runner.setup() - site = web.TCPSite(runner, '127.0.0.1', port) - await site.start() - return app - - async def shutdown(): - await runner.cleanup() - + proxy_port = unused_tcp_port_factory() + await runner.setup() + site = web.TCPSite(runner, '127.0.0.1', proxy_port) + await site.start() try: - yield start + yield app, proxy_port finally: - event_loop.run_until_complete(shutdown()) + await runner.cleanup() +@pytest.mark.xfail( + reason="pytest-dev/pytest-asyncio#153 should be resolved to make this test working" +) @pytest.mark.asyncio -async def test_proxy_web(monkeypatch, example_keypair, api_app, proxy_app, - unused_tcp_port_factory): - api_port = unused_tcp_port_factory() +async def test_proxy_web( + monkeypatch, example_keypair, + api_app_fixture, + proxy_app_fixture, +): + api_app, recv_queue, api_port = api_app_fixture api_url = 'http://127.0.0.1:{}'.format(api_port) monkeypatch.setenv('BACKEND_ACCESS_KEY', example_keypair[0]) monkeypatch.setenv('BACKEND_SECRET_KEY', example_keypair[1]) monkeypatch.setenv('BACKEND_ENDPOINT', api_url) monkeypatch.setattr(config, '_config', config.APIConfig()) - api_app, recv_queue = await api_app(api_port) + proxy_app, proxy_port = proxy_app_fixture proxy_client = aiohttp.ClientSession() - proxy_port = unused_tcp_port_factory() - proxy_app = await proxy_app(proxy_port) proxy_url = 'http://127.0.0.1:{}'.format(proxy_port) data = {"test": 1234} async with proxy_client.request('POST', proxy_url + '/echo', @@ -93,9 +85,15 @@ async def test_proxy_web(monkeypatch, example_keypair, api_app, proxy_app, assert ret['test'] == 1234 +@pytest.mark.xfail( + reason="pytest-dev/pytest-asyncio#153 should be resolved to make this test working" +) @pytest.mark.asyncio -async def test_proxy_web_502(monkeypatch, example_keypair, proxy_app, - unused_tcp_port_factory): +async def test_proxy_web_502( + monkeypatch, example_keypair, + proxy_app_fixture, + unused_tcp_port_factory, +): api_port = unused_tcp_port_factory() api_url = 'http://127.0.0.1:{}'.format(api_port) monkeypatch.setenv('BACKEND_ACCESS_KEY', example_keypair[0]) @@ -104,8 +102,7 @@ async def test_proxy_web_502(monkeypatch, example_keypair, proxy_app, monkeypatch.setattr(config, '_config', config.APIConfig()) # Skip creation of api_app; let the proxy use a non-existent server. proxy_client = aiohttp.ClientSession() - proxy_port = unused_tcp_port_factory() - proxy_app = await proxy_app(proxy_port) + proxy_app, proxy_port = proxy_app_fixture proxy_url = 'http://127.0.0.1:{}'.format(proxy_port) data = {"test": 1234} async with proxy_client.request('POST', proxy_url + '/echo', @@ -114,19 +111,23 @@ async def test_proxy_web_502(monkeypatch, example_keypair, proxy_app, assert resp.reason == 'Bad Gateway' +@pytest.mark.xfail( + reason="pytest-dev/pytest-asyncio#153 should be resolved to make this test working" +) @pytest.mark.asyncio -async def test_proxy_websocket(monkeypatch, example_keypair, api_app, proxy_app, - unused_tcp_port_factory): - api_port = unused_tcp_port_factory() +async def test_proxy_websocket( + monkeypatch, example_keypair, + api_app_fixture, + proxy_app_fixture, +): + api_app, recv_queue, api_port = api_app_fixture api_url = 'http://127.0.0.1:{}'.format(api_port) monkeypatch.setenv('BACKEND_ACCESS_KEY', example_keypair[0]) monkeypatch.setenv('BACKEND_SECRET_KEY', example_keypair[1]) monkeypatch.setenv('BACKEND_ENDPOINT', api_url) monkeypatch.setattr(config, '_config', config.APIConfig()) - api_app, recv_queue = await api_app(api_port) proxy_client = aiohttp.ClientSession() - proxy_port = unused_tcp_port_factory() - proxy_app = await proxy_app(proxy_port) + proxy_app, proxy_port = proxy_app_fixture proxy_url = 'http://127.0.0.1:{}'.format(proxy_port) ws = await proxy_client.ws_connect(proxy_url + '/stream/echo') await ws.send_str('test') diff --git a/tests/test_kernel.py b/tests/test_kernel.py index a4880bc8..6b7b6efe 100644 --- a/tests/test_kernel.py +++ b/tests/test_kernel.py @@ -4,7 +4,7 @@ import pytest from ai.backend.client.config import APIConfig -from ai.backend.client.session import Session +from ai.backend.client.session import api_session, Session from ai.backend.client.versioning import get_naming from ai.backend.client.test_utils import AsyncContextMock, AsyncMock @@ -43,12 +43,13 @@ def test_create_with_config(mocker, api_version): else: assert prefix == 'session' assert session.config is myconfig - cs = session.ComputeSession.get_or_create('python') + session.ComputeSession.get_or_create('python') mock_req.assert_called_once_with(session, 'POST', f'/{prefix}') - assert str(cs.session.config.endpoint) == 'https://localhost:9999' - assert cs.session.config.user_agent == 'BAIClientTest' - assert cs.session.config.access_key == '1234' - assert cs.session.config.secret_key == 'asdf' + current_api_session = api_session.get() + assert str(current_api_session.config.endpoint) == 'https://localhost:9999' + assert current_api_session.config.user_agent == 'BAIClientTest' + assert current_api_session.config.access_key == '1234' + assert current_api_session.config.secret_key == 'asdf' def test_create_kernel_url(mocker): diff --git a/tests/test_request.py b/tests/test_request.py index 7d003bda..2c0f8a83 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -114,16 +114,18 @@ def test_build_correct_url(mock_request_params): assert str(rqst._build_url()) == canonical_url -def test_fetch_invalid_method(mock_request_params): +@pytest.mark.asyncio +async def test_fetch_invalid_method(mock_request_params): mock_request_params['method'] = 'STRANGE' rqst = Request(**mock_request_params) with pytest.raises(AssertionError): - with rqst.fetch(): + async with rqst.fetch(): pass -def test_fetch(dummy_endpoint): +@pytest.mark.asyncio +async def test_fetch(dummy_endpoint): with aioresponses() as m, Session() as session: body = b'hello world' m.post( @@ -132,11 +134,11 @@ def test_fetch(dummy_endpoint): 'Content-Length': str(len(body))}, ) rqst = Request(session, 'POST', 'function') - with rqst.fetch() as resp: + async with rqst.fetch() as resp: assert isinstance(resp, Response) assert resp.status == 200 assert resp.content_type == 'text/plain' - assert resp.text() == body.decode() + assert await resp.text() == body.decode() assert resp.content_length == len(body) with aioresponses() as m, Session() as session: @@ -147,16 +149,17 @@ def test_fetch(dummy_endpoint): 'Content-Length': str(len(body))}, ) rqst = Request(session, 'POST', 'function') - with rqst.fetch() as resp: + async with rqst.fetch() as resp: assert isinstance(resp, Response) assert resp.status == 200 assert resp.content_type == 'application/json' - assert resp.text() == body.decode() - assert resp.json() == {'a': 1234, 'b': None} + assert await resp.text() == body.decode() + assert await resp.json() == {'a': 1234, 'b': None} assert resp.content_length == len(body) -def test_streaming_fetch(dummy_endpoint): +@pytest.mark.asyncio +async def test_streaming_fetch(dummy_endpoint): # Read content by chunks. with aioresponses() as m, Session() as session: body = b'hello world' @@ -166,17 +169,18 @@ def test_streaming_fetch(dummy_endpoint): 'Content-Length': str(len(body))}, ) rqst = Request(session, 'POST', 'function') - with rqst.fetch() as resp: + async with rqst.fetch() as resp: assert resp.status == 200 assert resp.content_type == 'text/plain' - assert resp.read(3) == b'hel' - assert resp.read(2) == b'lo' - resp.read() + assert await resp.read(3) == b'hel' + assert await resp.read(2) == b'lo' + await resp.read() with pytest.raises(AssertionError): - assert resp.text() + assert await resp.text() -def test_invalid_requests(dummy_endpoint): +@pytest.mark.asyncio +async def test_invalid_requests(dummy_endpoint): with aioresponses() as m, Session() as session: body = json.dumps({ 'type': 'https://api.backend.ai/probs/kernel-not-found', @@ -189,7 +193,7 @@ def test_invalid_requests(dummy_endpoint): ) rqst = Request(session, 'POST', '/') with pytest.raises(BackendAPIError) as e: - with rqst.fetch(): + async with rqst.fetch(): pass assert e.status == 404 assert e.data['type'] == \ @@ -244,21 +248,6 @@ async def test_fetch_timeout_async(dummy_endpoint): pass -def test_response_sync(defconfig, dummy_endpoint): - body = b'{"test": 1234}' - with aioresponses() as m: - m.post( - dummy_endpoint + 'function', status=200, body=body, - headers={'Content-Type': 'application/json', - 'Content-Length': str(len(body))}, - ) - with Session(config=defconfig) as session: - rqst = Request(session, 'POST', '/function') - with rqst.fetch() as resp: - assert resp.text() == '{"test": 1234}' - assert resp.json() == {'test': 1234} - - @pytest.mark.asyncio async def test_response_async(defconfig, dummy_endpoint): body = b'{"test": 5678}'