Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/huggingface_hub/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Any,
Callable,
Dict,
ForwardRef,
List,
Literal,
Optional,
Expand Down Expand Up @@ -325,6 +326,8 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
validator(name, value, args)
elif isinstance(expected_type, type): # simple types
_validate_simple_type(name, value, expected_type)
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The clause above, `isinstance(expected_type, type)`, is True when `expected_type` is `str`. This means `or isinstance(expected_type, str)` can be removed, correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. The clause above is triggered when `expected_type` is the `str` class itself, whereas the one I added handles cases such as `expected_type = "torch.tensor"`.

In other words, the type has not been resolved and is still a plain string.

Copy link
Contributor

@gante gante Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, derp, of course -- `expected_type` is a string instance here, not the `str` type 🤦

return
else:
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")

Expand Down
44 changes: 44 additions & 0 deletions tests/test_utils_strict_dataclass.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of modifying existing tests, I would prefer that you keep everything as it is and add a new class `ConfigWithForwardRef` specifically for this PR. Also, you are testing on `dtype: Union[ForwardRef("torch.dtype"), str] = "float32"`, which is not really conclusive: even without the forward ref, `"float32"` is already a valid string value (so `str` allows it).

So can you create a config like this:

@strict
@dataclass
class ConfigWithForwardRef:
    explicitForwardRef: ForwardRef("torch.dtype")
    implicitForwardRef: "torch.dtype"

and test both cases?

Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,30 @@ def strictly_positive(value: int):
raise ValueError(f"Value must be strictly positive, got {value}")


def dtype_validation(value: "ForwardDtype"):
    """Check that ``value`` is a string naming one of the accepted dtypes.

    Accepted names are ``"float32"``, ``"bfloat16"`` and ``"float16"``.

    Raises:
        ValueError: If ``value`` is not a ``str`` instance, or if it is a string
            that is not one of the accepted names.
    """
    if isinstance(value, str):
        if value in ("float32", "bfloat16", "float16"):
            return
        raise ValueError(f"Value must be one of `[float32, bfloat16, float16]` but got {value}")

    # Anything that is not a string is rejected outright.
    raise ValueError(f"Value must be string, got {value}")


@strict
@dataclass
class ConfigForwardRef:
    """Fixture exercising fields annotated with a forward reference.

    Forward-reference annotations are not type-checked in practice, so attaching
    a custom validator to such a field is highly recommended.
    """

    # Forward-ref field guarded by an explicit validator supplied via `validated_field`.
    forward_ref_validated: "ForwardDtype" = validated_field(validator=dtype_validation)
    # Forward-ref field without a validator.
    forward_ref: "ForwardDtype" = "float32"  # type is not validated by default


class ForwardDtype(str):
    """Placeholder `str` subclass standing in for a forward-referenced type (e.g. `torch.dtype`)."""


@strict
@dataclass
class Config:
Expand Down Expand Up @@ -62,6 +86,26 @@ def test_default_values():
assert config.hidden_size == 1024


def test_forward_ref_validation_is_skipped():
    """Forward-ref fields skip type validation, while custom validators still apply."""
    config = ConfigForwardRef(forward_ref_validated="float32", forward_ref="float32")
    assert (config.forward_ref, config.forward_ref_validated) == ("float32", "float32")

    # `forward_ref_validated` carries a validator in its field metadata, so
    # values the validator rejects must raise.
    for rejected in ("float64", -1, "not_dtype"):
        with pytest.raises(StrictDataclassFieldValidationError):
            ConfigForwardRef(forward_ref_validated=rejected)

    # `forward_ref` has no validator and its type is not checked => anything goes.
    for unchecked in (-1, ["float32"]):
        ConfigForwardRef(forward_ref=unchecked, forward_ref_validated="float32")


def test_invalid_type_initialization():
with pytest.raises(StrictDataclassFieldValidationError):
Config(model_type={"type": "bert"}, vocab_size=30000, hidden_size=768)
Expand Down
Loading