Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 39 additions & 38 deletions clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {

} // namespace

static mlir::Value lowerScalarToComplexCast(MLIRContext &ctx, CastOp op) {
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
cir::CastOp op) {
CIRBaseBuilderTy builder(ctx);
builder.setInsertionPoint(op);

Expand All @@ -40,7 +41,9 @@ static mlir::Value lowerScalarToComplexCast(MLIRContext &ctx, CastOp op) {
return builder.createComplexCreate(op.getLoc(), src, imag);
}

static mlir::Value lowerComplexToScalarCast(MLIRContext &ctx, CastOp op) {
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
cir::CastOp op,
cir::CastKind elemToBoolKind) {
CIRBaseBuilderTy builder(ctx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
CIRBaseBuilderTy builder(ctx);
cir::CIRBaseBuilderTy builder(ctx);

builder.setInsertionPoint(op);

Expand All @@ -52,24 +55,17 @@ static mlir::Value lowerComplexToScalarCast(MLIRContext &ctx, CastOp op) {
mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);

cir::CastKind elemToBoolKind;
if (op.getKind() == cir::CastKind::float_complex_to_bool)
elemToBoolKind = cir::CastKind::float_to_bool;
else if (op.getKind() == cir::CastKind::int_complex_to_bool)
elemToBoolKind = cir::CastKind::int_to_bool;
else
llvm_unreachable("invalid complex to bool cast kind");

cir::BoolType boolTy = builder.getBoolTy();
mlir::Value srcRealToBool =
builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
mlir::Value srcImagToBool =
builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);

return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
}

static mlir::Value lowerComplexToComplexCast(MLIRContext &ctx, CastOp op) {
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
cir::CastOp op,
cir::CastKind scalarCastKind) {
CIRBaseBuilderTy builder(ctx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
CIRBaseBuilderTy builder(ctx);
cir::CIRBaseBuilderTy builder(ctx);

builder.setInsertionPoint(op);

Expand All @@ -80,24 +76,6 @@ static mlir::Value lowerComplexToComplexCast(MLIRContext &ctx, CastOp op) {
mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
mlir::Value srcImag = builder.createComplexReal(op.getLoc(), src);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be createComplexImag?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, I will write a test case (Complex to Complex cast) that shows the problem now, to update it here and incubator


cir::CastKind scalarCastKind;
switch (op.getKind()) {
case cir::CastKind::float_complex:
scalarCastKind = cir::CastKind::floating;
break;
case cir::CastKind::float_complex_to_int_complex:
scalarCastKind = cir::CastKind::float_to_int;
break;
case cir::CastKind::int_complex:
scalarCastKind = cir::CastKind::integral;
break;
case cir::CastKind::int_complex_to_float_complex:
scalarCastKind = cir::CastKind::int_to_float;
break;
default:
llvm_unreachable("invalid complex to complex cast kind");
}

mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
dstComplexElemTy);
mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
Expand All @@ -114,19 +92,42 @@ void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
break;

case cir::CastKind::float_complex_to_real:
case cir::CastKind::int_complex_to_real:
case cir::CastKind::float_complex_to_bool:
case cir::CastKind::int_complex_to_real: {
loweredValue = lowerComplexToScalarCast(getContext(), op, op.getKind());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks fishy, why is it op.getKind() here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because of those kinds, we don't care about the kind, just call complex real op

break;
}

case cir::CastKind::float_complex_to_bool: {
loweredValue = lowerComplexToScalarCast(getContext(), op,
cir::CastKind::float_to_bool);
break;
}
case cir::CastKind::int_complex_to_bool: {
loweredValue = lowerComplexToScalarCast(getContext(), op);
loweredValue =
lowerComplexToScalarCast(getContext(), op, cir::CastKind::int_to_bool);
break;
}

case cir::CastKind::float_complex:
case cir::CastKind::float_complex_to_int_complex:
case cir::CastKind::int_complex:
case cir::CastKind::int_complex_to_float_complex:
loweredValue = lowerComplexToComplexCast(getContext(), op);
case cir::CastKind::float_complex: {
loweredValue =
lowerComplexToComplexCast(getContext(), op, cir::CastKind::floating);
break;
}
case cir::CastKind::float_complex_to_int_complex: {
loweredValue = lowerComplexToComplexCast(getContext(), op,
cir::CastKind::float_to_int);
break;
}
case cir::CastKind::int_complex: {
loweredValue =
lowerComplexToComplexCast(getContext(), op, cir::CastKind::integral);
break;
}
case cir::CastKind::int_complex_to_float_complex: {
loweredValue = lowerComplexToComplexCast(getContext(), op,
cir::CastKind::int_to_float);
break;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use immediately invoked lambda here and getContext beforehand, it will be way nicer :)

Something like this:

void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
  mlir::MLIRContext *ctx = getContext();
  Value loweredValue = [&]() -> Value {
    switch (op.getKind()) {
    case cir::CastKind::float_to_complex:
    case cir::CastKind::int_to_complex:
      return lowerScalarToComplexCast(ctx, op);
    case cir::CastKind::float_complex_to_real:
    case cir::CastKind::int_complex_to_real:
      return lowerComplexToScalarCast(ctx, op, op.getKind());
    case cir::CastKind::float_complex_to_bool:
      return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
    case cir::CastKind::int_complex_to_bool:
      return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
    case cir::CastKind::float_complex:
      return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
    case cir::CastKind::float_complex_to_int_complex:
      return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
    case cir::CastKind::int_complex:
      return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
    case cir::CastKind::int_complex_to_float_complex:
      return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
    default:
      return nullptr;
    }
  }();

  if (loweredValue) {
    op.replaceAllUsesWith(loweredValue);
    op.erase();
  }
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, i will do that


default:
return;
Expand Down