3 changes: 1 addition & 2 deletions python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1099,8 +1099,7 @@ class PRelu(OnnxOpConverter):
    def _impl_v1(cls, bb, inputs, attr, params):
        x = inputs[0]
        slope = inputs[1]
-        # TODO(tvm-team): Should add a new op for this.
-        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+        return relax.op.nn.prelu(x, slope)


class ThresholdedRelu(OnnxOpConverter):
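With this change an ONNX PRelu node imports as a single R.nn.prelu call instead of the old mul/relu decomposition. A minimal sketch of that import path (not part of this PR; it assumes a standard onnx.helper-built model and the relax ONNX frontend entry point):

# Sketch: import a one-node ONNX PRelu graph; the result should contain R.nn.prelu.
import onnx
from onnx import TensorProto, helper
from tvm.relax.frontend.onnx import from_onnx

node = helper.make_node("PRelu", ["x", "slope"], ["y"])
graph = helper.make_graph(
    [node],
    "prelu_test",
    inputs=[
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 32, 32]),
        helper.make_tensor_value_info("slope", TensorProto.FLOAT, [1]),
    ],
    outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 32, 32])],
)
model = helper.make_model(graph)

mod = from_onnx(model)
mod.show()  # the main function should now call R.nn.prelu(x, slope)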
45 changes: 43 additions & 2 deletions src/relax/op/nn/nn.cc
@@ -93,13 +93,54 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) {

TVM_FFI_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu);

StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) {
  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
  if (data_sinfo->IsUnknownNdim()) {
    return data_sinfo;
  }
  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
    ctx->ReportFatal(Diagnostic::Error(call) << "Prelu requires the input tensor to have float "
                                                "dtype. However, the given input dtype is "
                                             << data_sinfo->dtype);
  }
  const auto* attrs = call->attrs.as<PReluAttrs>();
  NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);

  return data_sinfo;
}

InferLayoutOutput InferLayoutPRelu(const Call& call,
                                   const Map<String, Array<String>>& desired_layouts,
                                   const VarLayoutMap& var_layout_map) {
  ICHECK(NoDesiredLayout(call, desired_layouts));
  const auto* attrs = call->attrs.as<PReluAttrs>();
  ICHECK(attrs) << "Invalid Call";

  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);

  // TODO(Siva): We could handle if the axis is not the sub indexed one.
  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
    int ndim = tensor_sinfo->ndim;
    layout = LayoutDecision(InitialLayout(ndim));
  }

  ObjectPtr<PReluAttrs> new_attrs = make_object<PReluAttrs>(*attrs);
  new_attrs->axis = FindAxis(layout->layout, attrs->axis);

  LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]);
  return InferLayoutOutput({layout, alpha_layout}, {layout}, Attrs(new_attrs));
}

TVM_REGISTER_OP("relax.nn.prelu")
    .set_num_inputs(2)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
    .set_attrs_type<PReluAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo",
-                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPRelu)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPRelu)
    .set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.softmax */
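The new InferStructInfoPRelu simply propagates the data tensor's shape and dtype (after checking for a float dtype and normalizing the axis attribute, which defaults to 1 per the prelu signature above), while InferLayoutPRelu rewrites that axis when a layout decision is applied. A small sketch of the struct-info side from Python, assuming the standard relax BlockBuilder API and the prelu binding added by this PR:

# Sketch (not from the PR): inspect the struct info inferred for R.nn.prelu.
from tvm import relax
from tvm.script import relax as R

x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
alpha = relax.Var("alpha", R.Tensor((3,), "float32"))  # channel-wise slope for axis=1

bb = relax.BlockBuilder()
with bb.function("f", [x, alpha]):
    # InferStructInfoPRelu: the result keeps the data tensor's shape and dtype.
    gv = bb.emit(relax.op.nn.prelu(x, alpha))
    bb.emit_func_output(gv)

# Expected struct info of gv: R.Tensor((2, 3, 4, 5), dtype="float32")
print(bb.get()["f"].body.blocks[0].bindings[0].var.struct_info)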
2 changes: 1 addition & 1 deletion tests/python/relax/test_frontend_onnx.py
@@ -948,7 +948,7 @@ def test_mish():


def test_prelu():
-    verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])


def test_thresholded_relu():
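The updated frontend test feeds a slope of shape [1], which the op broadcasts across the [3, 32, 32] input. As a reference for the expected numerics (a sketch, not the verify_binary harness, which is assumed to compare the Relax import against ONNX Runtime):

# Reference semantics sketch: PRelu with a broadcast slope.
import numpy as np

def prelu_ref(x: np.ndarray, slope: np.ndarray) -> np.ndarray:
    # A slope of shape [1] broadcasts against x of shape [3, 32, 32].
    return np.where(x > 0, x, slope * x)

x = np.random.randn(3, 32, 32).astype("float32")
slope = np.array([0.25], dtype="float32")
y = prelu_ref(x, slope)
assert y.shape == x.shape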
83 changes: 83 additions & 0 deletions tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1159,6 +1159,89 @@ def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
    tvm.ir.assert_structural_equal(mod, Expected)


def test_prelu():
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
            T.func_attr({"tir.noalias": True})
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(3),))
            for c in range(T.int64(3)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(3), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_prelu_symbolic():
    # fmt: off
    @tvm.script.ir_module
    class PRelu:
        @R.function
        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
            m = T.int64()
            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
            return gv

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
            m = T.int64()
            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
            return gv

        @T.prim_func(private=True)
        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
            T.func_attr({"tir.noalias": True})
            m = T.int64()
            x = T.match_buffer(var_x, (m, T.int64(7)))
            compute = T.match_buffer(var_compute, (m, T.int64(7)))
            # with T.block("root"):
            slope_broadcasted = T.alloc_buffer((T.int64(7),))
            for c in range(T.int64(7)):
                with T.block("slope_broadcasted"):
                    v_c = T.axis.spatial(T.int64(7), c)
                    T.reads(y[T.int64(0)])
                    T.writes(slope_broadcasted[v_c])
                    slope_broadcasted[v_c] = y[T.int64(0)]
            for i0, i1 in T.grid(m, T.int64(7)):
                with T.block("compute"):
                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
                    T.writes(compute[v_i0, v_i1])
                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
    # fmt: on

    mod = LegalizeOps()(PRelu)
    tvm.ir.assert_structural_equal(mod, Expected)


def test_gelu():
    # fmt: off
    @tvm.script.ir_module
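The two legalization tests above expect R.nn.prelu to lower to a TIR PrimFunc that first broadcasts the one-element slope and then computes an element-wise select: x where x > 0, otherwise slope * x. A hedged end-to-end sketch, under the assumption that the default relax build pipeline runs LegalizeOps and that an LLVM-enabled build of TVM is available:

# Sketch: legalize, build, and execute the 2x3 PRelu module on CPU.
import numpy as np
import tvm
from tvm import relax
from tvm.script import relax as R

@tvm.script.ir_module
class Mod:
    @R.function
    def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
        gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
        return gv

ex = relax.build(Mod, target="llvm")  # LegalizeOps is assumed to run in the default pipeline
vm = relax.VirtualMachine(ex, tvm.cpu())

x = np.array([[-1.0, 0.0, 2.0], [3.0, -4.0, 5.0]], dtype="float32")
slope = np.array([0.1], dtype="float32")
out = vm["main"](tvm.nd.array(x), tvm.nd.array(slope)).numpy()
np.testing.assert_allclose(out, np.where(x > 0, x, slope * x), rtol=1e-6)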
19 changes: 19 additions & 0 deletions tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -364,5 +364,24 @@ def foo(
    _check(foo, bb.get()["foo"])


def test_prelu():
    @R.function
    def foo(
        x: R.Tensor((2, 4, 4, 5), "float32"),
        alpha: R.Tensor((1,), "float32"),
    ) -> R.Tensor((2, 4, 4, 5), "float32"):
        gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.prelu(x, alpha)
        return gv

    x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
    alpha = relax.Var("alpha", R.Tensor((1,), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", [x, alpha]):
        gv = bb.emit(relax.op.nn.prelu(x, alpha))
        bb.emit_func_output(gv)

    _check(foo, bb.get()["foo"])


if __name__ == "__main__":
    tvm.testing.main()