1515from ax .storage .sqa_store .sqa_classes import (
1616 ONLY_ONE_FIELDS ,
1717 ONLY_ONE_METRIC_FIELDS ,
18+ SQAArm ,
1819 SQAMetric ,
1920 SQAParameter ,
2021 SQAParameterConstraint ,
@@ -51,7 +52,6 @@ def wrapper(fn: Callable) -> Callable:
5152 return wrapper
5253
5354
54- # pyre-fixme[3]: Return annotation cannot be `Any`.
5555def consistency_exactly_one (instance : SQABase , exactly_one_fields : list [str ]) -> Any :
5656 """Ensure that exactly one of `exactly_one_fields` has a value set."""
5757 values = [getattr (instance , field ) is not None for field in exactly_one_fields ]
@@ -62,29 +62,6 @@ def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) ->
6262 )
6363
6464
65- @listens_for_multiple (
66- targets = GR_LARGE_MODEL_ATTRS ,
67- identifier = "set" ,
68- # `retval=True` instruct the operation ('set' on attributes in `targets`) to use
69- # the return value of decorated function to set the attribute.
70- retval = True ,
71- # `propagate=True` ensures that targets with subclasses of SQA classes used by
72- # default Ax OSS encoder inherit the event listeners.
73- propagate = True ,
74- )
75- def do_not_set_existing_value_to_null (
76- instance : SQABase , new_value : T , old_value : T , initiator_event : event .Events
77- ) -> T :
78- no_value = [None , NO_VALUE ]
79- if new_value in no_value and old_value not in no_value :
80- logger .debug (
81- f"New value for attribute is `None` or has no value, but old value "
82- f"was set, so keeping the old value ({ old_value } )."
83- )
84- return old_value
85- return new_value
86-
87-
8865@event .listens_for (
8966 SQAParameter ,
9067 "before_insert" ,
@@ -116,3 +93,26 @@ def validate_metric(mapper: Mapper, connection: Connection, target: SQABase) ->
11693@event .listens_for (SQARunner , "before_update" )
11794def validate_runner (mapper : Mapper , connection : Connection , target : SQABase ) -> None :
11895 consistency_exactly_one (target , ["experiment_id" , "trial_id" ])
96+
97+
98+ @listens_for_multiple (
99+ targets = GR_LARGE_MODEL_ATTRS ,
100+ identifier = "set" ,
101+ # `retval=True` instruct the operation ('set' on attributes in `targets`) to use
102+ # the return value of decorated function to set the attribute.
103+ retval = True ,
104+ # `propagate=True` ensures that targets with subclasses of SQA classes used by
105+ # default Ax OSS encoder inherit the event listeners.
106+ propagate = True ,
107+ )
108+ def do_not_set_existing_value_to_null (
109+ instance : SQABase , new_value : T , old_value : T , initiator_event : event .Events
110+ ) -> T :
111+ no_value = [None , NO_VALUE ]
112+ if new_value in no_value and old_value not in no_value :
113+ logger .debug (
114+ f"New value for attribute is `None` or has no value, but old value "
115+ f"was set, so keeping the old value ({ old_value } )."
116+ )
117+ return old_value
118+ return new_value
0 commit comments