88//
99// Implements the base layer of the matcher framework.
1010//
11- // Matchers are methods that return a Matcher which provides a method
12- // match(Operation *op)
11+ // Matchers are methods that return a Matcher which provides a method one of the
12+ // following methods: match(Operation *op), match(Operation *op,
13+ // SetVector<Operation *> &matchedOps)
1314//
1415// The matcher functions are defined in include/mlir/IR/Matchers.h.
1516// This file contains the wrapper classes needed to construct matchers for
2526
2627namespace mlir ::query::matcher {
2728
29+ // Defaults to false if T has no match() method with the signature:
30+ // match(Operation* op).
31+ template <typename T, typename = void >
32+ struct has_simple_match : std::false_type {};
33+
34+ // Specialized type trait that evaluates to true if T has a match() method
35+ // with the signature: match(Operation* op).
36+ template <typename T>
37+ struct has_simple_match <T, std::void_t <decltype (std::declval<T>().match(
38+ std::declval<Operation *>()))>>
39+ : std::true_type {};
40+
41+ // Defaults to false if T has no match() method with the signature:
42+ // match(Operation* op, SetVector<Operation*>&).
43+ template <typename T, typename = void >
44+ struct has_bound_match : std::false_type {};
45+
46+ // Specialized type trait that evaluates to true if T has a match() method
47+ // with the signature: match(Operation* op, SetVector<Operation*>&).
48+ template <typename T>
49+ struct has_bound_match <T, std::void_t <decltype (std::declval<T>().match(
50+ std::declval<Operation *>(),
51+ std::declval<SetVector<Operation *> &>()))>>
52+ : std::true_type {};
53+
2854// Generic interface for matchers on an MLIR operation.
2955class MatcherInterface
3056 : public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
3157public:
3258 virtual ~MatcherInterface () = default ;
3359
3460 virtual bool match (Operation *op) = 0;
61+ virtual bool match (Operation *op, SetVector<Operation *> &matchedOps) = 0;
3562};
3663
3764// MatcherFnImpl takes a matcher function object and implements
@@ -40,14 +67,25 @@ template <typename MatcherFn>
4067class MatcherFnImpl : public MatcherInterface {
4168public:
4269 MatcherFnImpl (MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43- bool match (Operation *op) override { return matcherFn.match (op); }
70+
71+ bool match (Operation *op) override {
72+ if constexpr (has_simple_match<MatcherFn>::value)
73+ return matcherFn.match (op);
74+ return false ;
75+ }
76+
77+ bool match (Operation *op, SetVector<Operation *> &matchedOps) override {
78+ if constexpr (has_bound_match<MatcherFn>::value)
79+ return matcherFn.match (op, matchedOps);
80+ return false ;
81+ }
4482
4583private:
4684 MatcherFn matcherFn;
4785};
4886
49- // Matcher wraps a MatcherInterface implementation and provides a match()
50- // method that redirects calls to the underlying implementation.
87+ // Matcher wraps a MatcherInterface implementation and provides match()
88+ // methods that redirect calls to the underlying implementation.
5189class DynMatcher {
5290public:
5391 // Takes ownership of the provided implementation pointer.
@@ -62,12 +100,13 @@ class DynMatcher {
62100 }
63101
64102 bool match (Operation *op) const { return implementation->match (op); }
103+ bool match (Operation *op, SetVector<Operation *> &matchedOps) const {
104+ return implementation->match (op, matchedOps);
105+ }
65106
66- void setFunctionName (StringRef name) { functionName = name.str (); };
67-
68- bool hasFunctionName () const { return !functionName.empty (); };
69-
70- StringRef getFunctionName () const { return functionName; };
107+ void setFunctionName (StringRef name) { functionName = name.str (); }
108+ bool hasFunctionName () const { return !functionName.empty (); }
109+ StringRef getFunctionName () const { return functionName; }
71110
72111private:
73112 llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
0 commit comments