diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py index 3caaa6efc5..ba52b446b1 100644 --- a/test/quantization/test_config_serialization.py +++ b/test/quantization/test_config_serialization.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + import json import os import tempfile @@ -14,6 +20,7 @@ config_to_dict, ) from torchao.quantization.quant_api import ( + AOPerModuleConfig, Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, @@ -63,6 +70,14 @@ # Sparsity configs SemiSparseWeightConfig(), BlockSparseWeightConfig(blocksize=128), + AOPerModuleConfig({}), + AOPerModuleConfig({"_default": Int4WeightOnlyConfig(), "linear1": None}), + AOPerModuleConfig( + { + "linear1": Int4WeightOnlyConfig(), + "linear2": Int8DynamicActivationInt4WeightConfig(), + } + ), ] diff --git a/torchao/core/config.py b/torchao/core/config.py index fe03ac225b..a041130835 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -255,6 +255,14 @@ def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: else item for item in value ] + elif isinstance(value, dict): + # Handle dicts of possible configs + processed_data[key] = { + k: config_from_dict(v) + if isinstance(v, dict) and "_type" in v and "_data" in v + else v + for k, v in value.items() + } else: processed_data[key] = value