99from collections .abc import Callable
1010from dataclasses import dataclass , field
1111from enum import Enum
12- from typing import Any
12+ from typing import Any , cast
1313
1414from ax .analysis .analysis import AnalysisCard
1515
@@ -67,9 +67,10 @@ class SQAConfig:
6767 serialization function.
6868 """
6969
70+ EXPERIMENT_TYPES_WITH_NO_DATA_STORAGE : set [str ] = field (default_factory = set )
71+
7072 def _default_class_to_sqa_class (self = None ) -> dict [type [Base ], type [SQABase ]]:
71- # pyre-fixme[7]
72- return {
73+ ax_cls_to_sqa_cls = {
7374 AbandonedArm : SQAAbandonedArm ,
7475 AnalysisCard : SQAAnalysisCard ,
7576 Arm : SQAArm ,
@@ -84,6 +85,10 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
8485 Trial : SQATrial ,
8586 AuxiliaryExperiment : SQAAuxiliaryExperiment ,
8687 }
88+ return {
89+ cast (type [Base ], k ): cast (type [SQABase ], v )
90+ for k , v in ax_cls_to_sqa_cls .items ()
91+ }
8792
8893 class_to_sqa_class : dict [type [Base ], type [SQABase ]] = field (
8994 default_factory = _default_class_to_sqa_class
@@ -92,27 +97,21 @@ def _default_class_to_sqa_class(self=None) -> dict[type[Base], type[SQABase]]:
9297 generator_run_type_enum : Enum | type [Enum ] | None = GeneratorRunType
9398 auxiliary_experiment_purpose_enum : type [Enum ] = AuxiliaryExperimentPurpose
9499
95- # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
96- # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
97- # `typing.Type` to avoid runtime subscripting errors.
98- json_encoder_registry : dict [type , Callable [[Any ], dict [str , Any ]]] = field (
100+ # Encoding and decoding registries:
101+ json_encoder_registry : dict [type [Any ], Callable [[Any ], dict [str , Any ]]] = field (
99102 default_factory = lambda : CORE_ENCODER_REGISTRY
100103 )
101- # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
102- # pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
103- # `typing.Type` to avoid runtime subscripting errors.
104- json_class_encoder_registry : dict [type , Callable [[Any ], dict [str , Any ]]] = field (
105- default_factory = lambda : CORE_CLASS_ENCODER_REGISTRY
104+ json_class_encoder_registry : dict [type [Any ], Callable [[Any ], dict [str , Any ]]] = (
105+ field (default_factory = lambda : CORE_CLASS_ENCODER_REGISTRY )
106106 )
107-
108107 json_decoder_registry : TDecoderRegistry = field (
109108 default_factory = lambda : CORE_DECODER_REGISTRY
110109 )
111- # pyre-fixme[4]: Attribute annotation cannot contain `Any`.
112110 json_class_decoder_registry : dict [str , Callable [[dict [str , Any ]], Any ]] = field (
113111 default_factory = lambda : CORE_CLASS_DECODER_REGISTRY
114112 )
115113
114+ # Metric and runner class registries:
116115 metric_registry : dict [type [Metric ], int ] = field (
117116 default_factory = lambda : CORE_METRIC_REGISTRY
118117 )
0 commit comments