diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index abce622af48a..d1b6fb048b1a 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1322,27 +1322,6 @@ at::Tensor XLANativeFunctions::eq(const at::Tensor& self, XLATensor::eq(bridge::GetXlaTensor(self), bridge::GetXlaTensor(other))); } -at::Tensor XLANativeFunctions::erf(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::erf(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::erfc(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::erfc(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::erfinv(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor( - XLATensor::erfinv(bridge::GetXlaTensor(self))); -} - -at::Tensor XLANativeFunctions::exp(const at::Tensor& self) { - XLA_FN_COUNTER("xla::"); - return bridge::AtenFromXlaTensor(XLATensor::exp(bridge::GetXlaTensor(self))); -} - at::Tensor XLANativeFunctions::expand(const at::Tensor& self, at::IntArrayRef size, bool implicit) { XLA_FN_COUNTER("xla::"); diff --git a/torch_xla/csrc/ops/ops_lower_fn.cpp b/torch_xla/csrc/ops/ops_lower_fn.cpp index 0793d6290840..ffeed40049fa 100644 --- a/torch_xla/csrc/ops/ops_lower_fn.cpp +++ b/torch_xla/csrc/ops/ops_lower_fn.cpp @@ -53,6 +53,26 @@ torch_xla::XlaOpVector Cosh::Lower(LoweringContext* loctx) const { return ReturnOp(xla::Cosh(xla_input), loctx); } +torch_xla::XlaOpVector Erf::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Erf(xla_input), loctx); +} + +torch_xla::XlaOpVector Erfc::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Erfc(xla_input), loctx); +} + +torch_xla::XlaOpVector Erfinv::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::ErfInv(xla_input), loctx); +} + +torch_xla::XlaOpVector Exp::Lower(LoweringContext* loctx) const { + xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); + return ReturnOp(xla::Exp(xla_input), loctx); +} + torch_xla::XlaOpVector Floor::Lower(LoweringContext* loctx) const { xla::XlaOp xla_input = loctx->GetOutputOp(operand(0)); return ReturnOp(xla::Floor(xla_input), loctx); diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp index ff4703c66f78..bdd74a9112cd 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.cpp +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.cpp @@ -41,6 +41,22 @@ xla::Shape CoshOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } +xla::Shape ErfOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ErfcOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ErfinvOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + +xla::Shape ExpOutputShape(const torch::lazy::Value& input) { + return GetXlaShape(input); +} + xla::Shape FloorOutputShape(const torch::lazy::Value& input) { return GetXlaShape(input); } diff --git a/torch_xla/csrc/ops/ops_xla_shape_fn.h b/torch_xla/csrc/ops/ops_xla_shape_fn.h index 98ee86a23612..679875404d41 100644 --- a/torch_xla/csrc/ops/ops_xla_shape_fn.h +++ b/torch_xla/csrc/ops/ops_xla_shape_fn.h @@ -21,6 +21,14 @@ xla::Shape CosOutputShape(const torch::lazy::Value& input); xla::Shape CoshOutputShape(const torch::lazy::Value& input); +xla::Shape ErfOutputShape(const torch::lazy::Value& input); + +xla::Shape ErfcOutputShape(const torch::lazy::Value& input); + +xla::Shape ErfinvOutputShape(const torch::lazy::Value& input); + +xla::Shape ExpOutputShape(const torch::lazy::Value& input); + xla::Shape FloorOutputShape(const torch::lazy::Value& input); xla::Shape InverseOutputShape(const torch::lazy::Value& input); diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 26496d7d649c..b206a40b3def 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -555,12 +555,6 @@ class XLATensor : public c10::intrusive_ptr_target { static XLATensor eq(const XLATensor& input, const XLATensor& other); - static XLATensor erf(const XLATensor& input); - - static XLATensor erfc(const XLATensor& input); - - static XLATensor erfinv(const XLATensor& input); - static XLATensor exp(const XLATensor& input); static XLATensor expand(const XLATensor& input, std::vector size); diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 36be8cd3dee6..401d263a7ba3 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -1271,18 +1271,6 @@ XLATensor XLATensor::embedding_dense_backward(const XLATensor& grad_output, padding_idx, scale_grad_by_freq); } -XLATensor XLATensor::erf(const XLATensor& input) { - return input.CreateFrom(Erf(input.GetIrValue())); -} - -XLATensor XLATensor::erfc(const XLATensor& input) { - return input.CreateFrom(Erfc(input.GetIrValue())); -} - -XLATensor XLATensor::erfinv(const XLATensor& input) { - return input.CreateFrom(Erfinv(input.GetIrValue())); -} - XLATensor XLATensor::exp(const XLATensor& input) { return input.CreateFrom(Exp(input.GetIrValue())); } diff --git a/xla_native_functions.yaml b/xla_native_functions.yaml index cbdeec34c0bc..4aa60658a012 100644 --- a/xla_native_functions.yaml +++ b/xla_native_functions.yaml @@ -10,6 +10,10 @@ full_codegen: - atanh - cos - cosh + - erf + - erfc + - erfinv + - exp - floor - inverse - logdet @@ -121,10 +125,6 @@ supported: - empty_strided - eq.Scalar - eq.Tensor - - erf - - erfc - - erfinv - - exp - expand - expm1 - exponential_