Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3474,8 +3474,6 @@ def FuncOp : CIR_Op<"func", [
/// Returns the results types that the callable region produces when
/// executed.
llvm::ArrayRef<mlir::Type> getCallableResults() {
if (::llvm::isa<cir::VoidType>(getFunctionType().getReturnType()))
return {};
return getFunctionType().getReturnTypes();
}

Expand All @@ -3492,10 +3490,15 @@ def FuncOp : CIR_Op<"func", [
}

/// Returns the argument types of this function.
llvm::ArrayRef<mlir::Type> getArgumentTypes() { return getFunctionType().getInputs(); }
llvm::ArrayRef<mlir::Type> getArgumentTypes() {
return getFunctionType().getInputs();
}

/// Returns the result types of this function.
llvm::ArrayRef<mlir::Type> getResultTypes() { return getFunctionType().getReturnTypes(); }
/// Returns 0 or 1 result type of this function (0 in the case of a function
/// returing void)
llvm::ArrayRef<mlir::Type> getResultTypes() {
return getFunctionType().getReturnTypes();
}

/// Hook for OpTrait::FunctionOpInterfaceTrait, called after verifying that
/// the 'type' attribute is present and checks if it holds a function type.
Expand Down
25 changes: 19 additions & 6 deletions clang/include/clang/CIR/Dialect/IR/CIRTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -372,29 +372,38 @@ def CIR_VectorType : CIR_Type<"Vector", "vector",
def CIR_FuncType : CIR_Type<"Func", "func"> {
let summary = "CIR function type";
let description = [{
The `!cir.func` is a function type. It consists of a single return type, a
list of parameter types and can optionally be variadic.
The `!cir.func` is a function type. It consists of an optional return type,
a list of parameter types and can optionally be variadic.

Example:

```mlir
!cir.func<()>
!cir.func<!bool ()>
!cir.func<(!s8i, !s8i)>
!cir.func<!s32i (!s8i, !s8i)>
!cir.func<!s32i (!s32i, ...)>
```
}];

let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs, "mlir::Type":$returnType,
let parameters = (ins ArrayRefParameter<"mlir::Type">:$inputs,
"mlir::Type":$optionalReturnType,
"bool":$varArg);
// Use a custom parser to handle the optional return and argument types
// without an optional anchor.
let assemblyFormat = [{
`<` $returnType ` ` `(` custom<FuncTypeArgs>($inputs, $varArg) `>`
`<` custom<FuncType>($optionalReturnType, $inputs, $varArg) `>`
}];

let builders = [
// Construct with an actual return type or explicit !cir.void
TypeBuilderWithInferredContext<(ins
"llvm::ArrayRef<mlir::Type>":$inputs, "mlir::Type":$returnType,
CArg<"bool", "false">:$isVarArg), [{
return $_get(returnType.getContext(), inputs, returnType, isVarArg);
return $_get(returnType.getContext(), inputs,
mlir::isa<cir::VoidType>(returnType) ? nullptr
: returnType,
isVarArg);
}]>
];

Expand All @@ -408,11 +417,15 @@ def CIR_FuncType : CIR_Type<"Func", "func"> {
/// Returns the number of arguments to the function.
unsigned getNumInputs() const { return getInputs().size(); }

/// Returns the result type of the function as an actual return type or
/// explicit !cir.void
mlir::Type getReturnType() const;

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
llvm::ArrayRef<mlir::Type> getReturnTypes() const;

/// Returns whether the function is returns void.
/// Returns whether the function returns void.
bool isVoid() const;

/// Returns a clone of this function type with the given argument
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ mlir::Type CIRGenTypes::convertFunctionTypeInternal(QualType QFT) {
assert(QFT.isCanonical());
const Type *Ty = QFT.getTypePtr();
const FunctionType *FT = cast<FunctionType>(QFT.getTypePtr());
// First, check whether we can build the full fucntion type. If the function
// First, check whether we can build the full function type. If the function
// type depends on an incomplete type (e.g. a struct or enum), we cannot lower
// the function type.
assert(isFuncTypeConvertible(FT) && "NYI");
Expand Down
14 changes: 7 additions & 7 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2490,13 +2490,8 @@ void cir::FuncOp::print(OpAsmPrinter &p) {
p.printSymbolName(getSymName());
auto fnType = getFunctionType();
llvm::SmallVector<Type, 1> resultTypes;
if (!fnType.isVoid())
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(),
fnType.getReturnTypes());
else
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), {});
function_interface_impl::printFunctionSignature(
p, *this, fnType.getInputs(), fnType.isVarArg(), fnType.getReturnTypes());

if (mlir::ArrayAttr annotations = getAnnotationsAttr()) {
p << ' ';
Expand Down Expand Up @@ -2565,6 +2560,11 @@ LogicalResult cir::FuncOp::verifyType() {
if (!getNoProto() && type.isVarArg() && type.getNumInputs() == 0)
return emitError()
<< "prototyped function must have at least one non-variadic input";
if (auto rt = type.getReturnTypes();
!rt.empty() && mlir::isa<cir::VoidType>(rt.front()))
return emitOpError("The return type for a function returning void should "
"be empty instead of an explicit !cir.void");

return success();
}

Expand Down
94 changes: 82 additions & 12 deletions clang/lib/CIR/Dialect/IR/CIRTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <optional>

using cir::MissingFeatures;
Expand All @@ -41,12 +42,13 @@ using cir::MissingFeatures;
// CIR Custom Parser/Printer Signatures
//===----------------------------------------------------------------------===//

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnTypes,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg);

static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg);
static mlir::ParseResult parsePointerAddrSpace(mlir::AsmParser &p,
mlir::Attribute &addrSpaceAttr);
static void printPointerAddrSpace(mlir::AsmPrinter &p,
Expand Down Expand Up @@ -913,9 +915,38 @@ FuncType FuncType::clone(TypeRange inputs, TypeRange results) const {
return get(llvm::to_vector(inputs), results[0], isVarArg());
}

mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
// A special parser is needed for function returning void to handle the missing
// type.
static mlir::ParseResult parseFuncTypeReturn(mlir::AsmParser &p,
mlir::Type &optionalReturnType) {
if (succeeded(p.parseOptionalLParen())) {
// If we have already a '(', the function has no return type
optionalReturnType = {};
return mlir::success();
}
mlir::Type type;
if (p.parseType(type))
return mlir::failure();
if (isa<cir::VoidType>(type))
// An explicit !cir.void means also no return type.
optionalReturnType = {};
else
// Otherwise use the actual type.
optionalReturnType = type;
return p.parseLParen();
}

// A special pretty-printer for function returning or not a result.
static void printFuncTypeReturn(mlir::AsmPrinter &p,
mlir::Type optionalReturnType) {
if (optionalReturnType)
p << optionalReturnType << ' ';
p << '(';
}

static mlir::ParseResult
parseFuncTypeArgs(mlir::AsmParser &p, llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
isVarArg = false;
// `(` `)`
if (succeeded(p.parseOptionalRParen()))
Expand Down Expand Up @@ -945,8 +976,9 @@ mlir::ParseResult parseFuncTypeArgs(mlir::AsmParser &p,
return p.parseRParen();
}

void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
static void printFuncTypeArgs(mlir::AsmPrinter &p,
mlir::ArrayRef<mlir::Type> params,
bool isVarArg) {
llvm::interleaveComma(params, p,
[&p](mlir::Type type) { p.printType(type); });
if (isVarArg) {
Expand All @@ -957,11 +989,49 @@ void printFuncTypeArgs(mlir::AsmPrinter &p, mlir::ArrayRef<mlir::Type> params,
p << ')';
}

// Use a custom parser to handle the optional return and argument types without
// an optional anchor.
static mlir::ParseResult parseFuncType(mlir::AsmParser &p,
mlir::Type &optionalReturnTypes,
llvm::SmallVector<mlir::Type> &params,
bool &isVarArg) {
if (failed(parseFuncTypeReturn(p, optionalReturnTypes)))
return failure();
return parseFuncTypeArgs(p, params, isVarArg);
}

static void printFuncType(mlir::AsmPrinter &p, mlir::Type optionalReturnTypes,
mlir::ArrayRef<mlir::Type> params, bool isVarArg) {
printFuncTypeReturn(p, optionalReturnTypes);
printFuncTypeArgs(p, params, isVarArg);
}

// Return the actual return type or an explicit !cir.void if the function does
// not return anything
mlir::Type FuncType::getReturnType() const {
if (isVoid())
return cir::VoidType::get(getContext());
return static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
}

/// Returns the result type of the function as an ArrayRef, enabling better
/// integration with generic MLIR utilities.
llvm::ArrayRef<mlir::Type> FuncType::getReturnTypes() const {
return static_cast<detail::FuncTypeStorage *>(getImpl())->returnType;
if (isVoid())
return {};
return static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
}

bool FuncType::isVoid() const { return mlir::isa<VoidType>(getReturnType()); }
// Whether the function returns void
bool FuncType::isVoid() const {
auto rt =
static_cast<detail::FuncTypeStorage *>(getImpl())->optionalReturnType;
assert(!rt ||
!mlir::isa<cir::VoidType>(rt) &&
"The return type for a function returning void should be empty "
"instead of a real !cir.void");
return !rt;
}

//===----------------------------------------------------------------------===//
// MethodType Definitions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ FuncType LowerTypes::getFunctionType(const LowerFunctionInfo &FI) {
}
}

return FuncType::get(getMLIRContext(), ArgTypes, resultType, FI.isVariadic());
return FuncType::get(ArgTypes, resultType, FI.isVariadic());
}

/// Convert a CIR type to its ABI-specific default form.
Expand Down
6 changes: 3 additions & 3 deletions clang/test/CIR/CodeGen/fun-ptr.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ int foo(Data* d) {
return f(d);
}

// CIR: cir.func private {{@.*test.*}}() -> !cir.ptr<!cir.func<!void ()>>
// CIR: cir.func private {{@.*test.*}}() -> !cir.ptr<!cir.func<()>>
// CIR: cir.func {{@.*bar.*}}()
// CIR: [[RET:%.*]] = cir.call {{@.*test.*}}() : () -> !cir.ptr<!cir.func<!void ()>>
// CIR: cir.call [[RET]]() : (!cir.ptr<!cir.func<!void ()>>) -> ()
// CIR: [[RET:%.*]] = cir.call {{@.*test.*}}() : () -> !cir.ptr<!cir.func<()>>
// CIR: cir.call [[RET]]() : (!cir.ptr<!cir.func<()>>) -> ()
// CIR: cir.return

// LLVM: declare ptr {{@.*test.*}}()
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CIR/CodeGen/gnu-extension.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ void bar(void) {
}

//CHECK: cir.func @bar()
//CHECK: {{.*}} = cir.get_global @bar : !cir.ptr<!cir.func<!void ()>>
//CHECK: {{.*}} = cir.get_global @bar : !cir.ptr<!cir.func<()>>
//CHECK: cir.return
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/member-init-struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ C a, b(x), c(0, 2);
// CHECK: %[[VAL_8:.*]] = cir.get_member %[[VAL_2]][2] {name = "d"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.array<!s32i x 10>>
// CHECK: %[[VAL_9:.*]] = cir.const {{.*}} : !cir.array<!s32i x 10>
// CHECK: cir.store %[[VAL_9]], %[[VAL_8]] : !cir.array<!s32i x 10>, !cir.ptr<!cir.array<!s32i x 10>>
// CHECK: %[[VAL_10:.*]] = cir.get_member %[[VAL_2]][4] {name = "e"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.method<!cir.func<!void ()> in !ty_C>>
// CHECK: %[[VAL_11:.*]] = cir.const #cir.method<null> : !cir.method<!cir.func<!void ()> in !ty_C>
// CHECK: cir.store %[[VAL_11]], %[[VAL_10]] : !cir.method<!cir.func<!void ()> in !ty_C>, !cir.ptr<!cir.method<!cir.func<!void ()> in !ty_C>>
// CHECK: cir.return
// CHECK: %[[VAL_10:.*]] = cir.get_member %[[VAL_2]][4] {name = "e"} : !cir.ptr<!ty_C> -> !cir.ptr<!cir.method<!cir.func<()> in !ty_C>>
// CHECK: %[[VAL_11:.*]] = cir.const #cir.method<null> : !cir.method<!cir.func<()> in !ty_C>
// CHECK: cir.store %[[VAL_11]], %[[VAL_10]] : !cir.method<!cir.func<()> in !ty_C>, !cir.ptr<!cir.method<!cir.func<()> in !ty_C>>
// CHECK: cir.return
4 changes: 2 additions & 2 deletions clang/test/CIR/CodeGen/multi-vtable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,9 @@ int main() {

// CIR: cir.func @main() -> !s32i extra(#fn_attr) {

// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Mother>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Mother>)>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Mother>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Mother>)>>>

// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Child>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<!void (!cir.ptr<!ty_Child>)>>>
// CIR: %{{[0-9]+}} = cir.vtable.address_point( %{{[0-9]+}} : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Child>)>>>, vtable_index = 0, address_point_index = 0) : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!ty_Child>)>>>

// CIR: }

Expand Down
14 changes: 7 additions & 7 deletions clang/test/CIR/CodeGen/no-proto-fun-ptr.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ void check_noproto_ptr() {
}

// CHECK: cir.func no_proto @check_noproto_ptr()
// CHECK: [[ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<!void ()>>, !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, ["fun", init] {alignment = 8 : i64}
// CHECK: [[GGO:%.*]] = cir.get_global @empty : !cir.ptr<!cir.func<!void ()>>
// CHECK: cir.store [[GGO]], [[ALLOC]] : !cir.ptr<!cir.func<!void ()>>, !cir.ptr<!cir.ptr<!cir.func<!void ()>>>
// CHECK: [[ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<()>>, !cir.ptr<!cir.ptr<!cir.func<()>>>, ["fun", init] {alignment = 8 : i64}
// CHECK: [[GGO:%.*]] = cir.get_global @empty : !cir.ptr<!cir.func<()>>
// CHECK: cir.store [[GGO]], [[ALLOC]] : !cir.ptr<!cir.func<()>>, !cir.ptr<!cir.ptr<!cir.func<()>>>
// CHECK: cir.return

void empty(void) {}
Expand All @@ -20,8 +20,8 @@ void buz() {
}

// CHECK: cir.func no_proto @buz()
// CHECK: [[FNPTR_ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<!void (...)>>, !cir.ptr<!cir.ptr<!cir.func<!void (...)>>>, ["func"] {alignment = 8 : i64}
// CHECK: [[FNPTR:%.*]] = cir.load deref [[FNPTR_ALLOC]] : !cir.ptr<!cir.ptr<!cir.func<!void (...)>>>, !cir.ptr<!cir.func<!void (...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %1 : !cir.ptr<!cir.func<!void (...)>>), !cir.ptr<!cir.func<!void ()>>
// CHECK: cir.call [[CAST]]() : (!cir.ptr<!cir.func<!void ()>>) -> ()
// CHECK: [[FNPTR_ALLOC:%.*]] = cir.alloca !cir.ptr<!cir.func<(...)>>, !cir.ptr<!cir.ptr<!cir.func<(...)>>>, ["func"] {alignment = 8 : i64}
// CHECK: [[FNPTR:%.*]] = cir.load deref [[FNPTR_ALLOC]] : !cir.ptr<!cir.ptr<!cir.func<(...)>>>, !cir.ptr<!cir.func<(...)>>
// CHECK: [[CAST:%.*]] = cir.cast(bitcast, %1 : !cir.ptr<!cir.func<(...)>>), !cir.ptr<!cir.func<()>>
// CHECK: cir.call [[CAST]]() : (!cir.ptr<!cir.func<()>>) -> ()
// CHECK: cir.return
8 changes: 4 additions & 4 deletions clang/test/CIR/CodeGen/pointer-arith-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ void *f4_1(void *a, int b) { return (a -= b); }

FP f5(FP a, int b) { return a + b; }
// CIR-LABEL: f5
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, !cir.ptr<!cir.func<!void ()>>
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<!void ()>>, %[[STRIDE]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[STRIDE]] : !s32i)

// LLVM-LABEL: f5
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
Expand All @@ -67,10 +67,10 @@ FP f6_1(int a, FP b) { return (a += b); }

FP f7(FP a, int b) { return a - b; }
// CIR-LABEL: f7
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<!void ()>>>, !cir.ptr<!cir.func<!void ()>>
// CIR: %[[PTR:.*]] = cir.load {{.*}} : !cir.ptr<!cir.ptr<!cir.func<()>>>, !cir.ptr<!cir.func<()>>
// CIR: %[[STRIDE:.*]] = cir.load {{.*}} : !cir.ptr<!s32i>, !s32i
// CIR: %[[SUB:.*]] = cir.unary(minus, %[[STRIDE]]) : !s32i, !s32i
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<!void ()>>, %[[SUB]] : !s32i)
// CIR: cir.ptr_stride(%[[PTR]] : !cir.ptr<!cir.func<()>>, %[[SUB]] : !s32i)

// LLVM-LABEL: f7
// LLVM: %[[PTR:.*]] = load ptr, ptr {{.*}}, align 8
Expand Down
Loading
Loading