diff --git a/gomock/matchers.go b/gomock/matchers.go index 2822fb2c..b181fa4d 100644 --- a/gomock/matchers.go +++ b/gomock/matchers.go @@ -271,6 +271,37 @@ func (m inAnyOrderMatcher) String() string { return fmt.Sprintf("has the same elements as %v", m.x) } +type inMatcher struct { + x interface{} +} + +func (m inMatcher) Matches(x interface{}) bool { + matchers := m.prepareValue(m.x) + for _, matcher := range matchers { + if matcher.Matches(x) { + return true + } + } + return false +} + +func (m inMatcher) prepareValue(x interface{}) []Matcher { + xValue := reflect.ValueOf(x) + kind := xValue.Kind() + if kind != reflect.Slice && kind != reflect.Array { + return nil + } + matchers := make([]Matcher, 0, xValue.Len()) + for i := 0; i < xValue.Len(); i++ { + matchers = append(matchers, Eq(xValue.Index(i).Interface())) + } + return matchers +} + +func (m inMatcher) String() string { + return fmt.Sprintf("match one of %v", m.x) +} + // Constructors // All returns a composite Matcher that returns true if and only all of the @@ -339,3 +370,14 @@ func AssignableToTypeOf(x interface{}) Matcher { func InAnyOrder(x interface{}) Matcher { return inAnyOrderMatcher{x} } + +// In is a Matcher that returns true if the received value Eq one of the elements +// +// Example usage: +// m := In([]int{1,2}) +// m.Matches(1) // returns true +// m.Matches(2) // returns true +// m.Matches(3) // returns false +func In(x interface{}) Matcher { + return inMatcher{x} +} diff --git a/gomock/matchers_test.go b/gomock/matchers_test.go index 61bc1993..48e902b6 100644 --- a/gomock/matchers_test.go +++ b/gomock/matchers_test.go @@ -293,3 +293,44 @@ func TestInAnyOrder(t *testing.T) { }) } } + +func TestInMatcher_Matches(t *testing.T) { + tests := []struct { + name string + wanted interface{} + given interface{} + wantMatch bool + }{ + { + "match successfully", + []int{1, 2}, + 2, + true, + }, + { + "given not found", + []int{1, 2}, + 3, + false, + }, + { + "type not match, wanted should be slice", + 1, + 1, + false, + }, + { + "empty slice", + []int{}, + 1, + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := gomock.In(tt.wanted).Matches(tt.given); got != tt.wantMatch { + t.Errorf("got = %v, wantMatch %v", got, tt.wantMatch) + } + }) + } +}