
Performance regression after migrating to LTC codegen (addcdiv, addcmul) #3942

@ymwangg

Description


🐛 Bug

We noticed a ~18% performance drop in the BERT model after #3768. The regression appears to be caused by a new flag in upstream LTC that is not enabled by default here. This special-scalar check matters for XLA because it allows ops like torch.addcdiv(a, b, c, value=1.0) and torch.add(a, b, alpha=1.0) to be optimized via constant folding.
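For context, the idea behind the special-scalar check is that the scalar values 0 and 1 can be inlined into the IR as literal constants rather than lifted into graph parameters, which makes them visible to the constant folder. A minimal sketch of that decision (hypothetical names; the real logic sits behind the upstream LTC flag):

```python
# Hypothetical sketch of a "special scalar" check: scalars equal to 0 or 1
# are embedded as literal constants in the IR instead of being lifted into
# graph parameters, so the compiler can fold ops like x * 1 or x + 0.
def is_special_scalar(value) -> bool:
    # Only exact 0 and 1 qualify; any other value stays a runtime parameter.
    return value in (0, 1)

def lower_scalar(value):
    # Returns a (kind, payload) pair standing in for an IR node.
    if is_special_scalar(value):
        return ("constant", value)   # embedded literal -> foldable
    return ("parameter", value)      # device-side input -> opaque to the optimizer

print(lower_scalar(1.0))  # ('constant', 1.0)
print(lower_scalar(0.5))  # ('parameter', 0.5)
```

When the scalar is a parameter instead, the optimizer cannot assume anything about its value and must keep the broadcast/multiply in the graph.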

To Reproduce

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla
import torch_xla.debug.metrics as met

device = xm.xla_device()

a = torch.rand(10).to(device)
b = torch.rand(10).to(device)
c = torch.rand(10).to(device)
d = torch.addcdiv(a, b, c, value=1.0)
print(torch_xla._XLAC._get_xla_tensors_hlo([d]))
```

HLO dump:

```
HloModule IrToHlo.10, entry_computation_layout={(f32[],f32[10]{0},f32[10]{0},f32[10]{0})->(f32[10]{0})}

ENTRY %IrToHlo.10 (p0.1: f32[], p1.2: f32[10], p2.3: f32[10], p3.4: f32[10]) -> (f32[10]) {
  %p3.4 = f32[10]{0} parameter(3)
  %p2.3 = f32[10]{0} parameter(2)
  %p1.2 = f32[10]{0} parameter(1)
  %divide.5 = f32[10]{0} divide(f32[10]{0} %p2.3, f32[10]{0} %p1.2)
  %p0.1 = f32[] parameter(0)
  %broadcast.6 = f32[10]{0} broadcast(f32[] %p0.1), dimensions={}
  %multiply.7 = f32[10]{0} multiply(f32[10]{0} %divide.5, f32[10]{0} %broadcast.6)
  %add.8 = f32[10]{0} add(f32[10]{0} %p3.4, f32[10]{0} %multiply.7)
  ROOT %tuple.9 = (f32[10]{0}) tuple(f32[10]{0} %add.8)
}
```
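Because value=1.0 arrives as parameter %p0.1, the broadcast and multiply survive into the executable. Had it been a literal constant, the simplification is trivial: multiply-by-one disappears. A toy constant folder over a tiny expression tree (not XLA's actual pass, just an illustration of the rewrite it can apply):

```python
# Toy constant folder illustrating the simplification XLA can apply once
# the scalar is a literal: mul(x, const 1.0) -> x and add(x, const 0.0) -> x.
# Nodes are tuples ("op", arg, ...); leaves are parameter names.
def fold(node):
    if isinstance(node, tuple):
        op, *args = node
        args = [fold(a) for a in args]
        if op == "mul" and ("const", 1.0) in args:
            args.remove(("const", 1.0))
            return args[0]  # multiplying by one is the identity
        if op == "add" and ("const", 0.0) in args:
            args.remove(("const", 0.0))
            return args[0]  # adding zero is the identity
        return (op, *args)
    return node

# Mirrors the HLO above: add(p3, mul(div(p2, p1), 1.0))
expr = ("add", "p3", ("mul", ("div", "p2", "p1"), ("const", 1.0)))
print(fold(expr))  # ('add', 'p3', ('div', 'p2', 'p1'))
```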

Note that setting torch_lazy_handle_special_scalars=True fixes the special-scalar handling, but the result is then improperly cast to fp64:

```
HloModule IrToHlo.12, entry_computation_layout={(f32[10]{0},f32[10]{0},f32[10]{0})->(f64[10]{0})}

ENTRY %IrToHlo.12 (p0.2: f32[10], p1.3: f32[10], p2.4: f32[10]) -> (f64[10]) {
  %p2.4 = f32[10]{0} parameter(2)
  %convert.9 = f64[10]{0} convert(f32[10]{0} %p2.4)
  %p1.3 = f32[10]{0} parameter(1)
  %p0.2 = f32[10]{0} parameter(0)
  %divide.5 = f32[10]{0} divide(f32[10]{0} %p1.3, f32[10]{0} %p0.2)
  %convert.6 = f64[10]{0} convert(f32[10]{0} %divide.5)
  %constant.1 = f64[] constant(1)
  %broadcast.7 = f64[10]{0} broadcast(f64[] %constant.1), dimensions={}
  %multiply.8 = f64[10]{0} multiply(f64[10]{0} %convert.6, f64[10]{0} %broadcast.7)
  %add.10 = f64[10]{0} add(f64[10]{0} %convert.9, f64[10]{0} %multiply.8)
  ROOT %tuple.11 = (f64[10]{0}) tuple(f64[10]{0} %add.10)
}
```
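The f64 result looks like a dtype-propagation bug: a Python float is double precision, so a naively emitted constant node defaults to f64 (%constant.1 above) and drags the whole expression up from f32. A sketch of the distinction, with a hypothetical helper that emits the constant in the dtype of the tensor operand it combines with:

```python
# Sketch of the dtype pitfall: a Python float is double precision, so a
# naive constant node defaults to "f64" and promotes the expression
# (f32 -> f64), as in the HLO dump above. The fix is to emit the constant
# in the dtype of the tensor operand it will combine with.
def scalar_const(value, operand_dtype, match_operand=True):
    # Hypothetical helper; "f64" mimics the default for Python floats.
    dtype = operand_dtype if match_operand else "f64"
    return ("constant", dtype, value)

print(scalar_const(1.0, "f32", match_operand=False))  # naive: ('constant', 'f64', 1.0)
print(scalar_const(1.0, "f32"))                       # desired: ('constant', 'f32', 1.0)
```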

Expected behavior

value=1.0 in torch.addcdiv should be treated as a constant and folded away, without changing the result dtype.
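In other words, with value == 1.0 the op should reduce to a plain a + b / c. A user-side sketch of that equivalence over plain Python lists (standing in for tensors; not a torch_xla API):

```python
# Elementwise a + value * (b / c), skipping the multiply when value == 1.0,
# which mirrors the folding we expect the compiler to perform.
def addcdiv_no_scalar(a, b, c, value=1.0):
    if value == 1.0:
        return [x + y / z for x, y, z in zip(a, b, c)]
    return [x + value * (y / z) for x, y, z in zip(a, b, c)]

print(addcdiv_no_scalar([1.0], [4.0], [2.0]))  # [3.0]
```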

Environment

  • Reproducible on XLA backend: GPU
  • torch_xla version: master
