
Commit ee7dc34

zhuhaozhe and jianan-gu authored
delete plain weight while prepack (#1445) (#1593)
* delete plain weight while prepack
* only delete plain weight when user setting inplace
* fix ut

Co-authored-by: jianan-gu <[email protected]>
1 parent bc76ab1 commit ee7dc34

File tree

4 files changed: +8 -7 lines changed

intel_extension_for_pytorch/frontend.py
intel_extension_for_pytorch/nn/utils/_weight_prepack.py
tests/cpu/test_ipex_optimize.py
tests/cpu/test_weight_prepack.py

intel_extension_for_pytorch/frontend.py

Lines changed: 2 additions & 2 deletions
@@ -531,15 +531,15 @@ def optimize(
                "FP16 weight prepack needs the cpu support avx512_core_fp16, " + \
                "please set dtype to torch.float or set weights_prepack to False."
            optimized_model, optimized_optimizer, params_attr = utils._weight_prepack.weight_prepack_with_ipex(
-               optimized_model, optimized_optimizer, params_attr, 'cpu')
+               optimized_model, optimized_optimizer, params_attr, inplace, 'cpu')
            torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConv2d)
            torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXConvTranspose2d)
            torch._dynamo.allow_in_graph(utils._weight_prepack._IPEXLinear)
            torch._dynamo.allow_in_graph(utils._model_convert._LSTM)
        else:
            assert device_type == 'xpu', "Unknown device type, only support device CPU and XPU"
            optimized_model, optimized_optimizer, params_attr = utils._weight_prepack.weight_prepack_with_ipex(
-               optimized_model, optimized_optimizer, params_attr, 'xpu')
+               optimized_model, optimized_optimizer, params_attr, inplace, 'xpu')

    if opt_properties.graph_mode:
        _old_forward = optimized_model.forward
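
For context, a minimal usage sketch of the `inplace` flag that this change forwards into weight prepacking. The `SimpleNet` module and shapes are illustrative (not from the diff), and running it assumes intel_extension_for_pytorch is installed; torch.bfloat16 is used because, as the test comment removed below notes, float32 eval mode picks the MKL backend and does not prepack the weight.

import torch
import intel_extension_for_pytorch as ipex

# Illustrative module; any model with a prepackable Linear/Conv behaves the same way.
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.linear(x)

model = SimpleNet().eval()

# inplace=True is now passed through to weight_prepack_with_ipex, so the plain
# (un-prepacked) weight on the original module is expected to be dropped once
# the prepacked (blocked) copy exists, instead of keeping both in memory.
opt_model = ipex.optimize(model, dtype=torch.bfloat16, inplace=True)

# inplace=False (the default) leaves the original module untouched.
ref_model = ipex.optimize(SimpleNet().eval(), dtype=torch.bfloat16)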

intel_extension_for_pytorch/nn/utils/_weight_prepack.py

Lines changed: 3 additions & 1 deletion
@@ -405,7 +405,7 @@ def weight_prepack_with_ipex_xpu(module):
        weight_prepack_with_ipex_xpu(child)
    return module

-def weight_prepack_with_ipex(module, optimizer, params_attr, device_type='cpu'):
+def weight_prepack_with_ipex(module, optimizer, params_attr, inplace=False, device_type='cpu'):
    def convert(m, optimizer, params_attr):
        if _should_prepack(m, is_training=(optimizer!=None)) and (m.weight.dtype == torch.float32 or m.weight.dtype == torch.bfloat16 or m.weight.dtype == torch.half):
            weight = m.master_weight if hasattr(m, "master_weight") else m.weight
@@ -457,6 +457,8 @@ def convert(m, optimizer, params_attr):
            # replace optimizer's param with prepacked param, also prepack its state.
            optim._optimizer_utils.pack_optimizer_params_and_states(
                optimizer, params_pair, params_attr, m.weight.dtype)
+           if inplace:
+               del m.weight
            return new_m
        else:
            return m
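
A standalone note on the `del m.weight` line above (plain PyTorch, not IPEX code): `torch.nn.Module.__delattr__` removes the entry from the module's parameter dict, which is why the original module stops exposing a plain weight once the prepacked copy has been created.

import torch

m = torch.nn.Linear(4, 4)
print(hasattr(m, "weight"))                      # True
print("weight" in dict(m.named_parameters()))    # True

# Mirrors what convert() does when inplace=True after prepacking.
del m.weight

print(hasattr(m, "weight"))                      # False
print("weight" in dict(m.named_parameters()))    # False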

tests/cpu/test_ipex_optimize.py

Lines changed: 2 additions & 3 deletions
@@ -224,11 +224,10 @@ def test_optimize_inplace_behavior_eval_mode(self):
                opt_M = ipex.optimize(M, dtype=dtype, level=level, inplace=True)
                # After ConvBN folding, opt_M will be Graph Module while the M is original nn.Module which they
                # share parameters. But the changes on Graph Module cannot be reflected on original module. So
-               # only the un-opitimized weight will use same mem buffer with original module.
-               # While dtype = float, ipex.optimize will choose mkl backend and does not prepack weight
+               # only the un-opitimized weight will use same mem buffer with original module.
                if level == "O1":
                    self.assertTrue(M.conv.weight.data_ptr() != opt_M.conv.weight.data_ptr())
-                   self.assertTrue(dtype is torch.float or M.linear.weight.data_ptr() != opt_M.linear.weight.data_ptr())
+                   self.assertFalse(hasattr(M.linear, 'weight'))
                    # un-optimized part should be inplaced
                    self.assertTrue(M.embeddingbag.weight.data_ptr() == opt_M.embeddingbag.weight.data_ptr())
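
A rough sketch of the behavior the updated assertion targets; the `Toy` module below is illustrative (the real test uses its own module with conv, linear and embeddingbag layers) and assumes a bfloat16-capable CPU build of intel_extension_for_pytorch.

import torch
import intel_extension_for_pytorch as ipex

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

M = Toy().eval()
opt_M = ipex.optimize(M, dtype=torch.bfloat16, level="O1", inplace=True)

# With inplace=True the prepack path now deletes the plain weight from the
# original module, so the old data_ptr comparison no longer applies; the test
# instead checks that the attribute is gone. Expected output here: False.
print(hasattr(M.linear, "weight"))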

tests/cpu/test_weight_prepack.py

Lines changed: 1 addition & 1 deletion
@@ -610,7 +610,7 @@ def forward(self, x):
        # Example taken from GPT-J. The weight loaded from the state_dict is non-contiguous with the below size and stride:
        m.linear.weight = torch.nn.Parameter(copy.deepcopy(m.linear.weight).as_strided([oc, ic], [1, oc]))

-       optimized_m = ipex.optimize(m, dtype=dtype, inplace=True)
+       optimized_m = ipex.optimize(m, dtype=dtype, inplace=False)
        with torch.cpu.amp.autocast(enabled=True, dtype=dtype):
            jit_m = torch.jit.trace(optimized_m, x)
            jit_m = torch.jit.freeze(jit_m)
