Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
import java.util.function.Consumer;
import java.util.function.Function;
import org.apache.ratis.grpc.GrpcUtil;
import org.apache.ratis.grpc.metrics.ZeroCopyMetrics;
import org.apache.ratis.grpc.util.ZeroCopyMessageMarshaller;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.protocol.RaftServerProtocol;
import org.apache.ratis.server.util.ServerStringUtils;
import org.apache.ratis.thirdparty.io.grpc.ServerServiceDefinition;
import org.apache.ratis.thirdparty.io.grpc.Status;
import org.apache.ratis.thirdparty.io.grpc.StatusRuntimeException;
import org.apache.ratis.thirdparty.io.grpc.stub.StreamObserver;
Expand All @@ -41,15 +44,19 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static org.apache.ratis.grpc.GrpcUtil.addMethodWithCustomMarshaller;
import static org.apache.ratis.proto.grpc.RaftServerProtocolServiceGrpc.getAppendEntriesMethod;

class GrpcServerProtocolService extends RaftServerProtocolServiceImplBase {
public static final Logger LOG = LoggerFactory.getLogger(GrpcServerProtocolService.class);

static class PendingServerRequest<REQUEST> {
private final REQUEST request;
private final CompletableFuture<Void> future = new CompletableFuture<>();

PendingServerRequest(REQUEST request) {
this.request = request;
PendingServerRequest(ReferenceCountedObject<REQUEST> requestRef) {
this.request = requestRef.retain();
this.future.whenComplete((r, e) -> requestRef.release());
}

REQUEST getRequest() {
Expand Down Expand Up @@ -83,7 +90,21 @@ private String getPreviousRequestString() {
.orElse(null);
}

abstract CompletableFuture<REPLY> process(REQUEST request) throws IOException;
CompletableFuture<REPLY> process(REQUEST request) throws IOException {
throw new UnsupportedOperationException("This method is not supported.");
}

CompletableFuture<REPLY> process(ReferenceCountedObject<REQUEST> requestRef)
throws IOException {
try {
return process(requestRef.retain());
} finally {
requestRef.release();
}
}

void release(REQUEST req) {
}

abstract long getCallId(REQUEST request);

Expand Down Expand Up @@ -120,22 +141,29 @@ void composeRequest(CompletableFuture<REPLY> current) {

@Override
public void onNext(REQUEST request) {
ReferenceCountedObject<REQUEST> requestRef = ReferenceCountedObject.wrap(request, () -> {}, released -> {
if (released) {
release(request);
}
});

if (!replyInOrder(request)) {
try {
composeRequest(process(request).thenApply(this::handleReply));
composeRequest(process(requestRef).thenApply(this::handleReply));
} catch (Exception e) {
handleError(e, request);
release(request);
}
return;
}

final PendingServerRequest<REQUEST> current = new PendingServerRequest<>(request);
final PendingServerRequest<REQUEST> current = new PendingServerRequest<>(requestRef);
final PendingServerRequest<REQUEST> previous = previousOnNext.getAndSet(current);
final CompletableFuture<Void> previousFuture = Optional.ofNullable(previous)
.map(PendingServerRequest::getFuture)
.orElse(CompletableFuture.completedFuture(null));
try {
final CompletableFuture<REPLY> f = process(request).exceptionally(e -> {
final CompletableFuture<REPLY> f = process(requestRef).exceptionally(e -> {
// Handle cases, such as RaftServer is paused
handleError(e, request);
current.getFuture().completeExceptionally(e);
Expand Down Expand Up @@ -176,16 +204,35 @@ public void onError(Throwable t) {

private final Supplier<RaftPeerId> idSupplier;
private final RaftServer server;
private final ZeroCopyMessageMarshaller<AppendEntriesRequestProto> zeroCopyRequestMarshaller;

GrpcServerProtocolService(Supplier<RaftPeerId> idSupplier, RaftServer server) {
GrpcServerProtocolService(Supplier<RaftPeerId> idSupplier, RaftServer server, ZeroCopyMetrics zeroCopyMetrics) {
this.idSupplier = idSupplier;
this.server = server;
this.zeroCopyRequestMarshaller = new ZeroCopyMessageMarshaller<>(AppendEntriesRequestProto.getDefaultInstance(),
zeroCopyMetrics::onZeroCopyMessage, zeroCopyMetrics::onNonZeroCopyMessage, zeroCopyMetrics::onReleasedMessage);
}

RaftPeerId getId() {
return idSupplier.get();
}

ServerServiceDefinition bindServiceWithZeroCopy() {
ServerServiceDefinition orig = super.bindService();
ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(orig.getServiceDescriptor().getName());

// Add appendEntries with zero copy marshaller.
addMethodWithCustomMarshaller(orig, builder, getAppendEntriesMethod(), zeroCopyRequestMarshaller);
// Add remaining methods as is.
orig.getMethods().stream().filter(
x -> !x.getMethodDescriptor().getFullMethodName().equals(getAppendEntriesMethod().getFullMethodName())
).forEach(
builder::addMethod
);

return builder.build();
}

@Override
public void requestVote(RequestVoteRequestProto request,
StreamObserver<RequestVoteReplyProto> responseObserver) {
Expand Down Expand Up @@ -226,8 +273,14 @@ public StreamObserver<AppendEntriesRequestProto> appendEntries(
return new ServerRequestStreamObserver<AppendEntriesRequestProto, AppendEntriesReplyProto>(
RaftServerProtocol.Op.APPEND_ENTRIES, responseObserver) {
@Override
CompletableFuture<AppendEntriesReplyProto> process(AppendEntriesRequestProto request) throws IOException {
return server.appendEntriesAsync(ReferenceCountedObject.wrap(request));
CompletableFuture<AppendEntriesReplyProto> process(ReferenceCountedObject<AppendEntriesRequestProto> requestRef)
throws IOException {
return server.appendEntriesAsync(requestRef);
}

@Override
void release(AppendEntriesRequestProto req) {
zeroCopyRequestMarshaller.release(req);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,10 @@ private GrpcService(RaftServer raftServer, Supplier<RaftPeerId> idSupplier,

final NettyServerBuilder serverBuilder =
startBuildingNettyServer(serverHost, serverPort, serverTlsConfig, grpcMessageSizeMax, flowControlWindow);
GrpcServerProtocolService serverProtocolService = new GrpcServerProtocolService(idSupplier, raftServer,
zeroCopyMetrics);
serverBuilder.addService(ServerInterceptors.intercept(
new GrpcServerProtocolService(idSupplier, raftServer), serverInterceptor));
serverProtocolService.bindServiceWithZeroCopy(), serverInterceptor));
if (!separateAdminServer) {
addAdminService(raftServer, serverBuilder);
}
Expand Down