22
33#include " tensorflow/compiler/xla/client/lib/logdet.h"
44#include " tensorflow/compiler/xla/shape_util.h"
5+ #include " torch_xla/csrc/elementwise.h"
56#include " torch_xla/csrc/helpers.h"
67#include " torch_xla/csrc/pooling.h"
78
@@ -208,6 +209,10 @@ xla::Shape RsqrtOutputShape(const torch::lazy::Value& input) {
208209 return GetXlaShape (input);
209210}
210211
212+ xla::Shape SeluOutputShape (const torch::lazy::Value& input) {
213+ return GetXlaShape (input);
214+ }
215+
211216xla::Shape SgnOutputShape (const torch::lazy::Value& input) {
212217 return GetXlaShape (input);
213218}
@@ -216,6 +221,20 @@ xla::Shape SignOutputShape(const torch::lazy::Value& input) {
216221 return GetXlaShape (input);
217222}
218223
224+ xla::Shape SiluOutputShape (const torch::lazy::Value& input) {
225+ return GetXlaShape (input);
226+ }
227+
228+ xla::Shape SiluBackwardOutputShape (const torch::lazy::Value& grad_output,
229+ const torch::lazy::Value& input) {
230+ auto lower_for_shape_fn =
231+ [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
232+ return BuildSiLUBackward (operands[0 ], operands[1 ]);
233+ };
234+ return InferOutputShape ({GetXlaShape (grad_output), GetXlaShape (input)},
235+ lower_for_shape_fn);
236+ }
237+
219238xla::Shape SinOutputShape (const torch::lazy::Value& input) {
220239 return GetXlaShape (input);
221240}
0 commit comments