Merged
9 changes: 9 additions & 0 deletions torchao/dtypes/affine_quantized_tensor_ops.py
@@ -64,6 +64,12 @@
_linear_int8_act_int8_weight_check,
_linear_int8_act_int8_weight_impl,
)
from torchao.dtypes.uintx.q_dq_layout import (
_embedding_check as _embedding_q_dq_check,
)
from torchao.dtypes.uintx.q_dq_layout import (
_embedding_impl as _embedding_q_dq_impl,
)
from torchao.dtypes.uintx.q_dq_layout import (
_linear_check as _linear_q_dq_check,
)
@@ -263,6 +269,9 @@ def _(func, types, args, kwargs):

@implements(torch.nn.functional.embedding)
def _(func, types, args, kwargs):
if _embedding_q_dq_check(args, kwargs):
return _embedding_q_dq_impl(args, kwargs)

Contributor
why does line 299 only dequantize the weight but not actually run the embedding op?

# new_arg1 = args[1].dequantize()
# return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs)
assert isinstance(
13 changes: 13 additions & 0 deletions torchao/dtypes/uintx/q_dq_layout.py
@@ -50,3 +50,16 @@ def _linear_impl(input_tensor, weight_tensor, bias):
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)


def _embedding_check(args, kwargs):
_, weight_tensor = args
layout = weight_tensor.tensor_impl.get_layout()
return isinstance(layout, QDQLayout)


def _embedding_impl(args, kwargs):
input_tensor, weight_tensor = args
if isinstance(weight_tensor, AffineQuantizedTensor):
weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.embedding(input_tensor, weight_tensor, **kwargs)
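
To see how the two new pieces fit together: the check/impl pair above is registered in affine_quantized_tensor_ops.py (first file in this diff) so that torch.nn.functional.embedding on a QDQLayout weight dequantizes and falls back to the dense kernel. The following is a standalone sketch of that dispatch pattern, not the torchao internals; the registry and handler names are made up for illustration.

import torch
import torch.nn.functional as F

# Hypothetical registry of (check_fn, impl_fn) pairs, mirroring the dispatch idea above.
_EMBEDDING_HANDLERS = []

def register_embedding_handler(check_fn, impl_fn):
    _EMBEDDING_HANDLERS.append((check_fn, impl_fn))

def quantized_embedding(args, kwargs):
    # Try each layout-specific handler in registration order.
    for check_fn, impl_fn in _EMBEDDING_HANDLERS:
        if check_fn(args, kwargs):
            return impl_fn(args, kwargs)
    raise NotImplementedError("no registered layout handled this embedding call")

def _qdq_check(args, kwargs):
    return isinstance(args[1], torch.Tensor)

def _qdq_impl(args, kwargs):
    indices, weight = args
    if not weight.dtype.is_floating_point:
        weight = weight.dequantize()  # QDQLayout path: dequantize, then run the dense op
    return F.embedding(indices, weight, **kwargs)

register_embedding_handler(_qdq_check, _qdq_impl)
print(quantized_embedding((torch.tensor([0, 2]), torch.randn(4, 8)), {}))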
87 changes: 43 additions & 44 deletions torchao/experimental/quant_api.py
@@ -15,7 +15,7 @@
quantize_per_channel_group,
)

from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.granularity import Granularity, PerAxis, PerGroup, PerRow
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

logger = logging.getLogger(__name__)
@@ -366,32 +366,44 @@ def __init__(
):
super().__init__()
self.bit_width = bit_width
self.pack_weights_op = getattr(
torch.ops.torchao, f"_pack_embedding_{bit_width}bit"
)
self.embedding_op = getattr(torch.ops.torchao, f"_embedding_{bit_width}bit")

def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
assert has_weight_zeros, "has_weight_zeros must be True for QuantizedEmbedding"
num_embeddings, embedding_dim = weights.shape
if group_size == -1:
group_size = embedding_dim
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.bit_width, has_weight_zeros=True
embedding = torch.nn.Embedding(num_embeddings, embedding_dim)
embedding.weight = weights
quantize_(
embedding,
IntxWeightOnlyConfig(
weight_dtype=getattr(torch, f"int{self.bit_width}"),
granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
mapping_type=MappingType.ASYMMETRIC,
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
weight_qvals, weight_scales, weight_zeros = (
embedding.weight.tensor_impl.get_plain()
)
weight_scales = weight_scales.reshape(num_embeddings, -1)
weight_zeros = weight_zeros.reshape(num_embeddings, -1).to(torch.int8)
self.register_buffer(
"packed_weight_qvals", self.pack_weights_op(weight_qvals.to(torch.int8))
"packed_weight_qvals",
getattr(torch.ops.torchao, f"_pack_embedding_{self.bit_width}bit")(
weight_qvals.to(torch.int8)
),
)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.register_buffer("weight_scales", weight_scales)
self.register_buffer("weight_zeros", weight_zeros.to(torch.int8))
self.register_buffer("weight_zeros", weight_zeros)

def forward(self, x):
shape = x.shape
return self.embedding_op(
return getattr(torch.ops.torchao, f"_embedding_{self.bit_width}bit")(
self.packed_weight_qvals,
self.num_embeddings,
self.embedding_dim,
@@ -410,38 +422,23 @@ def __init__(
self.bit_width = bit_width

def quantize_and_pack_weights(self, weights, group_size, has_weight_zeros):
assert (
has_weight_zeros
), "has_weight_zeros must be True for QuantizedEmbeddingFallback"
num_embeddings, embedding_dim = weights.shape
if group_size == -1:
group_size = embedding_dim
self.group_size = group_size

weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.bit_width, has_weight_zeros=True
self.embedding = torch.nn.Embedding(*weights.shape)
self.embedding.weight = weights
quantize_(
self.embedding,
IntxWeightOnlyConfig(
weight_dtype=getattr(torch, f"int{self.bit_width}"),
granularity=PerGroup(group_size) if group_size > 0 else PerAxis(0),
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
mapping_type=MappingType.ASYMMETRIC,
),
lambda m, fqn: isinstance(m, torch.nn.Embedding),
)
self.weight_qvals = weight_qvals.to(torch.int32)
self.weight_scales = weight_scales
self.weight_zeros = weight_zeros.to(torch.int32)

def forward(self, x):
shape = x.shape
res = []
for i in x:
res.append(
dequantize_per_channel_group(
w_int8=self.weight_qvals[i, :].reshape(1, -1),
scales=self.weight_scales[i, :].reshape(1, -1),
zero_points=self.weight_zeros[i, :].reshape(1, -1),
quant_min=None, # TODO: why is this an arg for this function
quant_max=None, # TODO: why is this an arg for this function
dtype=None, # TODO: why is this an arg for this function
group_size=self.group_size,
output_dtype=torch.float32,
).reshape(-1)
)
return torch.stack(res).reshape(*shape, -1)
return self.embedding(x)


class QuantizedSharedEmbedding(nn.Module):
@@ -586,15 +583,16 @@ class EmbeddingQuantizer:
def __init__(
self,
weight_dtype: torch.dtype = torch.int4,
granularity: Union[PerRow, PerGroup] = PerRow(),
granularity: Granularity = PerAxis(0),
has_weight_zeros: bool = True,
use_fallback: bool = False,
):
bit_width = _dtype_to_bit_width(weight_dtype)

if isinstance(granularity, PerGroup):
group_size = granularity.group_size
elif isinstance(granularity, PerRow):
elif isinstance(granularity, PerAxis):
assert granularity.axis == 0
group_size = -1
else:
raise ValueError(f"Unsupported granularity: {granularity}")
@@ -630,6 +628,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
to_linear_activation_quantized,
)
from torchao.quantization.quant_api import (
IntxWeightOnlyConfig,
MappingType,
ZeroPointDomain,
to_affine_quantized_intx,
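
The quant_api.py changes above route weight quantization for both QuantizedEmbedding and QuantizedEmbeddingFallback through quantize_ with IntxWeightOnlyConfig, and EmbeddingQuantizer now takes PerAxis(0) (one scale/zero point per embedding row, the new default) or PerGroup(group_size) granularity. A minimal usage sketch, mirroring the tests further down in this diff; the module sizes are made up, and it assumes the torchao experimental lowbit kernels are built (otherwise pass use_fallback=True).

import torch
from torchao.experimental.quant_api import EmbeddingQuantizer
from torchao.quantization.granularity import PerAxis, PerGroup

# Toy model; real embedding sizes come from your network.
model = torch.nn.Sequential(torch.nn.Embedding(128, 256))

# One scale/zero point per embedding row:
EmbeddingQuantizer(
    weight_dtype=torch.int4,
    granularity=PerAxis(0),
    has_weight_zeros=True,
).quantize(model)

# Alternatively, group the embedding dimension into blocks of 64:
# EmbeddingQuantizer(
#     weight_dtype=torch.int4,
#     granularity=PerGroup(64),
#     has_weight_zeros=True,
# ).quantize(model)

indices = torch.randint(0, 128, (1, 16), dtype=torch.int64)
out = model(indices)  # runs the packed lowbit embedding kernel (or the fallback module)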
98 changes: 98 additions & 0 deletions torchao/experimental/quant_passes.py
@@ -215,3 +215,101 @@ def replace_q_dq_patterns_with_quantized_linear_ops_pass(

# Re-export
return torch.export.export(gm, *ep.example_inputs)


def _get_q_dq_embedding_patterns_replacements_and_filters(
weight_bit_width,
):
w_quant_min = -(1 << (weight_bit_width - 1))
w_quant_max = (1 << (weight_bit_width - 1)) - 1
w_target_dtype = torch.int8

def pattern(
indices,
w_int_data,
w_block_size,
w_scale,
w_zero_point,
):
dq_w = torch.ops.quant.dequantize_affine.default(
w_int_data,
w_block_size,
w_scale,
w_zero_point,
w_target_dtype,
w_quant_min,
w_quant_max,
)
return torch.ops.aten.embedding.default(dq_w, indices)

def replacement(
indices,
w_int_data,
w_block_size,
w_scale,
w_zero_point,
):
num_embeddings, embedding_dim = w_int_data.size()
packed_weight_qvals = getattr(
torch.ops.torchao, f"_pack_embedding_{weight_bit_width}bit"
)(w_int_data)
out_shape = indices.shape + (embedding_dim,)
group_size = w_block_size[-1]
n_groups = embedding_dim // group_size
w_scale = w_scale.reshape(-1, n_groups)
w_zero_point = w_zero_point.reshape(-1, n_groups)
return getattr(torch.ops.torchao, f"_embedding_{weight_bit_width}bit")(
packed_weight_qvals,
num_embeddings,
embedding_dim,
w_scale,
w_zero_point,
indices.reshape(-1),
).reshape(out_shape)

def match_filter(match, x, y):
def get_val(name):
node = [n for n in match.nodes_map if n.name == name][0]
return match.nodes_map[node]

# We only want w_block_size with shape [1, group_size]
w_block_size = get_val("w_block_size")
if len(w_block_size) != 2 or w_block_size[0] != 1:
return False

return True

return pattern, replacement, match_filter


def replace_q_dq_patterns_with_quantized_embedding_ops_pass(
ep: torch.export.ExportedProgram,
) -> torch.export.ExportedProgram:
"""
This replaces Q/DQ patterns with torchao quantized embedding ops.
It is intended for converting Q/DQ nodes exported with QDQLayout into
the lowbit quantized embedding ops.
"""
# TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
# See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/
assert (
len(ep.range_constraints) == 0
), "ExportedProgram with range constraints are not supported"

# ep.module() unlifts the weight inputs, which we need for constant folding
gm = ep.module()
for weight_bit_width in range(1, 9):
pattern, replacement, match_filter = (
_get_q_dq_embedding_patterns_replacements_and_filters(
weight_bit_width,
)
)
subgraph_rewriter.replace_pattern_with_filters(
gm, pattern, replacement, match_filters=[match_filter]
)

# Constant fold evaluates and removes the packing ops
constant_fold(gm)

# Re-export
return torch.export.export(gm, *ep.example_inputs)
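
A usage sketch for the new pass. It assumes `model` already carries embedding weights quantized to a QDQLayout AffineQuantizedTensor (the exact quantization config is not shown in this hunk), and the helper name and example indices are made up.

import torch
from torchao.experimental.quant_passes import (
    replace_q_dq_patterns_with_quantized_embedding_ops_pass,
)

def lower_quantized_embeddings(model: torch.nn.Module, example_indices: torch.Tensor):
    # Export first; the pass requires an ExportedProgram with no range constraints.
    ep = torch.export.export(model, (example_indices,))
    # Rewrites dequantize_affine + aten.embedding into torch.ops.torchao._embedding_<n>bit
    # and constant-folds the weight packing so it happens once at conversion time.
    ep = replace_q_dq_patterns_with_quantized_embedding_ops_pass(ep)
    return ep

# e.g. ep = lower_quantized_embeddings(quantized_model, torch.randint(0, 128, (1, 16)))
# print(ep.graph_module.code)  # should now reference the torchao embedding ops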
11 changes: 5 additions & 6 deletions torchao/experimental/tests/test_embedding_xbit_quantizer.py
@@ -19,7 +19,7 @@
Int8DynamicActivationIntxWeightConfig,
SharedEmbeddingQuantizer,
)
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.granularity import PerAxis, PerGroup, PerRow
from torchao.quantization.quant_api import quantize_


@@ -68,7 +68,7 @@ def test_accuracy(self):

def test_export_compile_aoti(self):
weight_dtype = torch.int4
granularity = PerRow()
granularity = PerAxis(0)
embedding_dim = 4096
num_embeddings = 131
model = torch.nn.Sequential(
@@ -113,7 +113,6 @@ def test_export_compile_aoti(self):

def test_shared_embedding(self):
weight_dtype = torch.int4
granularity = PerRow()
has_weight_zeros = True
embedding_dim = 4096
num_embeddings = 131
@@ -134,14 +133,14 @@
quantized_model_reference = copy.deepcopy(model)
EmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
granularity=PerAxis(0),
has_weight_zeros=has_weight_zeros,
).quantize(quantized_model_reference)
quantize_(
quantized_model_reference,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
granularity=granularity,
granularity=PerRow(),
Contributor
Should this be PerAxis as well?

Contributor Author
It can't be because that's controlled by Int8DynamicActivationIntxWeightConfig, which uses PerRow until #1968 lands

has_weight_zeros=has_weight_zeros,
round_weight_scale_to_bf16=False,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(
@@ -155,7 +154,7 @@
quantized_model = copy.deepcopy(model)
SharedEmbeddingQuantizer(
weight_dtype=weight_dtype,
granularity=granularity,
granularity=PerRow(),
Contributor
And this one as well?

has_weight_zeros=has_weight_zeros,
).quantize(quantized_model)
