Commit 6a73deb

dlibenzi authored and asuhan committed
Added aten::sigmoid operations.
1 parent 64bd93a · commit 6a73deb

10 files changed: 57 additions & 5 deletions

test/cpp/test_aten_xla_tensor.cpp

Lines changed: 10 additions & 0 deletions
@@ -584,6 +584,16 @@ TEST_F(AtenXlaTensorTest, TestAbs) {
   });
 }
 
+TEST_F(AtenXlaTensorTest, TestSigmoid) {
+  at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
+  at::Tensor b = at::sigmoid(a);
+  ForEachDevice([&](const Device& device) {
+    at::Tensor xla_a = bridge::CreateXlaTensor(a, device);
+    at::Tensor xla_b = at::sigmoid(xla_a);
+    AllClose(b, xla_b, /*rtol=*/1e-3, /*atol=*/1e-5);
+  });
+}
+
 TEST_F(AtenXlaTensorTest, TestAddCMul) {
   at::Tensor a = at::rand({2, 2}, at::TensorOptions(at::kFloat));
   at::Tensor b = at::rand({2, 2}, at::TensorOptions(at::kFloat));

torch_xla/csrc/aten_xla_type.cpp

Lines changed: 11 additions & 0 deletions
@@ -541,6 +541,17 @@ at::Tensor AtenXlaType::softmax(const at::Tensor& self, int64_t dim) const {
       XLATensor::softmax(bridge::GetXlaTensor(self), dim));
 }
 
+at::Tensor AtenXlaType::sigmoid(const at::Tensor& self) const {
+  return bridge::AtenFromXlaTensor(
+      XLATensor::sigmoid(bridge::GetXlaTensor(self)));
+}
+
+at::Tensor& AtenXlaType::sigmoid_(at::Tensor& self) const {
+  XLATensor self_tensor = bridge::GetXlaTensor(self);
+  XLATensor::sigmoid_(self_tensor);
+  return self;
+}
+
 at::Tensor AtenXlaType::max_pool2d(const at::Tensor& self,
                                    at::IntList kernel_size, at::IntList stride,
                                    at::IntList padding, at::IntList dilation,

torch_xla/csrc/aten_xla_type.h

Lines changed: 3 additions & 0 deletions
@@ -184,6 +184,9 @@ class AtenXlaType : public AtenXlaTypeBase {
 
   at::Tensor softmax(const at::Tensor& self, int64_t dim) const override;
 
+  at::Tensor sigmoid(const at::Tensor& self) const override;
+  at::Tensor& sigmoid_(at::Tensor& self) const override;
+
   at::Tensor max_pool2d(const at::Tensor& self, at::IntList kernel_size,
                         at::IntList stride, at::IntList padding,
                         at::IntList dilation, bool ceil_mode) const override;

torch_xla/csrc/elementwise.cpp

Lines changed: 7 additions & 0 deletions
@@ -91,4 +91,11 @@ xla::XlaOp BuildTypeAs(const torch::jit::Node* node,
   return xla::ConvertElementType(operand, target_type);
 }
 
+xla::XlaOp BuildSigmoid(const xla::XlaOp& input) {
+  xla::Shape shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp half =
+      XlaHelpers::ScalarValue<float>(0.5, shape.element_type(), input.builder());
+  return half + half * xla::Tanh(half * input);
+}
+
 }  // namespace torch_xla
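
A note on the formulation (not part of the commit): BuildSigmoid lowers sigmoid through tanh via the identity sigmoid(x) = 1 / (1 + e^-x) = 0.5 + 0.5 * tanh(0.5 * x), which follows from tanh(x/2) = (1 - e^-x) / (1 + e^-x), and which lets the lowering reuse XLA's native Tanh op. A minimal standalone C++ check of the identity, independent of XLA:

// Sanity check for the identity used by BuildSigmoid:
//   1 / (1 + exp(-x)) == 0.5 + 0.5 * tanh(0.5 * x)
#include <cmath>
#include <cstdio>

int main() {
  for (double x = -10.0; x <= 10.0; x += 0.25) {
    double direct = 1.0 / (1.0 + std::exp(-x));
    double via_tanh = 0.5 + 0.5 * std::tanh(0.5 * x);
    if (std::fabs(direct - via_tanh) > 1e-12) {
      std::printf("mismatch at x = %f\n", x);
      return 1;
    }
  }
  std::printf("identity holds to 1e-12 on [-10, 10]\n");
  return 0;
}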

torch_xla/csrc/elementwise.h

Lines changed: 2 additions & 0 deletions
@@ -28,4 +28,6 @@ xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output,
 // Computes the rectified linear unit (replace negative elements with 0).
 xla::XlaOp BuildRelu(const xla::XlaOp& input);
 
+xla::XlaOp BuildSigmoid(const xla::XlaOp& input);
+
 }  // namespace torch_xla

torch_xla/csrc/ops/ops.cpp

Lines changed: 10 additions & 0 deletions
@@ -90,6 +90,16 @@ NodePtr TransposeOp(const Value& input) {
                      output_shape, std::move(lower_fn));
 }
 
+NodePtr Sigmoid(const Value& input) {
+  auto lower_fn = [](const ir::Node& node,
+                     ir::LoweringContext* loctx) -> ir::XlaOpVector {
+    xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0));
+    return node.ReturnOp(BuildSigmoid(xla_input), loctx);
+  };
+  return ir::ops::GenericOp(ir::OpKind(at::aten::sigmoid), ir::OpList{input},
+                            input.shape(), std::move(lower_fn));
+}
+
 NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
               c10::optional<at::Scalar> max) {
   const xla::Shape& input_shape = input.shape();

torch_xla/csrc/ops/ops.h

Lines changed: 2 additions & 0 deletions
@@ -81,6 +81,8 @@ NodePtr Sqrt(const Value& input);
 
 NodePtr Pow(const Value& input, const Value& exponent);
 
+NodePtr Sigmoid(const Value& input);
+
 NodePtr Clamp(const Value& input, c10::optional<at::Scalar> min,
               c10::optional<at::Scalar> max);
 
torch_xla/csrc/tensor.cpp

Lines changed: 8 additions & 0 deletions
@@ -1136,6 +1136,14 @@ XLATensor XLATensor::softmax(const XLATensor& input, xla::int64 dim) {
                 input.GetDevice());
 }
 
+XLATensor XLATensor::sigmoid(const XLATensor& input) {
+  return Create(ir::ops::Sigmoid(input.GetIrValue()), input.GetDevice());
+}
+
+void XLATensor::sigmoid_(XLATensor& input) {
+  input.SetIrValue(ir::ops::Sigmoid(input.GetIrValue()));
+}
+
 XLATensor XLATensor::nll_loss(const XLATensor& input, const XLATensor& target) {
   return Create(ir::ops::NllLossOp(input.GetIrValue(), target.GetIrValue()),
                 input.GetDevice());
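
A note on the in-place variant (not part of the commit): at this level sigmoid_ does not touch device memory; SetIrValue re-points the tensor at a freshly built IR node, and the accumulated graph is compiled and executed later. A minimal sketch of that pattern, using illustrative stand-in types rather than the real torch_xla classes:

// Illustrative sketch of the lazy in-place pattern (stand-in types,
// not the real torch_xla API).
#include <cstdio>
#include <memory>
#include <string>
#include <utility>

// A node in the pending computation graph (single-operand ops only).
struct Node {
  std::string op;
  std::shared_ptr<Node> operand;
};
using Value = std::shared_ptr<Node>;

struct LazyTensor {
  Value ir_value;
  Value GetIrValue() const { return ir_value; }
  // "In-place" update: swap the IR value; no device data is touched.
  void SetIrValue(Value v) { ir_value = std::move(v); }
};

void sigmoid_(LazyTensor& t) {
  t.SetIrValue(std::make_shared<Node>(Node{"sigmoid", t.GetIrValue()}));
}

int main() {
  LazyTensor t{std::make_shared<Node>(Node{"parameter", nullptr})};
  sigmoid_(t);
  // The tensor now points at sigmoid(parameter); nothing has executed yet.
  std::printf("%s(%s)\n", t.GetIrValue()->op.c_str(),
              t.GetIrValue()->operand->op.c_str());
  return 0;
}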

torch_xla/csrc/tensor.h

Lines changed: 3 additions & 0 deletions
@@ -196,6 +196,9 @@ class XLATensor {
 
   static XLATensor softmax(const XLATensor& input, xla::int64 dim);
 
+  static XLATensor sigmoid(const XLATensor& input);
+  static void sigmoid_(XLATensor& input);
+
   static XLATensor ones(tensorflow::gtl::ArraySlice<const xla::int64> size,
                         const Device& device, at::ScalarType scalar_type);
   static XLATensor ones_like(const XLATensor& input, const Device& device,

torch_xla/csrc/translator.cpp

Lines changed: 1 addition & 5 deletions
@@ -366,11 +366,7 @@ void TranslateSigmoid(const torch::jit::Node* node, ComputationContext* cctx,
                       xla::XlaBuilder* b) {
   XLA_CHECK_EQ(node->inputs().size(), 1);
   xla::XlaOp xla_input = cctx->OpForInput(node, 0);
-  xla::Shape xla_input_shape = XlaHelpers::ShapeOfXlaOp(xla_input);
-  xla::XlaOp half =
-      XlaHelpers::ScalarValue<float>(0.5, xla_input_shape.element_type(), b);
-  xla::XlaOp xla_output = half + half * xla::Tanh(half * xla_input);
-  cctx->AddNodeOp(node, xla_output);
+  cctx->AddNodeOp(node, BuildSigmoid(xla_input));
 }
 
 void TranslateRelu(const torch::jit::Node* node, ComputationContext* cctx,
