
Conversation

wonjoo-wj (Collaborator) commented on May 20, 2022:

This PR moves the following logical boolean ops to full codegen:

  • logical_and
  • logical_not
  • logical_or
  • logical_xor

Currently blocked by missing upstream shape-inference functions (later resolved by pytorch/pytorch#78004 and pytorch/pytorch#79999; see below).

Generated LazyIr.h:

class LogicalAnd : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::logical_and);
  }

  LogicalAnd(const torch::lazy::Value& self, const torch::lazy::Value& other, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::logical_and),
              {self, other}, std::move(shapes),
              [&]() { return LogicalAndOutputShape(self, other); },
              /* num_outputs */ 1,
              torch::lazy::MHash())
  {
    
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    
    return ss.str();
  }

  

  bool CanBeReused(const torch::lazy::Value& self, const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LogicalNot : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::logical_not);
  }

  LogicalNot(const torch::lazy::Value& self, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::logical_not),
              {self}, std::move(shapes),
              [&]() { return LogicalNotOutputShape(self); },
              /* num_outputs */ 1,
              torch::lazy::MHash())
  {
    
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    
    return ss.str();
  }

  

  bool CanBeReused(const torch::lazy::Value& self) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LogicalOr : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::logical_or);
  }

  LogicalOr(const torch::lazy::Value& self, const torch::lazy::Value& other, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::logical_or),
              {self, other}, std::move(shapes),
              [&]() { return LogicalOrOutputShape(self, other); },
              /* num_outputs */ 1,
              torch::lazy::MHash())
  {
    
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    
    return ss.str();
  }

  

  bool CanBeReused(const torch::lazy::Value& self, const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};

class LogicalXor : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::logical_xor);
  }

  LogicalXor(const torch::lazy::Value& self, const torch::lazy::Value& other, std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::logical_xor),
              {self, other}, std::move(shapes),
              [&]() { return LogicalXorOutputShape(self, other); },
              /* num_outputs */ 1,
              torch::lazy::MHash())
  {
    
  }

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    
    return ss.str();
  }

  

  bool CanBeReused(const torch::lazy::Value& self, const torch::lazy::Value& other) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
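
The generated header only declares Lower() and references the *OutputShape helpers; those remain hand-written on the torch_xla side. A minimal sketch of what the LogicalAnd lowering could look like, assuming torch_xla's usual GetOutputOp/ReturnOp helpers and plain XLA builder ops (an illustration only, not this PR's exact code):

// Sketch only -- not this PR's actual lowering.
torch_xla::XlaOpVector LogicalAnd::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_other = loctx->GetOutputOp(operand(1));
  // Treat each operand as a boolean (non-zero -> true), then AND them.
  xla::XlaOp self_pred = xla::Ne(xla_self, xla::ZerosLike(xla_self));
  xla::XlaOp other_pred = xla::Ne(xla_other, xla::ZerosLike(xla_other));
  return ReturnOp(xla::And(self_pred, other_pred), loctx);
}

The other three ops would follow the same pattern with xla::Or, xla::Xor, and xla::Not.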

Generated XLANativeFunctions.cpp:

    at::Tensor XLANativeFunctions::logical_and(const at::Tensor & self, const at::Tensor & other) {
        
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);
        
        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<LogicalAnd>(lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            
            auto shapes = torch::lazy::compute_shape_logical_and(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::logical_and(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
        
            node = torch::lazy::MakeNode<LogicalAnd>(lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }
        
        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };

    
    at::Tensor XLANativeFunctions::logical_not(const at::Tensor & self) {
        
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self);
        TORCH_INTERNAL_ASSERT(common_device);
        
        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<LogicalNot>(lazy_self->GetIrValue());
        if (!node) {
            
            auto shapes = torch::lazy::compute_shape_logical_not(self);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self };
                const char* schema_str = "aten::logical_not(Tensor self) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
        
            node = torch::lazy::MakeNode<LogicalNot>(lazy_self->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }
        
        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };

    
    at::Tensor XLANativeFunctions::logical_or(const at::Tensor & self, const at::Tensor & other) {
        
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);
        
        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<LogicalOr>(lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            
            auto shapes = torch::lazy::compute_shape_logical_or(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::logical_or(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
        
            node = torch::lazy::MakeNode<LogicalOr>(lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }
        
        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };

    
    at::Tensor XLANativeFunctions::logical_xor(const at::Tensor & self, const at::Tensor & other) {
        
        XLA_FN_COUNTER("xla::");
        auto common_device = torch_xla::bridge::GetXlaDevice(self, other);
        TORCH_INTERNAL_ASSERT(common_device);
        
        torch_xla::XLATensorPtr lazy_self = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
        torch_xla::XLATensorPtr lazy_other = torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(other, *common_device);
        torch::lazy::NodePtr node = torch::lazy::ReuseNode<LogicalXor>(lazy_self->GetIrValue(), lazy_other->GetIrValue());
        if (!node) {
            
            auto shapes = torch::lazy::compute_shape_logical_xor(self, other);
            TORCH_INTERNAL_ASSERT(shapes.size() == 1);
            if(torch::lazy::symbolicShapeEnabled()){
                std::vector<torch::jit::IValue> inputs = { self, other };
                const char* schema_str = "aten::logical_xor(Tensor self, Tensor other) -> Tensor";
                applySymbolicShapesOnLT(schema_str, inputs, shapes);
            }
        
            node = torch::lazy::MakeNode<LogicalXor>(lazy_self->GetIrValue(), lazy_other->GetIrValue(), std::move(shapes));
            CacheNode(node);
        }
        
        auto result = torch_xla::bridge::AtenFromXlaTensor(
                torch_xla::XLATensor::Create(std::move(node), *common_device));
        return result;
    };
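
The *OutputShape helpers captured by the LazyIr.h constructors above are likewise hand-written. A rough sketch, assuming torch_xla's InferOutputShape/GetXlaShape helpers (helper names assumed here, not taken from this diff):

// Sketch only: derive the output shape by running the core lowering on shape-only operands.
xla::Shape LogicalAndOutputShape(const torch::lazy::Value& self,
                                 const torch::lazy::Value& other) {
  auto shape_fn = [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
    // Same core computation as Lower(), applied to placeholder operands.
    xla::XlaOp lhs = xla::Ne(operands[0], xla::ZerosLike(operands[0]));
    xla::XlaOp rhs = xla::Ne(operands[1], xla::ZerosLike(operands[1]));
    return xla::And(lhs, rhs);
  };
  return InferOutputShape({GetXlaShape(self), GetXlaShape(other)}, shape_fn);
}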

wonjoo-wj self-assigned this on May 20, 2022
wonjoo-wj changed the title from "[WIP] Full codegen logical boolean ops" to "Full codegen logical boolean ops" on Jul 20, 2022
wonjoo-wj (Author) commented:

This is now unblocked as of pytorch/pytorch#78004 and pytorch/pytorch#79999.

wonjoo-wj requested review from JackCaoG and steventk-g on July 20, 2022
JackCaoG (Collaborator) left a comment:

Thanks! Feel free to merge after the tests pass.

wonjoo-wj (Author) commented:

@steventk-g, FYI this PR is an example of adding shape inference functions to upstream PyTorch. Take a look at the linked pytorch PRs above for examples.
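
For reference, such an upstream shape-inference entry is short: the logical_* ops broadcast their inputs and always return a Bool tensor. A rough sketch in the spirit of torch/csrc/lazy/core/shape_inference.cpp (see the linked PRs for the actual implementations):

// Sketch only -- see the linked pytorch/pytorch PRs for the real code.
std::vector<torch::lazy::Shape> compute_shape_logical_and(
    const at::Tensor& self, const at::Tensor& other) {
  // Broadcast the two input sizes; the result dtype is always Bool.
  return {torch::lazy::Shape(c10::ScalarType::Bool,
                             at::infer_size(self.sizes(), other.sizes()))};
}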

wonjoo-wj (Author) commented:

GPU tests failed with:

Downloading Python-3.7.0.tar.xz...
-> https://www.python.org/ftp/python/3.7.0/Python-3.7.0.tar.xz
error: failed to download Python-3.7.0.tar.xz

BUILD FAILED (Ubuntu 20.04 using python-build 2.2.0)

This doesn't seem related to this PR; I'll give it a retry.

wonjoo-wj merged commit f62b9d8 into master on Jul 21, 2022
wonjoo-wj deleted the codegen-3 branch on July 21, 2022
wonjoo-wj (Author) commented on Jul 21, 2022:

CI is all green after retrying.
