Codegen hardshrink #3999
Conversation
  }
- xla::XlaOp BuildHardshrink(xla::XlaOp input, const at::Scalar& lambda) {
+ xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
We need to be careful about these at::Scalar -> xla::XlaOp changes because the default scalar type is currently f64. We can do a MaybeCast here; see https://github.com/pytorch/xla/blame/f8b3dfd45d753a8844aca871cb39511022bb35ff/torch_xla/csrc/elementwise.cpp#L383 for an example.
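For reference, a minimal sketch of the kind of cast being suggested, assuming a XlaHelpers::TypeOfXlaOp-style accessor; the helper name is hypothetical and the actual MaybeCast logic in the linked elementwise.cpp may differ:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "torch_xla/csrc/helpers.h"

// Hypothetical helper (not torch_xla's actual MaybeCast): convert the
// scalar-valued operand to the tensor input's element type so the later
// arithmetic never mixes f32 and f64 operands.
xla::XlaOp MaybeCastLambdaToInputType(xla::XlaOp input, xla::XlaOp lambda) {
  xla::PrimitiveType input_type = XlaHelpers::TypeOfXlaOp(input);
  xla::PrimitiveType lambda_type = XlaHelpers::TypeOfXlaOp(lambda);
  return lambda_type == input_type
             ? lambda
             : xla::ConvertElementType(lambda, input_type);
}
```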
If we don't cast, we may get an error like:

Seen floating point types of different precisions in %subtract.4 = f64[] subtract(f32[] %constant.3, f64[] %constant.1), metadata={op_type="aten__hardshrink" op_name="aten__hardshrink" source_file="[email protected]" source_line=1026}, but mixed precision is disallowed.

Is this the reason we have to do a MaybeCast here?
Also, shouldn't we cast the xla::XlaOp's element type to the at::Scalar's original type?
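To make the question concrete, here is roughly how the mismatch in that error can arise; this is an illustration only, with made-up names and values, not code from the PR:

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Illustration (hypothetical values): the Python-side scalar defaults to
// double, so a naive lowering produces an f64 constant even though the
// tensor operand is f32.
xla::XlaOp BuildMixedPrecisionSubtract(xla::XlaBuilder* builder) {
  xla::XlaOp input = xla::ConstantR0<float>(builder, 1.0f);   // f32[]
  xla::XlaOp lambda = xla::ConstantR0<double>(builder, 0.5);  // f64[]
  // f32[] - f64[]: XLA rejects this with "mixed precision is disallowed"
  // unless one side is converted first.
  return xla::Sub(input, lambda);
}
```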
Check torch_xla/csrc/elementwise.cpp line 23 (at 76df130):

XlaHelpers::ScalarValue(min_val, element_type, builder));

The scalar is built with the other operand's type in this case.
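In other words, with the at::Scalar signature the constant is created directly with the tensor's element type, along these lines; this is a sketch of that pattern under that assumption, not the exact function being replaced:

```cpp
#include <ATen/ATen.h>

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "torch_xla/csrc/helpers.h"

// Sketch of the scalar-based pattern: the at::Scalar is materialized with the
// element type of the tensor operand, so mixed-precision operands cannot occur.
xla::XlaOp BuildHardshrinkFromScalar(xla::XlaOp input, const at::Scalar& lambda) {
  const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input);
  xla::XlaOp lambda_op = XlaHelpers::ScalarValue(
      lambda, shape.element_type(), input.builder());
  // hardshrink(x) = 0 where |x| <= lambda, x elsewhere.
  return xla::Select(xla::Le(xla::Abs(input), lambda_op),
                     xla::ZerosLike(input), input);
}
```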
Got it. Thanks!
By "promotion issue", you meant when we have an operation with 2 mixed type operands, xla will try to convert one type to another implicitly, just like any c++ operator?
Yeah, we usually promote to the "more complex" type, like f64 or s64, which is slower to compute.
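A rough illustration of that implicit widening, using a hypothetical, simplified helper (torch_xla's real promotion utilities may behave differently):

```cpp
#include <utility>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "torch_xla/csrc/helpers.h"

// Hypothetical, simplified promotion rule: when operand types differ, widen
// both to the operand type with the larger bit width (e.g. f32 + f64 -> f64),
// which is why letting an f64 scalar leak in slows the whole computation.
std::pair<xla::XlaOp, xla::XlaOp> PromoteToWiderType(xla::XlaOp a, xla::XlaOp b) {
  xla::PrimitiveType type_a = XlaHelpers::TypeOfXlaOp(a);
  xla::PrimitiveType type_b = XlaHelpers::TypeOfXlaOp(b);
  if (type_a == type_b) {
    return {a, b};
  }
  xla::PrimitiveType wider =
      xla::primitive_util::BitWidth(type_a) >=
              xla::primitive_util::BitWidth(type_b)
          ? type_a
          : type_b;
  return {xla::ConvertElementType(a, wider),
          xla::ConvertElementType(b, wider)};
}
```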
wonjoo-wj left a comment
Thanks!
LazyIr.h:
XLANativeFunctions.cpp: