/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapreduce.lib.db;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.StringUtils;
/**
 * An InputFormat that reads input data from an SQL table.
* <p>
* DBInputFormat emits LongWritables containing the record number as
* key and DBWritables as value.
*
 * The SQL query and input class can be set using one of the two
 * setInput methods.
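 * <p>
 * A minimal usage sketch; the MyRecord class, JDBC driver, URL and
 * table/field names below are illustrative, not part of this API:
 * <pre>{@code
 * Job job = Job.getInstance(new Configuration(), "db-read");
 * DBConfiguration.configureDB(job.getConfiguration(),
 *     "com.mysql.jdbc.Driver", "jdbc:mysql://localhost/mydb",
 *     "user", "password");
 * DBInputFormat.setInput(job, MyRecord.class,
 *     "employees", null, "id", "id", "name");
 * }</pre>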
*/
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
extends InputFormat<LongWritable, T> implements Configurable {
private static final Log LOG = LogFactory.getLog(DBInputFormat.class);
protected String dbProductName = "DEFAULT";
/**
   * A class that does nothing, implementing both DBWritable and Writable.
*/
@InterfaceStability.Evolving
public static class NullDBWritable implements DBWritable, Writable {
@Override
public void readFields(DataInput in) throws IOException { }
@Override
    public void readFields(ResultSet resultSet) throws SQLException { }
@Override
public void write(DataOutput out) throws IOException { }
@Override
    public void write(PreparedStatement statement) throws SQLException { }
}
/**
   * An InputSplit that spans a set of rows.
*/
@InterfaceStability.Evolving
public static class DBInputSplit extends InputSplit implements Writable {
private long end = 0;
private long start = 0;
/**
     * Default constructor.
*/
public DBInputSplit() {
}
/**
     * Convenience constructor.
* @param start the index of the first row to select
* @param end the index of the last row to select
*/
public DBInputSplit(long start, long end) {
this.start = start;
this.end = end;
}
/** {@inheritDoc} */
public String[] getLocations() throws IOException {
// TODO Add a layer to enable SQL "sharding" and support locality
return new String[] {};
}
/**
* @return The index of the first row to select
*/
public long getStart() {
return start;
}
/**
* @return The index of the last row to select
*/
public long getEnd() {
return end;
}
/**
* @return The total row count in this split
*/
public long getLength() throws IOException {
return end - start;
}
/** {@inheritDoc} */
public void readFields(DataInput input) throws IOException {
start = input.readLong();
end = input.readLong();
}
/** {@inheritDoc} */
public void write(DataOutput output) throws IOException {
output.writeLong(start);
output.writeLong(end);
}
}
protected String conditions;
protected Connection connection;
protected String tableName;
protected String[] fieldNames;
protected DBConfiguration dbConf;
/** {@inheritDoc} */
public void setConf(Configuration conf) {
dbConf = new DBConfiguration(conf);
try {
this.connection = createConnection();
DatabaseMetaData dbMeta = connection.getMetaData();
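      // The database product name (e.g. "MYSQL", "ORACLE") is used by
      // createDBRecordReader() to select a vendor-specific reader.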
this.dbProductName =
StringUtils.toUpperCase(dbMeta.getDatabaseProductName());
    } catch (Exception ex) {
throw new RuntimeException(ex);
}
tableName = dbConf.getInputTableName();
fieldNames = dbConf.getInputFieldNames();
conditions = dbConf.getInputConditions();
}
public Configuration getConf() {
return dbConf.getConf();
}
public DBConfiguration getDBConf() {
return dbConf;
}
public Connection getConnection() {
// TODO Remove this code that handles backward compatibility.
if (this.connection == null) {
this.connection = createConnection();
}
return this.connection;
}
public Connection createConnection() {
try {
Connection newConnection = dbConf.getConnection();
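      // Run reads in a single explicit transaction at SERIALIZABLE
      // isolation, presumably so the row count taken in getSplits() and
      // the rows the record readers later see stay mutually consistent.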
newConnection.setAutoCommit(false);
newConnection.setTransactionIsolation(
Connection.TRANSACTION_SERIALIZABLE);
return newConnection;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
public String getDBProductName() {
return dbProductName;
}
protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
Configuration conf) throws IOException {
@SuppressWarnings("unchecked")
Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
try {
// use database product name to determine appropriate record reader.
if (dbProductName.startsWith("ORACLE")) {
// use Oracle-specific db reader.
return new OracleDBRecordReader<T>(split, inputClass,
conf, createConnection(), getDBConf(), conditions, fieldNames,
tableName);
} else if (dbProductName.startsWith("MYSQL")) {
// use MySQL-specific db reader.
return new MySQLDBRecordReader<T>(split, inputClass,
conf, createConnection(), getDBConf(), conditions, fieldNames,
tableName);
} else {
// Generic reader.
return new DBRecordReader<T>(split, inputClass,
conf, createConnection(), getDBConf(), conditions, fieldNames,
tableName);
}
} catch (SQLException ex) {
      throw new IOException(ex.getMessage(), ex);
}
}
/** {@inheritDoc} */
public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
TaskAttemptContext context) throws IOException, InterruptedException {
return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
}
/** {@inheritDoc} */
public List<InputSplit> getSplits(JobContext job) throws IOException {
ResultSet results = null;
Statement statement = null;
try {
statement = connection.createStatement();
results = statement.executeQuery(getCountQuery());
results.next();
long count = results.getLong(1);
int chunks = job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1);
long chunkSize = (count / chunks);
results.close();
statement.close();
List<InputSplit> splits = new ArrayList<InputSplit>();
      // Split the rows into the requested number of chunks; the last
      // chunk absorbs any remainder.
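      // For example, count = 10 and chunks = 3 gives chunkSize = 3 and
      // splits [0, 3), [3, 6) and [6, 10); the last split absorbs the
      // remainder.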
for (int i = 0; i < chunks; i++) {
DBInputSplit split;
        if ((i + 1) == chunks) {
          split = new DBInputSplit(i * chunkSize, count);
        } else {
          split = new DBInputSplit(i * chunkSize,
              (i * chunkSize) + chunkSize);
        }
splits.add(split);
}
connection.commit();
return splits;
} catch (SQLException e) {
throw new IOException("Got SQLException", e);
} finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {
        // Ignored: best-effort cleanup of the result set.
      }
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {
        // Ignored: best-effort cleanup of the statement.
      }
closeConnection();
}
}
  /** Returns the query for getting the total number of rows;
   * subclasses can override this for custom behaviour. */
protected String getCountQuery() {
    if (dbConf.getInputCountQuery() != null) {
return dbConf.getInputCountQuery();
}
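    // Otherwise build the default count query. With the illustrative
    // values tableName = "employees" and conditions = "salary > 1000",
    // this produces: SELECT COUNT(*) FROM employees WHERE salary > 1000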
    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);
    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
return query.toString();
}
/**
* Initializes the map-part of the job with the appropriate input settings.
*
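   * <p>
   * For example, given an existing Job job (table, condition and field
   * names here are illustrative):
   * <pre>{@code
   * DBInputFormat.setInput(job, MyRecord.class,
   *     "employees", "salary > 1000", "id", "id", "name", "salary");
   * }</pre>
   *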
* @param job The map-reduce job
* @param inputClass the class object implementing DBWritable, which is the
* Java object holding tuple fields.
* @param tableName The table to read data from
   * @param conditions The conditions used to select data,
   * e.g. '(updated &gt; 20070101 AND length &gt; 0)'
   * @param orderBy the field names used in the ORDER BY clause.
* @param fieldNames The field names in the table
* @see #setInput(Job, Class, String, String)
*/
public static void setInput(Job job,
Class<? extends DBWritable> inputClass,
      String tableName, String conditions,
String orderBy, String... fieldNames) {
job.setInputFormatClass(DBInputFormat.class);
DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
dbConf.setInputClass(inputClass);
dbConf.setInputTableName(tableName);
dbConf.setInputFieldNames(fieldNames);
dbConf.setInputConditions(conditions);
dbConf.setInputOrderBy(orderBy);
}
/**
* Initializes the map-part of the job with the appropriate input settings.
*
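   * <p>
   * For example, given an existing Job job (the query text and record
   * class are illustrative):
   * <pre>{@code
   * DBInputFormat.setInput(job, MyRecord.class,
   *     "SELECT id, name FROM employees ORDER BY id",
   *     "SELECT COUNT(id) FROM employees");
   * }</pre>
   *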
* @param job The map-reduce job
* @param inputClass the class object implementing DBWritable, which is the
* Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example:
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns
   * the number of records in the table.
   * Example: "SELECT COUNT(f1) FROM Mytable"
* @see #setInput(Job, Class, String, String, String, String...)
*/
public static void setInput(Job job,
Class<? extends DBWritable> inputClass,
String inputQuery, String inputCountQuery) {
job.setInputFormatClass(DBInputFormat.class);
DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
dbConf.setInputClass(inputClass);
dbConf.setInputQuery(inputQuery);
dbConf.setInputCountQuery(inputCountQuery);
}
protected void closeConnection() {
try {
if (null != this.connection) {
this.connection.close();
this.connection = null;
}
} catch (SQLException sqlE) {
LOG.debug("Exception on close", sqlE);
}
}
}