@@ -29,6 +29,30 @@ def strictly_positive(value: int):
2929 raise ValueError (f"Value must be strictly positive, got { value } " )
3030
3131
32+ def dtype_validation (value : "ForwardDtype" ):
33+ if not isinstance (value , str ):
34+ raise ValueError (f"Value must be string, got { value } " )
35+
36+ if isinstance (value , str ) and value not in ["float32" , "bfloat16" , "float16" ]:
37+ raise ValueError (f"Value must be one of `[float32, bfloat16, float16]` but got { value } " )
38+
39+
40+ @strict
41+ @dataclass
42+ class ConfigForwardRef :
43+ """Test forward reference handling.
44+
45+ In practice, forward reference types are not validated so a custom validator is highly recommended.
46+ """
47+
48+ forward_ref_validated : "ForwardDtype" = validated_field (validator = dtype_validation )
49+ forward_ref : "ForwardDtype" = "float32" # type is not validated by default
50+
51+
52+ class ForwardDtype (str ):
53+ """Dummy class to simulate a forward reference (e.g. `torch.dtype`)."""
54+
55+
3256@strict
3357@dataclass
3458class Config :
@@ -62,6 +86,26 @@ def test_default_values():
6286 assert config .hidden_size == 1024
6387
6488
89+ def test_forward_ref_validation_is_skipped ():
90+ config = ConfigForwardRef (forward_ref = "float32" , forward_ref_validated = "float32" )
91+ assert config .forward_ref == "float32"
92+ assert config .forward_ref_validated == "float32"
93+
94+ # The `forward_ref_validated` has proper validation added in field-metadata and will be validated
95+ with pytest .raises (StrictDataclassFieldValidationError ):
96+ ConfigForwardRef (forward_ref_validated = "float64" )
97+
98+ with pytest .raises (StrictDataclassFieldValidationError ):
99+ ConfigForwardRef (forward_ref_validated = - 1 )
100+
101+ with pytest .raises (StrictDataclassFieldValidationError ):
102+ ConfigForwardRef (forward_ref_validated = "not_dtype" )
103+
104+ # The `forward_ref` type is not validated => user can input anything
105+ ConfigForwardRef (forward_ref = - 1 , forward_ref_validated = "float32" )
106+ ConfigForwardRef (forward_ref = ["float32" ], forward_ref_validated = "float32" )
107+
108+
65109def test_invalid_type_initialization ():
66110 with pytest .raises (StrictDataclassFieldValidationError ):
67111 Config (model_type = {"type" : "bert" }, vocab_size = 30000 , hidden_size = 768 )
0 commit comments