🐛 Bug
We noticed a ~18% performance drop in the BERT model after #3768. The issue appears to be that a new flag in upstream LTC is not enabled by default here. This special-scalar check is important for XLA to optimize ops like `torch.addcdiv(a, b, c, value=1.0)` and `torch.add(a, b, alpha=1.0)` with constant folding.
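For context, a minimal eager-mode sketch (plain PyTorch, no torch_xla) of why these scalars are foldable: with the default `alpha=1.0`, `torch.add` is mathematically identical to plain addition, so a backend that inlines the scalar as a constant can simplify the resulting graph.

```python
import torch

x = torch.rand(4)
y = torch.rand(4)

# alpha=1.0 is the default scalar; add(x, y, alpha=1.0) computes
# x + alpha * y, so with alpha == 1.0 the multiply is a no-op and a
# backend that embeds the scalar as a constant can fold it away.
assert torch.allclose(torch.add(x, y, alpha=1.0), x + y)
```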
To Reproduce
```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
a = torch.rand(10).to(device)
b = torch.rand(10).to(device)
c = torch.rand(10).to(device)
d = torch.addcdiv(a, b, c, value=1.0)
print(torch_xla._XLAC._get_xla_tensors_hlo([d]))
```

HLO dump:
```
HloModule IrToHlo.10, entry_computation_layout={(f32[],f32[10]{0},f32[10]{0},f32[10]{0})->(f32[10]{0})}
ENTRY %IrToHlo.10 (p0.1: f32[], p1.2: f32[10], p2.3: f32[10], p3.4: f32[10]) -> (f32[10]) {
%p3.4 = f32[10]{0} parameter(3)
%p2.3 = f32[10]{0} parameter(2)
%p1.2 = f32[10]{0} parameter(1)
%divide.5 = f32[10]{0} divide(f32[10]{0} %p2.3, f32[10]{0} %p1.2)
%p0.1 = f32[] parameter(0)
%broadcast.6 = f32[10]{0} broadcast(f32[] %p0.1), dimensions={}
%multiply.7 = f32[10]{0} multiply(f32[10]{0} %divide.5, f32[10]{0} %broadcast.6)
%add.8 = f32[10]{0} add(f32[10]{0} %p3.4, f32[10]{0} %multiply.7)
ROOT %tuple.9 = (f32[10]{0}) tuple(f32[10]{0} %add.8)
}
```

Note: setting `torch_lazy_handle_special_scalars=True` solves the special-scalar problem, but the result is improperly cast to fp64:
```
HloModule IrToHlo.12, entry_computation_layout={(f32[10]{0},f32[10]{0},f32[10]{0})->(f64[10]{0})}
ENTRY %IrToHlo.12 (p0.2: f32[10], p1.3: f32[10], p2.4: f32[10]) -> (f64[10]) {
%p2.4 = f32[10]{0} parameter(2)
%convert.9 = f64[10]{0} convert(f32[10]{0} %p2.4)
%p1.3 = f32[10]{0} parameter(1)
%p0.2 = f32[10]{0} parameter(0)
%divide.5 = f32[10]{0} divide(f32[10]{0} %p1.3, f32[10]{0} %p0.2)
%convert.6 = f64[10]{0} convert(f32[10]{0} %divide.5)
%constant.1 = f64[] constant(1)
%broadcast.7 = f64[10]{0} broadcast(f64[] %constant.1), dimensions={}
%multiply.8 = f64[10]{0} multiply(f64[10]{0} %convert.6, f64[10]{0} %broadcast.7)
%add.10 = f64[10]{0} add(f64[10]{0} %convert.9, f64[10]{0} %multiply.8)
ROOT %tuple.11 = (f64[10]{0}) tuple(f64[10]{0} %add.10)
}
```

Expected behavior
`value=1.0` in `torch.addcdiv` should be treated as a constant.
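To illustrate the expected semantics, here is a small eager-mode check (an illustrative sketch in plain PyTorch, not the XLA lowering): with `value=1.0`, `addcdiv` reduces to `a + b / c` and should preserve the input dtype.

```python
import torch

a, b, c = torch.rand(10), torch.rand(10), torch.rand(10)

# With value=1.0, addcdiv(a, b, c) is mathematically a + (b / c);
# the scalar multiply is a no-op, so a constant-folding backend
# should be able to eliminate it entirely.
out = torch.addcdiv(a, b, c, value=1.0)
assert torch.allclose(out, a + b / c)

# The result should stay float32 -- no spurious upcast to float64.
assert out.dtype == torch.float32
```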
Environment
- Reproducible on XLA backend [CPU/TPU]: GPU
- torch_xla version: master