
Conversation

@JackCaoG
Collaborator

Fix #3777, #3776, #3775, #3774

LazyIR

class Selu : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::selu);
  }

  Selu(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::selu), {self}, std::move(shapes),
                [&]() { return SeluOutputShape(self); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const { return false; }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
class Silu : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::silu);
  }

  Silu(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::silu), {self}, std::move(shapes),
                [&]() { return SiluOutputShape(self); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self) const { return false; }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class SiluBackward : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::silu_backward);
  }

  SiluBackward(const torch::lazy::Value& grad_output,
               const torch::lazy::Value& self,
               std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::silu_backward),
                {grad_output, self}, std::move(shapes),
                [&]() { return SiluBackwardOutputShape(grad_output, self); },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& grad_output,
                   const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

NativeFunction

at::Tensor XLANativeFunctions::selu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<Selu>(lazy_self->GetIrValue());
  if (!node) {
    auto shapes = torch::lazy::compute_shape_selu(self);
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self};
      const char* schema_str = "aten::selu(Tensor self) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node =
        torch::lazy::MakeNode<Selu>(lazy_self->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
at::Tensor XLANativeFunctions::silu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<Silu>(lazy_self->GetIrValue());
  if (!node) {
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::silu(self_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self};
      const char* schema_str = "aten::silu(Tensor self) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node =
        torch::lazy::MakeNode<Silu>(lazy_self->GetIrValue(), std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};
at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
                                             const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_grad_output =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(grad_output,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<SiluBackward>(
      lazy_grad_output->GetIrValue(), lazy_self->GetIrValue());
  if (!node) {
    auto grad_output_meta = to_meta(grad_output);
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::silu_backward(grad_output_meta, self_meta);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {grad_output, self};
      const char* schema_str =
          "aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<SiluBackward>(lazy_grad_output->GetIrValue(),
                                               lazy_self->GetIrValue(),
                                               std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
};

@JackCaoG JackCaoG requested a review from wonjoo-wj July 27, 2022 05:37
@JackCaoG
Collaborator Author

[ RUN      ] AtenXlaTensorTest.TestSeluInPlace
/tmp/pytorch/xla/test/cpp/cpp_test_util.h:51: Failure
Value of: CloseValues(tensor, xla_tensor, rtol, atol)
  Actual: false
Expected: true
/tmp/pytorch/xla/test/cpp/cpp_test_util.h:51: Failure
Value of: CloseValues(tensor, xla_tensor, rtol, atol)
  Actual: false
Expected: true
[  FAILED  ] AtenXlaTensorTest.TestSeluInPlace (7 ms)

Weird, this test passed locally.

@JackCaoG
Collaborator Author

So I logged into the CI VM, ran the test manually, and this is what I see:

(1,1,.,.) = 
  0.9270  0.9614  0.4023  1.0079  0.4102  0.6314
  0.2696  0.8339  0.9885  0.1399  0.9820  0.6237
  0.9135  0.5965  0.7787  0.4512  0.9303  0.6030
  0.2801  0.6593  0.2833  0.4637  0.3120  0.8739

(2,1,.,.) = 
  0.1107  0.2832  0.3770  0.2095  0.5749  0.0065
  0.9998  0.0791  0.9309  0.6128  0.3548  0.8500
  0.6072  0.9498  0.5828  0.3597  0.6665  0.3829
  0.7464  0.9944  0.8290  0.2957  0.8286  0.6193
[ CPUFloatType{2,1,4,6} ]
(1,1,.,.) = 
  0.9270  0.9614  0.4023  1.0079  0.4102  0.6314
  0.2696  0.8339  0.9885  0.1399  0.9820  0.6237
  0.9135  0.5965  0.7787  0.4512  0.9303  0.6030
  0.2801  0.6593  0.2833  0.4637  0.3120  0.8739

(2,1,.,.) = 
  0.1107  0.2832  0.3770  0.2095  0.5749  0.0065
  0.9998  0.0791  0.9309  0.6128  0.3548  0.8500
  0.6072  0.9498  0.5828  0.3597  0.6665  0.3829
  0.7464  0.9944  0.8290  0.2957  0.8286  0.6193
[ XLAFloatType{2,1,4,6} ]
[       OK ] AtenXlaTensorTest.TestSeluInPlace (148 ms)
[----------] 1 test from AtenXlaTensorTest (148 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (148 ms total)
[  PASSED  ] 1 test.

Seems like it only fails when I run all of the cpp tests, and it passes if I only run that one test.

@JackCaoG
Collaborator Author

Weird.. I suspect this is caused by a caching issue. If I run it together with the other cpp tests, I see

(1,1,.,.) = 
  0.9270  0.9614  0.4023  1.0079  0.4102  0.6314
  0.2696  0.8339  0.9885  0.1399  0.9820  0.6237
  0.9135  0.5965  0.7787  0.4512  0.9303  0.6030
  0.2801  0.6593  0.2833  0.4637  0.3120  0.8739

(2,1,.,.) = 
  0.1107  0.2832  0.3770  0.2095  0.5749  0.0065
  0.9998  0.0791  0.9309  0.6128  0.3548  0.8500
  0.6072  0.9498  0.5828  0.3597  0.6665  0.3829
  0.7464  0.9944  0.8290  0.2957  0.8286  0.6193
[ CPUFloatType{2,1,4,6} ]
(1,1,.,.) = 
  2.2057  2.2875  0.9572  2.3983  0.9761  1.5022
  0.6414  1.9841  2.3519  0.3330  2.3365  1.4839
  2.1735  1.4193  1.8527  1.0735  2.2136  1.4348
  0.6665  1.5686  0.6741  1.1034  0.7423  2.0792

(2,1,.,.) = 
  0.2633  0.6737  0.8970  0.4984  1.3680  0.0154
  2.3789  0.1882  2.2150  1.4580  0.8441  2.0224
  1.4448  2.2600  1.3866  0.8558  1.5859  0.9110
  1.7761  2.3660  1.9726  0.7035  1.9716  1.4737
[ XLAFloatType{2,1,4,6} ]

@JackCaoG
Collaborator Author

I think I figured out what's going on: our Elu node is implemented incorrectly.

torch::lazy::NodePtr Elu(const torch::lazy::Value& input,
                         const at::Scalar& alpha, const at::Scalar& scale,
                         const at::Scalar& input_scale) {
  auto lower_fn = [=](const XlaNode& node,
                      LoweringContext* loctx) -> XlaOpVector {
    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
    return node.ReturnOp(BuildElu(xla_input, alpha, scale, input_scale), loctx);
  };
  return GenericOp(torch::lazy::OpKind(at::aten::elu), {input},
                   GetXlaShape(input), std::move(lower_fn));
}

GenericOp should only be used for ops that take nothing but torch::lazy::Value inputs, because it only includes the torch::lazy::Value operands in its node hash. What happened here is that selu_ uses elu, but with a different alpha and scale. Given the implementation above, alpha and scale are not included in the hash, so PyTorch/XLA mistakenly considers the two graphs identical and reuses the previously compiled graph even though they are different. I will work on a fix while doing the codegen for elu.
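To make the failure mode concrete, here is a minimal hand-written sketch of an Elu node that folds the scalar attributes into the node hash, so an elu carrying selu's alpha/scale and a plain elu no longer collide in the compilation cache. This is only an illustration: the exact XlaNode constructor overload and the use of toDouble() for hashing are assumptions for the sketch, not the codegen'd Elu that eventually lands.

class Elu : public XlaNode {
 public:
  Elu(const torch::lazy::Value& input, const at::Scalar& alpha,
      const at::Scalar& scale, const at::Scalar& input_scale)
      : XlaNode(torch::lazy::OpKind(at::aten::elu), {input}, GetXlaShape(input),
                /*num_outputs=*/1,
                // Unlike GenericOp, the scalar attributes participate in the
                // node hash, so graphs that differ only in alpha/scale/input_scale
                // hash differently and are never reused for one another.
                torch::lazy::MHash(alpha.toDouble(), scale.toDouble(),
                                   input_scale.toDouble())),
        alpha_(alpha),
        scale_(scale),
        input_scale_(input_scale) {}

  XlaOpVector Lower(LoweringContext* loctx) const override {
    xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
    return ReturnOp(BuildElu(xla_input, alpha_, scale_, input_scale_), loctx);
  }

 private:
  at::Scalar alpha_;
  at::Scalar scale_;
  at::Scalar input_scale_;
};

(For reference, selu is elu with alpha ≈ 1.6732632 and scale ≈ 1.0507010, while a plain elu call defaults to alpha = scale = input_scale = 1; that is exactly the pair of graphs that collided here.)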

FYI @wonjoolee95

@JackCaoG
Collaborator Author

@wonjoolee95 I saw you assigned elu to yourself, but I think I will just grab it to unblock this PR.

@JackCaoG
Collaborator Author

Oh OK, ELU takes at::Scalar, so it cannot be codegen'd until pytorch/pytorch#82208 is resolved. In the meantime I think I will manually override the selu native function and force it to use selu; this should unblock the test.
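For illustration only (not necessarily the change made in this PR), one way such a manual override could route the in-place variant around the mis-hashed Elu node is to have selu_ call the codegen'd functional selu and copy the result back. The signature and helpers below follow the pattern of the generated functions above, but this should be read as an assumption-laden sketch.

at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  // Go through the codegen'd functional selu so nothing lowers via the
  // GenericOp-based Elu node with the incomplete hash.
  at::Tensor result = XLANativeFunctions::selu(self);
  // Preserve in-place semantics by writing the result back into self.
  self.copy_(result);
  return self;
}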

@JackCaoG JackCaoG force-pushed the codegen_silu_selu branch from 43d9ea8 to 2459546 on July 29, 2022 04:47
@wonjoo-wj
Collaborator

> I think I figured out what's going on: our Elu node is implemented incorrectly […]

This is really interesting.. thanks for the investigation!

Collaborator

@wonjoo-wj wonjoo-wj left a comment


👍

@JackCaoG JackCaoG merged commit 63a85d9 into master Jul 29, 2022
@wonjoo-wj wonjoo-wj deleted the codegen_silu_selu branch August 1, 2022 22:09
ysiraichi added a commit that referenced this pull request May 22, 2025
- `SgnOp` and `SignOp`
    - Full codegen migration: #3577
    - Mistakenly re-introduced: #3572
- `LogSigmoid`
    - Introduced: #3539
    - Full codegen migration: #3743
- `SiLU`
    - Introduced: #2721
    - Full codegen migration: #3780
- `SiLUBackward`
    - Introduced: #3195
    - Full codegen migration: #3780
- `SeLU`
    - Introduced: #3547
    - Full codegen migration: #3780
- `Sigmoid`
    - Introduced: 6a73deb (no PR record)
    - Full codegen migration: #6342