20 changes: 1 addition & 19 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -2785,32 +2785,14 @@ at::Tensor XLANativeFunctions::select(const at::Tensor& self, int64_t dim,
       XLATensor::select(bridge::GetXlaTensor(self), dim, index));
 }
 
-at::Tensor XLANativeFunctions::selu(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::selu(bridge::GetXlaTensor(self)));
-}
-
+// TODO(JackCaoG): Remove after selu_ being codegened
 at::Tensor& XLANativeFunctions::selu_(at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
   XLATensor::selu_(self_tensor);
   return self;
 }
 
-at::Tensor XLANativeFunctions::silu(const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  return bridge::AtenFromXlaTensor(XLATensor::silu(bridge::GetXlaTensor(self)));
-}
-
-at::Tensor XLANativeFunctions::silu_backward(const at::Tensor& grad_output,
-                                             const at::Tensor& self) {
-  XLA_FN_COUNTER("xla::");
-  XLATensorPtr grad_output_tensor = bridge::GetXlaTensor(grad_output);
-  XLATensorPtr self_tensor = bridge::GetXlaTensor(self);
-  return bridge::AtenFromXlaTensor(
-      XLATensor::silu_backward(grad_output_tensor, self_tensor));
-}
-
 at::Tensor XLANativeFunctions::sigmoid(const at::Tensor& self) {
   XLA_FN_COUNTER("xla::");
   return bridge::AtenFromXlaTensor(
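The three functional wrappers deleted here are replaced by generated code: listing selu, silu, and silu_backward under full_codegen in xla_native_functions.yaml (see below) makes the torch/xla code generator emit the dispatcher bindings. Only the in-place selu_ keeps a hand-written binding, which is what the new TODO marks. As a minimal sketch of the pattern, assuming the usual lazy-tensor flow (the generated file differs in detail), the emitted wrapper is in spirit:

// Sketch only, not the generated source: a full_codegen op's wrapper
// builds the generated IR node directly instead of calling XLATensor::selu.
at::Tensor XLANativeFunctions::selu(const at::Tensor& self) {
  XLA_FN_COUNTER("xla::");
  XLATensorPtr xla_self = bridge::GetXlaTensor(self);
  torch::lazy::NodePtr node =
      torch::lazy::MakeNode<Selu>(xla_self->GetIrValue());
  return bridge::AtenFromXlaTensor(
      xla_self->CreateFrom(torch::lazy::Value(node)));
}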
16 changes: 16 additions & 0 deletions torch_xla/csrc/ops/ops_lower_fn.cpp
@@ -216,6 +216,11 @@ torch_xla::XlaOpVector Rsqrt::Lower(LoweringContext* loctx) const {
   return ReturnOp(xla::Rsqrt(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Selu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(BuildSelu(xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sgn::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(BuildSgn(xla_input), loctx);
@@ -226,6 +231,17 @@ torch_xla::XlaOpVector Sign::Lower(LoweringContext* loctx) const {
   return ReturnOp(BuildSign(xla_input), loctx);
 }
 
+torch_xla::XlaOpVector Silu::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
+  return ReturnOp(xla_input * BuildSigmoid(xla_input), loctx);
+}
+
+torch_xla::XlaOpVector SiluBackward::Lower(LoweringContext* loctx) const {
+  xla::XlaOp xla_grad_output = loctx->GetOutputOp(operand(0));
+  xla::XlaOp xla_input = loctx->GetOutputOp(operand(1));
+  return ReturnOp(BuildSiLUBackward(xla_grad_output, xla_input), loctx);
+}
+
 torch_xla::XlaOpVector Sin::Lower(LoweringContext* loctx) const {
   xla::XlaOp xla_input = loctx->GetOutputOp(operand(0));
   return ReturnOp(xla::Sin(xla_input), loctx);
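Silu::Lower inlines the forward definition, silu(x) = x * sigmoid(x), while the backward delegates to BuildSiLUBackward from elementwise.h. Differentiating x * sigmoid(x) gives sigmoid(x) * (1 + x * (1 - sigmoid(x))), so a sketch of the computation that builder has to produce (hypothetical helper name, bare XLA client ops, not the repository's implementation) is:

#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// grad_input = grad_output * sigmoid(x) * (1 + x * (1 - sigmoid(x)))
xla::XlaOp SiluBackwardSketch(xla::XlaOp grad_output, xla::XlaOp input) {
  xla::XlaOp sig = xla::Logistic(input);
  xla::XlaOp one = xla::ScalarLike(input, 1.0);
  return grad_output * sig * (one + input * (one - sig));
}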
19 changes: 19 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.cpp
@@ -2,6 +2,7 @@
 
 #include "tensorflow/compiler/xla/client/lib/logdet.h"
 #include "tensorflow/compiler/xla/shape_util.h"
+#include "torch_xla/csrc/elementwise.h"
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/pooling.h"
 
@@ -208,6 +209,10 @@ xla::Shape RsqrtOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SeluOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
 xla::Shape SgnOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
@@ -216,6 +221,20 @@ xla::Shape SignOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
 
+xla::Shape SiluOutputShape(const torch::lazy::Value& input) {
+  return GetXlaShape(input);
+}
+
+xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
+                                   const torch::lazy::Value& input) {
+  auto lower_for_shape_fn =
+      [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
+    return BuildSiLUBackward(operands[0], operands[1]);
+  };
+  return InferOutputShape({GetXlaShape(grad_output), GetXlaShape(input)},
+                          lower_for_shape_fn);
+}
+
 xla::Shape SinOutputShape(const torch::lazy::Value& input) {
   return GetXlaShape(input);
 }
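For the elementwise unary ops the output shape equals the input shape, so SeluOutputShape and SiluOutputShape simply forward GetXlaShape(input). silu_backward takes two operands, so its shape function instead replays the lowering on placeholder parameters and lets XLA deduce the result, which keeps the shape function and SiluBackward::Lower consistent by construction. Conceptually (a sketch of the idea, not the helper's actual body), InferOutputShape does something like:

#include <functional>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Trace lower_fn over parameter placeholders and return the deduced shape.
xla::Shape InferBinaryOutputShape(
    const xla::Shape& shape0, const xla::Shape& shape1,
    const std::function<xla::XlaOp(absl::Span<const xla::XlaOp>)>& lower_fn) {
  xla::XlaBuilder builder("shape_inference");
  std::vector<xla::XlaOp> params = {xla::Parameter(&builder, 0, shape0, "p0"),
                                    xla::Parameter(&builder, 1, shape1, "p1")};
  xla::XlaOp result = lower_fn(params);
  return builder.GetShape(result).ValueOrDie();
}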
7 changes: 7 additions & 0 deletions torch_xla/csrc/ops/ops_xla_shape_fn.h
@@ -84,10 +84,17 @@ xla::Shape RoundOutputShape(const torch::lazy::Value& input);
 
 xla::Shape RsqrtOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SeluOutputShape(const torch::lazy::Value& input);
+
 xla::Shape SgnOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SignOutputShape(const torch::lazy::Value& input);
 
+xla::Shape SiluOutputShape(const torch::lazy::Value& input);
+
+xla::Shape SiluBackwardOutputShape(const torch::lazy::Value& grad_output,
+                                   const torch::lazy::Value& input);
+
 xla::Shape SinOutputShape(const torch::lazy::Value& input);
 
 xla::Shape SinhOutputShape(const torch::lazy::Value& input);
1 change: 0 additions & 1 deletion torch_xla/csrc/tensor.h
@@ -1037,7 +1037,6 @@ class XLATensor : public c10::intrusive_ptr_target {
   static XLATensorPtr select(const XLATensorPtr& input, int64_t dim,
                              int64_t index);
 
-  static XLATensorPtr selu(const XLATensorPtr& input);
   static void selu_(XLATensorPtr& input);
 
   static XLATensorPtr silu(const XLATensorPtr& input);
14 changes: 0 additions & 14 deletions torch_xla/csrc/tensor_methods.cpp
@@ -2420,24 +2420,10 @@ XLATensorPtr XLATensor::select(const XLATensorPtr& input, int64_t dim,
   return tensor_ops::Select(input, dim, index);
 }
 
-XLATensorPtr XLATensor::selu(const XLATensorPtr& input) {
-  return input->CreateFrom(Selu(input->GetIrValue()));
-}
-
 void XLATensor::selu_(XLATensorPtr& input) {
   input->SetInPlaceIrValue(Selu(input->GetIrValue()));
 }
 
-XLATensorPtr XLATensor::silu(const XLATensorPtr& input) {
-  return input->CreateFrom(SiLU(input->GetIrValue()));
-}
-
-XLATensorPtr XLATensor::silu_backward(XLATensorPtr& grad_output,
-                                      XLATensorPtr& input) {
-  return input->CreateFrom(
-      SiLUBackward(grad_output->GetIrValue(), input->GetIrValue()));
-}
-
 XLATensorPtr XLATensor::sigmoid(const XLATensorPtr& input) {
   return input->CreateFrom(Sigmoid(input->GetIrValue()));
 }
6 changes: 3 additions & 3 deletions xla_native_functions.yaml
@@ -34,8 +34,11 @@ full_codegen:
 - reciprocal
 - round
 - rsqrt
+- selu
 - sgn
 - sign
+- silu
+- silu_backward
 - sin
 - sinh
 - tan
@@ -271,12 +274,9 @@ supported:
 - scatter.value_reduce
 - scatter_add
 - select.int
-- selu
 - selu_
 - sigmoid
 - sigmoid_backward
-- silu
-- silu_backward
 - slice.Tensor
 - slogdet
 - smooth_l1_loss
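Moving the three entries from supported to full_codegen is what drives the deletions above: the build now generates the ATen wrapper, the lazy IR node class, and the glue, leaving only Lower and the *OutputShape functions to write by hand. Approximately, following the general torch/xla codegen pattern (the generated header's exact contents differ), the node class emitted for selu looks like:

// Sketch of a generated node: the shape function and Lower are the
// hand-written hooks; constructor, hashing, and dispatch are generated.
class Selu : public XlaNode {
 public:
  Selu(const torch::lazy::Value& self)
      : XlaNode(torch::lazy::OpKind(at::aten::selu), {self},
                [&]() { return SeluOutputShape(self); },
                /*num_outputs=*/1) {}

  torch_xla::XlaOpVector Lower(LoweringContext* loctx) const override;
};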