Skip to content

Commit 30bb54d

Browse files
goxullanza
authored andcommitted
[CIR][ThroughMLIR] Templatize unary math op lowerings. (llvm#1557)
A lot of the unary math op lowerings follow the same template -- we can templatize this to remove redundant code and make things a little more neater. (Similar to what we do [here](https://github.com/llvm/clangir/blob/e4b8a48fb4d9a72a85e38f5439bcfb0673b4bea2/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp#L502)) I've checked all existing LIT tests via `ninja clang-check-cir` and they seem to be passing fine.
1 parent e51dfab commit 30bb54d

File tree

1 file changed

+37
-194
lines changed

1 file changed

+37
-194
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 37 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -279,173 +279,52 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
279279
}
280280
};
281281

282-
class CIRACosOpLowering : public mlir::OpConversionPattern<cir::ACosOp> {
283-
public:
284-
using OpConversionPattern<cir::ACosOp>::OpConversionPattern;
285-
286-
mlir::LogicalResult
287-
matchAndRewrite(cir::ACosOp op, OpAdaptor adaptor,
288-
mlir::ConversionPatternRewriter &rewriter) const override {
289-
rewriter.replaceOpWithNewOp<mlir::math::AcosOp>(op, adaptor.getSrc());
290-
return mlir::LogicalResult::success();
291-
}
292-
};
293-
294-
class CIRATanOpLowering : public mlir::OpConversionPattern<cir::ATanOp> {
295-
public:
296-
using OpConversionPattern<cir::ATanOp>::OpConversionPattern;
297-
298-
mlir::LogicalResult
299-
matchAndRewrite(cir::ATanOp op, OpAdaptor adaptor,
300-
mlir::ConversionPatternRewriter &rewriter) const override {
301-
rewriter.replaceOpWithNewOp<mlir::math::AtanOp>(op, adaptor.getSrc());
302-
return mlir::LogicalResult::success();
303-
}
304-
};
305-
306-
class CIRCosOpLowering : public mlir::OpConversionPattern<cir::CosOp> {
307-
public:
308-
using OpConversionPattern<cir::CosOp>::OpConversionPattern;
309-
310-
mlir::LogicalResult
311-
matchAndRewrite(cir::CosOp op, OpAdaptor adaptor,
312-
mlir::ConversionPatternRewriter &rewriter) const override {
313-
auto convertedType = getTypeConverter()->convertType(op.getType());
314-
rewriter.replaceOpWithNewOp<mlir::math::CosOp>(op, convertedType, adaptor.getSrc());
315-
return mlir::LogicalResult::success();
316-
}
317-
};
318-
319-
class CIRTanOpLowering : public mlir::OpConversionPattern<cir::TanOp> {
320-
public:
321-
using OpConversionPattern<cir::TanOp>::OpConversionPattern;
322-
323-
mlir::LogicalResult
324-
matchAndRewrite(cir::TanOp op, OpAdaptor adaptor,
325-
mlir::ConversionPatternRewriter &rewriter) const override {
326-
rewriter.replaceOpWithNewOp<mlir::math::TanOp>(op, adaptor.getSrc());
327-
return mlir::LogicalResult::success();
328-
}
329-
};
330-
331-
class CIRSqrtOpLowering : public mlir::OpConversionPattern<cir::SqrtOp> {
332-
public:
333-
using mlir::OpConversionPattern<cir::SqrtOp>::OpConversionPattern;
282+
/// Converts CIR unary math ops (e.g., cir::SinOp) to their MLIR equivalents
283+
/// (e.g., math::SinOp) using a generic template to avoid redundant boilerplate
284+
/// matchAndRewrite definitions.
334285

335-
mlir::LogicalResult
336-
matchAndRewrite(cir::SqrtOp op, OpAdaptor adaptor,
337-
mlir::ConversionPatternRewriter &rewriter) const override {
338-
rewriter.replaceOpWithNewOp<mlir::math::SqrtOp>(op, adaptor.getSrc());
339-
return mlir::LogicalResult::success();
340-
}
341-
};
342-
343-
class CIRFAbsOpLowering : public mlir::OpConversionPattern<cir::FAbsOp> {
344-
public:
345-
using mlir::OpConversionPattern<cir::FAbsOp>::OpConversionPattern;
346-
347-
mlir::LogicalResult
348-
matchAndRewrite(cir::FAbsOp op, OpAdaptor adaptor,
349-
mlir::ConversionPatternRewriter &rewriter) const override {
350-
rewriter.replaceOpWithNewOp<mlir::math::AbsFOp>(op, adaptor.getSrc());
351-
return mlir::LogicalResult::success();
352-
}
353-
};
354-
class CIRAbsOpLowering : public mlir::OpConversionPattern<cir::AbsOp> {
355-
public:
356-
using mlir::OpConversionPattern<cir::AbsOp>::OpConversionPattern;
357-
358-
mlir::LogicalResult
359-
matchAndRewrite(cir::AbsOp op, OpAdaptor adaptor,
360-
mlir::ConversionPatternRewriter &rewriter) const override {
361-
rewriter.replaceOpWithNewOp<mlir::math::AbsIOp>(op, adaptor.getSrc());
362-
return mlir::LogicalResult::success();
363-
}
364-
};
365-
366-
class CIRFloorOpLowering : public mlir::OpConversionPattern<cir::FloorOp> {
367-
public:
368-
using mlir::OpConversionPattern<cir::FloorOp>::OpConversionPattern;
369-
370-
mlir::LogicalResult
371-
matchAndRewrite(cir::FloorOp op, OpAdaptor adaptor,
372-
mlir::ConversionPatternRewriter &rewriter) const override {
373-
rewriter.replaceOpWithNewOp<mlir::math::FloorOp>(op, adaptor.getSrc());
374-
return mlir::LogicalResult::success();
375-
}
376-
};
377-
378-
class CIRCeilOpLowering : public mlir::OpConversionPattern<cir::CeilOp> {
379-
public:
380-
using mlir::OpConversionPattern<cir::CeilOp>::OpConversionPattern;
381-
382-
mlir::LogicalResult
383-
matchAndRewrite(cir::CeilOp op, OpAdaptor adaptor,
384-
mlir::ConversionPatternRewriter &rewriter) const override {
385-
rewriter.replaceOpWithNewOp<mlir::math::CeilOp>(op, adaptor.getSrc());
386-
return mlir::LogicalResult::success();
387-
}
388-
};
389-
390-
class CIRLog10OpLowering : public mlir::OpConversionPattern<cir::Log10Op> {
391-
public:
392-
using mlir::OpConversionPattern<cir::Log10Op>::OpConversionPattern;
393-
394-
mlir::LogicalResult
395-
matchAndRewrite(cir::Log10Op op, OpAdaptor adaptor,
396-
mlir::ConversionPatternRewriter &rewriter) const override {
397-
rewriter.replaceOpWithNewOp<mlir::math::Log10Op>(op, adaptor.getSrc());
398-
return mlir::LogicalResult::success();
399-
}
400-
};
401-
402-
class CIRLogOpLowering : public mlir::OpConversionPattern<cir::LogOp> {
403-
public:
404-
using mlir::OpConversionPattern<cir::LogOp>::OpConversionPattern;
405-
406-
mlir::LogicalResult
407-
matchAndRewrite(cir::LogOp op, OpAdaptor adaptor,
408-
mlir::ConversionPatternRewriter &rewriter) const override {
409-
rewriter.replaceOpWithNewOp<mlir::math::LogOp>(op, adaptor.getSrc());
410-
return mlir::LogicalResult::success();
411-
}
412-
};
413-
414-
class CIRLog2OpLowering : public mlir::OpConversionPattern<cir::Log2Op> {
415-
public:
416-
using mlir::OpConversionPattern<cir::Log2Op>::OpConversionPattern;
417-
418-
mlir::LogicalResult
419-
matchAndRewrite(cir::Log2Op op, OpAdaptor adaptor,
420-
mlir::ConversionPatternRewriter &rewriter) const override {
421-
rewriter.replaceOpWithNewOp<mlir::math::Log2Op>(op, adaptor.getSrc());
422-
return mlir::LogicalResult::success();
423-
}
424-
};
425-
426-
class CIRRoundOpLowering : public mlir::OpConversionPattern<cir::RoundOp> {
286+
template <typename CIROp, typename MLIROp>
287+
class CIRUnaryMathOpLowering : public mlir::OpConversionPattern<CIROp> {
427288
public:
428-
using mlir::OpConversionPattern<cir::RoundOp>::OpConversionPattern;
289+
using mlir::OpConversionPattern<CIROp>::OpConversionPattern;
429290

430291
mlir::LogicalResult
431-
matchAndRewrite(cir::RoundOp op, OpAdaptor adaptor,
292+
matchAndRewrite(CIROp op,
293+
typename mlir::OpConversionPattern<CIROp>::OpAdaptor adaptor,
432294
mlir::ConversionPatternRewriter &rewriter) const override {
433-
rewriter.replaceOpWithNewOp<mlir::math::RoundOp>(op, adaptor.getSrc());
295+
rewriter.replaceOpWithNewOp<MLIROp>(op, adaptor.getSrc());
434296
return mlir::LogicalResult::success();
435297
}
436298
};
437299

438-
class CIRExpOpLowering : public mlir::OpConversionPattern<cir::ExpOp> {
439-
public:
440-
using mlir::OpConversionPattern<cir::ExpOp>::OpConversionPattern;
441-
442-
mlir::LogicalResult
443-
matchAndRewrite(cir::ExpOp op, OpAdaptor adaptor,
444-
mlir::ConversionPatternRewriter &rewriter) const override {
445-
rewriter.replaceOpWithNewOp<mlir::math::ExpOp>(op, adaptor.getSrc());
446-
return mlir::LogicalResult::success();
447-
}
448-
};
300+
using CIRASinOpLowering =
301+
CIRUnaryMathOpLowering<cir::ASinOp, mlir::math::AsinOp>;
302+
using CIRSinOpLowering = CIRUnaryMathOpLowering<cir::SinOp, mlir::math::SinOp>;
303+
using CIRExp2OpLowering =
304+
CIRUnaryMathOpLowering<cir::Exp2Op, mlir::math::Exp2Op>;
305+
using CIRExpOpLowering = CIRUnaryMathOpLowering<cir::ExpOp, mlir::math::ExpOp>;
306+
using CIRRoundOpLowering =
307+
CIRUnaryMathOpLowering<cir::RoundOp, mlir::math::RoundOp>;
308+
using CIRLog2OpLowering =
309+
CIRUnaryMathOpLowering<cir::Log2Op, mlir::math::Log2Op>;
310+
using CIRLogOpLowering = CIRUnaryMathOpLowering<cir::LogOp, mlir::math::LogOp>;
311+
using CIRLog10OpLowering =
312+
CIRUnaryMathOpLowering<cir::Log10Op, mlir::math::Log10Op>;
313+
using CIRCeilOpLowering =
314+
CIRUnaryMathOpLowering<cir::CeilOp, mlir::math::CeilOp>;
315+
using CIRFloorOpLowering =
316+
CIRUnaryMathOpLowering<cir::FloorOp, mlir::math::FloorOp>;
317+
using CIRAbsOpLowering = CIRUnaryMathOpLowering<cir::AbsOp, mlir::math::AbsIOp>;
318+
using CIRFAbsOpLowering =
319+
CIRUnaryMathOpLowering<cir::FAbsOp, mlir::math::AbsFOp>;
320+
using CIRSqrtOpLowering =
321+
CIRUnaryMathOpLowering<cir::SqrtOp, mlir::math::SqrtOp>;
322+
using CIRCosOpLowering = CIRUnaryMathOpLowering<cir::CosOp, mlir::math::CosOp>;
323+
using CIRATanOpLowering =
324+
CIRUnaryMathOpLowering<cir::ATanOp, mlir::math::AtanOp>;
325+
using CIRACosOpLowering =
326+
CIRUnaryMathOpLowering<cir::ACosOp, mlir::math::AcosOp>;
327+
using CIRTanOpLowering = CIRUnaryMathOpLowering<cir::TanOp, mlir::math::TanOp>;
449328

450329
class CIRShiftOpLowering : public mlir::OpConversionPattern<cir::ShiftOp> {
451330
public:
@@ -480,42 +359,6 @@ class CIRShiftOpLowering : public mlir::OpConversionPattern<cir::ShiftOp> {
480359
}
481360
};
482361

483-
class CIRExp2OpLowering : public mlir::OpConversionPattern<cir::Exp2Op> {
484-
public:
485-
using mlir::OpConversionPattern<cir::Exp2Op>::OpConversionPattern;
486-
487-
mlir::LogicalResult
488-
matchAndRewrite(cir::Exp2Op op, OpAdaptor adaptor,
489-
mlir::ConversionPatternRewriter &rewriter) const override {
490-
rewriter.replaceOpWithNewOp<mlir::math::Exp2Op>(op, adaptor.getSrc());
491-
return mlir::LogicalResult::success();
492-
}
493-
};
494-
495-
class CIRSinOpLowering : public mlir::OpConversionPattern<cir::SinOp> {
496-
public:
497-
using mlir::OpConversionPattern<cir::SinOp>::OpConversionPattern;
498-
499-
mlir::LogicalResult
500-
matchAndRewrite(cir::SinOp op, OpAdaptor adaptor,
501-
mlir::ConversionPatternRewriter &rewriter) const override {
502-
rewriter.replaceOpWithNewOp<mlir::math::SinOp>(op, adaptor.getSrc());
503-
return mlir::LogicalResult::success();
504-
}
505-
};
506-
507-
class CIRASinOpLowering : public mlir::OpConversionPattern<cir::ASinOp> {
508-
public:
509-
using mlir::OpConversionPattern<cir::ASinOp>::OpConversionPattern;
510-
511-
mlir::LogicalResult
512-
matchAndRewrite(cir::ASinOp op, OpAdaptor adaptor,
513-
mlir::ConversionPatternRewriter &rewriter) const override {
514-
rewriter.replaceOpWithNewOp<mlir::math::AsinOp>(op, adaptor.getSrc());
515-
return mlir::LogicalResult::success();
516-
}
517-
};
518-
519362
template <typename CIROp, typename MLIROp>
520363
class CIRCountZerosBitOpLowering : public mlir::OpConversionPattern<CIROp> {
521364
public:

0 commit comments

Comments
 (0)