@@ -261,8 +261,10 @@ loopScheduling(scf::ForOp forOp,
261261 return 1 ;
262262 };
263263
264- std::optional<int64_t > ubConstant = getConstantIntValue (forOp.getUpperBound ());
265- std::optional<int64_t > lbConstant = getConstantIntValue (forOp.getLowerBound ());
264+ std::optional<int64_t > ubConstant =
265+ getConstantIntValue (forOp.getUpperBound ());
266+ std::optional<int64_t > lbConstant =
267+ getConstantIntValue (forOp.getLowerBound ());
266268 DenseMap<Operation *, unsigned > opCycles;
267269 std::map<unsigned , std::vector<Operation *>> wrappedSchedule;
268270 for (Operation &op : forOp.getBody ()->getOperations ()) {
@@ -447,113 +449,6 @@ void transform::TakeAssumedBranchOp::getEffects(
447449// LoopFuseSiblingOp
448450// ===----------------------------------------------------------------------===//
449451
450- // / Check if `target` and `source` are siblings, in the context that `target`
451- // / is being fused into `source`.
452- // /
453- // / This is a simple check that just checks if both operations are in the same
454- // / block and some checks to ensure that the fused IR does not violate
455- // / dominance.
456- static DiagnosedSilenceableFailure isOpSibling (Operation *target,
457- Operation *source) {
458- // Check if both operations are same.
459- if (target == source)
460- return emitSilenceableFailure (source)
461- << " target and source need to be different loops" ;
462-
463- // Check if both operations are in the same block.
464- if (target->getBlock () != source->getBlock ())
465- return emitSilenceableFailure (source)
466- << " target and source are not in the same block" ;
467-
468- // Check if fusion will violate dominance.
469- DominanceInfo domInfo (source);
470- if (target->isBeforeInBlock (source)) {
471- // Since `target` is before `source`, all users of results of `target`
472- // need to be dominated by `source`.
473- for (Operation *user : target->getUsers ()) {
474- if (!domInfo.properlyDominates (source, user, /* enclosingOpOk=*/ false )) {
475- return emitSilenceableFailure (target)
476- << " user of results of target should be properly dominated by "
477- " source" ;
478- }
479- }
480- } else {
481- // Since `target` is after `source`, all values used by `target` need
482- // to dominate `source`.
483-
484- // Check if operands of `target` are dominated by `source`.
485- for (Value operand : target->getOperands ()) {
486- Operation *operandOp = operand.getDefiningOp ();
487- // Operands without defining operations are block arguments. When `target`
488- // and `source` occur in the same block, these operands dominate `source`.
489- if (!operandOp)
490- continue ;
491-
492- // Operand's defining operation should properly dominate `source`.
493- if (!domInfo.properlyDominates (operandOp, source,
494- /* enclosingOpOk=*/ false ))
495- return emitSilenceableFailure (target)
496- << " operands of target should be properly dominated by source" ;
497- }
498-
499- // Check if values used by `target` are dominated by `source`.
500- bool failed = false ;
501- OpOperand *failedValue = nullptr ;
502- visitUsedValuesDefinedAbove (target->getRegions (), [&](OpOperand *operand) {
503- Operation *operandOp = operand->get ().getDefiningOp ();
504- if (operandOp && !domInfo.properlyDominates (operandOp, source,
505- /* enclosingOpOk=*/ false )) {
506- // `operand` is not an argument of an enclosing block and the defining
507- // op of `operand` is outside `target` but does not dominate `source`.
508- failed = true ;
509- failedValue = operand;
510- }
511- });
512-
513- if (failed)
514- return emitSilenceableFailure (failedValue->getOwner ())
515- << " values used inside regions of target should be properly "
516- " dominated by source" ;
517- }
518-
519- return DiagnosedSilenceableFailure::success ();
520- }
521-
522- // / Check if `target` scf.forall can be fused into `source` scf.forall.
523- // /
524- // / This simply checks if both loops have the same bounds, steps and mapping.
525- // / No attempt is made at checking that the side effects of `target` and
526- // / `source` are independent of each other.
527- static bool isForallWithIdenticalConfiguration (Operation *target,
528- Operation *source) {
529- auto targetOp = dyn_cast<scf::ForallOp>(target);
530- auto sourceOp = dyn_cast<scf::ForallOp>(source);
531- if (!targetOp || !sourceOp)
532- return false ;
533-
534- return targetOp.getMixedLowerBound () == sourceOp.getMixedLowerBound () &&
535- targetOp.getMixedUpperBound () == sourceOp.getMixedUpperBound () &&
536- targetOp.getMixedStep () == sourceOp.getMixedStep () &&
537- targetOp.getMapping () == sourceOp.getMapping ();
538- }
539-
540- // / Check if `target` scf.for can be fused into `source` scf.for.
541- // /
542- // / This simply checks if both loops have the same bounds and steps. No attempt
543- // / is made at checking that the side effects of `target` and `source` are
544- // / independent of each other.
545- static bool isForWithIdenticalConfiguration (Operation *target,
546- Operation *source) {
547- auto targetOp = dyn_cast<scf::ForOp>(target);
548- auto sourceOp = dyn_cast<scf::ForOp>(source);
549- if (!targetOp || !sourceOp)
550- return false ;
551-
552- return targetOp.getLowerBound () == sourceOp.getLowerBound () &&
553- targetOp.getUpperBound () == sourceOp.getUpperBound () &&
554- targetOp.getStep () == sourceOp.getStep ();
555- }
556-
557452DiagnosedSilenceableFailure
558453transform::LoopFuseSiblingOp::apply (transform::TransformRewriter &rewriter,
559454 transform::TransformResults &results,
@@ -569,25 +464,32 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
569464 << " source handle (got " << llvm::range_size (sourceOps) << " )" ;
570465 }
571466
572- Operation *target = *targetOps.begin ();
573- Operation *source = *sourceOps.begin ();
467+ auto target = dyn_cast<LoopLikeOpInterface>(*targetOps.begin ());
468+ auto source = dyn_cast<LoopLikeOpInterface>(*sourceOps.begin ());
469+ if (!target || !source)
470+ return emitSilenceableFailure (target->getLoc ())
471+ << " target or source is not a loop op" ;
574472
575- // Check if the target and source are siblings.
576- DiagnosedSilenceableFailure diag = isOpSibling (target, source );
577- if (!diag. succeeded ( ))
578- return diag;
473+ // Check if loops can be fused
474+ Diagnostic diag (target. getLoc (), DiagnosticSeverity::Error );
475+ if (!mlir::checkFusionStructuralLegality (target, source, diag ))
476+ return DiagnosedSilenceableFailure::silenceableFailure ( std::move ( diag)) ;
579477
580478 Operation *fusedLoop;
581- // / TODO: Support fusion for loop-like ops besides scf.for and scf.forall.
582- if (isForWithIdenticalConfiguration (target, source)) {
479+ // TODO: Support fusion for loop-like ops besides scf.for, scf.forall
480+ // and scf.parallel.
481+ if (isa<scf::ForOp>(target) && isa<scf::ForOp>(source)) {
583482 fusedLoop = fuseIndependentSiblingForLoops (
584483 cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
585- } else if (isForallWithIdenticalConfiguration (target, source)) {
484+ } else if (isa<scf::ForallOp> (target) && isa<scf::ForallOp>( source)) {
586485 fusedLoop = fuseIndependentSiblingForallLoops (
587486 cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
487+ } else if (isa<scf::ParallelOp>(target) && isa<scf::ParallelOp>(source)) {
488+ fusedLoop = fuseIndependentSiblingParallelLoops (
489+ cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
588490 } else
589491 return emitSilenceableFailure (target->getLoc ())
590- << " operations cannot be fused " ;
492+ << " unsupported loop type for fusion " ;
591493
592494 assert (fusedLoop && " failed to fuse operations" );
593495
0 commit comments