Skip to content

Commit d19aa5a

Browse files
Lena Kashtelyanfacebook-github-bot
authored andcommitted
Allow SQAArm.generator_run_id to be nullable (#4363)
Summary: **Goal of this stack: attach all arms on the experiment, to the experiment object itself, as the main "ground truth" for them.** Differential Revision: D83488268
1 parent 5416f64 commit d19aa5a

File tree

2 files changed

+26
-26
lines changed

2 files changed

+26
-26
lines changed

ax/storage/sqa_store/sqa_classes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,8 @@ class SQAMetric(Base):
164164
class SQAArm(Base):
165165
__tablename__: str = "arm_v2"
166166

167-
generator_run_id: Column[int] = Column(
168-
Integer, ForeignKey("generator_run_v2.id"), nullable=False
167+
generator_run_id: Column[int | None] = Column(
168+
Integer, ForeignKey("generator_run_v2.id")
169169
)
170170
id: Column[int] = Column(Integer, primary_key=True)
171171
name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))

ax/storage/sqa_store/validation.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from ax.storage.sqa_store.sqa_classes import (
1616
ONLY_ONE_FIELDS,
1717
ONLY_ONE_METRIC_FIELDS,
18+
SQAArm,
1819
SQAMetric,
1920
SQAParameter,
2021
SQAParameterConstraint,
@@ -51,7 +52,6 @@ def wrapper(fn: Callable) -> Callable:
5152
return wrapper
5253

5354

54-
# pyre-fixme[3]: Return annotation cannot be `Any`.
5555
def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any:
5656
"""Ensure that exactly one of `exactly_one_fields` has a value set."""
5757
values = [getattr(instance, field) is not None for field in exactly_one_fields]
@@ -62,29 +62,6 @@ def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) ->
6262
)
6363

6464

65-
@listens_for_multiple(
66-
targets=GR_LARGE_MODEL_ATTRS,
67-
identifier="set",
68-
# `retval=True` instruct the operation ('set' on attributes in `targets`) to use
69-
# the return value of decorated function to set the attribute.
70-
retval=True,
71-
# `propagate=True` ensures that targets with subclasses of SQA classes used by
72-
# default Ax OSS encoder inherit the event listeners.
73-
propagate=True,
74-
)
75-
def do_not_set_existing_value_to_null(
76-
instance: SQABase, new_value: T, old_value: T, initiator_event: event.Events
77-
) -> T:
78-
no_value = [None, NO_VALUE]
79-
if new_value in no_value and old_value not in no_value:
80-
logger.debug(
81-
f"New value for attribute is `None` or has no value, but old value "
82-
f"was set, so keeping the old value ({old_value})."
83-
)
84-
return old_value
85-
return new_value
86-
87-
8865
@event.listens_for(
8966
SQAParameter,
9067
"before_insert",
@@ -116,3 +93,26 @@ def validate_metric(mapper: Mapper, connection: Connection, target: SQABase) ->
11693
@event.listens_for(SQARunner, "before_update")
11794
def validate_runner(mapper: Mapper, connection: Connection, target: SQABase) -> None:
11895
consistency_exactly_one(target, ["experiment_id", "trial_id"])
96+
97+
98+
@listens_for_multiple(
99+
targets=GR_LARGE_MODEL_ATTRS,
100+
identifier="set",
101+
# `retval=True` instruct the operation ('set' on attributes in `targets`) to use
102+
# the return value of decorated function to set the attribute.
103+
retval=True,
104+
# `propagate=True` ensures that targets with subclasses of SQA classes used by
105+
# default Ax OSS encoder inherit the event listeners.
106+
propagate=True,
107+
)
108+
def do_not_set_existing_value_to_null(
109+
instance: SQABase, new_value: T, old_value: T, initiator_event: event.Events
110+
) -> T:
111+
no_value = [None, NO_VALUE]
112+
if new_value in no_value and old_value not in no_value:
113+
logger.debug(
114+
f"New value for attribute is `None` or has no value, but old value "
115+
f"was set, so keeping the old value ({old_value})."
116+
)
117+
return old_value
118+
return new_value

0 commit comments

Comments
 (0)