Conversation

@miladm miladm commented May 17, 2022

Full codegen for `sgn` and `sign`.


Generated LazyIr.h:

class Sgn : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::sgn);
  }

  Sgn(const torch_xla::XlaValue& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::sgn),
                {self}, std::move(shapes),
                [&]() { return SgnOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch_xla::XlaValue& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class Sign : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::sign);
  }

  Sign(const torch_xla::XlaValue& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::sign),
                {self}, std::move(shapes),
                [&]() { return SignOutputShape(self); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();

    return ss.str();
  }

  bool CanBeReused(const torch_xla::XlaValue& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
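
The generated nodes call `SgnOutputShape` and `SignOutputShape`, which the full-codegen flow expects to be hand-written rather than generated. A minimal sketch of what they might look like for these element-wise ops, assuming an accessor along the lines of `XlaValue::xla_shape()` (the helper names come from the generated header; the accessor and exact signatures are assumptions, not taken from this PR):

    // Hedged sketch, not generated output: hand-written shape inference for
    // the element-wise sgn/sign ops. The output shape simply mirrors the
    // input; the accessor name and signature may differ in the actual tree.
    xla::Shape SgnOutputShape(const torch_xla::XlaValue& input) {
      return input.xla_shape();  // same shape and dtype as the input
    }

    xla::Shape SignOutputShape(const torch_xla::XlaValue& input) {
      return input.xla_shape();
    }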

Generated XLANativeFunctions.cpp:

    at::Tensor XLANativeFunctions::sgn(const at::Tensor & self) {

        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Sgn>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::sgn(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::sgn(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }

            node = torch::lazy::MakeNode<Sgn>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };


    at::Tensor XLANativeFunctions::sign(const at::Tensor & self) {

        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);

        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<Sign>(lazy_self->GetIrValue());
        if (!node) {
            auto out_meta = at::meta::sign(self);
            std::vector<torch::lazy::Shape> shapes{
                torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::sign(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }

            node = torch::lazy::MakeNode<Sign>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }

        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };
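
The generated wrappers above only build and cache the IR node; the `Lower()` methods declared in the generated LazyIr.h are still written by hand. A minimal sketch of those lowerings, assuming builder helpers named `BuildSgn` and `BuildSign` (`BuildSign` is mentioned in the review below; `BuildSgn` is assumed by analogy and may be named differently):

    // Hedged sketch, not generated code: hand-written lowerings pairing with
    // the Lower() declarations in the generated LazyIr.h above.
    torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
      xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
      return ReturnOp(BuildSgn(xla_input), loctx);
    }

    torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
      xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
      return ReturnOp(BuildSign(xla_input), loctx);
    }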

@miladm miladm added the codegen label May 17, 2022
@miladm miladm requested review from JackCaoG and wonjoo-wj May 17, 2022 10:28
@miladm miladm self-assigned this May 17, 2022

Review thread on the following diff snippet:

    torch::lazy::MakeNode<Sign>(
        f, std::vector<torch::lazy::Shape>()) *
    torch::lazy::MakeNode<Sign>(
        divisor, std::vector<torch::lazy::Shape>()),

@miladm (Collaborator, Author) commented:

This kind of call pattern starts to get very messy. I suggest we think about a cleaner solution. Thinking...

@JackCaoG (Collaborator) replied:

We need to clean up all of the node-level lowerings: a node should not use another node to do its lowering directly. The solution is to write a lowering function for all of these ops and, in this case, call BuildSign.
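
A hedged sketch of what that cleanup could look like (not code from this PR): the enclosing op's `Lower()` composes `xla::XlaOp` values directly through builder helpers instead of nesting `MakeNode<Sign>` calls. Only `BuildSign` is named in the review; the op name and operand layout below are illustrative.

    // Hedged sketch of the suggested cleanup. "SomeOp" is a placeholder for
    // the op whose lowering currently multiplies two MakeNode<Sign>(...) nodes.
    torch_xla::XlaOpVector SomeOp::Lower(LoweringContext* loctx) const {
      xla::XlaOp f = loctx->GetOutputOp(operand(0));
      xla::XlaOp divisor = loctx->GetOutputOp(operand(1));
      // Compose at the XLA builder level rather than at the IR node level.
      xla::XlaOp result = BuildSign(f) * BuildSign(divisor);
      return ReturnOp(result, loctx);
    }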

@miladm miladm linked an issue May 17, 2022 that may be closed by this pull request

@JackCaoG JackCaoG left a comment

Thanks!


@JackCaoG JackCaoG merged commit 041ebf9 into master May 17, 2022
@JackCaoG JackCaoG deleted the ltc_sgn branch May 17, 2022 19:51
ysiraichi added a commit that referenced this pull request May 22, 2025
- `SgnOp` and `SignOp`
    - Full codegen migration: #3577
    - Mistakenly re-introduced: #3572
- `LogSigmoid`
    - Introduced: #3539
    - Full codegen migration: #3743
- `SiLU`
    - Introduced: #2721
    - Full codegen migration: #3780
- `SiLUBackward`
    - Introduced: #3195
    - Full codegen migration: #3780
- `SeLU`
    - Introduced: #3547
    - Full codegen migration: #3780
- `Sigmoid`
    - Introduced: 6a73deb (no PR record)
    - Full codegen migration: #6342
Successfully merging this pull request may close this issue: PyTorch/XLA Codegen Migration