Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/entitysdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@

from entitysdk.client import Client
from entitysdk.common import ProjectContext
from entitysdk.store import LocalAssetStore

__all__ = ["Client", "ProjectContext"]
__all__ = ["Client", "ProjectContext", "LocalAssetStore"]
62 changes: 15 additions & 47 deletions src/entitysdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from entitysdk.models.entity import Entity
from entitysdk.result import IteratorResult
from entitysdk.schemas.asset import DownloadedAssetFile
from entitysdk.store import LocalAssetStore
from entitysdk.token_manager import TokenFromValue, TokenManager
from entitysdk.types import (
ID,
Expand All @@ -33,8 +34,6 @@
)
from entitysdk.util import (
build_api_url,
create_intermediate_directories,
validate_filename_extension_consistency,
)
from entitysdk.utils.asset import filter_assets

Expand All @@ -53,6 +52,7 @@ def __init__(
http_client: httpx.Client | None = None,
token_manager: TokenManager | Token,
environment: DeploymentEnvironment | str | None = None,
local_store: LocalAssetStore | None = None,
) -> None:
"""Initialize client.

Expand All @@ -62,6 +62,7 @@ def __init__(
http_client: Optional HTTP client to use.
token_manager: Token manager or token to be used for authentication.
environment: Deployment environent.
local_store: LocalAssetStore object for using a local store.
"""
try:
environment = DeploymentEnvironment(environment) if environment else None
Expand All @@ -80,6 +81,7 @@ def __init__(
self._token_manager = (
TokenFromValue(token_manager) if isinstance(token_manager, Token) else token_manager
)
self._local_store = local_store

@staticmethod
def _handle_api_url(api_url: str | None, environment: DeploymentEnvironment | None) -> str:
Expand Down Expand Up @@ -525,22 +527,17 @@ def download_content(
Returns:
Asset content in bytes.
"""
url = (
route.get_assets_endpoint(
api_url=self.api_url,
entity_type=entity_type,
entity_id=entity_id,
asset_id=asset_id,
)
+ "/download"
)
context = self._optional_user_context(override_context=project_context)
return core.download_asset_content(
url=url,
project_context=context,
api_url=self.api_url,
asset_id=asset_id,
entity_id=entity_id,
entity_type=entity_type,
asset_path=asset_path,
project_context=context,
http_client=self._http_client,
token=self._token_manager.get_token(),
local_store=self._local_store,
)

def download_file(
Expand All @@ -567,46 +564,17 @@ def download_file(
Output file path.
"""
context = self._optional_user_context(override_context=project_context)

asset_endpoint = route.get_assets_endpoint(
return core.download_asset_file(
api_url=self.api_url,
entity_type=entity_type,
entity_id=entity_id,
asset_id=asset_id if isinstance(asset_id, ID) else asset_id.id,
)

if isinstance(asset_id, ID):
asset = core.get_entity(
asset_endpoint,
entity_type=Asset,
project_context=context,
http_client=self._http_client,
token=self._token_manager.get_token(),
)
else:
asset = asset_id

path: Path = Path(output_path)
if asset.is_directory:
if not asset_path:
raise EntitySDKError("Directory from directories require an `asset_path`")
else:
if asset_path:
raise EntitySDKError("Cannot pass `asset_path` to non-directories")

path = (
path / asset.path
if path.is_dir()
else validate_filename_extension_consistency(path, Path(asset.path).suffix)
)
create_intermediate_directories(path)
return core.download_asset_file(
url=f"{asset_endpoint}/download",
entity_type=entity_type,
asset_or_id=asset_id,
project_context=context,
asset_path=asset_path,
output_path=path,
output_path=Path(output_path),
http_client=self._http_client,
token=self._token_manager.get_token(),
local_store=self._local_store,
)

@staticmethod
Expand Down
106 changes: 97 additions & 9 deletions src/entitysdk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,14 @@
from entitysdk.models.entity import Entity
from entitysdk.result import IteratorResult
from entitysdk.route import get_assets_endpoint, get_entity_derivations_endpoint
from entitysdk.store import LocalAssetStore
from entitysdk.types import ID, AssetLabel, DerivationType
from entitysdk.util import make_db_api_request, stream_paginated_request
from entitysdk.util import (
create_intermediate_directories,
make_db_api_request,
stream_paginated_request,
validate_filename_extension_consistency,
)

L = logging.getLogger(__name__)

Expand Down Expand Up @@ -354,60 +360,142 @@ def list_directory(


def download_asset_file(
url: str,
*,
api_url: str,
entity_id: ID,
entity_type: type[Identifiable],
asset_or_id: ID | Asset,
output_path: Path,
asset_path: os.PathLike | None = None,
project_context: ProjectContext | None = None,
token: str,
http_client: httpx.Client | None = None,
local_store: LocalAssetStore | None = None,
) -> Path:
"""Download asset file to a file path.

Args:
url: URL of the asset.
api_url: The API URL to entitycore service.
entity_id: Resource id
entity_type: Resource type
asset_or_id: Asset id or asset instance
output_path: Path to save the file to.
asset_path: for asset directories, the path within the directory to the file
project_context: Project context.
token: Authorization access token.
http_client: HTTP client.
local_store: LocalAssetStore for using a local store.

Returns:
Output file path.
"""
asset_endpoint = get_assets_endpoint(
api_url=api_url,
entity_type=entity_type,
entity_id=entity_id,
asset_id=asset_or_id if isinstance(asset_or_id, ID) else asset_or_id.id,
)

if isinstance(asset_or_id, ID):
asset = get_entity(
asset_endpoint,
entity_type=Asset,
project_context=project_context,
http_client=http_client,
token=token,
)
else:
asset = asset_or_id

target_path: Path = Path(output_path)
source_path: Path = Path(asset.full_path)
if asset.is_directory:
if not asset_path:
raise EntitySDKError("Directory from directories require an `asset_path`")
source_path /= asset_path
else:
if asset_path:
raise EntitySDKError("Cannot pass `asset_path` to non-directories")

target_path = (
target_path / asset.path
if target_path.is_dir()
else validate_filename_extension_consistency(target_path, Path(asset.path).suffix)
)

create_intermediate_directories(target_path)

if local_store and local_store.path_exists(source_path):
local_store.link_path(source_path, target_path)
return target_path

bytes_content = download_asset_content(
url=url,
api_url=api_url,
asset_id=asset.id,
entity_id=entity_id,
entity_type=entity_type,
asset_path=asset_path,
project_context=project_context,
token=token,
http_client=http_client,
)
output_path.write_bytes(bytes_content)
return output_path
target_path.write_bytes(bytes_content)
return target_path


def download_asset_content(
url: str,
*,
api_url: str,
entity_id: ID,
entity_type: type[Identifiable],
asset_id: ID,
asset_path: os.PathLike | None = None,
project_context: ProjectContext | None = None,
token: str,
http_client: httpx.Client | None = None,
local_store: LocalAssetStore | None = None,
) -> bytes:
"""Download asset content.

Args:
url: URL of the asset.
api_url: The API URL to entitycore service.
entity_id: Resource id
entity_type: Resource type
asset_id: Asset id
asset_path: for asset directories, the path within the directory to the file
project_context: Project context.
token: Authorization access token.
http_client: HTTP client.
local_store: LocalAssetStore for using a local store.

Returns:
Asset content in bytes.
"""
asset_endpoint = get_assets_endpoint(
api_url=api_url,
entity_type=entity_type,
entity_id=entity_id,
asset_id=asset_id,
)

if local_store:
asset = get_entity(
asset_endpoint,
entity_type=Asset,
project_context=project_context,
http_client=http_client,
token=token,
)
if local_store.path_exists(asset.full_path):
path = asset.full_path
if asset.is_directory:
path = f"{path}/{asset_path}"
return local_store.read_bytes(path)

download_endpoint = f"{asset_endpoint}/download"

response = make_db_api_request(
url=url,
url=download_endpoint,
method="GET",
parameters={"asset_path": str(asset_path)} if asset_path else {},
project_context=project_context,
Expand Down
35 changes: 35 additions & 0 deletions src/entitysdk/store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Local asset store module."""

from dataclasses import dataclass
from pathlib import Path

from entitysdk.exception import EntitySDKError


@dataclass
class LocalAssetStore:
"""Class for locally stored asset data."""

prefix: Path

def __post_init__(self):
"""Post init."""
if not Path(self.prefix).exists():
raise EntitySDKError(f"Mount prefix path '{self.prefix}' does not exist")

def _local_path(self, path: str | Path) -> Path:
"""Return path from within the store."""
return Path(self.prefix, path)

def path_exists(self, path: str | Path) -> bool:
"""Return True if path exists in the store."""
return self._local_path(path).exists()

def link_path(self, source: str | Path, target: str | Path) -> Path:
"""Create a soft link from source to target."""
Path(target).symlink_to(self._local_path(source))
return Path(target)

def read_bytes(self, path: str | Path) -> bytes:
"""Read file from local store."""
return self._local_path(path).read_bytes()
2 changes: 1 addition & 1 deletion tests/unit/downloaders/test_memodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_download_memodel(
method="GET",
url=f"{api_url}/cell-morphology/{morph_id}/assets/{morph_asset_id}",
match_headers=request_headers,
json=_mock_morph_asset_response(morph_id) | {"path": "foo.asc"},
json=_mock_morph_asset_response(morph_asset_id) | {"path": "foo.asc"},
)
httpx_mock.add_response(
method="GET",
Expand Down
Loading