diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index 166c1157a1..107cfa0c2f 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -282,7 +282,12 @@ public PreparedStatement prepareStatement( final int resultSetHoldability) throws SQLException { checkOpen(); - return getMeta() - .createPreparedStatement(sql, resultSetType, resultSetConcurrency, resultSetHoldability); + return ArrowFlightPreparedStatement.builder(this) + .withQuery(sql) + .withGeneratedHandle() + .withResultSetType(resultSetType) + .withResultSetConcurrency(resultSetConcurrency) + .withResultSetHoldability(resultSetHoldability) + .build(); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java deleted file mode 100644 index 37ee93722a..0000000000 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.arrow.driver.jdbc; - -import java.sql.SQLException; -import java.sql.Statement; -import org.apache.arrow.flight.FlightInfo; - -/** A {@link Statement} that deals with {@link FlightInfo}. */ -public interface ArrowFlightInfoStatement extends Statement { - - @Override - ArrowFlightConnection getConnection() throws SQLException; - - /** - * Executes the query this {@link Statement} is holding. - * - * @return the {@link FlightInfo} for the results. - * @throws SQLException on error. - */ - FlightInfo executeFlightInfoQuery() throws SQLException; -} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index f230c95340..d383d239d1 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -67,7 +67,7 @@ public final class ArrowFlightJdbcFlightStreamResultSet throws SQLException { super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); this.connection = (ArrowFlightConnection) statement.connection; - this.flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + this.flightInfo = ((ArrowFlightMetaStatement) statement).executeFlightInfoQuery(); } /** Private constructor for fromFlightInfo. */ diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index e0a8727630..4da182ca63 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -48,10 +48,7 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) { @Override public void closeStatement(final StatementHandle statementHandle) { - AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (statement instanceof ArrowFlightPreparedStatement) { - ((ArrowFlightPreparedStatement) statement).closePreparedResources(); - } + getMetaStatement(statementHandle).closeStatement(); } @Override @@ -64,8 +61,7 @@ public ExecuteResult execute( final StatementHandle statementHandle, final List typedValues, final long maxRowCount) { - return getPreparedStatementInstance(statementHandle) - .executeWithTypedValues(statementHandle, typedValues, maxRowCount); + return getMetaStatement(statementHandle).execute(statementHandle, typedValues, maxRowCount); } @Override @@ -80,8 +76,7 @@ public ExecuteResult execute( public ExecuteBatchResult executeBatch( final StatementHandle statementHandle, final List> parameterValuesList) throws IllegalStateException { - return getPreparedStatementInstance(statementHandle) - .executeBatchWithTypedValues(statementHandle, parameterValuesList); + return getMetaStatement(statementHandle).executeBatch(statementHandle, parameterValuesList); } @Override @@ -96,31 +91,16 @@ public Frame fetch( String.format("%s does not use frames.", this), AvaticaConnection.HELPER.unsupported()); } - ArrowFlightPreparedStatement createPreparedStatement( - final String query, - final int resultSetType, - final int resultSetConcurrency, - final int resultSetHoldability) - throws SQLException { - return ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) - .withQuery(query) - .withGeneratedHandle() - .withResultSetType(resultSetType) - .withResultSetConcurrency(resultSetConcurrency) - .withResultSetHoldability(resultSetHoldability) - .build(); - } - @Override public StatementHandle prepare( final ConnectionHandle connectionHandle, final String query, final long maxRowCount) { try { - return createPreparedStatement( - query, - ResultSet.TYPE_FORWARD_ONLY, - ResultSet.CONCUR_READ_ONLY, - connection.getHoldability()) - .handle; + // This is the Avatica entry point used by Connection.prepareStatement(String). + ArrowFlightPreparedStatement stmt = + (ArrowFlightPreparedStatement) + connection.prepareStatement( + query, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + return stmt.handle; } catch (SQLException e) { throw new RuntimeException(e); } @@ -133,6 +113,7 @@ public ExecuteResult prepareAndExecute( final long maxRowCount, final PrepareCallback prepareCallback) throws NoSuchStatementException { + // This is the Avatica entry point used by Statement.execute(String). return prepareAndExecute( statementHandle, query, maxRowCount, -1 /* Not used */, prepareCallback); } @@ -146,20 +127,9 @@ public ExecuteResult prepareAndExecute( final PrepareCallback callback) throws NoSuchStatementException { try { - final AvaticaStatement statement = connection.statementMap.get(handle.id); - if (!(statement instanceof ArrowFlightStatement) - && !(statement instanceof ArrowFlightPreparedStatement)) { - throw new IllegalStateException("Prepared statement not found: " + handle); - } - if (statement instanceof ArrowFlightPreparedStatement) { - ((ArrowFlightPreparedStatement) statement).closePreparedResources(); - } - final ArrowFlightPreparedStatement preparedStatement = - ArrowFlightPreparedStatement.builder((ArrowFlightConnection) connection) - .withQuery(query) - .withExistingStatement(statement) - .build(); - return preparedStatement.prepareAndExecute(callback); + // This is the Avatica entry point used by Statement.execute(String). + return getMetaStatement(handle) + .prepareAndExecute(query, maxRowCount, maxRowsInFirstFrame, callback); } catch (SQLTimeoutException e) { // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and // Runtime @@ -216,30 +186,37 @@ void setDefaultConnectionProperties() { .setTransactionIsolation(Connection.TRANSACTION_NONE); } - private ArrowFlightPreparedStatement getPreparedStatementInstance( - StatementHandle statementHandle) { + private ArrowFlightMetaStatement getMetaStatement(StatementHandle statementHandle) { AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (!(statement instanceof ArrowFlightPreparedStatement)) { - throw new IllegalStateException("Prepared statement not found: " + statementHandle); + if (statement instanceof ArrowFlightMetaStatement) { + return (ArrowFlightMetaStatement) statement; } - return (ArrowFlightPreparedStatement) statement; + throw new IllegalStateException("Statement not found: " + statementHandle); } - ArrowFlightPreparedStatement getPreparedStatementInstanceOrNull(StatementHandle statementHandle) { - AvaticaStatement statement = connection.statementMap.get(statementHandle.id); - if (statement instanceof ArrowFlightPreparedStatement) { - return (ArrowFlightPreparedStatement) statement; - } - return null; + public static Signature buildDefaultSignature() { + return buildSignature(null, StatementType.SELECT); } - public static Signature buildDefaultSignature() { - return buildSignature(null, null, null); + public static Signature buildSignature(final String sql, final StatementType type) { + return buildSignature(sql, null, null, type); } /** Builds an Avatica signature from Arrow result and parameter schemas. */ public static Signature buildSignature( final String sql, final Schema resultSetSchema, final Schema parameterSchema) { + StatementType statementType = + resultSetSchema == null || resultSetSchema.getFields().isEmpty() + ? StatementType.IS_DML + : StatementType.SELECT; + return buildSignature(sql, resultSetSchema, parameterSchema, statementType); + } + + private static Signature buildSignature( + final String sql, + final Schema resultSetSchema, + final Schema parameterSchema, + final StatementType statementType) { List columnMetaData = resultSetSchema == null ? new ArrayList<>() @@ -248,10 +225,6 @@ public static Signature buildSignature( parameterSchema == null ? new ArrayList<>() : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields()); - StatementType statementType = - resultSetSchema == null || resultSetSchema.getFields().isEmpty() - ? StatementType.IS_DML - : StatementType.SELECT; return new Signature( columnMetaData, sql, diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java new file mode 100644 index 0000000000..415af19e8f --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaStatement.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.sql.Statement; +import java.util.List; +import org.apache.arrow.flight.FlightInfo; +import org.apache.calcite.avatica.Meta.ExecuteBatchResult; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; +import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.remote.TypedValue; + +/** Statement capabilities used by {@link ArrowFlightMetaImpl}. */ +interface ArrowFlightMetaStatement extends Statement { + + @Override + ArrowFlightConnection getConnection() throws SQLException; + + FlightInfo executeFlightInfoQuery() throws SQLException; + + /** + * Avatica routes {@link Statement#execute(String)} through Meta.prepareAndExecute(...), so plain + * statements still need this hook even when they support direct executeQuery/executeUpdate paths. + */ + ExecuteResult prepareAndExecute( + String query, long maxRowCount, int maxRowsInFirstFrame, PrepareCallback callback) + throws SQLException; + + default ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + throw new IllegalStateException( + "Statement operation is not supported for handle: " + statementHandle); + } + + default void closeStatement() {} +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index 1c6f0cdb21..bd7ebbe0e4 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -37,7 +37,7 @@ /** Arrow Flight JDBC's implementation {@link java.sql.PreparedStatement}. */ public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement - implements ArrowFlightInfoStatement { + implements ArrowFlightMetaStatement { private ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; @@ -80,6 +80,21 @@ ExecuteResult prepareAndExecute(final PrepareCallback callback) throws SQLExcept return new ExecuteResult(Collections.singletonList(metaResultSet)); } + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + Schema getDataSetSchema() { ensurePrepared(); return preparedStatement.getDataSetSchema(); @@ -143,6 +158,25 @@ ExecuteBatchResult executeBatchWithTypedValues( return new ExecuteBatchResult(updatedCounts); } + @Override + public ExecuteResult execute( + final StatementHandle statementHandle, + final List typedValues, + final long maxRowCount) { + return executeWithTypedValues(statementHandle, typedValues, maxRowCount); + } + + @Override + public ExecuteBatchResult executeBatch( + final StatementHandle statementHandle, final List> parameterValuesList) { + return executeBatchWithTypedValues(statementHandle, parameterValuesList); + } + + @Override + public void closeStatement() { + closePreparedResources(); + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { ensurePrepared(); diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java index 9e514ccc9f..0df8f20d2a 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -16,16 +16,24 @@ */ package org.apache.arrow.driver.jdbc; +import java.sql.ResultSet; import java.sql.SQLException; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStatusCode; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaResultSet; import org.apache.calcite.avatica.AvaticaStatement; import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.ExecuteResult; +import org.apache.calcite.avatica.Meta.PrepareCallback; import org.apache.calcite.avatica.Meta.StatementHandle; +import org.apache.calcite.avatica.Meta.StatementType; /** A SQL statement for querying data from an Arrow Flight server. */ -public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightInfoStatement { +public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightMetaStatement { ArrowFlightStatement( final ArrowFlightConnection connection, @@ -41,23 +49,140 @@ public ArrowFlightConnection getConnection() throws SQLException { return (ArrowFlightConnection) super.getConnection(); } + @Override + public ExecuteResult prepareAndExecute( + final String query, + final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws SQLException { + // Keep Avatica Statement.execute(String) behavior: Avatica calls Meta.prepareAndExecute, + // which resolves to this statement hook. + this.closeStatement(); + + return ArrowFlightPreparedStatement.builder(getConnection()) + .withQuery(query) + .withExistingStatement(this) + .build() + .prepareAndExecute(callback); + } + + @Override + public ResultSet executeQuery(final String sql) throws SQLException { + checkOpen(); + updateCount = -1; + switchToDirectStatementMode(); + try { + final Meta.Signature signature = + ArrowFlightMetaImpl.buildSignature(sql, StatementType.SELECT); + setSignature(signature); + return executeQueryInternal(signature, false); + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + + @Override + public long executeLargeUpdate(final String sql) throws SQLException { + checkOpen(); + clearOpenResultSet(); + updateCount = -1; + switchToDirectStatementMode(); + + try { + final long updatedCount = getConnection().getClientHandler().executeUpdate(sql); + setSignature(ArrowFlightMetaImpl.buildSignature(sql, StatementType.IS_DML)); + updateCount = updatedCount; + return updatedCount; + } catch (Exception exception) { + throw wrapStatementExecutionException(sql, exception); + } + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { - final ArrowFlightPreparedStatement preparedStatement = - getConnection().getMeta().getPreparedStatementInstanceOrNull(handle); + final ArrowFlightConnection connection = getConnection(); final Meta.Signature signature = getSignature(); if (signature == null) { return null; } - if (preparedStatement != null) { - final Schema resultSetSchema = preparedStatement.getDataSetSchema(); + // A Statement handle can point to either this direct statement instance or a prepared + // statement instance created by Avatica Statement.execute(String) through + // Meta.prepareAndExecute. + final AvaticaStatement currentStatement = connection.statementMap.get(handle.id); + if (currentStatement instanceof ArrowFlightPreparedStatement) { + // Prepared path: reuse the current statement implementation associated with the handle. + final FlightInfo flightInfo = + ((ArrowFlightPreparedStatement) currentStatement).executeFlightInfoQuery(); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } + + // Direct Statement.executeQuery(String) / executeUpdate(String) path. + final FlightInfo flightInfo = connection.getClientHandler().getInfo(signature.sql); + updateSignatureColumnsFromFlightInfo(signature, flightInfo); + return flightInfo; + } + + private void updateSignatureColumnsFromFlightInfo( + final Meta.Signature signature, final FlightInfo flightInfo) { + final Schema resultSetSchema = flightInfo.getSchemaOptional().orElse(null); + if (resultSetSchema != null) { signature.columns.addAll( ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); setSignature(signature); - return preparedStatement.executeFlightInfoQuery(); } + } + + private SQLException wrapStatementExecutionException(final String sql, final Exception exception) + throws SQLException { + if (!(exception instanceof SQLException)) { + return AvaticaConnection.HELPER.createException( + "Error while executing SQL \"" + sql + "\": " + exception.getMessage(), exception); + } + final SQLException sqlException = (SQLException) exception; + final String prefix = "Error while executing SQL \"" + sql + "\""; + final String message = sqlException.getMessage(); + if (message != null && message.startsWith(prefix)) { + return sqlException; + } + final Throwable cause = sqlException.getCause(); + if (cause instanceof FlightRuntimeException) { + final FlightStatusCode statusCode = ((FlightRuntimeException) cause).status().code(); + if (statusCode == FlightStatusCode.UNAVAILABLE) { + return sqlException; + } + } + return AvaticaConnection.HELPER.createException(prefix + ": " + message, sqlException); + } - throw new IllegalStateException("Prepared statement query not found: " + handle); + private void clearOpenResultSet() throws SQLException { + synchronized (this) { + if (openResultSet != null) { + final AvaticaResultSet resultSet = openResultSet; + openResultSet = null; + try { + resultSet.close(); + } catch (Exception exception) { + throw AvaticaConnection.HELPER.createException( + "Error while closing previous result set", exception); + } + } + } + } + + private void switchToDirectStatementMode() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final AvaticaStatement existingStatement = connection.statementMap.get(handle.id); + if (existingStatement == this) { + return; + } + if (existingStatement instanceof ArrowFlightPreparedStatement) { + // Release resources from previously attached statement implementation before switching back + // to direct statement mode for executeQuery/executeUpdate. + ((ArrowFlightPreparedStatement) existingStatement).closeStatement(); + } + connection.statementMap.put(handle.id, this); } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index f0ea284239..08b2c5f93e 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -267,6 +267,16 @@ public FlightInfo getInfo(final String query) { return sqlClient.execute(query, getOptions()); } + /** + * Executes an update query directly, without creating a prepared statement first. + * + * @param query The update query. + * @return the number of rows affected. + */ + public long executeUpdate(final String query) { + return sqlClient.executeUpdate(query, getOptions()); + } + @Override public void close() throws SQLException { if (catalog.isPresent()) { diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java new file mode 100644 index 0000000000..10ce0bd285 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/example/ArrowFlightJdbcSampleApp.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.example; + +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; + +/** + * Minimal sample app for using the Arrow Flight SQL JDBC driver. + * + *

Defaults are configured for a local Dremio instance: + * + *

    + *
  • host: {@code localhost} + *
  • port: {@code 32010} + *
  • user: {@code dremio} + *
  • password: {@code dremio123} + *
+ * + *

Arguments are optional and positional: + * + *

+ *   [host] [port] [user] [password] [selectSql] [updateSql]
+ * 
+ * + *

If {@code updateSql} is omitted, only {@code Statement.executeQuery(...)} is executed. + */ +public final class ArrowFlightJdbcSampleApp { + private static final String DEFAULT_HOST = "localhost"; + private static final int DEFAULT_PORT = 32010; + private static final String DEFAULT_USER = "dremio"; + private static final String DEFAULT_PASSWORD = "dremio123"; + private static final String DEFAULT_SELECT_SQL = "SELECT 1 AS sample_value"; + + private ArrowFlightJdbcSampleApp() {} + + public static void main(final String[] args) throws Exception { + final String host = getArg(args, 0, DEFAULT_HOST); + final int port = Integer.parseInt(getArg(args, 1, Integer.toString(DEFAULT_PORT))); + final String user = getArg(args, 2, DEFAULT_USER); + final String password = getArg(args, 3, DEFAULT_PASSWORD); + final String selectSql = getArg(args, 4, DEFAULT_SELECT_SQL); + final String updateSql = getArg(args, 5, ""); + + final String url = String.format("jdbc:arrow-flight-sql://%s:%d", host, port); + final Properties properties = new Properties(); + properties.setProperty("user", user); + properties.setProperty("password", password); + properties.setProperty("useEncryption", "false"); + + System.out.println("Connecting to " + url); + try (Connection connection = DriverManager.getConnection(url, properties); + Statement statement = connection.createStatement()) { + runSelect(statement, selectSql); + + if (updateSql.isEmpty()) { + System.out.println( + "Skipping Statement.executeUpdate(...) because no updateSql argument was provided."); + } else { + runUpdate(statement, updateSql); + } + } + } + + private static void runSelect(final Statement statement, final String selectSql) + throws SQLException { + System.out.println("Running Statement.executeQuery: " + selectSql); + try (ResultSet resultSet = statement.executeQuery(selectSql)) { + final ResultSetMetaData metadata = resultSet.getMetaData(); + final int columnCount = metadata.getColumnCount(); + int rowCount = 0; + while (resultSet.next()) { + rowCount++; + final StringBuilder rowBuilder = new StringBuilder(); + for (int i = 1; i <= columnCount; i++) { + if (i > 1) { + rowBuilder.append(", "); + } + rowBuilder.append(metadata.getColumnLabel(i)).append('=').append(resultSet.getObject(i)); + } + System.out.println("row " + rowCount + ": " + rowBuilder); + } + System.out.println("Statement.executeQuery returned " + rowCount + " row(s)"); + } + } + + private static void runUpdate(final Statement statement, final String updateSql) + throws SQLException { + System.out.println("Running Statement.executeUpdate: " + updateSql); + final int updateCount = statement.executeUpdate(updateSql); + System.out.println("Statement.executeUpdate affected " + updateCount + " row(s)"); + } + + private static String getArg(final String[] args, final int index, final String defaultValue) { + if (index >= args.length) { + return defaultValue; + } + final String arg = args[index]; + return arg == null || arg.isEmpty() ? defaultValue : arg; + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java index d4e7a0953d..138a8e5b76 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -109,15 +109,10 @@ public void testPrepareStatementRegistersCreatedStatementByGeneratedHandle() thr final ArrowFlightPreparedStatement arrowPreparedStatement = (ArrowFlightPreparedStatement) preparedStatement; - assertNotNull( - flightConnection - .getMeta() - .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + assertNotNull(flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); assertSame( arrowPreparedStatement, - flightConnection - .getMeta() - .getPreparedStatementInstanceOrNull(arrowPreparedStatement.handle)); + flightConnection.statementMap.get(arrowPreparedStatement.handle.id)); } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java index e4df71967b..20e2059722 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -18,6 +18,7 @@ import static org.hamcrest.CoreMatchers.allOf; import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.CoreMatchers.nullValue; @@ -146,12 +147,25 @@ public void testExecuteReplacesStatementMapEntryWithPreparedStatement() throws S assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); - final ArrowFlightPreparedStatement preparedStatement = - arrowConnection.getMeta().getPreparedStatementInstanceOrNull(arrowStatement.handle); + final Object preparedStatement = arrowConnection.statementMap.get(arrowStatement.handle.id); assertNotNull(preparedStatement); assertSame(preparedStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); - assertThat(preparedStatement.handle.id, is(equalTo(arrowStatement.handle.id))); + assertThat(preparedStatement, instanceOf(ArrowFlightPreparedStatement.class)); + } + + @Test + public void testExecuteQueryRestoresStatementMapEntryWithStatement() throws SQLException { + final ArrowFlightStatement arrowStatement = (ArrowFlightStatement) statement; + final ArrowFlightConnection arrowConnection = (ArrowFlightConnection) connection; + + assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); + + try (ResultSet resultSet = statement.executeQuery(SAMPLE_QUERY_CMD)) { + assertThat(resultSet.next(), is(true)); + } + + assertSame(arrowStatement, arrowConnection.statementMap.get(arrowStatement.handle.id)); } @Test diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java index f7c31c590c..05e85227f0 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java @@ -22,6 +22,7 @@ import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.startsWith; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -224,4 +225,13 @@ public void testShouldFailToPrepareStatementForBadStatement() { } assertThat(count, is(1)); } + + @Test + public void testExecuteLargeUpdateShouldWrapBadStatement() { + final String badQuery = "BAD INVALID UPDATE"; + final SQLException exception = + assertThrows(SQLException.class, () -> statement.executeLargeUpdate(badQuery)); + assertThat( + exception.getMessage(), startsWith(format("Error while executing SQL \"%s\"", badQuery))); + } } diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java new file mode 100644 index 0000000000..c5c35a9173 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementProtocolTest.java @@ -0,0 +1,376 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.protobuf.Message; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.function.Consumer; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.FlightSqlUtils; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class ArrowFlightStatementProtocolTest { + private static final String SELECT_QUERY = "SELECT * FROM PROTOCOL_SELECT"; + private static final String UPDATE_QUERY = "UPDATE PROTOCOL_UPDATE"; + private static final Schema QUERY_SCHEMA = + new Schema(Collections.singletonList(Field.nullable("id", MinorType.INT.getType()))); + + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + + @RegisterExtension + public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION = + FlightServerTestExtension.createStandardTestExtension(PRODUCER); + + private Connection connection; + + @BeforeAll + public static void setUpBeforeClass() { + PRODUCER.addSelectQuery( + SELECT_QUERY, + QUERY_SCHEMA, + Collections.singletonList( + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(QUERY_SCHEMA, allocator)) { + IntVector vector = (IntVector) root.getVector("id"); + vector.setSafe(0, 1); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + PRODUCER.addUpdateQuery(UPDATE_QUERY, 1); + + final Message commandGetDbSchemas = CommandGetDbSchemas.getDefaultInstance(); + final Consumer commandGetSchemasResultProducer = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = + VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + catalogName.setSafe(0, new Text("catalog_name #0")); + schemaName.setSafe(0, new Text("db_schema_name #0")); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + PRODUCER.addCatalogQuery(commandGetDbSchemas, commandGetSchemasResultProducer); + } + + @BeforeEach + public void setUp() throws SQLException { + PRODUCER.clearActionTypeCounter(); + PRODUCER.clearCommandTypeCounter(); + connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false); + } + + @AfterEach + public void tearDown() throws Exception { + AutoCloseables.close(connection); + } + + @AfterAll + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testStatementExecuteQueryUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement(); + ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testStatementExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(0)); + } + + @Test + public void testStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteThenExecuteUpdateUsesStatementProtocol() throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(SELECT_QUERY), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + assertThat(statement.executeUpdate(UPDATE_QUERY), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testStatementExecuteUpdateThenExecuteQueryUsesStatementProtocol() + throws SQLException { + try (Statement statement = connection.createStatement()) { + assertThat(statement.execute(UPDATE_QUERY), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + try (ResultSet resultSet = statement.executeQuery(SELECT_QUERY)) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteQueryUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY); + ResultSet resultSet = statement.executeQuery()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForQuery() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(SELECT_QUERY)) { + assertThat(statement.execute(), is(true)); + try (ResultSet resultSet = statement.getResultSet()) { + assertTrue(resultSet.next()); + } + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_QUERY, 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_STATEMENT_QUERY, 0), + is(0)); + } + + @Test + public void testPreparedStatementExecuteUpdateUsesPreparedProtocol() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.executeUpdate(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testPreparedStatementExecuteUsesPreparedProtocolForUpdate() throws SQLException { + try (PreparedStatement statement = connection.prepareStatement(UPDATE_QUERY)) { + assertThat(statement.execute(), is(false)); + assertThat(statement.getUpdateCount(), is(1)); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(1)); + assertThat( + PRODUCER + .getCommandTypeCounter() + .getOrDefault(MockFlightSqlProducer.COMMAND_PREPARED_STATEMENT_UPDATE, 0), + is(1)); + } + + @Test + public void testMetadataGetSchemasUsesJdbcApi() throws SQLException { + final DatabaseMetaData metaData = connection.getMetaData(); + try (ResultSet resultSet = metaData.getSchemas()) { + assertTrue(resultSet.next()); + } + + assertThat( + PRODUCER + .getActionTypeCounter() + .getOrDefault(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), 0), + is(0)); + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java index 45c2a96404..230c1346fb 100644 --- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java @@ -103,6 +103,12 @@ public final class MockFlightSqlProducer implements FlightSqlProducer { private final Map>> expectedParameterValues = new HashMap<>(); private final Map actionTypeCounter = new HashMap<>(); + private final Map commandTypeCounter = new HashMap<>(); + + public static final String COMMAND_STATEMENT_QUERY = "statement_query"; + public static final String COMMAND_STATEMENT_UPDATE = "statement_update"; + public static final String COMMAND_PREPARED_STATEMENT_QUERY = "prepared_statement_query"; + public static final String COMMAND_PREPARED_STATEMENT_UPDATE = "prepared_statement_update"; private static FlightInfo getFlightInfoExportedAndImportedKeys( final Message message, final FlightDescriptor descriptor) { @@ -269,6 +275,7 @@ public FlightInfo getFlightInfoStatement( final CommandStatementQuery commandStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_STATEMENT_QUERY); final String query = commandStatementQuery.getQuery(); final Entry> queryInfo = Preconditions.checkNotNull( @@ -289,6 +296,7 @@ public FlightInfo getFlightInfoPreparedStatement( final CommandPreparedStatementQuery commandPreparedStatementQuery, final CallContext callContext, final FlightDescriptor flightDescriptor) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_QUERY); final ByteString preparedStatementHandle = commandPreparedStatementQuery.getPreparedStatementHandle(); @@ -356,6 +364,7 @@ public Runnable acceptPutStatement( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_STATEMENT_UPDATE); return () -> { final String query = commandStatementUpdate.getQuery(); final BiConsumer> resultProvider = @@ -429,6 +438,7 @@ public Runnable acceptPutPreparedStatementUpdate( final CallContext callContext, final FlightStream flightStream, final StreamListener streamListener) { + incrementCommandTypeCounter(COMMAND_PREPARED_STATEMENT_UPDATE); final ByteString handle = commandPreparedStatementUpdate.getPreparedStatementHandle(); final String query = Preconditions.checkNotNull( @@ -651,10 +661,22 @@ public void clearActionTypeCounter() { actionTypeCounter.clear(); } + public void clearCommandTypeCounter() { + commandTypeCounter.clear(); + } + public Map getActionTypeCounter() { return actionTypeCounter; } + public Map getCommandTypeCounter() { + return commandTypeCounter; + } + + private void incrementCommandTypeCounter(String commandType) { + commandTypeCounter.put(commandType, commandTypeCounter.getOrDefault(commandType, 0) + 1); + } + private void getStreamCatalogFunctions( final Message ticket, final ServerStreamListener serverStreamListener) { Preconditions.checkNotNull(