From 50e9f37e10e04b29a1abb23d7fbeca3140746b33 Mon Sep 17 00:00:00 2001 From: Pedro Matias Date: Sat, 18 Apr 2026 03:59:32 +0100 Subject: [PATCH 1/2] Reuse updated handle --- .../arrow/flight/sql/FlightSqlClient.java | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index 623f9311e8..09af644ba9 100644 --- a/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -1217,6 +1217,7 @@ protected void updateCommandBuilder(CommandStatementIngest.Builder builder) { public static class PreparedStatement implements AutoCloseable { private final FlightClient client; private final ActionCreatePreparedStatementResult preparedStatementResult; + private ByteString handle; private VectorSchemaRoot parameterBindingRoot; private boolean isClosed; private Schema resultSetSchema; @@ -1229,6 +1230,7 @@ public static class PreparedStatement implements AutoCloseable { preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( preparedStatementResults.next().getBody(), ActionCreatePreparedStatementResult.class); + handle = preparedStatementResult.getPreparedStatementHandle(); isClosed = false; } @@ -1292,8 +1294,7 @@ public SchemaResult fetchSchema(CallOption... options) { FlightDescriptor.command( Any.pack( CommandPreparedStatementQuery.newBuilder() - .setPreparedStatementHandle( - preparedStatementResult.getPreparedStatementHandle()) + .setPreparedStatementHandle(handle) .build()) .toByteArray()); return client.getSchema(descriptor, options); @@ -1324,8 +1325,7 @@ public FlightInfo execute(final CallOption... options) { FlightDescriptor.command( Any.pack( CommandPreparedStatementQuery.newBuilder() - .setPreparedStatementHandle( - preparedStatementResult.getPreparedStatementHandle()) + .setPreparedStatementHandle(handle) .build()) .toByteArray()); @@ -1339,12 +1339,16 @@ public FlightInfo execute(final CallOption... options) { try (final ArrowBuf metadata = read.getApplicationMetadata()) { final FlightSql.DoPutPreparedStatementResult doPutPreparedStatementResult = FlightSql.DoPutPreparedStatementResult.parseFrom(metadata.nioBuffer()); + final ByteString updatedHandle = + doPutPreparedStatementResult.getPreparedStatementHandle(); + if (!updatedHandle.isEmpty()) { + handle = updatedHandle; + } descriptor = FlightDescriptor.command( Any.pack( CommandPreparedStatementQuery.newBuilder() - .setPreparedStatementHandle( - doPutPreparedStatementResult.getPreparedStatementHandle()) + .setPreparedStatementHandle(handle) .build()) .toByteArray()); } @@ -1396,8 +1400,7 @@ public long executeUpdate(final CallOption... options) { FlightDescriptor.command( Any.pack( CommandPreparedStatementUpdate.newBuilder() - .setPreparedStatementHandle( - preparedStatementResult.getPreparedStatementHandle()) + .setPreparedStatementHandle(handle) .build()) .toByteArray()); setParameters(parameterBindingRoot == null ? VectorSchemaRoot.of() : parameterBindingRoot); @@ -1434,8 +1437,7 @@ public void close(final CallOption... options) { FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType(), Any.pack( ActionClosePreparedStatementRequest.newBuilder() - .setPreparedStatementHandle( - preparedStatementResult.getPreparedStatementHandle()) + .setPreparedStatementHandle(handle) .build()) .toByteArray()); final Iterator closePreparedStatementResults = client.doAction(action, options); From ea8d7d1c007ff7661789ee657a4737e6cfa9ceb4 Mon Sep 17 00:00:00 2001 From: Pedro Matias Date: Sat, 18 Apr 2026 23:58:11 +0100 Subject: [PATCH 2/2] Add test --- .../arrow/flight/sql/test/TestFlightSql.java | 151 ++++++++++++++++++ 1 file changed, 151 insertions(+) diff --git a/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java b/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java index e2934ab1e9..22c7336c1b 100644 --- a/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java +++ b/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/test/TestFlightSql.java @@ -27,9 +27,13 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.PipedInputStream; import java.io.PipedOutputStream; +import java.nio.channels.Channels; import java.nio.charset.StandardCharsets; import java.sql.SQLException; import java.util.ArrayList; @@ -42,17 +46,21 @@ import java.util.stream.IntStream; import org.apache.arrow.flight.CancelFlightInfoRequest; import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightServer; import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.flight.FlightStream; import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.PutResult; import org.apache.arrow.flight.RenewFlightEndpointRequest; +import org.apache.arrow.flight.Result; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.NoOpFlightSqlProducer; import org.apache.arrow.flight.sql.example.FlightSqlExample; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions; @@ -60,6 +68,7 @@ import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementIngest.TableDefinitionOptions.TableNotExistOption; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; @@ -67,6 +76,8 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -1594,4 +1605,144 @@ public void testRenewEndpoint() { new RenewFlightEndpointRequest(info.getEndpoints().get(0)))); assertEquals(FlightStatusCode.UNIMPLEMENTED, fre.status().code()); } + + @Test + public void testPreparedStatementUsesUpdatedHandleAfterDoPut() throws Exception { + final ByteString originalHandle = ByteString.copyFromUtf8("original-handle"); + final ByteString updatedHandle = ByteString.copyFromUtf8("updated-handle"); + + try (BufferAllocator testAllocator = new RootAllocator(Integer.MAX_VALUE)) { + final Schema paramSchema = + new Schema(singletonList(Field.nullable("id", MinorType.INT.getType()))); + final UpdatedHandleFlightSqlProducer mockProducer = + new UpdatedHandleFlightSqlProducer( + testAllocator, originalHandle, updatedHandle, paramSchema); + + try (FlightServer testServer = + FlightServer.builder( + testAllocator, Location.forGrpcInsecure(LOCALHOST, 0), mockProducer) + .build() + .start(); + FlightSqlClient testClient = + new FlightSqlClient( + FlightClient.builder( + testAllocator, Location.forGrpcInsecure(LOCALHOST, testServer.getPort())) + .build())) { + + try (PreparedStatement ps = testClient.prepare("test query with param=?"); + VectorSchemaRoot params = VectorSchemaRoot.create(paramSchema, testAllocator)) { + final IntVector v = (IntVector) params.getVector(0); + v.setSafe(0, 42); + params.setRowCount(1); + ps.setParameters(params); + ps.execute(); // DoPut → server returns updatedHandle in DoPutPreparedStatementResult + } // close() called here via try-with-resources + + assertAll( + () -> + assertThat(mockProducer.executeHandle) + .as("getFlightInfoPreparedStatement must use the updated handle") + .isEqualTo(updatedHandle), + () -> + assertThat(mockProducer.closeHandle) + .as("ClosePreparedStatement must use the updated handle") + .isEqualTo(updatedHandle)); + } + } + } + + /** + * Minimal producer that returns an updated prepared-statement handle in the {@code + * CommandPreparedStatementQuery} used with {@code DoPut} and records which handle is used in + * subsequent operations, allowing the test to verify that the client propagates the updated + * handle correctly. + */ + private static final class UpdatedHandleFlightSqlProducer extends NoOpFlightSqlProducer { + + private final BufferAllocator allocator; + private final ByteString originalHandle; + private final ByteString updatedHandle; + private final ByteString serializedParamSchema; + ByteString executeHandle; + ByteString closeHandle; + + UpdatedHandleFlightSqlProducer( + BufferAllocator allocator, + ByteString originalHandle, + ByteString updatedHandle, + Schema paramSchema) { + this.allocator = allocator; + this.originalHandle = originalHandle; + this.updatedHandle = updatedHandle; + this.serializedParamSchema = serializeSchema(paramSchema); + } + + private static ByteString serializeSchema(Schema schema) { + try { + final ByteArrayOutputStream out = new ByteArrayOutputStream(); + MessageSerializer.serialize(new WriteChannel(Channels.newChannel(out)), schema); + return ByteString.copyFrom(out.toByteArray()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void createPreparedStatement( + FlightSql.ActionCreatePreparedStatementRequest request, + CallContext context, + StreamListener listener) { + listener.onNext( + new Result( + Any.pack( + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(originalHandle) + .setParameterSchema(serializedParamSchema) + .build()) + .toByteArray())); + listener.onCompleted(); + } + + @Override + public Runnable acceptPutPreparedStatementQuery( + FlightSql.CommandPreparedStatementQuery command, + CallContext context, + FlightStream flightStream, + StreamListener ackStream) { + return () -> { + while (flightStream.next()) { + // consume parameter batches + } + final byte[] responseBytes = + FlightSql.DoPutPreparedStatementResult.newBuilder() + .setPreparedStatementHandle(updatedHandle) + .build() + .toByteArray(); + final ArrowBuf buf = allocator.buffer(responseBytes.length); + buf.writeBytes(responseBytes); + try (PutResult putResult = PutResult.metadata(buf)) { + ackStream.onNext(putResult); + ackStream.onCompleted(); + } + }; + } + + @Override + public FlightInfo getFlightInfoPreparedStatement( + FlightSql.CommandPreparedStatementQuery command, + CallContext context, + FlightDescriptor descriptor) { + executeHandle = command.getPreparedStatementHandle(); + return new FlightInfo(new Schema(emptyList()), descriptor, emptyList(), -1, -1); + } + + @Override + public void closePreparedStatement( + FlightSql.ActionClosePreparedStatementRequest request, + CallContext context, + StreamListener listener) { + closeHandle = request.getPreparedStatementHandle(); + listener.onCompleted(); + } + } }