Conversation

@vanbasten23
Collaborator

LazyIr.h:

class HardshrinkBackward : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::hardshrink_backward);
  }

  HardshrinkBackward(const torch::lazy::Value& grad_out,
                     const torch::lazy::Value& self,
                     const torch::lazy::Value& lambd,
                     std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::hardshrink_backward),
                {grad_out, self, lambd}, std::move(shapes),
                [&]() {
                  return HardshrinkBackwardOutputShape(grad_out, self, lambd);
                },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& grad_out,
                   const torch::lazy::Value& self,
                   const torch::lazy::Value& lambd) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class SoftshrinkBackward : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::softshrink_backward);
  }

  SoftshrinkBackward(const torch::lazy::Value& grad_output,
                     const torch::lazy::Value& self,
                     const torch::lazy::Value& lambd,
                     std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::softshrink_backward),
                {grad_output, self, lambd}, std::move(shapes),
                [&]() {
                  return SoftshrinkBackwardOutputShape(grad_output, self,
                                                       lambd);
                },
                /* num_outputs */ 1, torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& grad_output,
                   const torch::lazy::Value& self,
                   const torch::lazy::Value& lambd) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
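
The HardshrinkBackwardOutputShape/SoftshrinkBackwardOutputShape helpers referenced by the constructors above are hand-written rather than generated. A minimal sketch of what they could look like, assuming torch_xla's GetXlaShape helper and that the backward output matches the incoming gradient's shape (names and placement are assumptions, not this PR's verbatim code):

// Sketch only: assumes GetXlaShape(const torch::lazy::Value&) is available.
xla::Shape HardshrinkBackwardOutputShape(const torch::lazy::Value& grad_out,
                                         const torch::lazy::Value& self,
                                         const torch::lazy::Value& lambd) {
  // The gradient w.r.t. the input has the same shape as the incoming gradient.
  return GetXlaShape(grad_out);
}

xla::Shape SoftshrinkBackwardOutputShape(const torch::lazy::Value& grad_output,
                                         const torch::lazy::Value& self,
                                         const torch::lazy::Value& lambd) {
  return GetXlaShape(grad_output);
}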

XlaNativeFunctions.cpp:

at::Tensor XLANativeFunctions::hardshrink_backward(const at::Tensor& grad_out,
                                                   const at::Tensor& self,
                                                   const at::Scalar& lambd) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(grad_out, self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_grad_out =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(grad_out,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  auto node_lambd =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          lambd, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<HardshrinkBackward>(
      lazy_grad_out->GetIrValue(), lazy_self->GetIrValue(), node_lambd);
  if (!node) {
    auto grad_out_meta = to_meta(grad_out);
    auto self_meta = to_meta(self);
    auto out_meta =
        at::meta::hardshrink_backward(grad_out_meta, self_meta, lambd);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {grad_out, self, lambd};
      const char* schema_str =
          "aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar "
          "lambd) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<HardshrinkBackward>(
        lazy_grad_out->GetIrValue(), lazy_self->GetIrValue(), node_lambd,
        std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}

at::Tensor XLANativeFunctions::softshrink_backward(
    const at::Tensor& grad_output, const at::Tensor& self,
    const at::Scalar& lambd) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(grad_output, self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_grad_output =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(grad_output,
                                                              *common_device);
  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self,
                                                              *common_device);
  auto node_lambd =
      torch::lazy::LazyGraphExecutor::Get()->GetIrValueForScalarFromCodegen(
          lambd, *common_device);
  torch::lazy::NodePtr node = torch::lazy::ReuseNode<SoftshrinkBackward>(
      lazy_grad_output->GetIrValue(), lazy_self->GetIrValue(), node_lambd);
  if (!node) {
    auto grad_output_meta = to_meta(grad_output);
    auto self_meta = to_meta(self);
    auto out_meta =
        at::meta::softshrink_backward(grad_output_meta, self_meta, lambd);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {grad_output, self, lambd};
      const char* schema_str =
          "aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar "
          "lambd) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<SoftshrinkBackward>(
        lazy_grad_output->GetIrValue(), lazy_self->GetIrValue(), node_lambd,
        std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}
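
The Lower() overrides declared in LazyIr.h are likewise hand-written. A rough sketch following torch_xla's usual lowering pattern (GetOutputOp on each operand, then ReturnOp), assuming BuildShrinkBackward now accepts the lambda as an xla::XlaOp; the Softshrink version would be identical apart from the class name:

// Sketch only; operand order matches the constructor: {grad_out, self, lambd}.
torch_xla::XlaOpVector HardshrinkBackward::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_grad_out = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(1));
  xla::XlaOp xla_lambd = loctx->GetOutputOp(operand(2));
  return ReturnOp(BuildShrinkBackward(xla_grad_out, xla_self, xla_lambd), loctx);
}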

@vanbasten23 vanbasten23 force-pushed the codegenHardshrinkBackward branch from da0be9a to dcd19cb Compare September 16, 2022 00:12
@vanbasten23 vanbasten23 marked this pull request as draft September 16, 2022 00:21
@vanbasten23 vanbasten23 removed the request for review from wonjoo-wj September 16, 2022 18:49
@vanbasten23 vanbasten23 force-pushed the codegenHardshrinkBackward branch from dcd19cb to 640ea5b Compare September 16, 2022 18:53
@JackCaoG JackCaoG requested a review from wonjoo-wj September 19, 2022 17:27
@wonjoo-wj
Collaborator

@vanbasten23, any reason why this is in draft?


xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
-                              const at::Scalar& lambda) {
+                              xla::XlaOp lambda) {
Collaborator

nit: we may want to do MaybeCast similar to 4ccfc24

Collaborator Author

done
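
For context on the MaybeCast nit: a hedged sketch of how BuildShrinkBackward could reconcile the lambda operand's element type with the input's once lambda is an xla::XlaOp. Helper names such as XlaHelpers::ShapeOfXlaOp are assumptions, not this PR's exact code:

xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input,
                               xla::XlaOp lambda) {
  // Assumption: XlaHelpers::ShapeOfXlaOp returns the xla::Shape of an XlaOp.
  xla::PrimitiveType input_type =
      XlaHelpers::ShapeOfXlaOp(input).element_type();
  // "MaybeCast": the scalar-derived lambda may carry a different element type
  // (e.g. f64 from a Python float), so convert it to the input's type first.
  xla::XlaOp typed_lambda = xla::ConvertElementType(lambda, input_type);
  // The gradient passes through where |input| > lambda and is zero elsewhere;
  // the same mask serves both hardshrink_backward and softshrink_backward.
  return xla::Select(xla::Gt(xla::Abs(input), typed_lambda), grad_output,
                     xla::ZerosLike(grad_output));
}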

@vanbasten23
Collaborator Author

@vanbasten23, any reason why this is in draft?

I realized I haven't done the MaybeCast change yet, so I'll fix that and then mark this PR as ready for review. :p

@vanbasten23 vanbasten23 reopened this Sep 19, 2022
@vanbasten23 vanbasten23 marked this pull request as ready for review September 20, 2022 00:04
@wonjoo-wj wonjoo-wj left a comment
Collaborator

Thanks! 👍

@vanbasten23 vanbasten23 merged commit ab15402 into master Sep 20, 2022