python/llm/src/ipex_llm/optimize.py — 9 changes: 5 additions & 4 deletions
@@ -202,7 +202,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     :param model: The original PyTorch model (nn.module)
     :param low_bit: str value, options are ``'sym_int4'``, ``'asym_int4'``, ``'sym_int5'``,
                     ``'asym_int5'``, ``'sym_int8'``, ``'nf3'``, ``'nf4'``, ``'fp4'``,
-                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'`` or ``'bf16'``,
+                    ``'fp8'``, ``'fp8_e4m3'``, ``'fp8_e5m2'``, ``'fp16'``, ``'bf16'`` or None,
                     ``'sym_int4'`` means symmetric int 4, ``'asym_int4'`` means
                     asymmetric int 4, ``'nf4'`` means 4-bit NormalFloat, etc.
                     Relevant low bit optimizations will be applied to the model.
@@ -225,10 +225,11 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
     >>> # (Optional) you can also save the optimized model by calling 'save_low_bit'
     >>> model.save_low_bit(saved_dir)
     """
-    invalidInputError(low_bit in ggml_tensor_qtype,
+    invalidInputError(low_bit is None or low_bit in ggml_tensor_qtype,
                       f"Unknown load_in_low_bit value: {low_bit}, expected:"
                       f" sym_int4, asym_int4, sym_int5, asym_int5 or sym_int8.")
-    invalidInputError(isinstance(model, torch.nn.Module),
+    invalidInputError(isinstance(model, torch.nn.Module) or
+                      model.__class__.__name__ == "StableDiffusionPipeline",
                       "model should be an instance of "
                       f"`torch.nn.Module`, but got {type(model)} at last.")
     # To adapt vLLM models
@@ -249,7 +249,7 @@ def optimize_model(model, low_bit='sym_int4', optimize_llm=True, modules_to_not_
         torch_dtype = torch.float16
     else:
         torch_dtype = kwargs.get("torch_dtype", "auto")
-    qtype = ggml_tensor_qtype[low_bit]
+    qtype = ggml_tensor_qtype[low_bit] if low_bit is not None else None
     model = ggml_convert_low_bit(model,
                                  qtype=qtype,
                                  torch_dtype=torch_dtype,
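For context, a minimal usage sketch of the low_bit=None path this file now allows (a hypothetical example, not part of this PR; the checkpoint name is an assumption). With low_bit=None the relaxed validation passes, qtype stays None, and the model is handed to ggml_convert_low_bit without any low-bit quantization:

# Hypothetical sketch, not part of this PR; the checkpoint name is illustrative.
from transformers import AutoModelForCausalLM
from ipex_llm import optimize_model

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# With low_bit=None, qtype=None is forwarded to ggml_convert_low_bit,
# so no ggml quantization is applied; only the remaining optimizations run.
model = optimize_model(model, low_bit=None)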
python/llm/src/ipex_llm/transformers/convert.py — 66 changes: 38 additions & 28 deletions
@@ -1060,7 +1060,7 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
         logger.info(f"Converting the current model to "
                     f"{list(ggml_tensor_qtype.keys())[index]} "
                     f"format......")
-    else:
+    elif qtype in gguf_mixed_qtype.values():
         index = list(gguf_mixed_qtype.values()).index(qtype)
         logger.info(f"Converting the current model to "
                     f"{list(gguf_mixed_qtype.keys())[index]} "
@@ -1089,34 +1089,35 @@ def ggml_convert_low_bit(model, qtype, optimize_model=True,
     enable_scale_search = use_scale_search(model_config, qtype)
 
     # mixed quantization needs model_config to choose custom quantization strategy
-    model, has_been_replaced = _replace_with_low_bit_linear(
-        model, qtype, modules_to_not_convert,
-        convert_shape_only, cpu_embedding,
-        imatrix_data=imatrix_data,
-        embedding_qtype=embedding_qtype,
-        model_config=model_config,
-        torch_dtype=torch_dtype,
-        enable_xetla=enable_xetla,
-        mixed_precision=mixed_precision,
-        act_order=act_order,
-        enable_scale_search=enable_scale_search,
-    )
-    if not has_been_replaced:
-        warnings.warn(
-            "No linear modules were found in "
-            "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
-            "instead of Linear layers. Please double check your model architecture, or submit "
-            "an issue on github if you think this is a bug."
-        )
-    elif device == "cpu":
-        if not (getattr(model, "quantization_method", None) == "gptq"):
-            if torch_dtype == "auto":
-                convert_bigdl_other_module(model, torch.float32)
-            else:
-                convert_bigdl_other_module(model, torch_dtype)
-    elif device == "meta":
-        # Do nothing here for weights are empty.
-        pass
+    if qtype is not None:
+        model, has_been_replaced = _replace_with_low_bit_linear(
+            model, qtype, modules_to_not_convert,
+            convert_shape_only, cpu_embedding,
+            imatrix_data=imatrix_data,
+            embedding_qtype=embedding_qtype,
+            model_config=model_config,
+            torch_dtype=torch_dtype,
+            enable_xetla=enable_xetla,
+            mixed_precision=mixed_precision,
+            act_order=act_order,
+            enable_scale_search=enable_scale_search,
+        )
+        if not has_been_replaced:
+            warnings.warn(
+                "No linear modules were found in "
+                "your model. This can happen for some architectures such as gpt2 that uses Conv1D "
+                "instead of Linear layers. Please double check your model architecture, or submit "
+                "an issue on github if you think this is a bug."
+            )
+        elif device == "cpu":
+            if not (getattr(model, "quantization_method", None) == "gptq"):
+                if torch_dtype == "auto":
+                    convert_bigdl_other_module(model, torch.float32)
+                else:
+                    convert_bigdl_other_module(model, torch_dtype)
+        elif device == "meta":
+            # Do nothing here for weights are empty.
+            pass
 
     if optimize_model:
         model = _optimize_post(model, lightweight_bmm)
@@ -1221,6 +1222,15 @@ def _optimize_ipex(model, qtype=ggml_tensor_qtype["bf16"]):


 def _optimize_post(model, lightweight_bmm=False):
+    try:
+        from diffusers import StableDiffusionPipeline
+        if isinstance(model, StableDiffusionPipeline):
+            from ipex_llm.transformers.models.sd15 import AttnProcessor2_0
+            model.unet.set_attn_processor(AttnProcessor2_0())
+            return model
+    except ModuleNotFoundError:
+        pass
+
     try:
         from sentence_transformers.SentenceTransformer import SentenceTransformer
         if isinstance(model, SentenceTransformer):
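Taken together with the optimize.py changes above, a hedged end-to-end sketch of how the new StableDiffusionPipeline branch in _optimize_post would be exercised (the checkpoint, device placement and prompt are assumptions; diffusers and the ipex_llm sd15 module must be importable for the branch to take effect):

# Hypothetical sketch, not part of this PR.
import torch
from diffusers import StableDiffusionPipeline
from ipex_llm import optimize_model

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
# optimize_model now accepts the pipeline object; _optimize_post detects it and
# installs ipex_llm.transformers.models.sd15.AttnProcessor2_0 on pipe.unet.
pipe = optimize_model(pipe, low_bit=None)
pipe = pipe.to("xpu")  # device placement is an assumption
image = pipe("a photo of an astronaut riding a horse on the moon").images[0]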