diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java index c7c6e65d3..8b9ede1db 100644 --- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -19,6 +19,7 @@ import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; import io.netty.util.concurrent.DefaultThreadFactory; +import java.sql.PreparedStatement; import java.sql.SQLException; import java.util.ArrayList; import java.util.HashMap; @@ -260,4 +261,40 @@ BufferAllocator getBufferAllocator() { public ArrowFlightMetaImpl getMeta() { return (ArrowFlightMetaImpl) this.meta; } + + @Override + public PreparedStatement prepareStatement(String sql) throws SQLException { + return NamedParamStatement.wrap(sql, s -> super.prepareStatement(s)); + } + + @Override + public PreparedStatement prepareStatement(String sql, int resultSetType, int resultSetConcurrency) + throws SQLException { + return NamedParamStatement.wrap( + sql, s -> super.prepareStatement(s, resultSetType, resultSetConcurrency)); + } + + @Override + public PreparedStatement prepareStatement( + String sql, int resultSetType, int resultSetConcurrency, int resultSetHoldability) + throws SQLException { + return NamedParamStatement.wrap( + sql, + s -> super.prepareStatement(s, resultSetType, resultSetConcurrency, resultSetHoldability)); + } + + @Override + public PreparedStatement prepareStatement(String sql, int autoGeneratedKeys) throws SQLException { + return NamedParamStatement.wrap(sql, s -> super.prepareStatement(s, autoGeneratedKeys)); + } + + @Override + public PreparedStatement prepareStatement(String sql, int[] columnIndexes) throws SQLException { + return NamedParamStatement.wrap(sql, s -> super.prepareStatement(s, columnIndexes)); + } + + @Override + public PreparedStatement prepareStatement(String sql, String[] columnNames) throws SQLException { + return NamedParamStatement.wrap(sql, s -> super.prepareStatement(s, columnNames)); + } } diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ForwardingPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ForwardingPreparedStatement.java new file mode 100644 index 000000000..41cf9eaac --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ForwardingPreparedStatement.java @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.Date; +import java.sql.NClob; +import java.sql.ParameterMetaData; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; + +/** + * Abstract {@link PreparedStatement} decorator that forwards every method to a {@link #delegate()}. + * + *

Subclasses override only the methods they need to customize; all other calls are forwarded + * transparently. + */ +abstract class ForwardingPreparedStatement implements PreparedStatement { + + /** Returns the underlying {@link PreparedStatement} to which all calls are forwarded. */ + protected abstract PreparedStatement delegate(); + + // --- PreparedStatement --- + + @Override + public ResultSet executeQuery() throws SQLException { + return delegate().executeQuery(); + } + + @Override + public int executeUpdate() throws SQLException { + return delegate().executeUpdate(); + } + + @Override + public void setNull(int parameterIndex, int sqlType) throws SQLException { + delegate().setNull(parameterIndex, sqlType); + } + + @Override + public void setBoolean(int parameterIndex, boolean x) throws SQLException { + delegate().setBoolean(parameterIndex, x); + } + + @Override + public void setByte(int parameterIndex, byte x) throws SQLException { + delegate().setByte(parameterIndex, x); + } + + @Override + public void setShort(int parameterIndex, short x) throws SQLException { + delegate().setShort(parameterIndex, x); + } + + @Override + public void setInt(int parameterIndex, int x) throws SQLException { + delegate().setInt(parameterIndex, x); + } + + @Override + public void setLong(int parameterIndex, long x) throws SQLException { + delegate().setLong(parameterIndex, x); + } + + @Override + public void setFloat(int parameterIndex, float x) throws SQLException { + delegate().setFloat(parameterIndex, x); + } + + @Override + public void setDouble(int parameterIndex, double x) throws SQLException { + delegate().setDouble(parameterIndex, x); + } + + @Override + public void setBigDecimal(int parameterIndex, BigDecimal x) throws SQLException { + delegate().setBigDecimal(parameterIndex, x); + } + + @Override + public void setString(int parameterIndex, String x) throws SQLException { + delegate().setString(parameterIndex, x); + } + + @Override + public void setBytes(int parameterIndex, byte[] x) throws SQLException { + delegate().setBytes(parameterIndex, x); + } + + @Override + public void setDate(int parameterIndex, Date x) throws SQLException { + delegate().setDate(parameterIndex, x); + } + + @Override + public void setTime(int parameterIndex, Time x) throws SQLException { + delegate().setTime(parameterIndex, x); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x) throws SQLException { + delegate().setTimestamp(parameterIndex, x); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate().setAsciiStream(parameterIndex, x, length); + } + + @Override + @Deprecated + public void setUnicodeStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate().setUnicodeStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, int length) throws SQLException { + delegate().setBinaryStream(parameterIndex, x, length); + } + + @Override + public void clearParameters() throws SQLException { + delegate().clearParameters(); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType) throws SQLException { + delegate().setObject(parameterIndex, x, targetSqlType); + } + + @Override + public void setObject(int parameterIndex, Object x) throws SQLException { + delegate().setObject(parameterIndex, x); + } + + @Override + public boolean execute() throws SQLException { + return delegate().execute(); + } + + @Override + public void addBatch() throws SQLException { + delegate().addBatch(); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, int length) + throws SQLException { + delegate().setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setRef(int parameterIndex, Ref x) throws SQLException { + delegate().setRef(parameterIndex, x); + } + + @Override + public void setBlob(int parameterIndex, Blob x) throws SQLException { + delegate().setBlob(parameterIndex, x); + } + + @Override + public void setClob(int parameterIndex, Clob x) throws SQLException { + delegate().setClob(parameterIndex, x); + } + + @Override + public void setArray(int parameterIndex, Array x) throws SQLException { + delegate().setArray(parameterIndex, x); + } + + @Override + public ResultSetMetaData getMetaData() throws SQLException { + return delegate().getMetaData(); + } + + @Override + public void setDate(int parameterIndex, Date x, Calendar cal) throws SQLException { + delegate().setDate(parameterIndex, x, cal); + } + + @Override + public void setTime(int parameterIndex, Time x, Calendar cal) throws SQLException { + delegate().setTime(parameterIndex, x, cal); + } + + @Override + public void setTimestamp(int parameterIndex, Timestamp x, Calendar cal) throws SQLException { + delegate().setTimestamp(parameterIndex, x, cal); + } + + @Override + public void setNull(int parameterIndex, int sqlType, String typeName) throws SQLException { + delegate().setNull(parameterIndex, sqlType, typeName); + } + + @Override + public void setURL(int parameterIndex, URL x) throws SQLException { + delegate().setURL(parameterIndex, x); + } + + @Override + public ParameterMetaData getParameterMetaData() throws SQLException { + return delegate().getParameterMetaData(); + } + + @Override + public void setRowId(int parameterIndex, RowId x) throws SQLException { + delegate().setRowId(parameterIndex, x); + } + + @Override + public void setNString(int parameterIndex, String value) throws SQLException { + delegate().setNString(parameterIndex, value); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value, long length) + throws SQLException { + delegate().setNCharacterStream(parameterIndex, value, length); + } + + @Override + public void setNClob(int parameterIndex, NClob value) throws SQLException { + delegate().setNClob(parameterIndex, value); + } + + @Override + public void setClob(int parameterIndex, Reader reader, long length) throws SQLException { + delegate().setClob(parameterIndex, reader, length); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream, long length) + throws SQLException { + delegate().setBlob(parameterIndex, inputStream, length); + } + + @Override + public void setNClob(int parameterIndex, Reader reader, long length) throws SQLException { + delegate().setNClob(parameterIndex, reader, length); + } + + @Override + public void setSQLXML(int parameterIndex, SQLXML xmlObject) throws SQLException { + delegate().setSQLXML(parameterIndex, xmlObject); + } + + @Override + public void setObject(int parameterIndex, Object x, int targetSqlType, int scaleOrLength) + throws SQLException { + delegate().setObject(parameterIndex, x, targetSqlType, scaleOrLength); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x, long length) throws SQLException { + delegate().setAsciiStream(parameterIndex, x, length); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x, long length) throws SQLException { + delegate().setBinaryStream(parameterIndex, x, length); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader, long length) + throws SQLException { + delegate().setCharacterStream(parameterIndex, reader, length); + } + + @Override + public void setAsciiStream(int parameterIndex, InputStream x) throws SQLException { + delegate().setAsciiStream(parameterIndex, x); + } + + @Override + public void setBinaryStream(int parameterIndex, InputStream x) throws SQLException { + delegate().setBinaryStream(parameterIndex, x); + } + + @Override + public void setCharacterStream(int parameterIndex, Reader reader) throws SQLException { + delegate().setCharacterStream(parameterIndex, reader); + } + + @Override + public void setNCharacterStream(int parameterIndex, Reader value) throws SQLException { + delegate().setNCharacterStream(parameterIndex, value); + } + + @Override + public void setClob(int parameterIndex, Reader reader) throws SQLException { + delegate().setClob(parameterIndex, reader); + } + + @Override + public void setBlob(int parameterIndex, InputStream inputStream) throws SQLException { + delegate().setBlob(parameterIndex, inputStream); + } + + @Override + public void setNClob(int parameterIndex, Reader reader) throws SQLException { + delegate().setNClob(parameterIndex, reader); + } + + // --- Statement --- + + @Override + public ResultSet executeQuery(String sql) throws SQLException { + return delegate().executeQuery(sql); + } + + @Override + public int executeUpdate(String sql) throws SQLException { + return delegate().executeUpdate(sql); + } + + @Override + public void close() throws SQLException { + delegate().close(); + } + + @Override + public int getMaxFieldSize() throws SQLException { + return delegate().getMaxFieldSize(); + } + + @Override + public void setMaxFieldSize(int max) throws SQLException { + delegate().setMaxFieldSize(max); + } + + @Override + public int getMaxRows() throws SQLException { + return delegate().getMaxRows(); + } + + @Override + public void setMaxRows(int max) throws SQLException { + delegate().setMaxRows(max); + } + + @Override + public void setEscapeProcessing(boolean enable) throws SQLException { + delegate().setEscapeProcessing(enable); + } + + @Override + public int getQueryTimeout() throws SQLException { + return delegate().getQueryTimeout(); + } + + @Override + public void setQueryTimeout(int seconds) throws SQLException { + delegate().setQueryTimeout(seconds); + } + + @Override + public void cancel() throws SQLException { + delegate().cancel(); + } + + @Override + public SQLWarning getWarnings() throws SQLException { + return delegate().getWarnings(); + } + + @Override + public void clearWarnings() throws SQLException { + delegate().clearWarnings(); + } + + @Override + public void setCursorName(String name) throws SQLException { + delegate().setCursorName(name); + } + + @Override + public boolean execute(String sql) throws SQLException { + return delegate().execute(sql); + } + + @Override + public ResultSet getResultSet() throws SQLException { + return delegate().getResultSet(); + } + + @Override + public int getUpdateCount() throws SQLException { + return delegate().getUpdateCount(); + } + + @Override + public boolean getMoreResults() throws SQLException { + return delegate().getMoreResults(); + } + + @Override + public void setFetchDirection(int direction) throws SQLException { + delegate().setFetchDirection(direction); + } + + @Override + public int getFetchDirection() throws SQLException { + return delegate().getFetchDirection(); + } + + @Override + public void setFetchSize(int rows) throws SQLException { + delegate().setFetchSize(rows); + } + + @Override + public int getFetchSize() throws SQLException { + return delegate().getFetchSize(); + } + + @Override + public int getResultSetConcurrency() throws SQLException { + return delegate().getResultSetConcurrency(); + } + + @Override + public int getResultSetType() throws SQLException { + return delegate().getResultSetType(); + } + + @Override + public void addBatch(String sql) throws SQLException { + delegate().addBatch(sql); + } + + @Override + public void clearBatch() throws SQLException { + delegate().clearBatch(); + } + + @Override + public int[] executeBatch() throws SQLException { + return delegate().executeBatch(); + } + + @Override + public Connection getConnection() throws SQLException { + return delegate().getConnection(); + } + + @Override + public boolean getMoreResults(int current) throws SQLException { + return delegate().getMoreResults(current); + } + + @Override + public ResultSet getGeneratedKeys() throws SQLException { + return delegate().getGeneratedKeys(); + } + + @Override + public int executeUpdate(String sql, int autoGeneratedKeys) throws SQLException { + return delegate().executeUpdate(sql, autoGeneratedKeys); + } + + @Override + public int executeUpdate(String sql, int[] columnIndexes) throws SQLException { + return delegate().executeUpdate(sql, columnIndexes); + } + + @Override + public int executeUpdate(String sql, String[] columnNames) throws SQLException { + return delegate().executeUpdate(sql, columnNames); + } + + @Override + public boolean execute(String sql, int autoGeneratedKeys) throws SQLException { + return delegate().execute(sql, autoGeneratedKeys); + } + + @Override + public boolean execute(String sql, int[] columnIndexes) throws SQLException { + return delegate().execute(sql, columnIndexes); + } + + @Override + public boolean execute(String sql, String[] columnNames) throws SQLException { + return delegate().execute(sql, columnNames); + } + + @Override + public int getResultSetHoldability() throws SQLException { + return delegate().getResultSetHoldability(); + } + + @Override + public boolean isClosed() throws SQLException { + return delegate().isClosed(); + } + + @Override + public void setPoolable(boolean poolable) throws SQLException { + delegate().setPoolable(poolable); + } + + @Override + public boolean isPoolable() throws SQLException { + return delegate().isPoolable(); + } + + @Override + public void closeOnCompletion() throws SQLException { + delegate().closeOnCompletion(); + } + + @Override + public boolean isCloseOnCompletion() throws SQLException { + return delegate().isCloseOnCompletion(); + } + + // --- Wrapper --- + + @Override + public T unwrap(Class iface) throws SQLException { + if (iface.isInstance(this)) { + return iface.cast(this); + } + return delegate().unwrap(iface); + } + + @Override + public boolean isWrapperFor(Class iface) throws SQLException { + if (iface.isInstance(this)) { + return true; + } + return delegate().isWrapperFor(iface); + } +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedParamStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedParamStatement.java new file mode 100644 index 000000000..55cf0d35a --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedParamStatement.java @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLType; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.List; +import org.apache.arrow.driver.jdbc.utils.NamedSqlParser; + +/** + * Package-private decorator for {@link PreparedStatement} that implements {@link + * NamedPreparedStatement}, providing setter methods using parameter names instead of positional + * indices. Callers should use the {@link NamedPreparedStatement} interface — never this class + * directly. + * + *

All standard {@link PreparedStatement} / {@link java.sql.Statement} methods are forwarded + * transparently via {@link ForwardingPreparedStatement}. + */ +class NamedParamStatement extends ForwardingPreparedStatement implements NamedPreparedStatement { + + /** + * Functional interface for the caller-supplied {@link PreparedStatement} factory, allowing + * lambdas that throw {@link SQLException}. + */ + @FunctionalInterface + interface PrepareFunction { + PreparedStatement prepare(String positionalSql) throws SQLException; + } + + /** + * Parses {@code sql} for named parameters and wraps the delegate produced by {@code fn} in a + * {@link NamedParamStatement}. + * + * @param sql the original SQL, possibly containing {@code :name} tokens + * @param fn supplies the underlying {@link PreparedStatement} from the translated positional SQL + * @return a {@link NamedPreparedStatement} ready for use + * @throws SQLException if parsing fails or {@code fn} throws + */ + static NamedPreparedStatement wrap(String sql, PrepareFunction fn) throws SQLException { + NamedSqlParser.ParseResult parsed = NamedSqlParser.parse(sql); + return new NamedParamStatement(fn.prepare(parsed.positionalSql), parsed); + } + + private final PreparedStatement delegate; + private final NamedSqlParser.ParseResult parseResult; + + /** + * Creates a statement that uses named parameters. + * + * @param delegate the underlying PreparedStatement + * @param parseResult the result of parsing the named-parameter SQL + */ + NamedParamStatement(PreparedStatement delegate, NamedSqlParser.ParseResult parseResult) { + this.delegate = delegate; + this.parseResult = parseResult; + } + + @Override + protected PreparedStatement delegate() { + return delegate; + } + + private List getIndices(String name) throws SQLException { + List indices = parseResult.nameToIndices.get(name); + if (indices == null || indices.isEmpty()) { + throw new SQLException("Unknown parameter name: '" + name + "'"); + } + return indices; + } + + // --- Named setters --- + + @Override + public void setNull(String name, int sqlType) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNull(index, sqlType); + } + } + + @Override + public void setBoolean(String name, boolean x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBoolean(index, x); + } + } + + @Override + public void setByte(String name, byte x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setByte(index, x); + } + } + + @Override + public void setShort(String name, short x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setShort(index, x); + } + } + + @Override + public void setInt(String name, int x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setInt(index, x); + } + } + + @Override + public void setLong(String name, long x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setLong(index, x); + } + } + + @Override + public void setFloat(String name, float x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setFloat(index, x); + } + } + + @Override + public void setDouble(String name, double x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setDouble(index, x); + } + } + + @Override + public void setBigDecimal(String name, BigDecimal x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBigDecimal(index, x); + } + } + + @Override + public void setString(String name, String x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setString(index, x); + } + } + + @Override + public void setBytes(String name, byte[] x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBytes(index, x); + } + } + + @Override + public void setDate(String name, Date x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setDate(index, x); + } + } + + @Override + public void setTime(String name, Time x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setTime(index, x); + } + } + + @Override + public void setTimestamp(String name, Timestamp x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setTimestamp(index, x); + } + } + + @Override + public void setObject(String name, Object x, int targetSqlType) throws SQLException { + for (int index : getIndices(name)) { + delegate.setObject(index, x, targetSqlType); + } + } + + @Override + public void setObject(String name, Object x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setObject(index, x); + } + } + + @Override + public void setArray(String name, Array x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setArray(index, x); + } + } + + @Override + public void setNull(String name, int sqlType, String typeName) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNull(index, sqlType, typeName); + } + } + + @Override + public void setDate(String name, Date x, Calendar cal) throws SQLException { + for (int index : getIndices(name)) { + delegate.setDate(index, x, cal); + } + } + + @Override + public void setTime(String name, Time x, Calendar cal) throws SQLException { + for (int index : getIndices(name)) { + delegate.setTime(index, x, cal); + } + } + + @Override + public void setTimestamp(String name, Timestamp x, Calendar cal) throws SQLException { + for (int index : getIndices(name)) { + delegate.setTimestamp(index, x, cal); + } + } + + @Override + public void setObject(String name, Object x, int targetSqlType, int scaleOrLength) + throws SQLException { + for (int index : getIndices(name)) { + delegate.setObject(index, x, targetSqlType, scaleOrLength); + } + } + + @Override + public void setURL(String name, URL x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setURL(index, x); + } + } + + @Override + public void setRef(String name, Ref x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setRef(index, x); + } + } + + @Override + public void setBlob(String name, Blob x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBlob(index, x); + } + } + + @Override + public void setBlob(String name, InputStream inputStream, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBlob(index, inputStream, length); + } + } + + @Override + public void setBlob(String name, InputStream inputStream) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBlob(index, inputStream); + } + } + + @Override + public void setClob(String name, Clob x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setClob(index, x); + } + } + + @Override + public void setClob(String name, Reader reader, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setClob(index, reader, length); + } + } + + @Override + public void setClob(String name, Reader reader) throws SQLException { + for (int index : getIndices(name)) { + delegate.setClob(index, reader); + } + } + + @Override + public void setNClob(String name, NClob value) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNClob(index, value); + } + } + + @Override + public void setNClob(String name, Reader reader, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNClob(index, reader, length); + } + } + + @Override + public void setNClob(String name, Reader reader) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNClob(index, reader); + } + } + + @Override + public void setNString(String name, String value) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNString(index, value); + } + } + + @Override + public void setNCharacterStream(String name, Reader value, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNCharacterStream(index, value, length); + } + } + + @Override + public void setNCharacterStream(String name, Reader value) throws SQLException { + for (int index : getIndices(name)) { + delegate.setNCharacterStream(index, value); + } + } + + @Override + public void setAsciiStream(String name, InputStream x, int length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setAsciiStream(index, x, length); + } + } + + @Override + public void setAsciiStream(String name, InputStream x, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setAsciiStream(index, x, length); + } + } + + @Override + public void setAsciiStream(String name, InputStream x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setAsciiStream(index, x); + } + } + + @Override + public void setBinaryStream(String name, InputStream x, int length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBinaryStream(index, x, length); + } + } + + @Override + public void setBinaryStream(String name, InputStream x, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBinaryStream(index, x, length); + } + } + + @Override + public void setBinaryStream(String name, InputStream x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setBinaryStream(index, x); + } + } + + @Override + public void setCharacterStream(String name, Reader reader, int length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setCharacterStream(index, reader, length); + } + } + + @Override + public void setCharacterStream(String name, Reader reader, long length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setCharacterStream(index, reader, length); + } + } + + @Override + public void setCharacterStream(String name, Reader reader) throws SQLException { + for (int index : getIndices(name)) { + delegate.setCharacterStream(index, reader); + } + } + + @Override + public void setRowId(String name, RowId x) throws SQLException { + for (int index : getIndices(name)) { + delegate.setRowId(index, x); + } + } + + @Override + public void setSQLXML(String name, SQLXML xmlObject) throws SQLException { + for (int index : getIndices(name)) { + delegate.setSQLXML(index, xmlObject); + } + } + + @Override + @Deprecated + public void setUnicodeStream(String name, InputStream x, int length) throws SQLException { + for (int index : getIndices(name)) { + delegate.setUnicodeStream(index, x, length); + } + } + + @Override + public void setObject(String name, Object x, SQLType targetSqlType, int scaleOrLength) + throws SQLException { + for (int index : getIndices(name)) { + delegate.setObject(index, x, targetSqlType, scaleOrLength); + } + } + + @Override + public void setObject(String name, Object x, SQLType targetSqlType) throws SQLException { + for (int index : getIndices(name)) { + delegate.setObject(index, x, targetSqlType); + } + } +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedPreparedStatement.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedPreparedStatement.java new file mode 100644 index 000000000..230bafa61 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/NamedPreparedStatement.java @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.Ref; +import java.sql.RowId; +import java.sql.SQLException; +import java.sql.SQLType; +import java.sql.SQLXML; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; + +/** + * Extension of {@link PreparedStatement} that adds setter methods accepting parameter names instead + * of positional indices. + * + *

Instances are returned by {@link ArrowFlightConnection#prepareStatement(String)} when the SQL + * contains at least one {@code :name} token. Every {@code setXxx(String name, …)} method resolves + * the name to one or more 1-based positional indices and forwards to the corresponding {@code + * setXxx(int, …)} method on the underlying delegate statement. + * + *

Usage: + * + *

{@code
+ * NamedPreparedStatement ps = (NamedPreparedStatement)
+ *     conn.prepareStatement("SELECT * FROM t WHERE id = :id AND status = :status");
+ * ps.setInt("id", 42);
+ * ps.setString("status", "active");
+ * ResultSet rs = ps.executeQuery();
+ * }
+ */ +public interface NamedPreparedStatement extends PreparedStatement { + + /** Sets the designated parameter to SQL NULL. */ + void setNull(String name, int sqlType) throws SQLException; + + /** Sets the designated parameter to the given Java boolean value. */ + void setBoolean(String name, boolean x) throws SQLException; + + /** Sets the designated parameter to the given Java byte value. */ + void setByte(String name, byte x) throws SQLException; + + /** Sets the designated parameter to the given Java short value. */ + void setShort(String name, short x) throws SQLException; + + /** Sets the designated parameter to the given Java int value. */ + void setInt(String name, int x) throws SQLException; + + /** Sets the designated parameter to the given Java long value. */ + void setLong(String name, long x) throws SQLException; + + /** Sets the designated parameter to the given Java float value. */ + void setFloat(String name, float x) throws SQLException; + + /** Sets the designated parameter to the given Java double value. */ + void setDouble(String name, double x) throws SQLException; + + /** Sets the designated parameter to the given {@link BigDecimal} value. */ + void setBigDecimal(String name, BigDecimal x) throws SQLException; + + /** Sets the designated parameter to the given Java String value. */ + void setString(String name, String x) throws SQLException; + + /** Sets the designated parameter to the given Java array of bytes. */ + void setBytes(String name, byte[] x) throws SQLException; + + /** Sets the designated parameter to the given {@link Date} value. */ + void setDate(String name, Date x) throws SQLException; + + /** Sets the designated parameter to the given {@link Time} value. */ + void setTime(String name, Time x) throws SQLException; + + /** Sets the designated parameter to the given {@link Timestamp} value. */ + void setTimestamp(String name, Timestamp x) throws SQLException; + + /** Sets the value of the designated parameter with the given object and SQL type. */ + void setObject(String name, Object x, int targetSqlType) throws SQLException; + + /** Sets the value of the designated parameter using the given object. */ + void setObject(String name, Object x) throws SQLException; + + /** Sets the designated parameter to the given {@link Array} object. */ + void setArray(String name, Array x) throws SQLException; + + /** Sets the designated parameter to SQL NULL with the given type name. */ + void setNull(String name, int sqlType, String typeName) throws SQLException; + + /** + * Sets the designated parameter to the given {@link Date} value using the given {@link Calendar}. + */ + void setDate(String name, Date x, Calendar cal) throws SQLException; + + /** + * Sets the designated parameter to the given {@link Time} value using the given {@link Calendar}. + */ + void setTime(String name, Time x, Calendar cal) throws SQLException; + + /** + * Sets the designated parameter to the given {@link Timestamp} value using the given {@link + * Calendar}. + */ + void setTimestamp(String name, Timestamp x, Calendar cal) throws SQLException; + + /** + * Sets the value of the designated parameter with the given object, SQL type, and scale/length. + */ + void setObject(String name, Object x, int targetSqlType, int scaleOrLength) throws SQLException; + + /** Sets the designated parameter to the given {@link URL} value. */ + void setURL(String name, URL x) throws SQLException; + + /** Sets the designated parameter to the given {@link Ref} object. */ + void setRef(String name, Ref x) throws SQLException; + + /** Sets the designated parameter to the given {@link Blob} object. */ + void setBlob(String name, Blob x) throws SQLException; + + /** + * Sets the designated parameter to a {@link Blob} read from the given stream of the given length. + */ + void setBlob(String name, InputStream inputStream, long length) throws SQLException; + + /** Sets the designated parameter to a {@link Blob} read from the given stream. */ + void setBlob(String name, InputStream inputStream) throws SQLException; + + /** Sets the designated parameter to the given {@link Clob} object. */ + void setClob(String name, Clob x) throws SQLException; + + /** + * Sets the designated parameter to a {@link Clob} read from the given reader of the given length. + */ + void setClob(String name, Reader reader, long length) throws SQLException; + + /** Sets the designated parameter to a {@link Clob} read from the given reader. */ + void setClob(String name, Reader reader) throws SQLException; + + /** Sets the designated parameter to the given {@link NClob} object. */ + void setNClob(String name, NClob value) throws SQLException; + + /** + * Sets the designated parameter to a {@link NClob} read from the given reader of the given + * length. + */ + void setNClob(String name, Reader reader, long length) throws SQLException; + + /** Sets the designated parameter to a {@link NClob} read from the given reader. */ + void setNClob(String name, Reader reader) throws SQLException; + + /** Sets the designated parameter to the given national character set {@link String} value. */ + void setNString(String name, String value) throws SQLException; + + /** + * Sets the designated parameter to the given national character {@link Reader} of the given + * length. + */ + void setNCharacterStream(String name, Reader value, long length) throws SQLException; + + /** Sets the designated parameter to the given national character {@link Reader}. */ + void setNCharacterStream(String name, Reader value) throws SQLException; + + /** Sets the designated parameter to the given ASCII stream of the given length (int). */ + void setAsciiStream(String name, InputStream x, int length) throws SQLException; + + /** Sets the designated parameter to the given ASCII stream of the given length (long). */ + void setAsciiStream(String name, InputStream x, long length) throws SQLException; + + /** Sets the designated parameter to the given ASCII stream. */ + void setAsciiStream(String name, InputStream x) throws SQLException; + + /** Sets the designated parameter to the given binary stream of the given length (int). */ + void setBinaryStream(String name, InputStream x, int length) throws SQLException; + + /** Sets the designated parameter to the given binary stream of the given length (long). */ + void setBinaryStream(String name, InputStream x, long length) throws SQLException; + + /** Sets the designated parameter to the given binary stream. */ + void setBinaryStream(String name, InputStream x) throws SQLException; + + /** Sets the designated parameter to the given character stream of the given length (int). */ + void setCharacterStream(String name, Reader reader, int length) throws SQLException; + + /** Sets the designated parameter to the given character stream of the given length (long). */ + void setCharacterStream(String name, Reader reader, long length) throws SQLException; + + /** Sets the designated parameter to the given character stream. */ + void setCharacterStream(String name, Reader reader) throws SQLException; + + /** Sets the designated parameter to the given {@link RowId} object. */ + void setRowId(String name, RowId x) throws SQLException; + + /** Sets the designated parameter to the given {@link SQLXML} object. */ + void setSQLXML(String name, SQLXML xmlObject) throws SQLException; + + /** + * Sets the designated parameter to the given Unicode stream (deprecated). + * + * @deprecated {@code setUnicodeStream} is deprecated; use {@link #setCharacterStream} instead. + */ + @Deprecated + void setUnicodeStream(String name, InputStream x, int length) throws SQLException; + + /** Sets the value of the designated parameter with the given object and {@link SQLType}. */ + void setObject(String name, Object x, SQLType targetSqlType, int scaleOrLength) + throws SQLException; + + /** Sets the value of the designated parameter using the given object and {@link SQLType}. */ + void setObject(String name, Object x, SQLType targetSqlType) throws SQLException; +} diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParser.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParser.java new file mode 100644 index 000000000..3cdd3570e --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParser.java @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.utils; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Single-pass SQL scanner that translates named bind parameters ({@code :name}) to positional + * {@code ?} placeholders, building the index mappings needed by {@link + * org.apache.arrow.driver.jdbc.NamedPreparedStatement}. + * + *

The scanner understands string literals ({@code '…'} and {@code "…"}), line comments ({@code + * --}), block comments ({@code /* … *\/}), and PostgreSQL cast syntax ({@code ::type}) so that + * colon characters in those contexts are never mistaken for parameter markers. + */ +public final class NamedSqlParser { + + private NamedSqlParser() {} + + // ------------------------------------------------------------------------- + // Public API + // ------------------------------------------------------------------------- + + /** + * Parses {@code sql}, replacing every {@code :name} token with {@code ?} and recording the + * position mappings. + * + * @throws SQLException if the query mixes positional {@code ?} and named {@code :name} + * parameters. + */ + public static ParseResult parse(String sql) throws SQLException { + if (sql == null || sql.isEmpty()) { + return new ParseResult(sql, Collections.emptyMap(), Collections.emptyList()); + } + + StringBuilder out = new StringBuilder(sql.length()); + Map> nameToIndices = new HashMap<>(); + List orderedNames = new ArrayList<>(); + + boolean foundPositional = false; + boolean foundNamed = false; + + int paramIndex = 0; // 0-based; 1-based slot = paramIndex + 1 + int i = 0; + int len = sql.length(); + + State state = State.NORMAL; + char stringDelimiter = 0; + + while (i < len) { + char c = sql.charAt(i); + + switch (state) { + case NORMAL: + if (c == '-' && i + 1 < len && sql.charAt(i + 1) == '-') { + out.append(c).append(sql.charAt(i + 1)); + i += 2; + state = State.LINE_COMMENT; + } else if (c == '/' && i + 1 < len && sql.charAt(i + 1) == '*') { + out.append(c).append(sql.charAt(i + 1)); + i += 2; + state = State.BLOCK_COMMENT; + } else if (c == '\'' || c == '"') { + stringDelimiter = c; + out.append(c); + i++; + state = State.STRING; + } else if (c == ':') { + // PostgreSQL cast :: + if (i + 1 < len && sql.charAt(i + 1) == ':') { + out.append("::"); // emit both, move past them + i += 2; + } else if (i + 1 < len && isWordChar(sql.charAt(i + 1))) { + // Named parameter — collect name + int nameStart = i + 1; + int j = nameStart; + while (j < len && isWordChar(sql.charAt(j))) { + j++; + } + String name = sql.substring(nameStart, j); + if (foundPositional) { + throw new SQLException( + "Cannot mix positional '?' and named ':name' parameters in the same query"); + } + foundNamed = true; + int slot = ++paramIndex; // 1-based + nameToIndices.computeIfAbsent(name, k -> new ArrayList<>()).add(slot); + orderedNames.add(name); + out.append('?'); + i = j; + } else { + // Bare ':' with no name — emit as-is + out.append(c); + i++; + } + } else if (c == '?') { + if (foundNamed) { + throw new SQLException( + "Cannot mix positional '?' and named ':name' parameters in the same query"); + } + foundPositional = true; + out.append(c); + i++; + } else { + out.append(c); + i++; + } + break; + + case STRING: + out.append(c); + i++; + if (c == stringDelimiter) { + // Handle escaped delimiter ('') or ("") + if (i < len && sql.charAt(i) == stringDelimiter) { + out.append(sql.charAt(i)); + i++; + } else { + state = State.NORMAL; + } + } + break; + + case LINE_COMMENT: + out.append(c); + i++; + if (c == '\n') { + state = State.NORMAL; + } + break; + + case BLOCK_COMMENT: + out.append(c); + i++; + if (c == '*' && i < len && sql.charAt(i) == '/') { + out.append(sql.charAt(i)); + i++; + state = State.NORMAL; + } + break; + + default: + out.append(c); + i++; + } + } + + return new ParseResult(out.toString(), nameToIndices, orderedNames); + } + + // ------------------------------------------------------------------------- + // Helpers + // ------------------------------------------------------------------------- + + private static boolean isWordChar(char c) { + return Character.isLetterOrDigit(c) || c == '_'; + } + + private enum State { + NORMAL, + STRING, + LINE_COMMENT, + BLOCK_COMMENT + } + + // ------------------------------------------------------------------------- + // Result + // ------------------------------------------------------------------------- + + /** Immutable result of a {@link #parse} call. */ + public static final class ParseResult { + /** The original SQL with every {@code :name} token replaced by {@code ?}. */ + public final String positionalSql; + + /** Maps each parameter name to the list of 1-based positional slots it occupies. */ + public final Map> nameToIndices; + + /** + * Ordered list of parameter names in the order they appear in the SQL (0-based index + * corresponds to the (index+1)-th {@code ?} placeholder). + */ + public final List orderedNames; + + ParseResult( + String positionalSql, Map> nameToIndices, List orderedNames) { + this.positionalSql = positionalSql; + this.nameToIndices = + Collections.unmodifiableMap(nameToIndices != null ? nameToIndices : new HashMap<>()); + this.orderedNames = + Collections.unmodifiableList(orderedNames != null ? orderedNames : new ArrayList<>()); + } + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/NamedParamStatementTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/NamedParamStatementTest.java new file mode 100644 index 000000000..7474a1dc0 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/NamedParamStatementTest.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Types; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class NamedParamStatementTest { + + public static final MockFlightSqlProducer PRODUCER = CoreMockedSqlProducers.getLegacyProducer(); + + @RegisterExtension + public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION = + FlightServerTestExtension.createStandardTestExtension(PRODUCER); + + private static Connection connection; + + @BeforeAll + public static void setup() throws SQLException { + connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false); + } + + @AfterAll + public static void tearDown() throws SQLException { + connection.close(); + } + + @BeforeEach + public void before() { + PRODUCER.clearActionTypeCounter(); + } + + /** + * Every prepareStatement call returns a NamedPreparedStatement — even for positional SQL. The + * cast is always safe, matching Oracle's OraclePreparedStatement behaviour. + */ + @Test + public void testPositionalSqlCastIsAlwaysSafe() throws SQLException { + final String positionalQuery = "SELECT 1 WHERE id = ?"; + PRODUCER.addSelectQuery( + positionalQuery, + new Schema( + Collections.singletonList( + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))), + Collections.singletonList( + listener -> { + try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + VectorSchemaRoot root = + VectorSchemaRoot.create( + new Schema( + Collections.singletonList( + Field.nullable( + "", + org.apache.arrow.vector.types.Types.MinorType.INT + .getType()))), + allocator)) { + root.setRowCount(0); + listener.start(root); + listener.putNext(); + } catch (Throwable t) { + listener.error(t); + } finally { + listener.completed(); + } + })); + + // Must NOT throw ClassCastException — cast is always safe + PreparedStatement ps = connection.prepareStatement(positionalQuery); + assertTrue( + ps instanceof NamedPreparedStatement, + "prepareStatement must always return NamedPreparedStatement"); + ps.close(); + } + + @Test + public void testSingleNamedParamSelect() throws SQLException { + final String clientQuery = "Fake query with :p1 and :p2"; + final String serverQuery = "Fake query with ? and ?"; + + final Schema schema = + new Schema( + Collections.singletonList( + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))); + final Schema parameterSchema = + new Schema( + Arrays.asList( + Field.nullable("", ArrowType.Utf8.INSTANCE), + new Field( + "", + FieldType.nullable(ArrowType.List.INSTANCE), + Collections.singletonList( + Field.nullable( + "", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))))); + final List> expected = + Collections.singletonList(Arrays.asList(new Text("foo"), new Integer[] {1, 2, null})); + + PRODUCER.addSelectQuery( + serverQuery, + schema, + Collections.singletonList( + listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + ((IntVector) root.getVector(0)).setSafe(0, 10); + root.setRowCount(1); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + + PRODUCER.addExpectedParameters(serverQuery, parameterSchema, expected); + + try (final NamedPreparedStatement ps = + (NamedPreparedStatement) connection.prepareStatement(clientQuery)) { + ps.setString("p1", "foo"); + ps.setArray("p2", connection.createArrayOf("INTEGER", new Integer[] {1, 2, null})); + + try (final ResultSet resultSet = ps.executeQuery()) { + assertTrue(resultSet.next()); + assertEquals(10, resultSet.getInt(1)); + } + } + } + + @Test + public void testDuplicateNamedParam() throws SQLException { + final String clientQuery = "SELECT * FROM t WHERE id = :id OR parent_id = :id"; + final String serverQuery = "SELECT * FROM t WHERE id = ? OR parent_id = ?"; + + final Schema parameterSchema = + new Schema( + Arrays.asList( + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()), + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))); + final List> expected = Collections.singletonList(Arrays.asList(42, 42)); + + PRODUCER.addUpdateQuery(serverQuery, 1); + PRODUCER.addExpectedParameters(serverQuery, parameterSchema, expected); + + try (final NamedPreparedStatement ps = + (NamedPreparedStatement) connection.prepareStatement(clientQuery)) { + ps.setInt("id", 42); + assertEquals(1, ps.executeUpdate()); + } + } + + @Test + public void testUnknownNameThrows() throws SQLException { + final String clientQuery = "named unknown test :val"; + final String serverQuery = "named unknown test ?"; + + final Schema paramSchema = + new Schema(Collections.singletonList(Field.nullable("", ArrowType.Utf8.INSTANCE))); + PRODUCER.addUpdateQuery(serverQuery, 0); + PRODUCER.addExpectedParameters(serverQuery, paramSchema, null); + + try (final NamedPreparedStatement ps = + (NamedPreparedStatement) connection.prepareStatement(clientQuery)) { + // Known name: must not throw + ps.setString("val", "test"); + + // Unknown name: must throw with descriptive message + SQLException ex = assertThrows(SQLException.class, () -> ps.setInt("nonexistent", 1)); + assertTrue(ex.getMessage().contains("Unknown parameter name: 'nonexistent'")); + } + } + + @Test + public void testSetNullExecutesUpdate() throws SQLException { + final String clientQuery = "setNull test :id"; + final String serverQuery = "setNull test ?"; + + final Schema paramSchema = + new Schema( + Collections.singletonList( + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))); + PRODUCER.addUpdateQuery(serverQuery, 7); + PRODUCER.addExpectedParameters(serverQuery, paramSchema, null); + + try (final NamedPreparedStatement ps = + (NamedPreparedStatement) connection.prepareStatement(clientQuery)) { + ps.setNull("id", Types.INTEGER); + assertEquals(7, ps.executeUpdate()); + } + } + + @Test + public void testAddBatch() throws SQLException { + final String clientQuery = "UPDATE t SET name = :name WHERE id = :id"; + final String serverQuery = "UPDATE t SET name = ? WHERE id = ?"; + + Schema parameterSchema = + new Schema( + Arrays.asList( + Field.nullable("", ArrowType.Utf8.INSTANCE), + Field.nullable("", org.apache.arrow.vector.types.Types.MinorType.INT.getType()))); + List> expected = + Arrays.asList(Arrays.asList(new Text("foo"), 1), Arrays.asList(new Text("bar"), 2)); + + PRODUCER.addUpdateQuery(serverQuery, 42); + PRODUCER.addExpectedParameters(serverQuery, parameterSchema, expected); + + try (final NamedPreparedStatement ps = + (NamedPreparedStatement) connection.prepareStatement(clientQuery)) { + ps.setString("name", "foo"); + ps.setInt("id", 1); + ps.addBatch(); + + ps.setString("name", "bar"); + ps.setInt("id", 2); + ps.addBatch(); + + int[] updated = ps.executeBatch(); + assertEquals(42, updated[0]); + } + } +} diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParserTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParserTest.java new file mode 100644 index 000000000..af783b6d4 --- /dev/null +++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/NamedSqlParserTest.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.arrow.driver.jdbc.utils; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import org.junit.jupiter.api.Test; + +public class NamedSqlParserTest { + + @Test + public void testNoParameters() throws SQLException { + String sql = "SELECT * FROM t"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals(sql, res.positionalSql); + assertTrue(res.nameToIndices.isEmpty()); + assertTrue(res.orderedNames.isEmpty()); + } + + @Test + public void testSingleNamedParam() throws SQLException { + String sql = "SELECT * FROM t WHERE id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t WHERE id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + assertEquals(Collections.singletonList("id"), res.orderedNames); + } + + @Test + public void testDuplicateNamedParam() throws SQLException { + String sql = "SELECT * FROM t WHERE start_date = :date OR end_date = :date"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t WHERE start_date = ? OR end_date = ?", res.positionalSql); + assertEquals(Arrays.asList(1, 2), res.nameToIndices.get("date")); + assertEquals(Arrays.asList("date", "date"), res.orderedNames); + } + + @Test + public void testTwoDistinctNames() throws SQLException { + String sql = "SELECT * FROM t WHERE a = :first AND b = :second"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t WHERE a = ? AND b = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("first")); + assertEquals(Collections.singletonList(2), res.nameToIndices.get("second")); + assertEquals(Arrays.asList("first", "second"), res.orderedNames); + } + + @Test + public void testInsideStringLiteral() throws SQLException { + String sql = "SELECT * FROM t WHERE name = ':ignored' AND id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t WHERE name = ':ignored' AND id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + } + + @Test + public void testInsideDoubleQuotedLiteral() throws SQLException { + String sql = "SELECT * FROM \":ignored\" WHERE id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM \":ignored\" WHERE id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + } + + @Test + public void testInsideLineComment() throws SQLException { + String sql = "SELECT * FROM t -- comment with :ignored \n WHERE id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t -- comment with :ignored \n WHERE id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + } + + @Test + public void testInsideBlockComment() throws SQLException { + String sql = "SELECT * FROM t /* comment with :ignored */ WHERE id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t /* comment with :ignored */ WHERE id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + } + + @Test + public void testPostgreSQLCast() throws SQLException { + String sql = "SELECT col::int FROM t WHERE id = :id"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT col::int FROM t WHERE id = ?", res.positionalSql); + assertEquals(Collections.singletonList(1), res.nameToIndices.get("id")); + } + + @Test + public void testEmptyStringName() throws SQLException { + String sql = "SELECT * FROM t WHERE id = : AND val = :val"; + NamedSqlParser.ParseResult res = NamedSqlParser.parse(sql); + assertEquals("SELECT * FROM t WHERE id = : AND val = ?", res.positionalSql); + assertTrue(res.nameToIndices.containsKey("val")); + assertFalse(res.nameToIndices.containsKey("")); + } + + @Test + public void testMixedParams() { + String sql = "SELECT * FROM t WHERE id = ? AND name = :name"; + assertThrows(SQLException.class, () -> NamedSqlParser.parse(sql)); + } +}