Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/huggingface_hub/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Any,
Callable,
Dict,
ForwardRef,
List,
Literal,
Optional,
Expand Down Expand Up @@ -325,6 +326,8 @@ def type_validator(name: str, value: Any, expected_type: Any) -> None:
validator(name, value, args)
elif isinstance(expected_type, type): # simple types
_validate_simple_type(name, value, expected_type)
elif isinstance(expected_type, ForwardRef) or isinstance(expected_type, str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The clause above, isinstance(expected_type, type), is True when expected_type is str. This means or isinstance(expected_type, str) can be removed, correct?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. The clause above is triggered when expected_type is the `str` class itself, whereas the one I added handles cases such as expected_type = "torch.tensor", where the annotation is a string instance.

In other words, the type is not resolved and remains in its string form.

Copy link
Contributor

@gante gante Sep 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, derp, of course -- expected_type is a string instance here, not the `str` type itself 🤦

return
else:
raise TypeError(f"Unsupported type for field '{name}': {expected_type}")

Expand Down
62 changes: 60 additions & 2 deletions tests/test_utils_strict_dataclass.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of modifying existing tests I would prefer that you keep everything as it is + add a new class ConfigWithForwardRef specifically for this PR. Also you are testing on dtype: Union[ForwardRef("torch.dtype"), str] = "float32" which is not really conclusive since even without the forward ref, "float32" is already a valid string value (so str allows it).

So can you create a config like this:

@strict
@dataclass
class ConfigWithForwardRef:
    explicitForwardRef: ForwardRef("torch.dtype")
    implicitForwardRef: "torch.dtype"

and test both cases?

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
from dataclasses import asdict, astuple, dataclass, is_dataclass
from typing import Any, Dict, List, Literal, Optional, Set, Tuple, Union, get_type_hints
from typing import Any, Dict, ForwardRef, List, Literal, Optional, Set, Tuple, Union, get_type_hints

import jedi
import pytest
Expand Down Expand Up @@ -29,6 +29,24 @@ def strictly_positive(value: int):
raise ValueError(f"Value must be strictly positive, got {value}")


def dtype_validation(value):
    """Validate that ``value`` names one of the supported dtypes.

    Torch isn't installed in the test runners, so only the string form of a
    dtype is checked here.

    Args:
        value: Candidate value for a ``dtype`` field.

    Raises:
        ValueError: If ``value`` is not a string, or is a string other than
            ``"float32"``, ``"bfloat16"`` or ``"float16"``.
    """
    if not isinstance(value, str):
        raise ValueError(f"Value must be string or `torch.dtype` but got {value}")

    # Past the guard above `value` is known to be a string, so only membership is checked.
    if value not in ("float32", "bfloat16", "float16"):
        raise ValueError(f"Value must be one of `[float32, bfloat16, float16]` but got {value}")


# Strict dataclass used by the forward-reference tests: its `dtype` annotation
# contains a `ForwardRef` to `torch.dtype`, which is not importable in the runners.
@strict
@dataclass
class ConfigForwardRef:
model_type: str
# The annotation is a Union of the forward reference and `str`;
# `dtype_validation` is attached as the field's custom validator.
dtype: Union[ForwardRef("torch.dtype"), str] = validated_field(validator=[dtype_validation])
# Validated by `positive_int` and `multiple_of_64`.
hidden_size: int = validated_field(validator=[positive_int, multiple_of_64])
# Defaults to 16 via the `strictly_positive` validated field.
vocab_size: int = strictly_positive(default=16)


@strict
@dataclass
class Config:
Expand Down Expand Up @@ -62,6 +80,42 @@ def test_default_values():
assert config.hidden_size == 1024


def test_forward_ref_validation():
    """Check `ConfigForwardRef` validation around its `ForwardRef`-annotated `dtype` field."""
    base_kwargs = {"model_type": "bert", "vocab_size": 30000, "hidden_size": 768, "dtype": "float32"}

    # A fully valid instance keeps every value it was given.
    config = ConfigForwardRef(**base_kwargs)
    for field_name, expected in base_kwargs.items():
        assert getattr(config, field_name) == expected

    # All fields are still checked against their type hints.
    for wrong_type in ({"model_type": {"type": "bert"}}, {"vocab_size": "30000"}):
        with pytest.raises(StrictDataclassFieldValidationError):
            ConfigForwardRef(**{**base_kwargs, **wrong_type})

    # The `ForwardRef` part of the `dtype` annotation is skipped by the type check;
    # validating the value is left to the validator supplied in the field metadata.
    dtype_kwargs = {"model_type": "bert", "vocab_size": 100, "hidden_size": 768}
    for accepted in ("float32", "float16", "bfloat16"):
        ConfigForwardRef(dtype=accepted, **dtype_kwargs)

    # Values rejected for `dtype`: an unknown name and several non-string values.
    for rejected in ("bfloat64", 0, 10.0, ["float32"], {"text_config": "float32"}):
        with pytest.raises(StrictDataclassFieldValidationError):
            ConfigForwardRef(dtype=rejected, **dtype_kwargs)


def test_invalid_type_initialization():
with pytest.raises(StrictDataclassFieldValidationError):
Config(model_type={"type": "bert"}, vocab_size=30000, hidden_size=768)
Expand Down Expand Up @@ -306,7 +360,11 @@ def test_is_recognized_as_dataclass():
def test_behave_as_a_dataclass():
# Check that dataclasses.asdict works
config = Config(model_type="bert", hidden_size=768)
assert asdict(config) == {"model_type": "bert", "hidden_size": 768, "vocab_size": 16}
assert asdict(config) == {
"model_type": "bert",
"hidden_size": 768,
"vocab_size": 16,
}

# Check that dataclasses.astuple works
assert astuple(config) == ("bert", 768, 16)
Expand Down
Loading