Conversation

@vanbasten23 (Collaborator) commented on Sep 13, 2022:

LazyIr.h:

class Hardshrink : public XlaNode {
 public:
  static torch::lazy::OpKind ClassOpKind() {
    return torch::lazy::OpKind(at::aten::hardshrink);
  }

  Hardshrink(const torch::lazy::Value& self, const torch::lazy::Value& lambd,
             std::vector<torch::lazy::Shape>&& shapes)
      : XlaNode(torch::lazy::OpKind(at::aten::hardshrink),
                {self, lambd}, std::move(shapes),
                [&]() { return HardshrinkOutputShape(self, lambd); },
                /* num_outputs */ 1,
                torch::lazy::MHash()) {}

  std::string ToString() const override {
    std::stringstream ss;
    ss << XlaNode::ToString();
    return ss.str();
  }

  bool CanBeReused(const torch::lazy::Value& self,
                   const torch::lazy::Value& lambd) const {
    return false;
  }

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};
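
The class above only declares Lower() and refers to HardshrinkOutputShape(); their definitions are not part of this description. A minimal sketch of what they typically look like under the torch_xla codegen conventions (GetXlaShape, GetOutputOp, ReturnOp, and BuildHardshrink are assumed from that pattern, not copied from this PR's diff):

// Sketch only: shape inference. hardshrink is elementwise, so the output
// keeps the shape of `self`; `lambd` does not affect it.
xla::Shape HardshrinkOutputShape(const torch::lazy::Value& self,
                                 const torch::lazy::Value& lambd) {
  return GetXlaShape(self);
}

// Sketch only: lowering. Fetch the already-lowered operands and forward
// them to the hand-written XLA builder.
torch_xla::XlaOpVector Hardshrink::Lower(LoweringContext* loctx) const {
  xla::XlaOp xla_self = loctx->GetOutputOp(operand(0));
  xla::XlaOp xla_lambd = loctx->GetOutputOp(operand(1));
  return ReturnOp(BuildHardshrink(xla_self, xla_lambd), loctx);
}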

XLANativeFunctions.cpp:

at::Tensor XLANativeFunctions::hardshrink(const at::Tensor& self,
                                          const at::Scalar& lambd) {
  XLA_FN_COUNTER("xla::");
  auto common_device = torch_xla::bridge::GetXlaDevice(self);
  TORCH_INTERNAL_ASSERT(common_device);

  torch_xla::XLATensorPtr lazy_self =
      torch_xla::bridge::GetXlaTensorOrCreateForWrappedNumber(self, *common_device);
  auto node_lambd = torch::lazy::LazyGraphExecutor::Get()->
      GetIrValueForScalarFromCodegen(lambd, *common_device);
  torch::lazy::NodePtr node =
      torch::lazy::ReuseNode<Hardshrink>(lazy_self->GetIrValue(), node_lambd);
  if (!node) {
    auto self_meta = to_meta(self);
    auto out_meta = at::meta::hardshrink(self_meta, lambd);

    std::vector<torch::lazy::Shape> shapes{
        torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
    TORCH_INTERNAL_ASSERT(shapes.size() == 1);
    if (torch::lazy::symbolicShapeEnabled()) {
      std::vector<torch::jit::IValue> inputs = {self, lambd};
      const char* schema_str =
          "aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor";
      applySymbolicShapesOnLT(schema_str, inputs, shapes);
    }

    node = torch::lazy::MakeNode<Hardshrink>(lazy_self->GetIrValue(), node_lambd,
                                             std::move(shapes));
    CacheNode(node);
  }

  auto result = torch_xla::bridge::AtenFromXlaTensor(
      torch_xla::XLATensor::Create(std::move(node), *common_device));
  return result;
}

The signature change under review, with lambd now passed as an XLA operand instead of a baked-in scalar:

-xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda) {
+xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
Collaborator commented:
We need to be careful about these at::Scalar -> xla::XlaOp changes, because the scalar currently defaults to f64. We can do a MaybeCast here; take a look at https://github.com/pytorch/xla/blame/f8b3dfd45d753a8844aca871cb39511022bb35ff/torch_xla/csrc/elementwise.cpp#L383 for an example.
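
For reference, a minimal sketch of what the suggested MaybeCast could look like inside the new BuildHardshrink, assuming the MaybeConvertTo and XlaHelpers::ShapeOfXlaOp helpers used in the linked elementwise.cpp; this is illustrative, not the PR's actual implementation:

// Sketch only: cast the scalar-derived operand (often created as f64 by
// default) to the input's element type before using it, so XLA never sees
// mixed f32/f64 operands.
xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambd) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
  lambd = MaybeConvertTo(lambd, shape.element_type());
  // hardshrink(x): zero out values with |x| <= lambd, keep the rest.
  return xla::Select(xla::Le(xla::Abs(input), lambd),
                     xla::ZerosLike(input), input);
}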

@vanbasten23 (Collaborator, Author) commented on Sep 14, 2022:

If we don't cast, then we may get an error like:

Seen floating point types of different precisions in %subtract.4 = f64[] subtract(f32[] %constant.3, f64[] %constant.1), metadata={op_type="aten__hardshrink" op_name="aten__hardshrink" source_file="[email protected]" source_line=1026}, but mixed precision is disallowed.

Is this the reason we have to do a MaybeCast here?

Also, shouldn't we cast the element type of the xla::XlaOp to the original type of the at::Scalar?

Collaborator commented:

Check

XlaHelpers::ScalarValue(min_val, element_type, builder));

for the existing behavior. A scalar often defaults to f64 when the user doesn't specify a type, so using the scalar's own type would cause additional type-promotion issues. It is easier to use the other operand's type in this case.
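
For contrast, a sketch of roughly how the old at::Scalar signature (the one removed in this PR's diff) builds its constant in that existing style: the scalar is materialized directly in the input's element type via XlaHelpers::ScalarValue, so no f64 constant ever enters the graph and no promotion is needed. The body below is illustrative only; just the ScalarValue call pattern comes from the quoted line.

// Sketch only: old-style builder taking the scalar by value.
xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
  // Materialize lambda directly in input's element type.
  xla::XlaOp lambd_op = XlaHelpers::ScalarValue(
      lambda, shape.element_type(), input.builder());
  return xla::Select(xla::Le(xla::Abs(input), lambd_op),
                     xla::ZerosLike(input), input);
}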

@vanbasten23 (Collaborator, Author) commented:

Got it, thanks!

By "promotion issue", do you mean that when an operation has two operands of mixed types, XLA will try to convert one type to the other implicitly, just like any C++ operator would?

Collaborator commented:

Yeah, we usually promote to the "more complex" type, like f64 or s64, which is slower to compute.
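
(Illustrative aside: XLA itself rejects mixed-precision operands, as the error above shows, so any promotion has to be made explicit before the op is emitted. A hypothetical two-liner, with lhs_f32 and rhs_f64 standing in for the mismatched operands:)

// Widen the f32 side to f64 so both operands match; the op then runs in
// f64, which is the slower path being described.
xla::XlaOp lhs_f64 = xla::ConvertElementType(lhs_f32, xla::F64);
xla::XlaOp diff = xla::Sub(lhs_f64, rhs_f64);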

@wonjoo-wj (Collaborator) left a review comment:

Thanks!

@vanbasten23 merged commit a1691bd into master on Sep 15, 2022.