What is the recommended way to expose the dtype that a tensor subclass appears to have, i.e. what should calling `subclass_tensor.dtype` return?
I see that the current AffineQuantizedTensor and NF4Tensor show the original dtype. I understand that this helps with compatibility for existing code (e.g. in gpt-fast, the KVCache dtype is taken from the weight dtype):
ao/torchao/_models/llama/model.py (line 122 at f172c47):

```python
dtype = self.output.weight.dtype
```
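
Here is a minimal sketch of the behavior in question, assuming a torchao build where `quantize_` and `int8_weight_only` are exported from `torchao.quantization` (the exact API names may differ across versions):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

linear = torch.nn.Linear(16, 16, dtype=torch.bfloat16)
quantize_(linear, int8_weight_only())

# The weight is now a tensor subclass backed by int8 data,
# yet its reported dtype is still the original bfloat16.
print(type(linear.weight))  # a subclass of torch.Tensor
print(linear.weight.dtype)  # torch.bfloat16
```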
However, personally I find this a bit unintuitive, because the weight is actually no longer FP32/BF16 (it only appears to be, for compatibility reasons I suppose).
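
One way to surface the actual storage dtypes is to recurse through the inner tensors. This is only a sketch, and it assumes the subclass implements the traceable-subclass `__tensor_flatten__` protocol (which the torchao subclasses generally do):

```python
import torch

def storage_dtypes(t: torch.Tensor):
    """Recursively collect the dtypes of a subclass's inner tensors.

    __tensor_flatten__ returns the attribute names of the inner
    tensors plus some context; plain tensors are the base case.
    """
    if hasattr(t, "__tensor_flatten__"):
        inner_attrs, _ctx = t.__tensor_flatten__()
        return {name: storage_dtypes(getattr(t, name)) for name in inner_attrs}
    return t.dtype  # plain tensor: this is the real storage dtype
```

On the example above, `storage_dtypes(linear.weight)` would map the inner tensor names (e.g. the int data and scales) to their int8/float dtypes, while `linear.weight.dtype` still reports bfloat16.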
@msaroufim also mentions that:

> This is unfortunately a big limitation with subclasses, mostly because of limitations with autograd that are very difficult to get rid of.