| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to you under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.calcite.runtime; |
| |
| import org.apache.calcite.DataContext; |
| import org.apache.calcite.avatica.SqlType; |
| import org.apache.calcite.linq4j.AbstractEnumerable; |
| import org.apache.calcite.linq4j.Enumerable; |
| import org.apache.calcite.linq4j.Enumerator; |
| import org.apache.calcite.linq4j.Linq4j; |
| import org.apache.calcite.linq4j.function.Function0; |
| import org.apache.calcite.linq4j.function.Function1; |
| import org.apache.calcite.linq4j.tree.Primitive; |
| import org.apache.calcite.util.Static; |
| |
| import org.checkerframework.checker.nullness.qual.Nullable; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.math.BigDecimal; |
| import java.net.URL; |
| import java.sql.Blob; |
| import java.sql.Clob; |
| import java.sql.Connection; |
| import java.sql.Date; |
| import java.sql.NClob; |
| import java.sql.PreparedStatement; |
| import java.sql.Ref; |
| import java.sql.ResultSet; |
| import java.sql.ResultSetMetaData; |
| import java.sql.RowId; |
| import java.sql.SQLException; |
| import java.sql.SQLFeatureNotSupportedException; |
| import java.sql.SQLXML; |
| import java.sql.Statement; |
| import java.sql.Time; |
| import java.sql.Timestamp; |
| import java.sql.Types; |
| import java.time.Instant; |
| import java.util.ArrayList; |
| import java.util.List; |
| import javax.sql.DataSource; |
| |
| import static org.apache.calcite.linq4j.Nullness.castNonNull; |
| |
| /** |
| * Executes a SQL statement and returns the result as an {@link Enumerable}. |
| * |
| * @param <T> Element type |
| */ |
| public class ResultSetEnumerable<T> extends AbstractEnumerable<T> { |
| private final DataSource dataSource; |
| private final String sql; |
| private final Function1<ResultSet, Function0<T>> rowBuilderFactory; |
| private final @Nullable PreparedStatementEnricher preparedStatementEnricher; |
| |
| private static final Logger LOGGER = |
| LoggerFactory.getLogger(ResultSetEnumerable.class); |
| |
| private @Nullable Long queryStart; |
| private long timeout; |
| private boolean timeoutSetFailed; |
| |
| private static final Function1<ResultSet, Function0<@Nullable Object>> AUTO_ROW_BUILDER_FACTORY = |
| resultSet -> { |
| final ResultSetMetaData metaData; |
| final int columnCount; |
| try { |
| metaData = resultSet.getMetaData(); |
| columnCount = metaData.getColumnCount(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| if (columnCount == 1) { |
| return () -> { |
| try { |
| return resultSet.getObject(1); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| }; |
| } else { |
| return () -> convertColumns(resultSet, metaData, columnCount); |
| } |
| }; |
| |
| private static @Nullable Object[] convertColumns(ResultSet resultSet, ResultSetMetaData metaData, |
| int columnCount) { |
| final List<@Nullable Object> list = new ArrayList<>(columnCount); |
| try { |
| for (int i = 0; i < columnCount; i++) { |
| if (metaData.getColumnType(i + 1) == Types.TIMESTAMP) { |
| long v = resultSet.getLong(i + 1); |
| if (v == 0 && resultSet.wasNull()) { |
| list.add(null); |
| } else { |
| list.add(v); |
| } |
| } else { |
| list.add(resultSet.getObject(i + 1)); |
| } |
| } |
| return list.toArray(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| private ResultSetEnumerable( |
| DataSource dataSource, |
| String sql, |
| Function1<ResultSet, Function0<T>> rowBuilderFactory, |
| @Nullable PreparedStatementEnricher preparedStatementEnricher) { |
| this.dataSource = dataSource; |
| this.sql = sql; |
| this.rowBuilderFactory = rowBuilderFactory; |
| this.preparedStatementEnricher = preparedStatementEnricher; |
| } |
| |
| private ResultSetEnumerable( |
| DataSource dataSource, |
| String sql, |
| Function1<ResultSet, Function0<T>> rowBuilderFactory) { |
| this(dataSource, sql, rowBuilderFactory, null); |
| } |
| |
| /** Creates a ResultSetEnumerable. */ |
| public static ResultSetEnumerable<@Nullable Object> of(DataSource dataSource, String sql) { |
| return of(dataSource, sql, AUTO_ROW_BUILDER_FACTORY); |
| } |
| |
| /** Creates a ResultSetEnumerable that retrieves columns as specific |
| * Java types. */ |
| public static ResultSetEnumerable<@Nullable Object> of(DataSource dataSource, String sql, |
| Primitive[] primitives) { |
| return of(dataSource, sql, primitiveRowBuilderFactory(primitives)); |
| } |
| |
| /** Executes a SQL query and returns the results as an enumerator, using a |
| * row builder to convert JDBC column values into rows. */ |
| public static <T> ResultSetEnumerable<T> of( |
| DataSource dataSource, |
| String sql, |
| Function1<ResultSet, Function0<T>> rowBuilderFactory) { |
| return new ResultSetEnumerable<>(dataSource, sql, rowBuilderFactory); |
| } |
| |
| /** Executes a SQL query and returns the results as an enumerator, using a |
| * row builder to convert JDBC column values into rows. |
| * |
| * <p>It uses a {@link PreparedStatement} for computing the query result, |
| * and that means that it can bind parameters. */ |
| public static <T> ResultSetEnumerable<T> of( |
| DataSource dataSource, |
| String sql, |
| Function1<ResultSet, Function0<T>> rowBuilderFactory, |
| PreparedStatementEnricher consumer) { |
| return new ResultSetEnumerable<>(dataSource, sql, rowBuilderFactory, consumer); |
| } |
| |
| public void setTimeout(DataContext context) { |
| this.queryStart = (Long) context.get(DataContext.Variable.UTC_TIMESTAMP.camelName); |
| Object timeout = context.get(DataContext.Variable.TIMEOUT.camelName); |
| if (timeout instanceof Long) { |
| this.timeout = (Long) timeout; |
| } else { |
| if (timeout != null) { |
| LOGGER.debug("Variable.TIMEOUT should be `long`. Given value was {}", timeout); |
| } |
| this.timeout = 0; |
| } |
| } |
| |
| /** Called from generated code that proposes to create a |
| * {@code ResultSetEnumerable} over a prepared statement. */ |
| public static PreparedStatementEnricher createEnricher(Integer[] indexes, |
| DataContext context) { |
| return preparedStatement -> { |
| for (int i = 0; i < indexes.length; i++) { |
| final int index = indexes[i]; |
| setDynamicParam(preparedStatement, i + 1, |
| context.get("?" + index)); |
| } |
| }; |
| } |
| |
| /** Assigns a value to a dynamic parameter in a prepared statement, calling |
| * the appropriate {@code setXxx} method based on the type of the value. */ |
| private static void setDynamicParam(PreparedStatement preparedStatement, |
| int i, @Nullable Object value) throws SQLException { |
| if (value == null) { |
| // TODO: use proper type instead of ANY |
| preparedStatement.setObject(i, null, SqlType.ANY.id); |
| } else if (value instanceof Timestamp) { |
| preparedStatement.setTimestamp(i, (Timestamp) value); |
| } else if (value instanceof Time) { |
| preparedStatement.setTime(i, (Time) value); |
| } else if (value instanceof String) { |
| preparedStatement.setString(i, (String) value); |
| } else if (value instanceof Integer) { |
| preparedStatement.setInt(i, (Integer) value); |
| } else if (value instanceof Double) { |
| preparedStatement.setDouble(i, (Double) value); |
| } else if (value instanceof java.sql.Array) { |
| preparedStatement.setArray(i, (java.sql.Array) value); |
| } else if (value instanceof BigDecimal) { |
| preparedStatement.setBigDecimal(i, (BigDecimal) value); |
| } else if (value instanceof Boolean) { |
| preparedStatement.setBoolean(i, (Boolean) value); |
| } else if (value instanceof Blob) { |
| preparedStatement.setBlob(i, (Blob) value); |
| } else if (value instanceof Byte) { |
| preparedStatement.setByte(i, (Byte) value); |
| } else if (value instanceof NClob) { |
| preparedStatement.setNClob(i, (NClob) value); |
| } else if (value instanceof Clob) { |
| preparedStatement.setClob(i, (Clob) value); |
| } else if (value instanceof byte[]) { |
| preparedStatement.setBytes(i, (byte[]) value); |
| } else if (value instanceof Date) { |
| preparedStatement.setDate(i, (Date) value); |
| } else if (value instanceof Float) { |
| preparedStatement.setFloat(i, (Float) value); |
| } else if (value instanceof Long) { |
| preparedStatement.setLong(i, (Long) value); |
| } else if (value instanceof Ref) { |
| preparedStatement.setRef(i, (Ref) value); |
| } else if (value instanceof RowId) { |
| preparedStatement.setRowId(i, (RowId) value); |
| } else if (value instanceof Short) { |
| preparedStatement.setShort(i, (Short) value); |
| } else if (value instanceof URL) { |
| preparedStatement.setURL(i, (URL) value); |
| } else if (value instanceof SQLXML) { |
| preparedStatement.setSQLXML(i, (SQLXML) value); |
| } else { |
| preparedStatement.setObject(i, value); |
| } |
| } |
| |
| @Override public Enumerator<T> enumerator() { |
| if (preparedStatementEnricher == null) { |
| return enumeratorBasedOnStatement(); |
| } else { |
| return enumeratorBasedOnPreparedStatement(); |
| } |
| } |
| |
| private Enumerator<T> enumeratorBasedOnStatement() { |
| Connection connection = null; |
| Statement statement = null; |
| try { |
| connection = dataSource.getConnection(); |
| statement = connection.createStatement(); |
| setTimeoutIfPossible(statement); |
| if (statement.execute(sql)) { |
| final ResultSet resultSet = statement.getResultSet(); |
| statement = null; |
| connection = null; |
| return new ResultSetEnumerator<>(resultSet, rowBuilderFactory); |
| } else { |
| Integer updateCount = statement.getUpdateCount(); |
| //noinspection unchecked |
| return Linq4j.singletonEnumerator((T) updateCount); |
| } |
| } catch (SQLException e) { |
| throw Static.RESOURCE.exceptionWhilePerformingQueryOnJdbcSubSchema(sql) |
| .ex(e); |
| } finally { |
| closeIfPossible(connection, statement); |
| } |
| } |
| |
| private Enumerator<T> enumeratorBasedOnPreparedStatement() { |
| Connection connection = null; |
| PreparedStatement preparedStatement = null; |
| try { |
| connection = dataSource.getConnection(); |
| preparedStatement = connection.prepareStatement(sql); |
| setTimeoutIfPossible(preparedStatement); |
| castNonNull(preparedStatementEnricher).enrich(preparedStatement); |
| if (preparedStatement.execute()) { |
| final ResultSet resultSet = preparedStatement.getResultSet(); |
| preparedStatement = null; |
| connection = null; |
| return new ResultSetEnumerator<>(resultSet, rowBuilderFactory); |
| } else { |
| Integer updateCount = preparedStatement.getUpdateCount(); |
| //noinspection unchecked |
| return Linq4j.singletonEnumerator((T) updateCount); |
| } |
| } catch (SQLException e) { |
| throw Static.RESOURCE.exceptionWhilePerformingQueryOnJdbcSubSchema(sql) |
| .ex(e); |
| } finally { |
| closeIfPossible(connection, preparedStatement); |
| } |
| } |
| |
| private void setTimeoutIfPossible(Statement statement) throws SQLException { |
| Long queryStart = this.queryStart; |
| if (timeout == 0 || queryStart == null) { |
| return; |
| } |
| long now = System.currentTimeMillis(); |
| long secondsLeft = (queryStart + timeout - now) / 1000; |
| if (secondsLeft <= 0) { |
| throw Static.RESOURCE.queryExecutionTimeoutReached( |
| String.valueOf(timeout), |
| String.valueOf(Instant.ofEpochMilli(queryStart))).ex(); |
| } |
| if (secondsLeft > Integer.MAX_VALUE) { |
| // Just ignore the timeout if it happens to be too big, we can't squeeze it into int |
| return; |
| } |
| try { |
| statement.setQueryTimeout((int) secondsLeft); |
| } catch (SQLFeatureNotSupportedException e) { |
| if (!timeoutSetFailed && LOGGER.isDebugEnabled()) { |
| // We don't really want to print this again and again if enumerable is used multiple times |
| LOGGER.debug("Failed to set query timeout " + secondsLeft + " seconds", e); |
| timeoutSetFailed = true; |
| } |
| } |
| } |
| |
| private static void closeIfPossible(@Nullable Connection connection, |
| @Nullable Statement statement) { |
| if (statement != null) { |
| try { |
| statement.close(); |
| } catch (SQLException e) { |
| // ignore |
| } |
| } |
| if (connection != null) { |
| try { |
| connection.close(); |
| } catch (SQLException e) { |
| // ignore |
| } |
| } |
| } |
| |
| /** Implementation of {@link Enumerator} that reads from a |
| * {@link ResultSet}. |
| * |
| * @param <T> element type */ |
| private static class ResultSetEnumerator<T> implements Enumerator<T> { |
| private final Function0<T> rowBuilder; |
| private @Nullable ResultSet resultSet; |
| |
| ResultSetEnumerator( |
| ResultSet resultSet, |
| Function1<ResultSet, Function0<T>> rowBuilderFactory) { |
| this.resultSet = resultSet; |
| this.rowBuilder = rowBuilderFactory.apply(resultSet); |
| } |
| |
| private ResultSet resultSet() { |
| return castNonNull(resultSet); |
| } |
| |
| @Override public T current() { |
| return rowBuilder.apply(); |
| } |
| |
| @Override public boolean moveNext() { |
| try { |
| return resultSet().next(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| @Override public void reset() { |
| try { |
| resultSet().beforeFirst(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| @Override public void close() { |
| ResultSet savedResultSet = resultSet; |
| if (savedResultSet != null) { |
| try { |
| resultSet = null; |
| final Statement statement = savedResultSet.getStatement(); |
| savedResultSet.close(); |
| if (statement != null) { |
| final Connection connection = statement.getConnection(); |
| statement.close(); |
| if (connection != null) { |
| connection.close(); |
| } |
| } |
| } catch (SQLException e) { |
| // ignore |
| } |
| } |
| } |
| } |
| |
| private static Function1<ResultSet, Function0<@Nullable Object>> |
| primitiveRowBuilderFactory(final Primitive[] primitives) { |
| return resultSet -> { |
| final ResultSetMetaData metaData; |
| final int columnCount; |
| try { |
| metaData = resultSet.getMetaData(); |
| columnCount = metaData.getColumnCount(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| assert columnCount == primitives.length; |
| if (columnCount == 1) { |
| return () -> { |
| try { |
| return resultSet.getObject(1); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| }; |
| } |
| return () -> convertPrimitiveColumns(primitives, resultSet, columnCount); |
| }; |
| } |
| |
| private static @Nullable Object[] convertPrimitiveColumns(Primitive[] primitives, |
| ResultSet resultSet, int columnCount) { |
| final List<@Nullable Object> list = new ArrayList<>(columnCount); |
| try { |
| for (int i = 0; i < columnCount; i++) { |
| list.add(primitives[i].jdbcGet(resultSet, i + 1)); |
| } |
| return list.toArray(); |
| } catch (SQLException e) { |
| throw new RuntimeException(e); |
| } |
| } |
| |
| /** |
| * Consumer for decorating a {@link PreparedStatement}, that is, setting |
| * its parameters. |
| */ |
| public interface PreparedStatementEnricher { |
| void enrich(PreparedStatement statement) throws SQLException; |
| } |
| } |