Commit ec282f6

ds patch for 2.7 (#3638)
* Revert "Handle new linear modules in DeepSpeed v0.16.5 (#3622) (#3631)"

  This reverts commit d74bd04.

* update

* fix comment

Signed-off-by: Liu, Mingzhi <[email protected]>

---------

Signed-off-by: Liu, Mingzhi <[email protected]>
Co-authored-by: Chunyuan WU <[email protected]>
1 parent c12230b commit ec282f6

4 files changed: +60 -24 lines changed

intel_extension_for_pytorch/nn/utils/_parameter_wrapper.py

Lines changed: 12 additions & 1 deletion
@@ -46,6 +46,15 @@ def IPEX_WEIGHT_PREPACK_MODULE_CPU():
             deepspeed_modules_mapping.update(
                 {LmHeadLinearAllreduce: _IPEXLmHeadLinearAllreduce}
             )
+        if len(deepspeed_modules) > 3:
+            for module in deepspeed_modules[3:]:
+                if module not in deepspeed_modules_mapping:
+                    if issubclass(module, LinearAllreduce):
+                        deepspeed_modules_mapping[module] = _IPEXLinearAllreduce
+                    elif issubclass(module, LinearLayer):
+                        deepspeed_modules_mapping[module] = _IPEXLinear
+                    else:
+                        raise ValueError(f"Unrecognized module type: {module}")
         torch_modules.update(deepspeed_modules_mapping)

     return torch_modules

@@ -190,7 +199,9 @@ def get_shared_parameter_status(module, shared_p):
     if deepspeed_modules is not None:
         LinearAllreduce, LinearLayer = deepspeed_modules[:2]

-        if isinstance(module, (LinearLayer, LinearAllreduce)):
+        if isinstance(module, (LinearLayer, LinearAllreduce)) or issubclass(
+            type(module), (LinearLayer, LinearAllreduce)
+        ):
             module.weight = torch.nn.Parameter(module.weight, requires_grad=False)
             if module.bias is not None:
                 module.bias = torch.nn.Parameter(module.bias, requires_grad=False)
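For context on the loop added above: may_import_deepspeed_modules() (see the next file) puts the three known classes first, so deepspeed_modules[3:] holds only layer classes discovered at runtime, and each of them is routed to an IPEX prepack module by its base class. Below is a minimal, self-contained sketch of that issubclass dispatch; the classes are hypothetical stand-ins, not the real DeepSpeed or IPEX types.

# Hypothetical stand-ins that only mimic the class relationships.
class LinearAllreduce: ...
class LinearLayer: ...
class FusedLinearLayer(LinearLayer): ...   # plays the role of an entry in deepspeed_modules[3:]

class _IPEXLinearAllreduce: ...
class _IPEXLinear: ...

deepspeed_modules_mapping = {}
extra_modules = [FusedLinearLayer]          # i.e. deepspeed_modules[3:]

for module in extra_modules:
    if module not in deepspeed_modules_mapping:
        if issubclass(module, LinearAllreduce):
            deepspeed_modules_mapping[module] = _IPEXLinearAllreduce
        elif issubclass(module, LinearLayer):
            deepspeed_modules_mapping[module] = _IPEXLinear
        else:
            raise ValueError(f"Unrecognized module type: {module}")

print({k.__name__: v.__name__ for k, v in deepspeed_modules_mapping.items()})
# {'FusedLinearLayer': '_IPEXLinear'}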

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 14 additions & 14 deletions
@@ -101,22 +101,22 @@ def may_import_deepspeed_modules():
     try:
         # import deepspeed in a global space will raise circular import error
         # intel-extension-for-deepspeed imports both IPEX and deepspeed
-        from deepspeed.module_inject.layers import (
-            LinearAllreduce,
-            LinearLayer,
-            LmHeadLinearAllreduce,
-            fused_LinearLayer,
-            GateUpPack_LinearLayer,
-        )
+        import deepspeed.module_inject.layers as dslayers
+        from deepspeed.module_inject.layers import LinearAllreduce, LinearLayer
+
+        ds_layers = [LinearAllreduce, LinearLayer]
+
+        from deepspeed.module_inject.layers import LmHeadLinearAllreduce

-        ds_layers = [
-            LinearAllreduce,
-            LinearLayer,
-            LmHeadLinearAllreduce,
-            fused_LinearLayer,
-            GateUpPack_LinearLayer,
-        ]
+        ds_layers.append(LmHeadLinearAllreduce)
+        ds_layers += list(
+            cls
+            for cls in dslayers.LinearAllreduce.__subclasses__()
+            if cls is not LmHeadLinearAllreduce
+        )
+        ds_layers += list(cls for cls in dslayers.LinearLayer.__subclasses__())
         return ds_layers
+
     except ImportError:
         return None
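The rewrite above stops hard-coding fused_LinearLayer and GateUpPack_LinearLayer and instead discovers them (and any future variants) through __subclasses__(), which lists the direct subclasses of a class that exist at call time. A tiny illustration with made-up classes, not DeepSpeed's:

class LinearAllreduce: ...
class LmHeadLinearAllreduce(LinearAllreduce): ...
class ConvLinearAllreduce(LinearAllreduce): ...   # stands in for a newly added variant

# __subclasses__() returns only the direct subclasses defined so far, so filtering
# out LmHeadLinearAllreduce leaves exactly the classes a fixed import list misses.
extra = [
    cls
    for cls in LinearAllreduce.__subclasses__()
    if cls is not LmHeadLinearAllreduce
]
print([cls.__name__ for cls in extra])   # ['ConvLinearAllreduce']

Note that __subclasses__() does not recurse, so a subclass of a subclass would not be picked up by this scheme.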

intel_extension_for_pytorch/quantization/_quantize.py

Lines changed: 24 additions & 0 deletions
@@ -137,6 +137,12 @@ def _may_insert_deepspeed_modules(
                 LmHeadLinearAllreduce: q_lm_head_linear_all_reduce_module,
             }
         )
+        if len(deepspeed_modules) > 3:
+            for module in deepspeed_modules[3:]:
+                if issubclass(module, LinearLayer):
+                    deepspeed_modules_dict[module] = q_linear_layer_module
+                elif issubclass(module, LinearAllreduce):
+                    deepspeed_modules_dict[module] = q_linear_all_reduce_module
         torch_modules.update(deepspeed_modules_dict)
     return torch_modules

@@ -231,6 +237,12 @@ def _float_module(cls):
         if deepspeed_modules is not None:
             LinearLayer = deepspeed_modules[1]
             _FLOAT_MODULE.extend([LinearLayer])
+
+            if len(deepspeed_modules) > 3:
+                for module in deepspeed_modules[3:]:
+                    if issubclass(module, LinearLayer):
+                        _FLOAT_MODULE.extend([module])
+
         return _FLOAT_MODULE

     def __repr__(self):

@@ -260,6 +272,10 @@ def _float_module(cls):
             ), "DynamicQuantizedLinearAllreduce requires deepspeed to be installed"
             LinearAllreduce = deepspeed_modules[0]
             _FLOAT_MODULE = [LinearAllreduce]
+            if len(deepspeed_modules) > 3:
+                for module in deepspeed_modules[3:]:
+                    if issubclass(module, LinearAllreduce):
+                        _FLOAT_MODULE.extend([module])
         return _FLOAT_MODULE

     def __init__(

@@ -361,6 +377,7 @@ def may_quantize_deepspeed_modules(
     IPEX_QUANTIZATION_MODULE, q_config, module_mappings, qconfig_spec
 ):
     deepspeed_modules = may_import_deepspeed_modules()
+
     if deepspeed_modules is not None:
         LinearAllreduce, LinearLayer = deepspeed_modules[:2]
         module_mappings.update(IPEX_QUANTIZATION_MODULE)

@@ -375,6 +392,13 @@ def may_quantize_deepspeed_modules(
                 LmHeadLinearAllreduce: q_config,
             }
         )
+        if len(deepspeed_modules) > 3:
+            for module in deepspeed_modules[3:]:
+                deepspeed_qconfig_spec.update(
+                    {
+                        module: q_config,
+                    }
+                )

         qconfig_spec.update(deepspeed_qconfig_spec)
     return module_mappings, qconfig_spec
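The quantization hunks above follow the same convention as the eager-mode path: indices 0-2 of the returned list are the fixed, known classes and everything past index 2 is an extra layer type; for the qconfig spec, every extra class simply receives the same q_config. A small sketch of that last pattern, with plain placeholders instead of the real IPEX qconfig objects:

# Placeholder classes and config; the real code uses DeepSpeed layer classes
# and an IPEX/torch qconfig object here.
class LinearAllreduce: ...
class LinearLayer: ...
class LmHeadLinearAllreduce(LinearAllreduce): ...
class FusedLinearAllreduce(LinearAllreduce): ...   # hypothetical "extra" class

q_config = object()
deepspeed_modules = [LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, FusedLinearAllreduce]

deepspeed_qconfig_spec = {
    LinearAllreduce: q_config,
    LinearLayer: q_config,
    LmHeadLinearAllreduce: q_config,
}
# Every class discovered past the first three gets the same q_config.
if len(deepspeed_modules) > 3:
    for module in deepspeed_modules[3:]:
        deepspeed_qconfig_spec.update({module: q_config})

print(len(deepspeed_qconfig_spec))   # 4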

intel_extension_for_pytorch/utils/weight_only_quantization.py

Lines changed: 10 additions & 9 deletions
@@ -292,22 +292,23 @@ def _convert_woq_with_low_precision_checkpoint(

     deepspeed_modules = may_import_deepspeed_modules()
     if deepspeed_modules is not None:
-        (
-            LinearAllreduce,
-            LinearLayer,
-            LmHeadLinearAllreduce,
-            fused_LinearLayer,
-            GateUpPack_LinearLayer,
-        ) = deepspeed_modules
+        LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, *extra_linear_modules = (
+            deepspeed_modules
+        )
+
         q_op_map.update(
             {
                 LinearAllreduce: IpexWoqLinearAllreduce,
                 LinearLayer: WeightOnlyQuantizedLinear,
-                fused_LinearLayer: WeightOnlyQuantizedLinear,
-                GateUpPack_LinearLayer: WeightOnlyQuantizedLinear,
             }
         )

+        if extra_linear_modules:
+            for module in extra_linear_modules:
+                if issubclass(module, LinearAllreduce):
+                    q_op_map[module] = IpexWoqLinearAllreduce
+                elif issubclass(module, LinearLayer):
+                    q_op_map[module] = WeightOnlyQuantizedLinear
         linear_modules = tuple(q_op_map.keys())

     def _convert(mod, attr_name):
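The unpacking change above uses Python's starred assignment so this call site no longer has to know the exact length of the returned list: the first three names bind the known classes and extra_linear_modules collects whatever else was discovered, possibly nothing. A stand-alone illustration with plain strings instead of the real classes:

modules = ["LinearAllreduce", "LinearLayer", "LmHeadLinearAllreduce", "FusedLinear"]

allreduce, linear, lm_head, *extra = modules
print(extra)       # ['FusedLinear']

allreduce, linear, lm_head, *extra = modules[:3]
print(extra)       # [] -- the follow-up loop over extra modules then simply does nothing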
