Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ax/adapter/tests/test_torch_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,7 @@ def test_candidate_metadata_propagation(self) -> None:
exp = get_branin_experiment(with_status_quo=True, with_completed_batch=True)
# Check that the metadata is correctly re-added to observation
# features during `fit`.
# pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run_structs`.
preexisting_batch_gr = exp.trials[0]._generator_runs[0]
preexisting_batch_gr = exp.trials[0].generator_runs[0]
preexisting_batch_gr._candidate_metadata_by_arm_signature = {
preexisting_batch_gr.arms[0].signature: {
"preexisting_batch_cand_metadata": "some_value"
Expand Down
12 changes: 10 additions & 2 deletions ax/storage/sqa_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,8 +1043,16 @@ def trial_to_sqa(
)
return trial_sqa

def experiment_data_to_sqa(self, experiment: Experiment) -> list[SQAData]:
"""Convert Ax experiment data to SQLAlchemy."""
def experiment_data_to_sqa(
self,
experiment: Experiment,
) -> list[SQAData]:
if (
experiment.experiment_type
in self.config.EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE
):
return []

return [
self.data_to_sqa(data=data, trial_index=trial_index, timestamp=timestamp)
for trial_index, data_by_timestamp in experiment.data_by_trial.items()
Expand Down
27 changes: 13 additions & 14 deletions ax/storage/sqa_store/sqa_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from typing import Any, cast

from ax.analysis.analysis import AnalysisCard

Expand Down Expand Up @@ -67,9 +67,10 @@ class SQAConfig:
serialization function.
"""

EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE: set[str] = field(default_factory=set)

def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
# pyre-fixme[7]
return {
ax_cls_to_sqa_cls = {
AbandonedArm: SQAAbandonedArm,
AnalysisCard: SQAAnalysisCard,
Arm: SQAArm,
Expand All @@ -84,6 +85,10 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
Trial: SQATrial,
AuxiliaryExperiment: SQAAuxiliaryExperiment,
}
return {
cast(type[Base], k): cast(type[SQABase], v)
for k, v in ax_cls_to_sqa_cls.items()
}

class_to_sqa_class: dict[type[Base], type[SQABase]] = field(
default_factory=_default_class_to_sqa_class
Expand All @@ -92,27 +97,21 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType
auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose

# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field(
# Encoding and decoding registries:
json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = field(
default_factory=lambda: CORE_ENCODER_REGISTRY
)
# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field(
default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY
json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = (
field(default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY)
)

json_decoder_registry: TDecoderRegistry = field(
default_factory=lambda: CORE_DECODER_REGISTRY
)
# pyre-fixme[4]: Attribute annotation cannot contain `Any`.
json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] = field(
default_factory=lambda: CORE_CLASS_DECODER_REGISTRY
)

# Metric and runner class registries:
metric_registry: dict[type[Metric], int] = field(
default_factory=lambda: CORE_METRIC_REGISTRY
)
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,6 @@ def creator() -> Mock:

def test_GeneratorRunTypeValidation(self) -> None:
experiment = get_experiment_with_batch_trial()
# pyre-fixme[16]: `BaseTrial` has no attribute `generator_run_structs`.
generator_run = experiment.trials[0].generator_runs[0]
generator_run._generator_run_type = "foobar"
with self.assertRaises(SQAEncodeError):
Expand Down