Skip to content

Commit ae08790

Browse files
Use git xet transfer to check if xet is enabled (#3381)
* use git xet transfer to check if xet is enabled
* nit
* fix syntax
* fix
* Apply suggestions from code review

Co-authored-by: Lucain <[email protected]>

* fix docstring
* add transfers to docstring

---------

Co-authored-by: Lucain <[email protected]>
1 parent 8e8a425 commit ae08790

File tree

4 files changed

+244
-159
lines changed

4 files changed

+244
-159
lines changed

src/huggingface_hub/_commit_api.py

Lines changed: 125 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,7 @@
3333
validate_hf_hub_args,
3434
)
3535
from .utils import tqdm as hf_tqdm
36+
from .utils._runtime import is_xet_available
3637

3738

3839
if TYPE_CHECKING:
@@ -353,7 +354,7 @@ def _warn_on_overwriting_operations(operations: List[CommitOperation]) -> None:
353354

354355

355356
@validate_hf_hub_args
356-
def _upload_lfs_files(
357+
def _upload_files(
357358
*,
358359
additions: List[CommitOperationAdd],
359360
repo_type: str,
@@ -362,6 +363,86 @@ def _upload_lfs_files(
362363
endpoint: Optional[str] = None,
363364
num_threads: int = 5,
364365
revision: Optional[str] = None,
366+
create_pr: Optional[bool] = None,
367+
):
368+
"""
369+
Negotiates per-file transfer (LFS vs Xet) and uploads in batches.
370+
"""
371+
xet_additions: List[CommitOperationAdd] = []
372+
lfs_actions: List[Dict] = []
373+
lfs_oid2addop: Dict[str, CommitOperationAdd] = {}
374+
375+
for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES):
376+
chunk_list = [op for op in chunk]
377+
378+
transfers: List[str] = ["basic", "multipart"]
379+
has_buffered_io_data = any(isinstance(op.path_or_fileobj, io.BufferedIOBase) for op in chunk_list)
380+
if is_xet_available():
381+
if not has_buffered_io_data:
382+
transfers.append("xet")
383+
else:
384+
logger.warning(
385+
"Uploading files as a binary IO buffer is not supported by Xet Storage. "
386+
"Falling back to HTTP upload."
387+
)
388+
389+
actions_chunk, errors_chunk, chosen_transfer = post_lfs_batch_info(
390+
upload_infos=[op.upload_info for op in chunk_list],
391+
repo_id=repo_id,
392+
repo_type=repo_type,
393+
revision=revision,
394+
endpoint=endpoint,
395+
headers=headers,
396+
token=None, # already passed in 'headers'
397+
transfers=transfers,
398+
)
399+
if errors_chunk:
400+
message = "\n".join(
401+
[
402+
f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
403+
for err in errors_chunk
404+
]
405+
)
406+
raise ValueError(f"LFS batch API returned errors:\n{message}")
407+
408+
# If server returns a transfer we didn't offer (e.g "xet" while uploading from BytesIO),
409+
# fall back to LFS for this chunk.
410+
if chosen_transfer == "xet" and ("xet" in transfers):
411+
xet_additions.extend(chunk_list)
412+
else:
413+
lfs_actions.extend(actions_chunk)
414+
for op in chunk_list:
415+
lfs_oid2addop[op.upload_info.sha256.hex()] = op
416+
417+
if len(lfs_actions) > 0:
418+
_upload_lfs_files(
419+
actions=lfs_actions,
420+
oid2addop=lfs_oid2addop,
421+
headers=headers,
422+
endpoint=endpoint,
423+
num_threads=num_threads,
424+
)
425+
426+
if len(xet_additions) > 0:
427+
_upload_xet_files(
428+
additions=xet_additions,
429+
repo_type=repo_type,
430+
repo_id=repo_id,
431+
headers=headers,
432+
endpoint=endpoint,
433+
revision=revision,
434+
create_pr=create_pr,
435+
)
436+
437+
438+
@validate_hf_hub_args
439+
def _upload_lfs_files(
440+
*,
441+
actions: List[Dict],
442+
oid2addop: Dict[str, CommitOperationAdd],
443+
headers: Dict[str, str],
444+
endpoint: Optional[str] = None,
445+
num_threads: int = 5,
365446
):
366447
"""
367448
Uploads the content of `additions` to the Hub using the large file storage protocol.
@@ -370,9 +451,21 @@ def _upload_lfs_files(
370451
- LFS Batch API: https://github.com/git-lfs/git-lfs/blob/main/docs/api/batch.md
371452
372453
Args:
373-
additions (`List` of `CommitOperationAdd`):
374-
The files to be uploaded
375-
repo_type (`str`):
454+
actions (`List[Dict]`):
455+
LFS batch actions returned by the server.
456+
oid2addop (`Dict[str, CommitOperationAdd]`):
457+
A dictionary mapping the OID of the file to the corresponding `CommitOperationAdd` object.
458+
headers (`Dict[str, str]`):
459+
Headers to use for the request, including authorization headers and user agent.
460+
endpoint (`str`, *optional*):
461+
The endpoint to use for the request. Defaults to `constants.ENDPOINT`.
462+
num_threads (`int`, *optional*):
463+
The number of concurrent threads to use when uploading. Defaults to 5.
464+
465+
Raises:
466+
[`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError)
467+
If an upload failed for any reason
468+
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
376469
Type of the repo to upload to: `"model"`, `"dataset"` or `"space"`.
377470
repo_id (`str`):
378471
A namespace (user or an organization) and a repo name separated
@@ -392,50 +485,17 @@ def _upload_lfs_files(
392485
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError)
393486
If the LFS batch endpoint returned an HTTP error.
394487
"""
395-
# Step 1: retrieve upload instructions from the LFS batch endpoint.
396-
# Upload instructions are retrieved by chunk of 256 files to avoid reaching
397-
# the payload limit.
398-
batch_actions: List[Dict] = []
399-
for chunk in chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES):
400-
batch_actions_chunk, batch_errors_chunk = post_lfs_batch_info(
401-
upload_infos=[op.upload_info for op in chunk],
402-
repo_id=repo_id,
403-
repo_type=repo_type,
404-
revision=revision,
405-
endpoint=endpoint,
406-
headers=headers,
407-
token=None, # already passed in 'headers'
408-
)
409-
410-
# If at least 1 error, we do not retrieve information for other chunks
411-
if batch_errors_chunk:
412-
message = "\n".join(
413-
[
414-
f"Encountered error for file with OID {err.get('oid')}: `{err.get('error', {}).get('message')}"
415-
for err in batch_errors_chunk
416-
]
417-
)
418-
raise ValueError(f"LFS batch endpoint returned errors:\n{message}")
419-
420-
batch_actions += batch_actions_chunk
421-
oid2addop = {add_op.upload_info.sha256.hex(): add_op for add_op in additions}
422-
423-
# Step 2: ignore files that have already been uploaded
488+
# Filter out files already present upstream
424489
filtered_actions = []
425-
for action in batch_actions:
490+
for action in actions:
426491
if action.get("actions") is None:
427492
logger.debug(
428-
f"Content of file {oid2addop[action['oid']].path_in_repo} is already"
429-
" present upstream - skipping upload."
493+
f"Content of file {oid2addop[action['oid']].path_in_repo} is already present upstream - skipping upload."
430494
)
431495
else:
432496
filtered_actions.append(action)
433497

434-
if len(filtered_actions) == 0:
435-
logger.debug("No LFS files to upload.")
436-
return
437-
438-
# Step 3: upload files concurrently according to these instructions
498+
# Upload according to server-provided actions
439499
def _wrapped_lfs_upload(batch_action) -> None:
440500
try:
441501
operation = oid2addop[batch_action["oid"]]
@@ -576,30 +636,30 @@ def token_refresher() -> Tuple[str, int]:
576636
progress, progress_callback = None, None
577637

578638
try:
579-
for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)):
580-
_chunk = [op for op in chunk]
581-
582-
bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)]
583-
paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))]
584-
585-
if len(paths_ops) > 0:
586-
upload_files(
587-
[str(op.path_or_fileobj) for op in paths_ops],
588-
xet_endpoint,
589-
access_token_info,
590-
token_refresher,
591-
progress_callback,
592-
repo_type,
593-
)
594-
if len(bytes_ops) > 0:
595-
upload_bytes(
596-
[op.path_or_fileobj for op in bytes_ops],
597-
xet_endpoint,
598-
access_token_info,
599-
token_refresher,
600-
progress_callback,
601-
repo_type,
602-
)
639+
all_bytes_ops = [op for op in additions if isinstance(op.path_or_fileobj, bytes)]
640+
all_paths_ops = [op for op in additions if isinstance(op.path_or_fileobj, (str, Path))]
641+
642+
if len(all_paths_ops) > 0:
643+
all_paths = [str(op.path_or_fileobj) for op in all_paths_ops]
644+
upload_files(
645+
all_paths,
646+
xet_endpoint,
647+
access_token_info,
648+
token_refresher,
649+
progress_callback,
650+
repo_type,
651+
)
652+
653+
if len(all_bytes_ops) > 0:
654+
all_bytes = [op.path_or_fileobj for op in all_bytes_ops]
655+
upload_bytes(
656+
all_bytes,
657+
xet_endpoint,
658+
access_token_info,
659+
token_refresher,
660+
progress_callback,
661+
repo_type,
662+
)
603663

604664
finally:
605665
if progress is not None:

src/huggingface_hub/hf_api.py

Lines changed: 8 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
import inspect
18-
import io
1918
import json
2019
import re
2120
import struct
@@ -46,7 +45,7 @@
4645
Union,
4746
overload,
4847
)
49-
from urllib.parse import quote, unquote
48+
from urllib.parse import quote
5049

5150
import requests
5251
from requests.exceptions import HTTPError
@@ -62,8 +61,7 @@
6261
_fetch_files_to_copy,
6362
_fetch_upload_modes,
6463
_prepare_commit_payload,
65-
_upload_lfs_files,
66-
_upload_xet_files,
64+
_upload_files,
6765
_warn_on_overwriting_operations,
6866
)
6967
from ._inference_endpoints import InferenceEndpoint, InferenceEndpointType
@@ -132,13 +130,8 @@
132130
validate_hf_hub_args,
133131
)
134132
from .utils import tqdm as hf_tqdm
135-
from .utils._auth import (
136-
_get_token_from_environment,
137-
_get_token_from_file,
138-
_get_token_from_google_colab,
139-
)
133+
from .utils._auth import _get_token_from_environment, _get_token_from_file, _get_token_from_google_colab
140134
from .utils._deprecation import _deprecate_arguments, _deprecate_method
141-
from .utils._runtime import is_xet_available
142135
from .utils._typing import CallableT
143136
from .utils.endpoint_helpers import _is_emission_within_threshold
144137

@@ -4502,6 +4495,10 @@ def preupload_lfs_files(
45024495
f"Skipped upload for {len(new_lfs_additions) - len(new_lfs_additions_to_upload)} LFS file(s) "
45034496
"(ignored by gitignore file)."
45044497
)
4498+
# If no LFS files remain to upload, keep previous behavior and log explicitly
4499+
if len(new_lfs_additions_to_upload) == 0:
4500+
logger.debug("No LFS files to upload.")
4501+
return
45054502
# Prepare upload parameters
45064503
upload_kwargs = {
45074504
"additions": new_lfs_additions_to_upload,
@@ -4514,32 +4511,7 @@ def preupload_lfs_files(
45144511
# PR (i.e. `revision`).
45154512
"revision": revision if not create_pr else None,
45164513
}
4517-
# Upload files using Xet protocol if all of the following are true:
4518-
# - xet is enabled for the repo,
4519-
# - the files are provided as str or paths objects,
4520-
# - the library is installed.
4521-
# Otherwise, default back to LFS.
4522-
xet_enabled = self.repo_info(
4523-
repo_id=repo_id,
4524-
repo_type=repo_type,
4525-
revision=unquote(revision) if revision is not None else revision,
4526-
expand="xetEnabled",
4527-
token=token,
4528-
).xet_enabled
4529-
has_buffered_io_data = any(
4530-
isinstance(addition.path_or_fileobj, io.BufferedIOBase) for addition in new_lfs_additions_to_upload
4531-
)
4532-
if xet_enabled and not has_buffered_io_data and is_xet_available():
4533-
logger.debug("Uploading files using Xet Storage..")
4534-
_upload_xet_files(**upload_kwargs, create_pr=create_pr) # type: ignore [arg-type]
4535-
else:
4536-
if xet_enabled and is_xet_available():
4537-
if has_buffered_io_data:
4538-
logger.warning(
4539-
"Uploading files as a binary IO buffer is not supported by Xet Storage. "
4540-
"Falling back to HTTP upload."
4541-
)
4542-
_upload_lfs_files(**upload_kwargs, num_threads=num_threads) # type: ignore [arg-type]
4514+
_upload_files(**upload_kwargs, num_threads=num_threads, create_pr=create_pr) # type: ignore [arg-type]
45434515
for addition in new_lfs_additions_to_upload:
45444516
addition._is_uploaded = True
45454517
if free_memory:

src/huggingface_hub/lfs.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -108,7 +108,8 @@ def post_lfs_batch_info(
108108
revision: Optional[str] = None,
109109
endpoint: Optional[str] = None,
110110
headers: Optional[Dict[str, str]] = None,
111-
) -> Tuple[List[dict], List[dict]]:
111+
transfers: Optional[List[str]] = None,
112+
) -> Tuple[List[dict], List[dict], Optional[str]]:
112113
"""
113114
Requests the LFS batch endpoint to retrieve upload instructions
114115
@@ -127,11 +128,14 @@ def post_lfs_batch_info(
127128
The git revision to upload to.
128129
headers (`dict`, *optional*):
129130
Additional headers to include in the request
131+
transfers (`list`, *optional*):
132+
List of transfer methods to use. Defaults to ["basic", "multipart"].
130133
131134
Returns:
132-
`LfsBatchInfo`: 2-tuple:
135+
`LfsBatchInfo`: 3-tuple:
133136
- First element is the list of upload instructions from the server
134-
- Second element is an list of errors, if any
137+
- Second element is a list of errors, if any
138+
- Third element is the chosen transfer adapter if provided by the server (e.g. "basic", "multipart", "xet")
135139
136140
Raises:
137141
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
@@ -146,7 +150,7 @@ def post_lfs_batch_info(
146150
batch_url = f"{endpoint}/{url_prefix}{repo_id}.git/info/lfs/objects/batch"
147151
payload: Dict = {
148152
"operation": "upload",
149-
"transfers": ["basic", "multipart"],
153+
"transfers": transfers if transfers is not None else ["basic", "multipart"],
150154
"objects": [
151155
{
152156
"oid": upload.sha256.hex(),
@@ -172,9 +176,13 @@ def post_lfs_batch_info(
172176
if not isinstance(objects, list):
173177
raise ValueError("Malformed response from server")
174178

179+
chosen_transfer = batch_info.get("transfer")
180+
chosen_transfer = chosen_transfer if isinstance(chosen_transfer, str) else None
181+
175182
return (
176183
[_validate_batch_actions(obj) for obj in objects if "error" not in obj],
177184
[_validate_batch_error(obj) for obj in objects if "error" in obj],
185+
chosen_transfer,
178186
)
179187

180188

0 commit comments

Comments (0)