diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java index a6e9b815ee..e265d8b924 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/server/DataStreamManagement.java @@ -140,19 +140,19 @@ static class StreamInfo { private final boolean primary; private final LocalStream local; private final Set remotes; - private final RaftServer server; + private final Division division; private final AtomicReference> previous = new AtomicReference<>(CompletableFuture.completedFuture(null)); - StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture stream, RaftServer server, + StreamInfo(RaftClientRequest request, boolean primary, CompletableFuture stream, Division division, CheckedBiFunction, Set, IOException> getStreams, Function metricsConstructor) throws IOException { this.request = request; this.primary = primary; this.local = new LocalStream(stream, metricsConstructor.apply(RequestType.LOCAL_WRITE)); - this.server = server; - final Set successors = getSuccessors(server.getId()); + this.division = division; + final Set successors = getSuccessors(division.getId()); final Set outs = getStreams.apply(request, successors); this.remotes = outs.stream() .map(o -> new RemoteStream(o, metricsConstructor.apply(RequestType.REMOTE_WRITE))) @@ -167,16 +167,12 @@ RaftClientRequest getRequest() { return request; } - Division getDivision() throws IOException { - return server.getDivision(request.getRaftGroupId()); + Division getDivision() { + return division; } Collection getCommitInfos() { - try { - return getDivision().getCommitInfos(); - } catch (IOException e) { - throw new IllegalStateException(e); - } + return getDivision().getCommitInfos(); } boolean isPrimary() { @@ -196,7 +192,7 @@ public String toString() { return JavaUtils.getClassSimpleName(getClass()) + ":" + request; } - private Set getSuccessors(RaftPeerId peerId) throws IOException { + private Set getSuccessors(RaftPeerId peerId) { final RaftConfiguration conf = getDivision().getRaftConf(); final RoutingTable routingTable = request.getRoutingTable(); @@ -208,7 +204,7 @@ private Set getSuccessors(RaftPeerId peerId) throws IOException { // Default start topology // get the other peers from the current configuration return conf.getCurrentPeers().stream() - .filter(p -> !p.getId().equals(server.getId())) + .filter(p -> !p.getId().equals(division.getId())) .collect(Collectors.toSet()); } @@ -276,7 +272,8 @@ private StreamInfo newStreamInfo(ByteBuf buf, final RaftClientRequest request = ClientProtoUtils.toRaftClientRequest( RaftClientRequestProto.parseFrom(buf.nioBuffer())); final boolean isPrimary = server.getId().equals(request.getServerId()); - return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), server, getStreams, + final Division division = server.getDivision(request.getRaftGroupId()); + return new StreamInfo(request, isPrimary, computeDataStreamIfAbsent(request), division, getStreams, getMetrics()::newRequestMetrics); } catch (Throwable e) { throw new CompletionException(e); @@ -411,6 +408,18 @@ void read(DataStreamRequestByteBuf request, ChannelHandlerContext ctx, readImpl(request, ctx, getStreams); } catch (Throwable t) { replyDataStreamException(t, request, ctx); + removeDataStream(ClientInvocationId.valueOf(request.getClientId(), request.getStreamId()), null); + } + } + + private void removeDataStream(ClientInvocationId invocationId, StreamInfo info) { + final StreamInfo removed = streams.remove(invocationId); + if (info == null) { + info = removed; + } + if (info != null) { + info.getDivision().getDataStreamMap().remove(invocationId); + info.getLocal().cleanUp(); } } @@ -429,8 +438,6 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct () -> newStreamInfo(request.slice(), getStreams)); info = streams.computeIfAbsent(key, id -> supplier.get()); if (!supplier.isInitialized()) { - final StreamInfo removed = streams.remove(key); - removed.getLocal().cleanUp(); throw new IllegalStateException("Failed to create a new stream for " + request + " since a stream already exists Key: " + key + " StreamInfo:" + info); } @@ -468,9 +475,8 @@ private void readImpl(DataStreamRequestByteBuf request, ChannelHandlerContext ct }, requestExecutor)).whenComplete((v, exception) -> { try { if (exception != null) { - final StreamInfo removed = streams.remove(key); replyDataStreamException(server, exception, info.getRequest(), request, ctx); - removed.getLocal().cleanUp(); + removeDataStream(key, info); } } finally { request.release(); diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java index 47138919df..7735c3e309 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java @@ -151,6 +151,7 @@ class MultiDataStreamStateMachine extends BaseStateMachine { @Override public CompletableFuture stream(RaftClientRequest request) { final SingleDataStream s = new SingleDataStream(request); + LOG.info("XXX {} put {}, {}", this, ClientInvocationId.valueOf(request), s); streams.put(ClientInvocationId.valueOf(request), s); return CompletableFuture.completedFuture(s); } @@ -179,7 +180,9 @@ SingleDataStream getSingleDataStream(RaftClientRequest request) { } SingleDataStream getSingleDataStream(ClientInvocationId invocationId) { - return streams.get(invocationId); + final SingleDataStream s = streams.get(invocationId); + LOG.info("XXX {}: get {} return {}", this, invocationId, s); + return s; } Collection getStreams() { @@ -329,6 +332,8 @@ static CompletableFuture writeAndCloseAndAssertReplies( static void assertHeader(RaftServer server, RaftClientRequest header, int dataSize, boolean stepDownLeader) throws Exception { + LOG.info("XXX {}: dataSize={}, stepDownLeader={}, header={}", + server.getId(), dataSize, stepDownLeader, header); // check header Assertions.assertEquals(RaftClientRequest.dataStreamRequestType(), header.getType()); diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamWithMock.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamWithMock.java index 503f8cf66e..1d8c67a43d 100644 --- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamWithMock.java +++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestNettyDataStreamWithMock.java @@ -59,12 +59,18 @@ public void setup() { RaftConfigKeys.DataStream.setType(properties, SupportedDataStreamType.NETTY); } - RaftServer.Division mockDivision(RaftServer server) { + + RaftServer.Division mockDivision(RaftServer server, RaftGroupId groupId) { final RaftServer.Division division = mock(RaftServer.Division.class); when(division.getRaftServer()).thenReturn(server); when(division.getRaftConf()).thenAnswer(i -> getRaftConf()); final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine(); + try { + stateMachine.initialize(server, groupId, null); + } catch (IOException e) { + throw new IllegalStateException(e); + } when(division.getStateMachine()).thenReturn(stateMachine); final DataStreamMap streamMap = RaftServerTestUtil.newDataStreamMap(server.getId()); @@ -95,7 +101,7 @@ private void testMockCluster(int numServers, RaftException leaderException, when(raftServer.getId()).thenReturn(peerId); when(raftServer.getPeer()).thenReturn(RaftPeer.newBuilder().setId(peerId).build()); if (getStateMachineException == null) { - final RaftServer.Division myDivision = mockDivision(raftServer); + final RaftServer.Division myDivision = mockDivision(raftServer, groupId); when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenReturn(myDivision); } else { when(raftServer.getDivision(Mockito.any(RaftGroupId.class))).thenThrow(getStateMachineException);