Skip to content
32 changes: 32 additions & 0 deletions core/src/main/java/io/grpc/internal/CloseWithHeadersMarker.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright 2025 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.internal;

import io.grpc.Status;

/**
* Marker to be used for Status sent to {@link ServerStream#cancel(Status)} to signal that stream
* should be closed by sending headers.
*/
public class CloseWithHeadersMarker extends Throwable {
private static final long serialVersionUID = 0L;

@Override
public synchronized Throwable fillInStackTrace() {
return this;
}
}
24 changes: 20 additions & 4 deletions core/src/main/java/io/grpc/internal/ServerCallImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ private void handleInternalError(Throwable internalError) {
serverCallTracer.reportCallEnded(false); // error so always false
}

/**
* Close the {@link ServerStream} because parsing request message failed.
* Similar to {@link #handleInternalError(Throwable)}.
*/
private void handleParseError(StatusRuntimeException parseError) {
cancelled = true;
log.log(Level.WARNING, "Cancelling the stream because of parse error", parseError);
stream.cancel(parseError.getStatus().withCause(new CloseWithHeadersMarker()));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be made more obvious with some API changes

serverCallTracer.reportCallEnded(false); // error so always false
}

/**
* All of these callbacks are assumed to called on an application thread, and the caller is
* responsible for handling thrown exceptions.
Expand Down Expand Up @@ -327,18 +338,23 @@ private void messagesAvailableInternal(final MessageProducer producer) {
return;
}

InputStream message;
InputStream message = null;
try {
while ((message = producer.next()) != null) {
ReqT parsed;
try {
listener.onMessage(call.method.parseRequest(message));
} catch (Throwable t) {
parsed = call.method.parseRequest(message);
} catch (StatusRuntimeException e) {
GrpcUtil.closeQuietly(message);
throw t;
GrpcUtil.closeQuietly(producer);
call.handleParseError(e);
return;
}
message.close();
listener.onMessage(parsed);
}
} catch (Throwable t) {
GrpcUtil.closeQuietly(message);
GrpcUtil.closeQuietly(producer);
Throwables.throwIfUnchecked(t);
throw new RuntimeException(t);
Expand Down
40 changes: 40 additions & 0 deletions core/src/test/java/io/grpc/internal/ServerCallImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,11 @@
import io.grpc.SecurityLevel;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.ServerCallImpl.ServerStreamListenerImpl;
import io.perfmark.PerfMark;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import org.junit.Before;
Expand All @@ -69,6 +71,8 @@ public class ServerCallImplTest {

@Mock private ServerStream stream;
@Mock private ServerCall.Listener<Long> callListener;
@Mock private StreamListener.MessageProducer messageProducer;
@Mock private InputStream message;

private final CallTracer serverCallTracer = CallTracer.getDefaultFactory().create();
private ServerCallImpl<Long, Long> call;
Expand Down Expand Up @@ -493,6 +497,42 @@ public void streamListener_unexpectedRuntimeException() {
assertThat(e).hasMessageThat().isEqualTo("unexpected exception");
}

@Test
public void streamListener_statusRuntimeException() throws IOException {
MethodDescriptor<Long, Long> failingParseMethod = MethodDescriptor.<Long, Long>newBuilder()
.setType(MethodType.UNARY)
.setFullMethodName("service/method")
.setRequestMarshaller(new LongMarshaller() {
@Override
public Long parse(InputStream stream) {
throw new StatusRuntimeException(Status.RESOURCE_EXHAUSTED
.withDescription("Decompressed gRPC message exceeds maximum size"));
}
})
.setResponseMarshaller(new LongMarshaller())
.build();

call = new ServerCallImpl<>(stream, failingParseMethod, requestHeaders, context,
DecompressorRegistry.getDefaultInstance(), CompressorRegistry.getDefaultInstance(),
serverCallTracer, PerfMark.createTag());

ServerStreamListenerImpl<Long> streamListener =
new ServerCallImpl.ServerStreamListenerImpl<>(call, callListener, context);

when(messageProducer.next()).thenReturn(message, (InputStream) null);
streamListener.messagesAvailable(messageProducer);
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);

verify(stream).cancel(statusCaptor.capture());
Status status = statusCaptor.getValue();
assertEquals(Status.RESOURCE_EXHAUSTED.getCode(), status.getCode());
assertEquals("Decompressed gRPC message exceeds maximum size", status.getDescription());

streamListener.halfClosed();
verify(callListener, never()).onHalfClose();
verify(callListener, never()).onMessage(any());
}

private static class LongMarshaller implements Marshaller<Long> {
@Override
public InputStream stream(Long value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2024,7 +2024,7 @@ private void assertPayload(Payload expected, Payload actual) {
}
}

private static void assertCodeEquals(Status.Code expected, Status actual) {
protected static void assertCodeEquals(Status.Code expected, Status actual) {
assertWithMessage("Unexpected status: %s", actual).that(actual.getCode()).isEqualTo(expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.testing.integration;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import com.google.protobuf.ByteString;
Expand All @@ -37,6 +38,8 @@
import io.grpc.ServerCall.Listener;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.GrpcUtil;
import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.InternalNettyServerBuilder;
Expand All @@ -53,7 +56,9 @@
import java.io.OutputStream;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TestName;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

Expand Down Expand Up @@ -84,10 +89,16 @@ public static void registerCompressors() {
compressors.register(Codec.Identity.NONE);
}

@Rule
public final TestName currentTest = new TestName();

@Override
protected ServerBuilder<?> getServerBuilder() {
NettyServerBuilder builder = NettyServerBuilder.forPort(0, InsecureServerCredentials.create())
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.maxInboundMessageSize(
DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME.equals(currentTest.getMethodName())
? 1000
: AbstractInteropTest.MAX_MESSAGE_SIZE)
.compressorRegistry(compressors)
.decompressorRegistry(decompressors)
.intercept(new ServerInterceptor() {
Expand Down Expand Up @@ -126,6 +137,22 @@ public void compresses() {
assertTrue(FZIPPER.anyWritten);
}

private static final String DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME =
"decompressedMessageTooLong";

@Test
public void decompressedMessageTooLong() {
assertEquals(DECOMPRESSED_MESSAGE_TOO_LONG_METHOD_NAME, currentTest.getMethodName());
final SimpleRequest bigRequest = SimpleRequest.newBuilder()
.setPayload(Payload.newBuilder().setBody(ByteString.copyFrom(new byte[10_000])))
.build();
StatusRuntimeException e = assertThrows(StatusRuntimeException.class,
() -> blockingStub.withCompression("gzip").unaryCall(bigRequest));
assertCodeEquals(Code.RESOURCE_EXHAUSTED, e.getStatus());
assertEquals("Decompressed gRPC message exceeds maximum size 1000",
e.getStatus().getDescription());
}

@Override
protected NettyChannelBuilder createChannelBuilder() {
NettyChannelBuilder builder = NettyChannelBuilder.forAddress(getListenAddress())
Expand Down
7 changes: 6 additions & 1 deletion netty/src/main/java/io/grpc/netty/NettyServerStream.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.Metadata;
import io.grpc.Status;
import io.grpc.internal.AbstractServerStream;
import io.grpc.internal.CloseWithHeadersMarker;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.internal.WritableBuffer;
Expand Down Expand Up @@ -130,7 +131,11 @@ public void writeTrailers(Metadata trailers, boolean headersSent, Status status)
@Override
public void cancel(Status status) {
try (TaskCloseable ignore = PerfMark.traceTask("NettyServerStream$Sink.cancel")) {
writeQueue.enqueue(CancelServerStreamCommand.withReset(transportState(), status), true);
CancelServerStreamCommand cmd =
status.getCause() instanceof CloseWithHeadersMarker
? CancelServerStreamCommand.withReason(transportState(), status)
: CancelServerStreamCommand.withReset(transportState(), status);
writeQueue.enqueue(cmd, true);
}
}
}
Expand Down