diff --git a/ax/adapter/tests/test_torch_adapter.py b/ax/adapter/tests/test_torch_adapter.py index 50e7b2a1dd2..12a3095c1a1 100644 --- a/ax/adapter/tests/test_torch_adapter.py +++ b/ax/adapter/tests/test_torch_adapter.py @@ -483,8 +483,7 @@ def test_candidate_metadata_propagation(self) -> None: exp = get_branin_experiment(with_status_quo=True, with_completed_batch=True) # Check that the metadata is correctly re-added to observation # features during `fit`. - # pyre-fixme[16]: `BaseTrial` has no attribute `_generator_run_structs`. - preexisting_batch_gr = exp.trials[0]._generator_runs[0] + preexisting_batch_gr = exp.trials[0].generator_runs[0] preexisting_batch_gr._candidate_metadata_by_arm_signature = { preexisting_batch_gr.arms[0].signature: { "preexisting_batch_cand_metadata": "some_value" diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 5d50c9a895b..ae1d529cc2f 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -1043,8 +1043,16 @@ def trial_to_sqa( ) return trial_sqa - def experiment_data_to_sqa(self, experiment: Experiment) -> list[SQAData]: - """Convert Ax experiment data to SQLAlchemy.""" + def experiment_data_to_sqa( + self, + experiment: Experiment, + ) -> list[SQAData]: + if ( + experiment.experiment_type + in self.config.EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE + ): + return [] + return [ self.data_to_sqa(data=data, trial_index=trial_index, timestamp=timestamp) for trial_index, data_by_timestamp in experiment.data_by_trial.items() diff --git a/ax/storage/sqa_store/sqa_config.py b/ax/storage/sqa_store/sqa_config.py index 1102a6f14a7..07fb03cf321 100644 --- a/ax/storage/sqa_store/sqa_config.py +++ b/ax/storage/sqa_store/sqa_config.py @@ -9,7 +9,7 @@ from collections.abc import Callable from dataclasses import dataclass, field from enum import Enum -from typing import Any +from typing import Any, cast from ax.analysis.analysis import AnalysisCard @@ -67,9 +67,10 @@ class SQAConfig: serialization function. """ + EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE: set[str] = field(default_factory=set) + def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: - # pyre-fixme[7] - return { + ax_cls_to_sqa_cls = { AbandonedArm: SQAAbandonedArm, AnalysisCard: SQAAnalysisCard, Arm: SQAArm, @@ -84,6 +85,10 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: Trial: SQATrial, AuxiliaryExperiment: SQAAuxiliaryExperiment, } + return { + cast(type[Base], k): cast(type[SQABase], v) + for k, v in ax_cls_to_sqa_cls.items() + } class_to_sqa_class: dict[type[Base], type[SQABase]] = field( default_factory=_default_class_to_sqa_class @@ -92,27 +97,21 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]: generator_run_type_enum: Enum | type[Enum] | None = GeneratorRunType auxiliary_experiment_purpose_enum: type[Enum] = AuxiliaryExperimentPurpose - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( + # Encoding and decoding registries: + json_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = field( default_factory=lambda: CORE_ENCODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. - # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use - # `typing.Type` to avoid runtime subscripting errors. - json_class_encoder_registry: dict[type, Callable[[Any], dict[str, Any]]] = field( - default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY + json_class_encoder_registry: dict[type[Any], Callable[[Any], dict[str, Any]]] = ( + field(default_factory=lambda: CORE_CLASS_ENCODER_REGISTRY) ) - json_decoder_registry: TDecoderRegistry = field( default_factory=lambda: CORE_DECODER_REGISTRY ) - # pyre-fixme[4]: Attribute annotation cannot contain `Any`. json_class_decoder_registry: dict[str, Callable[[dict[str, Any]], Any]] = field( default_factory=lambda: CORE_CLASS_DECODER_REGISTRY ) + # Metric and runner class registries: metric_registry: dict[type[Metric], int] = field( default_factory=lambda: CORE_METRIC_REGISTRY ) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 3c1baa72fd1..f97ea8ca502 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -218,7 +218,6 @@ def creator() -> Mock: def test_GeneratorRunTypeValidation(self) -> None: experiment = get_experiment_with_batch_trial() - # pyre-fixme[16]: `BaseTrial` has no attribute `generator_run_structs`. generator_run = experiment.trials[0].generator_runs[0] generator_run._generator_run_type = "foobar" with self.assertRaises(SQAEncodeError):