Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ax/storage/sqa_store/sqa_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ class SQAMetric(Base):
class SQAArm(Base):
__tablename__: str = "arm_v2"

generator_run_id: Column[int] = Column(
Integer, ForeignKey("generator_run_v2.id"), nullable=False
generator_run_id: Column[int | None] = Column(
Integer, ForeignKey("generator_run_v2.id")
)
id: Column[int] = Column(Integer, primary_key=True)
name: Column[str | None] = Column(String(NAME_OR_TYPE_FIELD_LENGTH))
Expand Down
48 changes: 24 additions & 24 deletions ax/storage/sqa_store/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ax.storage.sqa_store.sqa_classes import (
ONLY_ONE_FIELDS,
ONLY_ONE_METRIC_FIELDS,
SQAArm,
SQAMetric,
SQAParameter,
SQAParameterConstraint,
Expand Down Expand Up @@ -51,7 +52,6 @@ def wrapper(fn: Callable) -> Callable:
return wrapper


# pyre-fixme[3]: Return annotation cannot be `Any`.
def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) -> Any:
"""Ensure that exactly one of `exactly_one_fields` has a value set."""
values = [getattr(instance, field) is not None for field in exactly_one_fields]
Expand All @@ -62,29 +62,6 @@ def consistency_exactly_one(instance: SQABase, exactly_one_fields: list[str]) ->
)


@listens_for_multiple(
targets=GR_LARGE_MODEL_ATTRS,
identifier="set",
# `retval=True` instruct the operation ('set' on attributes in `targets`) to use
# the return value of decorated function to set the attribute.
retval=True,
# `propagate=True` ensures that targets with subclasses of SQA classes used by
# default Ax OSS encoder inherit the event listeners.
propagate=True,
)
def do_not_set_existing_value_to_null(
instance: SQABase, new_value: T, old_value: T, initiator_event: event.Events
) -> T:
no_value = [None, NO_VALUE]
if new_value in no_value and old_value not in no_value:
logger.debug(
f"New value for attribute is `None` or has no value, but old value "
f"was set, so keeping the old value ({old_value})."
)
return old_value
return new_value


@event.listens_for(
SQAParameter,
"before_insert",
Expand Down Expand Up @@ -116,3 +93,26 @@ def validate_metric(mapper: Mapper, connection: Connection, target: SQABase) ->
@event.listens_for(SQARunner, "before_update")
def validate_runner(mapper: Mapper, connection: Connection, target: SQABase) -> None:
consistency_exactly_one(target, ["experiment_id", "trial_id"])


@listens_for_multiple(
targets=GR_LARGE_MODEL_ATTRS,
identifier="set",
# `retval=True` instruct the operation ('set' on attributes in `targets`) to use
# the return value of decorated function to set the attribute.
retval=True,
# `propagate=True` ensures that targets with subclasses of SQA classes used by
# default Ax OSS encoder inherit the event listeners.
propagate=True,
)
def do_not_set_existing_value_to_null(
instance: SQABase, new_value: T, old_value: T, initiator_event: event.Events
) -> T:
no_value = [None, NO_VALUE]
if new_value in no_value and old_value not in no_value:
logger.debug(
f"New value for attribute is `None` or has no value, but old value "
f"was set, so keeping the old value ({old_value})."
)
return old_value
return new_value
Loading