@@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,
303303
304304 case Stmt::CaseStmtClass:
305305 case Stmt::DefaultStmtClass:
306- assert (0 &&
307- " Should not get here, currently handled directly from SwitchStmt" );
306+ return buildSwitchCase (cast<SwitchCase>(*S));
308307 break ;
309308
310309 case Stmt::BreakStmtClass:
@@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
715714 return buildCaseDefaultCascade (&S, condType, caseAttrs);
716715}
717716
718- mlir::LogicalResult
719- CIRGenFunction::buildSwitchCase (const SwitchCase &S, mlir::Type condType,
720- SmallVector<mlir::Attribute, 4 > &caseAttrs) {
717+ mlir::LogicalResult CIRGenFunction::buildSwitchCase (const SwitchCase &S) {
718+ assert (!caseAttrsStack.empty () &&
719+ " build switch case without seeting case attrs" );
720+ assert (!condTypeStack.empty () &&
721+ " build switch case without specifying the type of the condition" );
722+
721723 if (S.getStmtClass () == Stmt::CaseStmtClass)
722- return buildCaseStmt (cast<CaseStmt>(S), condType, caseAttrs);
724+ return buildCaseStmt (cast<CaseStmt>(S), condTypeStack.back (),
725+ caseAttrsStack.back ());
723726
724727 if (S.getStmtClass () == Stmt::DefaultStmtClass)
725- return buildDefaultStmt (cast<DefaultStmt>(S), condType, caseAttrs);
728+ return buildDefaultStmt (cast<DefaultStmt>(S), condTypeStack.back (),
729+ caseAttrsStack.back ());
726730
727731 llvm_unreachable (" expect case or default stmt" );
728732}
@@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
987991 return mlir::success ();
988992}
989993
990- mlir::LogicalResult CIRGenFunction::buildSwitchBody (
991- const Stmt *S, mlir::Type condType,
992- llvm::SmallVector<mlir::Attribute, 4 > &caseAttrs) {
994+ mlir::LogicalResult CIRGenFunction::buildSwitchBody (const Stmt *S) {
993995 if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
994996 mlir::Block *lastCaseBlock = nullptr ;
995997 auto res = mlir::success ();
996998 for (auto *c : compoundStmt->body ()) {
997999 if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
998- res = buildSwitchCase (*switchCase, condType, caseAttrs );
1000+ res = buildSwitchCase (*switchCase);
9991001 } else if (lastCaseBlock) {
10001002 // This means it's a random stmt following up a case, just
10011003 // emit it as part of previous known case.
@@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
10451047 [&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
10461048 currLexScope->setAsSwitch ();
10471049
1048- llvm::SmallVector<mlir::Attribute, 4 > caseAttrs;
1050+ caseAttrsStack.push_back ({});
1051+ condTypeStack.push_back (condV.getType ());
10491052
1050- res = buildSwitchBody (S.getBody (), condV. getType (), caseAttrs );
1053+ res = buildSwitchBody (S.getBody ());
10511054
10521055 os.addRegions (currLexScope->getSwitchRegions ());
1053- os.addAttribute (" cases" , builder.getArrayAttr (caseAttrs));
1056+ os.addAttribute (" cases" , builder.getArrayAttr (caseAttrsStack.back ()));
1057+
1058+ caseAttrsStack.pop_back ();
1059+ condTypeStack.pop_back ();
10541060 });
10551061
10561062 if (res.failed ())
0 commit comments