4949import org .opensearch .core .common .breaker .NoopCircuitBreaker ;
5050import org .opensearch .core .index .Index ;
5151import org .opensearch .core .index .shard .ShardId ;
52+ import org .opensearch .core .tasks .TaskCancelledException ;
5253import org .opensearch .core .tasks .resourcetracker .TaskResourceInfo ;
5354import org .opensearch .core .tasks .resourcetracker .TaskResourceUsage ;
5455import org .opensearch .index .query .MatchAllQueryBuilder ;
6667import org .opensearch .threadpool .TestThreadPool ;
6768import org .opensearch .threadpool .ThreadPool ;
6869import org .opensearch .transport .Transport ;
70+ import org .opensearch .transport .TransportException ;
6971import org .junit .After ;
7072import org .junit .Before ;
7173
@@ -136,6 +138,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
136138 controlled ,
137139 false ,
138140 false ,
141+ false ,
139142 expected ,
140143 resourceUsage ,
141144 new SearchShardIterator (null , null , Collections .emptyList (), null )
@@ -148,6 +151,7 @@ private AbstractSearchAsyncAction<SearchPhaseResult> createAction(
148151 ActionListener <SearchResponse > listener ,
149152 final boolean controlled ,
150153 final boolean failExecutePhaseOnShard ,
154+ final boolean throw4xxExceptionOnShard ,
151155 final boolean catchExceptionWhenExecutePhaseOnShard ,
152156 final AtomicLong expected ,
153157 final TaskResourceUsage resourceUsage ,
@@ -217,7 +221,11 @@ protected void executePhaseOnShard(
217221 final SearchActionListener <SearchPhaseResult > listener
218222 ) {
219223 if (failExecutePhaseOnShard ) {
220- listener .onFailure (new ShardNotFoundException (shardIt .shardId ()));
224+ if (throw4xxExceptionOnShard ) {
225+ listener .onFailure (new TransportException (new TaskCancelledException (shardIt .shardId ().toString ())));
226+ } else {
227+ listener .onFailure (new ShardNotFoundException (shardIt .shardId ()));
228+ }
221229 } else {
222230 if (catchExceptionWhenExecutePhaseOnShard ) {
223231 try {
@@ -585,6 +593,7 @@ public void onFailure(Exception e) {
585593 false ,
586594 true ,
587595 false ,
596+ false ,
588597 new AtomicLong (),
589598 new TaskResourceUsage (randomLong (), randomLong ()),
590599 shards
@@ -601,6 +610,62 @@ public void onFailure(Exception e) {
601610 assertThat (searchResponse .getSuccessfulShards (), equalTo (0 ));
602611 }
603612
613+ public void testSkipInValidRetryInMultiReplicas () throws InterruptedException {
614+ final Index index = new Index ("test" , UUID .randomUUID ().toString ());
615+ final CountDownLatch latch = new CountDownLatch (1 );
616+ final AtomicBoolean fail = new AtomicBoolean (true );
617+
618+ List <String > targetNodeIds = List .of ("n1" , "n2" , "n3" );
619+ final SearchShardIterator [] shards = IntStream .range (2 , 4 )
620+ .mapToObj (i -> new SearchShardIterator (null , new ShardId (index , i ), targetNodeIds , null , null , null ))
621+ .toArray (SearchShardIterator []::new );
622+
623+ SearchRequest searchRequest = new SearchRequest ().allowPartialSearchResults (true );
624+ searchRequest .setMaxConcurrentShardRequests (1 );
625+
626+ final ArraySearchPhaseResults <SearchPhaseResult > queryResult = new ArraySearchPhaseResults <>(shards .length );
627+ AbstractSearchAsyncAction <SearchPhaseResult > action = createAction (
628+ searchRequest ,
629+ queryResult ,
630+ new ActionListener <SearchResponse >() {
631+ @ Override
632+ public void onResponse (SearchResponse response ) {
633+
634+ }
635+
636+ @ Override
637+ public void onFailure (Exception e ) {
638+ if (fail .compareAndExchange (true , false )) {
639+ try {
640+ throw new RuntimeException ("Simulated exception" );
641+ } finally {
642+ executor .submit (() -> latch .countDown ());
643+ }
644+ }
645+ }
646+ },
647+ false ,
648+ true ,
649+ true ,
650+ false ,
651+ new AtomicLong (),
652+ new TaskResourceUsage (randomLong (), randomLong ()),
653+ shards
654+ );
655+ action .run ();
656+ assertTrue (latch .await (1 , TimeUnit .SECONDS ));
657+ InternalSearchResponse internalSearchResponse = InternalSearchResponse .empty ();
658+ SearchResponse searchResponse = action .buildSearchResponse (internalSearchResponse , action .buildShardFailures (), null , null );
659+ assertSame (searchResponse .getAggregations (), internalSearchResponse .aggregations ());
660+ assertSame (searchResponse .getSuggest (), internalSearchResponse .suggest ());
661+ assertSame (searchResponse .getProfileResults (), internalSearchResponse .profile ());
662+ assertSame (searchResponse .getHits (), internalSearchResponse .hits ());
663+ assertThat (searchResponse .getSuccessfulShards (), equalTo (0 ));
664+ for (int i = 0 ; i < shards .length ; i ++) {
665+ assertEquals (targetNodeIds .size () - 1 , shards [i ].remaining ());
666+ }
667+ }
668+
604669 public void testOnShardSuccessPhaseDoneFailure () throws InterruptedException {
605670 final Index index = new Index ("test" , UUID .randomUUID ().toString ());
606671 final CountDownLatch latch = new CountDownLatch (1 );
@@ -633,6 +698,7 @@ public void onFailure(Exception e) {
633698 false ,
634699 false ,
635700 false ,
701+ false ,
636702 new AtomicLong (),
637703 new TaskResourceUsage (randomLong (), randomLong ()),
638704 shards
@@ -685,6 +751,7 @@ public void onFailure(Exception e) {
685751 },
686752 false ,
687753 false ,
754+ false ,
688755 catchExceptionWhenExecutePhaseOnShard ,
689756 new AtomicLong (),
690757 new TaskResourceUsage (randomLong (), randomLong ()),
0 commit comments