Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,19 +140,19 @@ static class StreamInfo {
private final boolean primary;
private final LocalStream local;
private final Set<RemoteStream> remotes;
private final RaftServer server;
private final Division division;
private final AtomicReference<CompletableFuture<Void>> previous
= new AtomicReference<>(CompletableFuture.completedFuture(null));

StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<DataStream> stream, RaftServer server,
StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture<DataStream> stream, Division division,
CheckedBiFunction<RaftClientRequest, Set<RaftPeer>, Set<DataStreamOutputImpl>, IOException> getStreams,
Function<RequestType, RequestMetrics> metricsConstructor)
throws IOException {
this.request = request;
this.primary = primary;
this.local = new LocalStream(stream, metricsConstructor.apply(RequestType.LOCAL_WRITE));
this.server = server;
final Set<RaftPeer> successors = getSuccessors(server.getId());
this.division = division;
final Set<RaftPeer> successors = getSuccessors(division.getId());
final Set<DataStreamOutputImpl> outs = getStreams.apply(request, successors);
this.remotes = outs.stream()
.map(o -> new RemoteStream(o, metricsConstructor.apply(RequestType.REMOTE_WRITE)))
Expand All @@ -167,16 +167,12 @@ RaftClientRequest getRequest() {
return request;
}

Division getDivision() throws IOException {
return server.getDivision(request.getRaftGroupId());
Division getDivision() {
return division;
}

Collection<CommitInfoProto> getCommitInfos() {
try {
return getDivision().getCommitInfos();
} catch (IOException e) {
throw new IllegalStateException(e);
}
return getDivision().getCommitInfos();
}

boolean isPrimary() {
Expand All @@ -196,7 +192,7 @@ public String toString() {
return JavaUtils.getClassSimpleName(getClass()) + ":" + request;
}

private Set<RaftPeer> getSuccessors(RaftPeerId peerId) throws IOException {
private Set<RaftPeer> getSuccessors(RaftPeerId peerId) {
final RaftConfiguration conf = getDivision().getRaftConf();
final RoutingTable routingTable = request.getRoutingTable();

Expand All @@ -208,7 +204,7 @@ private Set<RaftPeer> getSuccessors(RaftPeerId peerId) throws IOException {
// Default start topology
// get the other peers from the current configuration
return conf.getCurrentPeers().stream()
.filter(p -> !p.getId().equals(server.getId()))
.filter(p -> !p.getId().equals(division.getId()))
.collect(Collectors.toSet());
}

Expand Down Expand Up @@ -276,7 +272,8 @@ private StreamInfo newStreamInfo(ByteBuf buf,
final RaftClientRequest request = ClientProtoUtils.toRaftClientRequest(
RaftClientRequestProto.parseFrom(buf.nioBuffer()));
final boolean isPrimary = server.getId().equals(request.getServerId());
return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), server, getStreams,
final Division division = server.getDivision(request.getRaftGroupId());
return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), division, getStreams,
getMetrics()::newRequestMetrics);
} catch (Throwable e) {
throw new CompletionException(e);
Expand Down Expand Up @@ -411,6 +408,18 @@ void read(DataStreamRequestByteBuf request, ChannelHandlerContext ctx,
readImpl(request, ctx, getStreams);
} catch (Throwable t) {
replyDataStreamException(t, request, ctx);
removeDataStream(ClientInvocationId.valueOf(request.getClientId(), request.getStreamId()), null);
}
}

private void removeDataStream(ClientInvocationId invocationId, StreamInfo info) {
final StreamInfo removed = streams.remove(invocationId);
if (info == null) {
info = removed;
}
if (info != null) {
info.getDivision().getDataStreamMap().remove(invocationId);
info.getLocal().cleanUp();
}
}

Expand All @@ -429,8 +438,6 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct
() -> newStreamInfo(request.slice(), getStreams));
info = streams.computeIfAbsent(key, id -> supplier.get());
if (!supplier.isInitialized()) {
final StreamInfo removed = streams.remove(key);
removed.getLocal().cleanUp();
throw new IllegalStateException("Failed to create a new stream for " + request
+ " since a stream already exists Key: " + key + " StreamInfo:" + info);
}
Expand Down Expand Up @@ -468,9 +475,8 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct
}, requestExecutor)).whenComplete((v, exception) -> {
try {
if (exception != null) {
final StreamInfo removed = streams.remove(key);
replyDataStreamException(server, exception, info.getRequest(), request, ctx);
removed.getLocal().cleanUp();
removeDataStream(key, info);
}
} finally {
request.release();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class MultiDataStreamStateMachine extends BaseStateMachine {
@Override
public CompletableFuture<DataStream> stream(RaftClientRequest request) {
final SingleDataStream s = new SingleDataStream(request);
LOG.info("XXX {} put {}, {}", this, ClientInvocationId.valueOf(request), s);
streams.put(ClientInvocationId.valueOf(request), s);
return CompletableFuture.completedFuture(s);
}
Expand Down Expand Up @@ -179,7 +180,9 @@ SingleDataStream getSingleDataStream(RaftClientRequest request) {
}

SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
return streams.get(invocationId);
final SingleDataStream s = streams.get(invocationId);
LOG.info("XXX {}: get {} return {}", this, invocationId, s);
return s;
}

Collection<SingleDataStream> getStreams() {
Expand Down Expand Up @@ -329,6 +332,8 @@ static CompletableFuture<RaftClientReply> writeAndCloseAndAssertReplies(

static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize, boolean stepDownLeader)
throws Exception {
LOG.info("XXX {}: dataSize={}, stepDownLeader={}, header={}",
server.getId(), dataSize, stepDownLeader, header);
// check header
Assertions.assertEquals(RaftClientRequest.dataStreamRequestType(), header.getType());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ public void setup() {
RaftConfigKeys.DataStream.setType(properties, SupportedDataStreamType.NETTY);
}

RaftServer.Division mockDivision(RaftServer server) {

RaftServer.Division mockDivision(RaftServer server, RaftGroupId groupId) {
final RaftServer.Division division = mock(RaftServer.Division.class);
when(division.getRaftServer()).thenReturn(server);
when(division.getRaftConf()).thenAnswer(i -> getRaftConf());

final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine();
try {
stateMachine.initialize(server, groupId, null);
} catch (IOException e) {
throw new IllegalStateException(e);
}
when(division.getStateMachine()).thenReturn(stateMachine);

final DataStreamMap streamMap = RaftServerTestUtil.newDataStreamMap(server.getId());
Expand Down Expand Up @@ -95,7 +101,7 @@ private void testMockCluster(int numServers, RaftException leaderException,
when(raftServer.getId()).thenReturn(peerId);
when(raftServer.getPeer()).thenReturn(RaftPeer.newBuilder().setId(peerId).build());
if (getStateMachineException == null) {
final RaftServer.Division myDivision = mockDivision(raftServer);
final RaftServer.Division myDivision = mockDivision(raftServer, groupId);
when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenReturn(myDivision);
} else {
when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenThrow(getStateMachineException);
Expand Down