@@ -116,26 +116,6 @@ torch::lazy::NodePtr Logit(const torch::lazy::Value& input,
116116 torch::lazy::MHash (eps));
117117}
118118
119- torch::lazy::NodePtr SgnOp (const torch::lazy::Value& input) {
120- auto lower_fn = [](const XlaNode& node,
121- LoweringContext* loctx) -> XlaOpVector {
122- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
123- return node.ReturnOp (BuildSgn (xla_input), loctx);
124- };
125- return GenericOp (torch::lazy::OpKind (at::aten::sgn), {input},
126- GetXlaShape (input), std::move (lower_fn));
127- }
128-
129- torch::lazy::NodePtr SignOp (const torch::lazy::Value& input) {
130- auto lower_fn = [](const XlaNode& node,
131- LoweringContext* loctx) -> XlaOpVector {
132- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
133- return node.ReturnOp (BuildSign (xla_input), loctx);
134- };
135- return GenericOp (torch::lazy::OpKind (at::aten::sign), {input},
136- GetXlaShape (input), std::move (lower_fn));
137- }
138-
139119torch::lazy::NodePtr Prelu (const torch::lazy::Value& input,
140120 const torch::lazy::Value& weight) {
141121 auto lower_fn = [](const XlaNode& node,
@@ -169,57 +149,6 @@ torch::lazy::NodePtr PreluBackward(const torch::lazy::Value& grad,
169149 std::move (lower_fn), /* num_outputs=*/ 2 );
170150}
171151
172- torch::lazy::NodePtr LogSigmoid (const torch::lazy::Value& input) {
173- auto lower_fn = [](const XlaNode& node,
174- LoweringContext* loctx) -> XlaOpVector {
175- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
176- return node.ReturnOps (BuildLogSigmoid (xla_input), loctx);
177- };
178- return GenericOp (torch::lazy::OpKind (at::aten::log_sigmoid), {input},
179- GetXlaShape (input), std::move (lower_fn), /* num_outputs=*/ 2 );
180- }
181-
182- torch::lazy::NodePtr SiLU (const torch::lazy::Value& input) {
183- auto lower_fn = [](const XlaNode& node,
184- LoweringContext* loctx) -> XlaOpVector {
185- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
186- return node.ReturnOp (xla_input * BuildSigmoid (xla_input), loctx);
187- };
188- return GenericOp (torch::lazy::OpKind (at::aten::silu), {input},
189- GetXlaShape (input), std::move (lower_fn));
190- }
191-
192- torch::lazy::NodePtr SiLUBackward (const torch::lazy::Value& grad_output,
193- const torch::lazy::Value& input) {
194- auto lower_fn = [](const XlaNode& node,
195- LoweringContext* loctx) -> XlaOpVector {
196- xla::XlaOp xla_grad_output = loctx->GetOutputOp (node.operand (0 ));
197- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (1 ));
198- return node.ReturnOp (BuildSiLUBackward (xla_grad_output, xla_input), loctx);
199- };
200- auto lower_for_shape_fn =
201- [](absl::Span<const xla::XlaOp> operands) -> xla::XlaOp {
202- return BuildSiLUBackward (operands[0 ], operands[1 ]);
203- };
204- return GenericOp (
205- torch::lazy::OpKind (at::aten::silu_backward), {grad_output, input},
206- [&]() {
207- return InferOutputShape ({GetXlaShape (grad_output), GetXlaShape (input)},
208- lower_for_shape_fn);
209- },
210- std::move (lower_fn));
211- }
212-
213- torch::lazy::NodePtr Sigmoid (const torch::lazy::Value& input) {
214- auto lower_fn = [](const XlaNode& node,
215- LoweringContext* loctx) -> XlaOpVector {
216- xla::XlaOp xla_input = loctx->GetOutputOp (node.operand (0 ));
217- return node.ReturnOp (BuildSigmoid (xla_input), loctx);
218- };
219- return GenericOp (torch::lazy::OpKind (at::aten::sigmoid), {input},
220- GetXlaShape (input), std::move (lower_fn));
221- }
222-
223152torch::lazy::NodePtr SigmoidBackward (const torch::lazy::Value& grad_output,
224153 const torch::lazy::Value& output) {
225154 torch::lazy::Value scalar_1 = ScalarOp (1 , GetXlaShape (output));
0 commit comments