blob: d9aad1952d10a0da63d08ed5ec85f99a37723751 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.hadoop.cql3;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.net.InetAddress;
import java.nio.ByteBuffer;
import java.util.*;
import com.google.common.base.Function;
import com.google.common.base.Joiner;
import com.google.common.base.Splitter;
import com.datastax.driver.core.TypeCodec;
import org.apache.cassandra.utils.AbstractIterator;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.datastax.driver.core.Cluster;
import com.datastax.driver.core.ColumnDefinitions;
import com.datastax.driver.core.ColumnMetadata;
import com.datastax.driver.core.LocalDate;
import com.datastax.driver.core.Metadata;
import com.datastax.driver.core.ResultSet;
import com.datastax.driver.core.Row;
import com.datastax.driver.core.Session;
import com.datastax.driver.core.TableMetadata;
import com.datastax.driver.core.Token;
import com.datastax.driver.core.TupleValue;
import com.datastax.driver.core.UDTValue;
import com.google.common.reflect.TypeToken;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.hadoop.ColumnFamilySplit;
import org.apache.cassandra.hadoop.ConfigHelper;
import org.apache.cassandra.hadoop.HadoopCompat;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.Pair;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
/**
 * <p>
 * CqlRecordReader reads the rows returned from the CQL query.
 * It uses CQL auto-paging.
 * </p>
 * <p>
 * Each record's key is a Long row counter local to this reader; the first row returned has key 1.
 * </p>
 * {@code
 * Row as C* java driver CQL result set row
 * 1) select clause must include partition key columns (to calculate the progress based on the actual CF row processed)
 * 2) where clause must include token(partition_key1, ... , partition_keyn) > ? and
 * token(partition_key1, ... , partition_keyn) <= ? (in the right order)
 * }
 */
public class CqlRecordReader extends RecordReader<Long, Row>
        implements org.apache.hadoop.mapred.RecordReader<Long, Row>, AutoCloseable
{
    private static final Logger logger = LoggerFactory.getLogger(CqlRecordReader.class);

    private ColumnFamilySplit split;
    private RowIterator rowIterator;

    private Pair<Long, Row> currentRow;
    // Estimated number of rows in the split; only used for progress reporting.
    // Kept as a long so that a large split length is not truncated by an int cast.
    private long totalRowCount;
    private String keyspace;
    private String cfName;
    private String cqlQuery;
    private Cluster cluster;
    private Session session;
    private IPartitioner partitioner;
    private String inputColumns;
    private String userDefinedWhereClauses;

    private List<String> partitionKeys = new ArrayList<>();

    // partition keys -- key aliases
    private LinkedHashMap<String, Boolean> partitionBoundColumns = Maps.newLinkedHashMap();
    protected int nativeProtocolVersion = 1;

    public CqlRecordReader()
    {
        super();
    }

    /**
     * Reads the job configuration, connects to the cluster and starts executing the query for this split.
     *
     * @param split   the split to read; must be a {@link ColumnFamilySplit}
     * @param context the task context holding the Hadoop configuration
     * @throws IOException declared by the Hadoop API
     * @throws RuntimeException if creating the cluster fails or no session could be opened
     * @throws AssertionError if a custom CQL query is combined with input columns and / or where clauses
     */
    @Override
    public void initialize(InputSplit split, TaskAttemptContext context) throws IOException
    {
        this.split = (ColumnFamilySplit) split;
        Configuration conf = HadoopCompat.getConfiguration(context);
        // A split length of Long.MAX_VALUE is not used as an estimate; fall back to the configured split size.
        totalRowCount = (this.split.getLength() < Long.MAX_VALUE)
                        ? this.split.getLength()
                        : ConfigHelper.getInputSplitSize(conf);
        cfName = ConfigHelper.getInputColumnFamily(conf);
        keyspace = ConfigHelper.getInputKeyspace(conf);
        partitioner = ConfigHelper.getInputPartitioner(conf);
        inputColumns = CqlConfigHelper.getInputcolumns(conf);
        userDefinedWhereClauses = CqlConfigHelper.getInputWhereClauses(conf);

        try
        {
            // NOTE(review): when a cluster is already present, initialization stops here and
            // neither the session nor the row iterator is (re)created.
            if (cluster != null)
                return;

            // create a Cluster instance
            String[] locations = split.getLocations();
            cluster = CqlConfigHelper.getInputCluster(locations, conf);
        }
        catch (Exception e)
        {
            throw new RuntimeException(e);
        }

        if (cluster != null)
            session = cluster.connect(quote(keyspace));

        if (session == null)
            throw new RuntimeException("Can't create connection session");

        //get negotiated serialization protocol
        nativeProtocolVersion = cluster.getConfiguration().getProtocolOptions().getProtocolVersion().toInt();

        // If the user provides a CQL query then we will use it without validation
        // otherwise we will fall back to building a query using the:
        //   inputColumns
        //   whereClauses
        cqlQuery = CqlConfigHelper.getInputCql(conf);
        // validate that the user hasn't tried to give us a custom query along with input columns
        // and where clauses
        if (StringUtils.isNotEmpty(cqlQuery) && (StringUtils.isNotEmpty(inputColumns) ||
                                                 StringUtils.isNotEmpty(userDefinedWhereClauses)))
        {
            throw new AssertionError("Cannot define a custom query with input columns and / or where clauses");
        }

        if (StringUtils.isEmpty(cqlQuery))
            cqlQuery = buildQuery();
        logger.trace("cqlQuery {}", cqlQuery);

        rowIterator = new RowIterator();
        logger.trace("created {}", rowIterator);
    }

    /**
     * Closes the session and the cluster. The cluster is closed even if closing the session throws.
     */
    public void close()
    {
        try
        {
            if (session != null)
                session.close();
        }
        finally
        {
            if (cluster != null)
                cluster.close();
        }
    }

    public Long getCurrentKey()
    {
        return currentRow.left;
    }

    public Row getCurrentValue()
    {
        return currentRow.right;
    }

    /**
     * Estimates progress as the number of distinct partition keys read so far divided by the estimated row count.
     *
     * @return a value in [0, 1]; 1 once the iterator is exhausted, 0 when there is no positive row estimate
     */
    public float getProgress()
    {
        if (!rowIterator.hasNext())
            return 1.0F;

        // Without a positive estimate the ratio is meaningless (0/0 would report NaN to Hadoop).
        if (totalRowCount <= 0)
            return 0.0F;

        // the progress is likely to be reported slightly off the actual but close enough
        float progress = ((float) rowIterator.totalRead / totalRowCount);
        return progress > 1.0F ? 1.0F : progress;
    }

    /**
     * Advances to the next row of the result set.
     *
     * @return {@code true} if a row is available through {@link #getCurrentKey()} / {@link #getCurrentValue()}
     * @throws IOException wrapping any failure raised while fetching the row; the original exception is the cause
     */
    public boolean nextKeyValue() throws IOException
    {
        if (!rowIterator.hasNext())
        {
            logger.trace("Finished scanning {} rows (estimate was: {})", rowIterator.totalRead, totalRowCount);
            return false;
        }

        try
        {
            currentRow = rowIterator.next();
        }
        catch (Exception e)
        {
            // throw it as IOException, so client can catch it and handle it at client side;
            // keep the original exception as the cause so its stack trace is not lost
            throw new IOException(e.getMessage(), e);
        }
        return true;
    }

    // Because the old Hadoop API wants us to write to the key and value
    // and the new asks for them, we need to copy the output of the new API
    // to the old. Thus, expect a small performance hit.
    // And obviously this wouldn't work for wide rows. But since ColumnFamilyInputFormat
    // and ColumnFamilyRecordReader don't support them, it should be fine for now.
    // Note: 'value' must be the WrappedRow returned by createValue(); 'key' is not updated.
    public boolean next(Long key, Row value) throws IOException
    {
        if (nextKeyValue())
        {
            ((WrappedRow)value).setRow(getCurrentValue());
            return true;
        }
        return false;
    }

    public long getPos() throws IOException
    {
        return rowIterator.totalRead;
    }

    public Long createKey()
    {
        return Long.valueOf(0L);
    }

    public Row createValue()
    {
        return new WrappedRow();
    }

    /**
     * Return native version protocol of the cluster connection
     * @return serialization protocol version.
     */
    public int getNativeProtocolVersion()
    {
        return nativeProtocolVersion;
    }

    /** CQL row iterator
     *  Input cql query
     *  1) select clause must include key columns (if we use partition key based row count)
     *  2) where clause must include token(partition_key1 ... partition_keyn) > ? and
     *     token(partition_key1 ... partition_keyn) <= ?
     *  totalRead counts how many times the partition key changes between consecutive rows.
     */
    private class RowIterator extends AbstractIterator<Pair<Long, Row>>
    {
        private long keyId = 0L;
        protected int totalRead = 0; // total number of cf rows read
        protected Iterator<Row> rows;
        private Map<String, ByteBuffer> previousRowKey = new HashMap<>(); // previous CF row key

        public RowIterator()
        {
            AbstractType type = partitioner.getTokenValidator();
            // the query's two bind markers are the split's start (exclusive) and end (inclusive) tokens
            ResultSet rs = session.execute(cqlQuery,
                                           type.compose(type.fromString(split.getStartToken())),
                                           type.compose(type.fromString(split.getEndToken())));
            for (ColumnMetadata meta : getTableMetadata().getPartitionKey())
                partitionBoundColumns.put(meta.getName(), Boolean.TRUE);
            rows = rs.iterator();
        }

        protected Pair<Long, Row> computeNext()
        {
            if (rows == null || !rows.hasNext())
                return endOfData();

            Row row = rows.next();
            Map<String, ByteBuffer> keyColumns = new HashMap<>(partitionBoundColumns.size());
            for (String column : partitionBoundColumns.keySet())
                keyColumns.put(column, row.getBytesUnsafe(column));

            // increase total CF row read
            if (previousRowKey.isEmpty() && !keyColumns.isEmpty())
            {
                previousRowKey = keyColumns;
                totalRead++;
            }
            else
            {
                for (String column : partitionBoundColumns.keySet())
                {
                    // this is not correct - but we don't seem to have easy access to better type information here
                    if (ByteBufferUtil.compareUnsigned(keyColumns.get(column), previousRowKey.get(column)) != 0)
                    {
                        previousRowKey = keyColumns;
                        totalRead++;
                        break;
                    }
                }
            }
            keyId ++;
            return Pair.create(keyId, row);
        }
    }

    /**
     * Mutable {@link Row} that forwards every call to the row last passed to {@link #setRow(Row)}.
     * Used as the reusable value object of the old mapred API (see {@code createValue()} and {@code next()}).
     */
    private static class WrappedRow implements Row
    {
        private Row row;

        public void setRow(Row row)
        {
            this.row = row;
        }

        @Override
        public ColumnDefinitions getColumnDefinitions()
        {
            return row.getColumnDefinitions();
        }

        @Override
        public boolean isNull(int i)
        {
            return row.isNull(i);
        }

        @Override
        public boolean isNull(String name)
        {
            return row.isNull(name);
        }

        @Override
        public Object getObject(int i)
        {
            return row.getObject(i);
        }

        @Override
        public <T> T get(int i, Class<T> aClass)
        {
            return row.get(i, aClass);
        }

        @Override
        public <T> T get(int i, TypeToken<T> typeToken)
        {
            return row.get(i, typeToken);
        }

        @Override
        public <T> T get(int i, TypeCodec<T> typeCodec)
        {
            return row.get(i, typeCodec);
        }

        @Override
        public Object getObject(String s)
        {
            return row.getObject(s);
        }

        @Override
        public <T> T get(String s, Class<T> aClass)
        {
            return row.get(s, aClass);
        }

        @Override
        public <T> T get(String s, TypeToken<T> typeToken)
        {
            return row.get(s, typeToken);
        }

        @Override
        public <T> T get(String s, TypeCodec<T> typeCodec)
        {
            return row.get(s, typeCodec);
        }

        @Override
        public boolean getBool(int i)
        {
            return row.getBool(i);
        }

        @Override
        public boolean getBool(String name)
        {
            return row.getBool(name);
        }

        @Override
        public short getShort(int i)
        {
            return row.getShort(i);
        }

        @Override
        public short getShort(String s)
        {
            return row.getShort(s);
        }

        @Override
        public byte getByte(int i)
        {
            return row.getByte(i);
        }

        @Override
        public byte getByte(String s)
        {
            return row.getByte(s);
        }

        @Override
        public int getInt(int i)
        {
            return row.getInt(i);
        }

        @Override
        public int getInt(String name)
        {
            return row.getInt(name);
        }

        @Override
        public long getLong(int i)
        {
            return row.getLong(i);
        }

        @Override
        public long getLong(String name)
        {
            return row.getLong(name);
        }

        @Override
        public Date getTimestamp(int i)
        {
            return row.getTimestamp(i);
        }

        @Override
        public Date getTimestamp(String s)
        {
            return row.getTimestamp(s);
        }

        @Override
        public LocalDate getDate(int i)
        {
            return row.getDate(i);
        }

        @Override
        public LocalDate getDate(String s)
        {
            return row.getDate(s);
        }

        @Override
        public long getTime(int i)
        {
            return row.getTime(i);
        }

        @Override
        public long getTime(String s)
        {
            return row.getTime(s);
        }

        @Override
        public float getFloat(int i)
        {
            return row.getFloat(i);
        }

        @Override
        public float getFloat(String name)
        {
            return row.getFloat(name);
        }

        @Override
        public double getDouble(int i)
        {
            return row.getDouble(i);
        }

        @Override
        public double getDouble(String name)
        {
            return row.getDouble(name);
        }

        @Override
        public ByteBuffer getBytesUnsafe(int i)
        {
            return row.getBytesUnsafe(i);
        }

        @Override
        public ByteBuffer getBytesUnsafe(String name)
        {
            return row.getBytesUnsafe(name);
        }

        @Override
        public ByteBuffer getBytes(int i)
        {
            return row.getBytes(i);
        }

        @Override
        public ByteBuffer getBytes(String name)
        {
            return row.getBytes(name);
        }

        @Override
        public String getString(int i)
        {
            return row.getString(i);
        }

        @Override
        public String getString(String name)
        {
            return row.getString(name);
        }

        @Override
        public BigInteger getVarint(int i)
        {
            return row.getVarint(i);
        }

        @Override
        public BigInteger getVarint(String name)
        {
            return row.getVarint(name);
        }

        @Override
        public BigDecimal getDecimal(int i)
        {
            return row.getDecimal(i);
        }

        @Override
        public BigDecimal getDecimal(String name)
        {
            return row.getDecimal(name);
        }

        @Override
        public UUID getUUID(int i)
        {
            return row.getUUID(i);
        }

        @Override
        public UUID getUUID(String name)
        {
            return row.getUUID(name);
        }

        @Override
        public InetAddress getInet(int i)
        {
            return row.getInet(i);
        }

        @Override
        public InetAddress getInet(String name)
        {
            return row.getInet(name);
        }

        @Override
        public <T> List<T> getList(int i, Class<T> elementsClass)
        {
            return row.getList(i, elementsClass);
        }

        @Override
        public <T> List<T> getList(int i, TypeToken<T> typeToken)
        {
            return row.getList(i, typeToken);
        }

        @Override
        public <T> List<T> getList(String name, Class<T> elementsClass)
        {
            return row.getList(name, elementsClass);
        }

        @Override
        public <T> List<T> getList(String s, TypeToken<T> typeToken)
        {
            return row.getList(s, typeToken);
        }

        @Override
        public <T> Set<T> getSet(int i, Class<T> elementsClass)
        {
            return row.getSet(i, elementsClass);
        }

        @Override
        public <T> Set<T> getSet(int i, TypeToken<T> typeToken)
        {
            return row.getSet(i, typeToken);
        }

        @Override
        public <T> Set<T> getSet(String name, Class<T> elementsClass)
        {
            return row.getSet(name, elementsClass);
        }

        @Override
        public <T> Set<T> getSet(String s, TypeToken<T> typeToken)
        {
            return row.getSet(s, typeToken);
        }

        @Override
        public <K, V> Map<K, V> getMap(int i, Class<K> keysClass, Class<V> valuesClass)
        {
            return row.getMap(i, keysClass, valuesClass);
        }

        @Override
        public <K, V> Map<K, V> getMap(int i, TypeToken<K> typeToken, TypeToken<V> typeToken1)
        {
            return row.getMap(i, typeToken, typeToken1);
        }

        @Override
        public <K, V> Map<K, V> getMap(String name, Class<K> keysClass, Class<V> valuesClass)
        {
            return row.getMap(name, keysClass, valuesClass);
        }

        @Override
        public <K, V> Map<K, V> getMap(String s, TypeToken<K> typeToken, TypeToken<V> typeToken1)
        {
            return row.getMap(s, typeToken, typeToken1);
        }

        @Override
        public UDTValue getUDTValue(int i)
        {
            return row.getUDTValue(i);
        }

        @Override
        public UDTValue getUDTValue(String name)
        {
            return row.getUDTValue(name);
        }

        @Override
        public TupleValue getTupleValue(int i)
        {
            return row.getTupleValue(i);
        }

        @Override
        public TupleValue getTupleValue(String name)
        {
            return row.getTupleValue(name);
        }

        @Override
        public Token getToken(int i)
        {
            return row.getToken(i);
        }

        @Override
        public Token getToken(String name)
        {
            return row.getToken(name);
        }

        @Override
        public Token getPartitionKeyToken()
        {
            return row.getPartitionKeyToken();
        }
    }

    /**
     * Build a query for the reader of the form:
     *
     * SELECT * FROM ks.cf WHERE token(pk1,...pkn)>? AND token(pk1,...pkn)<=? [AND user where clauses] [ALLOW FILTERING]
     */
    private String buildQuery()
    {
        fetchKeys();

        List<String> columns = getSelectColumns();
        String selectColumnList = columns.size() == 0 ? "*" : makeColumnList(columns);
        String partitionKeyList = makeColumnList(partitionKeys);

        // The user's where clause is passed as an argument, never as part of the format pattern,
        // so a '%' inside it (e.g. a LIKE pattern) cannot break String.format.
        return String.format("SELECT %s FROM %s.%s WHERE token(%s)>? AND token(%s)<=?%s",
                             selectColumnList, quote(keyspace), quote(cfName), partitionKeyList, partitionKeyList,
                             getAdditionalWhereClauses());
    }

    /**
     * @return {@code ""} when no user where clauses are configured, otherwise
     *         {@code " AND <clauses> ALLOW FILTERING"}
     */
    private String getAdditionalWhereClauses()
    {
        if (StringUtils.isEmpty(userDefinedWhereClauses))
            return "";
        // user restrictions may touch non-key columns, so filtering has to be allowed
        return " AND " + userDefinedWhereClauses + " ALLOW FILTERING";
    }

    /**
     * Computes the explicit select list: all partition key columns followed by the configured input columns.
     *
     * @return an empty list when no input columns are configured (meaning "select *")
     */
    private List<String> getSelectColumns()
    {
        List<String> selectColumns = new ArrayList<>();

        if (StringUtils.isNotEmpty(inputColumns))
        {
            // We must select all the partition keys plus any other columns the user wants
            selectColumns.addAll(partitionKeys);
            // trim around commas and drop empty entries so "a, b," yields [a, b]
            for (String column : Splitter.on(',').trimResults().omitEmptyStrings().split(inputColumns))
            {
                if (!partitionKeys.contains(column))
                    selectColumns.add(column);
            }
        }
        return selectColumns;
    }

    /** Joins the given column names with commas, quoting each one as a CQL identifier. */
    private String makeColumnList(Collection<String> columns)
    {
        return Joiner.on(',').join(Iterables.transform(columns, new Function<String, String>()
        {
            public String apply(String column)
            {
                return quote(column);
            }
        }));
    }

    /** Appends the names of the table's partition key columns to {@code partitionKeys}. */
    private void fetchKeys()
    {
        // get CF meta data
        TableMetadata tableMetadata = getTableMetadata();

        //Here we assume that tableMetadata.getPartitionKey() always
        //returns the list of columns in order of component_index
        for (ColumnMetadata partitionKey : tableMetadata.getPartitionKey())
        {
            partitionKeys.add(partitionKey.getName());
        }
    }

    /**
     * Looks up the driver metadata of the input table.
     *
     * @return the table metadata, never {@code null}
     * @throws RuntimeException if the keyspace or the table is missing from the cluster metadata
     */
    private TableMetadata getTableMetadata()
    {
        Metadata metadata = cluster.getMetadata();
        String quotedKeyspace = Metadata.quote(keyspace);
        // check the keyspace first so a missing keyspace is reported instead of raising an NPE
        TableMetadata tableMetadata = metadata.getKeyspace(quotedKeyspace) == null
                                      ? null
                                      : metadata.getKeyspace(quotedKeyspace).getTable(Metadata.quote(cfName));
        if (tableMetadata == null)
        {
            throw new RuntimeException("No table metadata found for " + keyspace + "." + cfName);
        }
        return tableMetadata;
    }

    /** Quotes a CQL identifier, doubling any embedded double quotes. */
    private String quote(String identifier)
    {
        return "\"" + identifier.replaceAll("\"", "\"\"") + "\"";
    }
}