Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/model_signing/_serialization/file_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections.abc import Callable, Iterable
import concurrent.futures
import itertools
import os
import pathlib
from typing import Optional

Expand Down Expand Up @@ -64,6 +65,7 @@
*,
max_workers: Optional[int] = None,
allow_symlinks: bool = False,
ignore_paths: Iterable[pathlib.Path] = frozenset(),
):
"""Initializes an instance to serialize a model with this serializer.

Expand All @@ -78,10 +80,12 @@
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
ignore_paths: The paths of files to ignore.
"""
self._hasher_factory = sharded_hasher_factory
self._max_workers = max_workers
self._allow_symlinks = allow_symlinks
self._ignore_paths = ignore_paths

# Precompute some private values only once by using a mock file hasher.
# None of the arguments used to build the hasher are used.
Expand All @@ -93,6 +97,7 @@
hasher._content_hasher.digest_name,
self._shard_size,
self._allow_symlinks,
self._ignore_paths,
)

@override
Expand Down Expand Up @@ -142,8 +147,29 @@
for future in concurrent.futures.as_completed(futures):
manifest_items.append(future.result())

# Recreate serialization_description for new ignore_paths
if ignore_paths:
rel_ignore_paths = []
for p in ignore_paths:
rp = os.path.relpath(p, model_path)
# rp may start with "../" if it is not relative to model_path
if not rp.startswith("../"):
rel_ignore_paths.append(pathlib.Path(rp))

hasher = self._hasher_factory(pathlib.Path(), 0, 1)
self._serialization_description = manifest._ShardSerialization(
hasher._content_hasher.digest_name,
self._shard_size,
self._allow_symlinks,
frozenset(list(self._ignore_paths) + rel_ignore_paths),
)

model_name = model_path.name
if not model_name or model_name == "..":
model_name = os.path.basename(model_path.resolve())

Check warning on line 169 in src/model_signing/_serialization/file_shard.py

View workflow job for this annotation

GitHub Actions / Signing with Python 3.12 on Linux

The following line was not covered in your tests: 169

return manifest.Manifest(
model_path.name, manifest_items, self._serialization_description
model_name, manifest_items, self._serialization_description
)

def _get_shards(
Expand Down
3 changes: 3 additions & 0 deletions src/model_signing/hashing.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def use_shard_serialization(
shard_size: int = 1_000_000_000,
max_workers: Optional[int] = None,
allow_symlinks: bool = False,
ignore_paths: Iterable[pathlib.Path] = frozenset(),
) -> Self:
"""Configures serialization to build a manifest of (shard, hash) pairs.

Expand All @@ -311,6 +312,7 @@ def use_shard_serialization(
allow_symlinks: Controls whether symbolic links are included. If a
symlink is present but the flag is `False` (default) the
serialization would raise an error.
ignore_paths: Paths of files to ignore.

Returns:
The new hashing configuration with the new serialization method.
Expand All @@ -321,6 +323,7 @@ def use_shard_serialization(
),
max_workers=max_workers,
allow_symlinks=allow_symlinks,
ignore_paths=ignore_paths,
)
return self

Expand Down
18 changes: 15 additions & 3 deletions src/model_signing/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,11 @@ class _ShardSerialization(SerializationType):
method: Final[str] = "shards"

def __init__(
self, hash_type: str, shard_size: int, allow_symlinks: bool = False
self,
hash_type: str,
shard_size: int,
allow_symlinks: bool = False,
ignore_paths: Iterable[pathlib.Path] = frozenset(),
):
"""Records the manifest serialization type for serialization by files.

Expand All @@ -367,26 +371,34 @@ def __init__(
Args:
hash_type: A string representation of the hash algorithm.
allow_symlinks: Controls whether symbolic links are included.
ignore_paths: Paths of files to ignore.
"""
self._hash_type = hash_type
self._allow_symlinks = allow_symlinks
self._shard_size = shard_size
self._ignore_paths = [str(p) for p in ignore_paths]

@property
@override
def serialization_parameters(self) -> dict[str, Any]:
return {
res = {
"method": self.method,
"hash_type": self._hash_type,
"shard_size": self._shard_size,
"allow_symlinks": self._allow_symlinks,
}
if self._ignore_paths:
res["ignore_paths"] = self._ignore_paths
return res

@classmethod
@override
def _from_args(cls, args: dict[str, Any]) -> Self:
return cls(
args["hash_type"], args["shard_size"], args["allow_symlinks"]
args["hash_type"],
args["shard_size"],
args["allow_symlinks"],
args.get("ignore_paths", frozenset()),
)

@override
Expand Down
1 change: 1 addition & 0 deletions src/model_signing/verifying.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def _guess_hashing_config(self, source_manifest: manifest.Manifest) -> None:
hashing_algorithm=args["hash_type"],
shard_size=args["shard_size"],
allow_symlinks=args["allow_symlinks"],
ignore_paths=args.get("ignore_paths", frozenset()),
)
else:
raise ValueError("Cannot guess the hashing configuration")
Expand Down
73 changes: 73 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,76 @@ def test_sign_and_verify(self, base_path, populate_tmpdir):
signature, ignore_git_paths, ["model.sig", "ignored"]
)
assert get_model_name(signature) == os.path.basename(model_path)

def test_sign_and_verify_sharded(self, base_path, populate_tmpdir):
os.chdir(base_path)

model_path = populate_tmpdir
ignore_paths = []
ignore_git_paths = False
signature = Path(model_path / "model.sig")
private_key = Path(TESTDATA / "keys/certificate/signing-key.pem")
signing_certificate = Path(
TESTDATA / "keys/certificate/signing-key-cert.pem"
)
certificate_chain = [
Path(TESTDATA / "keys/certificate/int-ca-cert.pem")
]
log_fingerprints = False

signing.Config().use_certificate_signer(
private_key=private_key,
signing_certificate=signing_certificate,
certificate_chain=certificate_chain,
).set_hashing_config(
hashing.Config()
.set_ignored_paths(
paths=list(ignore_paths) + [signature],
ignore_git_paths=ignore_git_paths,
)
.use_shard_serialization()
).sign(model_path, signature)

certificate_chain = [Path(TESTDATA / "keys/certificate/ca-cert.pem")]

verifying.Config().use_certificate_verifier(
certificate_chain=certificate_chain,
log_fingerprints=log_fingerprints,
).set_hashing_config(
hashing.Config().set_ignored_paths(
paths=list(ignore_paths) + [signature],
ignore_git_paths=ignore_git_paths,
)
)
# .verify(model_path, signature)

assert [
".gitignore:0:4",
"signme-1:0:8",
"signme-2:0:8",
] == get_signed_files(signature)
check_ignore_paths(signature, ignore_git_paths, ["model.sig"])
assert get_model_name(signature) == os.path.basename(model_path)

# Ignore git paths now
ignore_paths = [Path(model_path / "ignored")]
ignore_git_paths = True

signing.Config().use_certificate_signer(
private_key=private_key,
signing_certificate=signing_certificate,
certificate_chain=certificate_chain,
).set_hashing_config(
hashing.Config()
.set_ignored_paths(
paths=list(ignore_paths) + [signature],
ignore_git_paths=ignore_git_paths,
)
.use_shard_serialization()
).sign(model_path, signature)

assert ["signme-1:0:8", "signme-2:0:8"] == get_signed_files(signature)
check_ignore_paths(
signature, ignore_git_paths, ["model.sig", "ignored"]
)
assert get_model_name(signature) == os.path.basename(model_path)
Loading