2121#include " mlir/Interfaces/InferIntRangeInterface.h"
2222#include " mlir/Query/Matcher/MatchersInternal.h"
2323#include " mlir/Query/Query.h"
24- #include " llvm/ADT/SetVector.h"
25-
2624namespace mlir {
2725
2826namespace detail {
@@ -366,21 +364,14 @@ struct RecursivePatternMatcher {
366364 std::tuple<OperandMatchers...> operandMatchers;
367365};
368366
369- // / Fills `backwardSlice` with the computed backward slice (i.e.
370- // / all the transitive defs of op)
371- // /
372- // / The implementation traverses the def chains in postorder traversal for
373- // / efficiency reasons: if an operation is already in `backwardSlice`, no
374- // / need to traverse its definitions again. Since use-def chains form a DAG,
375- // / this terminates.
376- // /
377- // / Upon return to the root call, `backwardSlice` is filled with a
378- // / postorder list of defs. This happens to be a topological order, from the
379- // / point of view of the use-def chains.
367+ // / A matcher encapsulating the initial `getBackwardSlice` method from
368+ // / SliceAnalysis.h
369+ // / Additionally, it limits the slice computation to a certain depth level using
370+ // / a custom filter
380371// /
381- // / Example starting from node 8
372+ // / Example starting from node 9, assuming the matcher
373+ // / computes the slice for the first two depth levels
382374// / ============================
383- // /
384375// / 1 2 3 4
385376// / |_______| |______|
386377// / | | |
@@ -393,240 +384,52 @@ struct RecursivePatternMatcher {
393384// / 9
394385// /
395386// / Assuming all local orders match the numbering order:
396- // / {1, 2, 5, 3, 4, 6}
397- // /
398-
387+ // / {5, 7, 6, 8, 9}
399388class BackwardSliceMatcher {
400389public:
401- BackwardSliceMatcher (mlir:: query::matcher::DynMatcher &&innerMatcher,
390+ BackwardSliceMatcher (query::matcher::DynMatcher &&innerMatcher,
402391 int64_t maxDepth)
403392 : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
404-
405393 bool match (Operation *op, SetVector<Operation *> &backwardSlice,
406- mlir:: query::QueryOptions &options) {
394+ query::QueryOptions &options) {
407395
408396 if (innerMatcher.match (op) &&
409397 matches (op, backwardSlice, options, maxDepth)) {
410- if (!options.inclusive ) {
411- // Don't insert the top level operation, we just queried on it and don't
412- // want it in the results.
413- backwardSlice.remove (op);
414- }
415398 return true ;
416399 }
417400 return false ;
418401 }
419402
420403private:
421- bool matches (Operation *op, SetVector<Operation *> &backwardSlice,
422- mlir::query::QueryOptions &options, int64_t remainingDepth) {
423-
424- if (op->hasTrait <OpTrait::IsIsolatedFromAbove>()) {
425- return false ;
426- }
427-
428- auto processValue = [&](Value value) {
429- // We need to check the current depth level;
430- // if we have reached level 0, we stop further traversing
431- if (remainingDepth == 0 ) {
432- return ;
433- }
434- if (auto *definingOp = value.getDefiningOp ()) {
435- // We omit traversing the same operations
436- if (backwardSlice.count (definingOp) == 0 )
437- matches (definingOp, backwardSlice, options, remainingDepth - 1 );
438- } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
439- if (options.omitBlockArguments )
440- return ;
441- Block *block = blockArg.getOwner ();
442-
443- Operation *parentOp = block->getParentOp ();
444- // TODO: determine whether we want to recurse backward into the other
445- // blocks of parentOp, which are not technically backward unless they
446- // flow into us. For now, just bail.
447- if (parentOp && backwardSlice.count (parentOp) == 0 ) {
448- if (parentOp->getNumRegions () != 1 &&
449- parentOp->getRegion (0 ).getBlocks ().size () != 1 ) {
450- llvm::errs ()
451- << " Error: Expected parentOp to have exactly one region and "
452- << " exactly one block, but found " << parentOp->getNumRegions ()
453- << " regions and "
454- << (parentOp->getRegion (0 ).getBlocks ().size ()) << " blocks.\n " ;
455- };
456- matches (parentOp, backwardSlice, options, remainingDepth - 1 );
457- }
458- } else {
459- llvm_unreachable (" No definingOp and not a block argument\n " );
460- return ;
461- }
462- };
463-
464- if (!options.omitUsesFromAbove ) {
465- llvm::for_each (op->getRegions (), [&](Region ®ion) {
466- // Walk this region recursively to collect the regions that descend from
467- // this op's nested regions (inclusive).
468- SmallPtrSet<Region *, 4 > descendents;
469- region.walk (
470- [&](Region *childRegion) { descendents.insert (childRegion); });
471- region.walk ([&](Operation *op) {
472- for (OpOperand &operand : op->getOpOperands ()) {
473- if (!descendents.contains (operand.get ().getParentRegion ()))
474- processValue (operand.get ());
475- }
476- });
477- });
478- }
479-
480- llvm::for_each (op->getOperands (), processValue);
481- backwardSlice.insert (op);
482- return true ;
483- }
404+ bool matches (Operation *rootOp, llvm::SetVector<Operation *> &backwardSlice,
405+ query::QueryOptions &options, int64_t maxDepth);
484406
485407private:
486408 // The outer matcher (e.g., BackwardSliceMatcher) relies on the innerMatcher
487409 // to determine whether we want to traverse the DAG or not. For example, we
488410 // want to explore the DAG only if the top-level operation name is
489411 // "arith.addf".
490- mlir::query::matcher::DynMatcher innerMatcher;
491-
412+ query::matcher::DynMatcher innerMatcher;
492413 // maxDepth specifies the maximum depth that the matcher can traverse in the
493414 // DAG. For example, if maxDepth is 2, the matcher will explore the defining
494415 // operations of the top-level op up to 2 levels.
495416 int64_t maxDepth;
496417};
497-
498- // / Fills `forwardSlice` with the computed forward slice (i.e. all
499- // / the transitive uses of op)
500- // /
501- // /
502- // / The implementation traverses the use chains in postorder traversal for
503- // / efficiency reasons: if an operation is already in `forwardSlice`, no
504- // / need to traverse its uses again. Since use-def chains form a DAG, this
505- // / terminates.
506- // /
507- // / Upon return to the root call, `forwardSlice` is filled with a
508- // / postorder list of uses (i.e. a reverse topological order). To get a proper
509- // / topological order, we just reverse the order in `forwardSlice` before
510- // / returning.
511- // /
512- // / Example starting from node 0
513- // / ============================
514- // /
515- // / 0
516- // / ___________|___________
517- // / 1 2 3 4
518- // / |_______| |______|
519- // / | | |
520- // / | 5 6
521- // / |___|_____________|
522- // / | |
523- // / 7 8
524- // / |_______________|
525- // / |
526- // / 9
527- // /
528- // / Assuming all local orders match the numbering order:
529- // / 1. after getting back to the root getForwardSlice, `forwardSlice` may
530- // / contain:
531- // / {9, 7, 8, 5, 1, 2, 6, 3, 4}
532- // / 2. reversing the result of 1. gives:
533- // / {4, 3, 6, 2, 1, 5, 8, 7, 9}
534- // /
535- class ForwardSliceMatcher {
536- public:
537- ForwardSliceMatcher (mlir::query::matcher::DynMatcher &&innerMatcher,
538- int64_t maxDepth)
539- : innerMatcher(std::move(innerMatcher)), maxDepth(maxDepth) {}
540-
541- bool match (Operation *op, SetVector<Operation *> &forwardSlice,
542- mlir::query::QueryOptions &options) {
543- if (innerMatcher.match (op) &&
544- matches (op, forwardSlice, options, maxDepth)) {
545- if (!options.inclusive ) {
546- // Don't insert the top level operation, we just queried on it and don't
547- // want it in the results.
548- forwardSlice.remove (op);
549- }
550- // Reverse to get back the actual topological order.
551- // std::reverse does not work out of the box on SetVector and I want an
552- // in-place swap based thing (the real std::reverse, not the LLVM
553- // adapter).
554- SmallVector<Operation *, 0 > v (forwardSlice.takeVector ());
555- forwardSlice.insert (v.rbegin (), v.rend ());
556- return true ;
557- }
558- return false ;
559- }
560-
561- private:
562- bool matches (Operation *op, SetVector<Operation *> &forwardSlice,
563- mlir::query::QueryOptions &options, int64_t remainingDepth) {
564-
565- // We need to check the current depth level;
566- // if we have reached level 0, we stop further traversing and insert
567- // the last user in def-use chain
568- if (remainingDepth == 0 ) {
569- forwardSlice.insert (op);
570- return true ;
571- }
572-
573- for (Region ®ion : op->getRegions ())
574- for (Block &block : region)
575- for (Operation &blockOp : block)
576- if (forwardSlice.count (&blockOp) == 0 )
577- matches (&blockOp, forwardSlice, options, remainingDepth - 1 );
578- for (Value result : op->getResults ()) {
579- for (Operation *userOp : result.getUsers ())
580- // We omit traversing the same operations
581- if (forwardSlice.count (userOp) == 0 )
582- matches (userOp, forwardSlice, options, remainingDepth - 1 );
583- }
584-
585- forwardSlice.insert (op);
586- return true ;
587- }
588-
589- private:
590- // The outer matcher e.g (ForwardSliceMatcher) relies on the innerMatcher to
591- // determine whether we want to traverse the graph or not. E.g: we want to
592- // explore the DAG only if the top level operation name is "arith.addf"
593- mlir::query::matcher::DynMatcher innerMatcher;
594-
595- // maxDepth specifies the maximum depth that the matcher can traverse the
596- // graph E.g: if maxDepth is 2, the matcher will explore the user
597- // operations of the top level op up to 2 levels
598- int64_t maxDepth;
599- };
600-
601418} // namespace detail
602419
603420// Matches transitive defs of a top level operation up to 1 level
604421inline detail::BackwardSliceMatcher
605- m_DefinedBy (mlir:: query::matcher::DynMatcher innerMatcher) {
422+ m_DefinedBy (query::matcher::DynMatcher innerMatcher) {
606423 return detail::BackwardSliceMatcher (std::move (innerMatcher), 1 );
607424}
608425
609426// Matches transitive defs of a top level operation up to N levels
610427inline detail::BackwardSliceMatcher
611- m_GetDefinitions (mlir::query::matcher::DynMatcher innerMatcher,
612- int64_t maxDepth) {
428+ m_GetDefinitions (query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
613429 assert (maxDepth >= 0 && " maxDepth must be non-negative" );
614430 return detail::BackwardSliceMatcher (std::move (innerMatcher), maxDepth);
615431}
616432
617- // Matches uses of a top level operation up to 1 level
618- inline detail::ForwardSliceMatcher
619- m_UsedBy (mlir::query::matcher::DynMatcher innerMatcher) {
620- return detail::ForwardSliceMatcher (std::move (innerMatcher), 1 );
621- }
622-
623- // Matches uses of a top level operation up to N levels
624- inline detail::ForwardSliceMatcher
625- m_GetUses (mlir::query::matcher::DynMatcher innerMatcher, int64_t maxDepth) {
626- assert (maxDepth >= 0 && " maxDepth must be non-negative" );
627- return detail::ForwardSliceMatcher (std::move (innerMatcher), maxDepth);
628- }
629-
630433// / Matches a constant foldable operation.
631434inline detail::constant_op_matcher m_Constant () {
632435 return detail::constant_op_matcher ();
0 commit comments