// blob: 61ba5603007975fd95c39f7a20bcb39f32aa5175 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapred.split;
import java.io.IOException;
import java.util.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.classification.InterfaceStability.Evolving;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.mapred.InputFormat;
import org.apache.hadoop.mapred.InputSplit;
import org.apache.hadoop.mapred.JobConf;
import org.apache.hadoop.mapred.RecordReader;
import org.apache.hadoop.mapred.Reporter;
import org.apache.tez.common.ReflectionUtils;
import org.apache.tez.dag.api.TezException;
import org.apache.tez.common.Preconditions;
/**
 * An InputFormat that provides a generic grouping around the splits
 * of a real InputFormat.
 *
 * <p>{@link #getSplits(JobConf, int)} asks the wrapped format for its splits and hands them
 * to a {@link TezMapredSplitsGrouper}. {@link #getRecordReader(InputSplit, JobConf, Reporter)}
 * returns a reader that walks the wrapped splits of one {@link TezGroupedSplit} in order,
 * opening one wrapped record reader at a time.
 *
 * @param <K> key type produced by the wrapped input format
 * @param <V> value type produced by the wrapped input format
 */
@Public
@Evolving
public class TezGroupedSplitsInputFormat<K, V>
    implements InputFormat<K, V>, Configurable {

  private static final Logger LOG = LoggerFactory.getLogger(TezGroupedSplitsInputFormat.class);

  /** The real input format; when unset it is created from the split in getRecordReader(). */
  InputFormat<K, V> wrappedInputFormat;
  /** Number of splits requested from the grouper; starts at 0. */
  int desiredNumSplits = 0;
  /** Configuration passed to the grouper and to the reflective input format construction. */
  Configuration conf;
  /** Estimator passed to the grouper; stays null unless a setter call provided one. */
  SplitSizeEstimator estimator;
  /** Location provider passed to the grouper; stays null unless a setter call provided one. */
  SplitLocationProvider locationProvider;

  public TezGroupedSplitsInputFormat() {
  }

  /**
   * Sets the input format whose splits are grouped and whose record readers are used.
   *
   * @param wrappedInputFormat the real input format; a null value is stored as-is, in which
   *                           case the format is instantiated from the split's recorded class
   *                           name when a record reader is requested
   */
  public void setInputFormat(InputFormat<K, V> wrappedInputFormat) {
    this.wrappedInputFormat = wrappedInputFormat;
    // Null-safe on purpose: acceptance of null must not depend on whether debug logging is on.
    String formatName =
        (wrappedInputFormat == null) ? null : wrappedInputFormat.getClass().getName();
    LOG.debug("wrappedInputFormat: {}", formatName);
  }

  /**
   * Sets the estimator that is handed to the splits grouper.
   *
   * @param estimator the estimator to use
   * @throws NullPointerException if {@code estimator} is null
   */
  public void setSplitSizeEstimator(SplitSizeEstimator estimator) {
    this.estimator = Objects.requireNonNull(estimator);
    LOG.debug("Split size estimator : {}", estimator);
  }

  /**
   * Sets the location provider that is handed to the splits grouper.
   *
   * @param locationProvider the location provider to use
   * @throws NullPointerException if {@code locationProvider} is null
   */
  public void setSplitLocationProvider(SplitLocationProvider locationProvider) {
    this.locationProvider = Objects.requireNonNull(locationProvider);
    LOG.debug("Split location provider: {}", locationProvider);
  }

  /**
   * Sets the number of grouped splits to request from the grouper.
   *
   * @param num desired number of splits; must not be negative
   */
  public void setDesiredNumberOfSplits(int num) {
    Preconditions.checkArgument(num >= 0);
    this.desiredNumSplits = num;
    LOG.debug("desiredNumSplits: {}", desiredNumSplits);
  }

  /**
   * Computes the wrapped format's splits and groups them.
   *
   * @param job       job configuration forwarded to the wrapped input format
   * @param numSplits split-count hint forwarded to the wrapped input format
   * @return the grouped splits produced by {@link TezMapredSplitsGrouper}
   * @throws IOException          if the wrapped format or the grouper fails
   * @throws NullPointerException if no wrapped input format has been set
   */
  @Override
  public InputSplit[] getSplits(JobConf job, int numSplits) throws IOException {
    Objects.requireNonNull(wrappedInputFormat,
        "wrappedInputFormat must be set before calling getSplits()");
    InputSplit[] originalSplits = wrappedInputFormat.getSplits(job, numSplits);
    TezMapredSplitsGrouper grouper = new TezMapredSplitsGrouper();
    String wrappedInputFormatName = wrappedInputFormat.getClass().getName();
    return grouper.getGroupedSplits(conf, originalSplits, desiredNumSplits,
        wrappedInputFormatName, estimator, locationProvider);
  }

  /**
   * Creates a reader over all wrapped splits of the given grouped split.
   *
   * @param split    the split to read; it is cast to {@link TezGroupedSplit}
   * @param job      job configuration forwarded to the wrapped record readers
   * @param reporter reporter forwarded to the wrapped record readers
   * @return a reader that chains the record readers of the wrapped splits
   * @throws IOException if the wrapped input format cannot be initialized from the split
   *                     (the underlying {@link TezException} is the cause)
   */
  @Override
  public RecordReader<K, V> getRecordReader(InputSplit split, JobConf job,
      Reporter reporter) throws IOException {
    TezGroupedSplit groupedSplit = (TezGroupedSplit) split;
    try {
      initInputFormatFromSplit(groupedSplit);
    } catch (TezException e) {
      throw new IOException(e);
    }
    return new TezGroupedSplitsRecordReader(groupedSplit, job, reporter);
  }

  /**
   * Instantiates the wrapped input format from the class name stored in the split, unless a
   * wrapped input format is already present.
   *
   * @param split grouped split carrying the wrapped input format class name
   * @throws TezException if the class cannot be resolved or instantiated
   */
  @SuppressWarnings({ "unchecked", "rawtypes" })
  void initInputFormatFromSplit(TezGroupedSplit split) throws TezException {
    if (wrappedInputFormat == null) {
      Class<? extends InputFormat> clazz = (Class<? extends InputFormat>)
          getClassFromName(split.wrappedInputFormatName);
      try {
        wrappedInputFormat = org.apache.hadoop.util.ReflectionUtils.newInstance(clazz, conf);
      } catch (Exception e) {
        throw new TezException(e);
      }
    }
  }

  /**
   * Resolves a class by name through Tez's {@link ReflectionUtils}.
   *
   * @param name class name to resolve
   * @return the resolved class
   * @throws TezException if the lookup fails
   */
  static Class<?> getClassFromName(String name) throws TezException {
    return ReflectionUtils.getClazz(name);
  }

  /**
   * Record reader that reads the wrapped splits of a {@link TezGroupedSplit} one after
   * another, using the enclosing instance's wrapped input format to open each reader.
   */
  public class TezGroupedSplitsRecordReader implements RecordReader<K, V> {

    TezGroupedSplit groupedSplit;
    JobConf job;
    Reporter reporter;
    /** Index of the next wrapped split to open; one past the split currently being read. */
    int idx = 0;
    /** Summed length of the wrapped splits whose readers were closed on advancing. */
    long progress;
    /** Reader of the wrapped split currently being read, or null when none is open. */
    RecordReader<K, V> curReader;

    /**
     * Creates the reader and immediately opens the reader of the first wrapped split, if any.
     *
     * @param split    grouped split to read
     * @param job      job configuration forwarded to the wrapped record readers
     * @param reporter reporter forwarded to the wrapped record readers
     * @throws IOException if closing or opening a wrapped reader fails with an IOException
     */
    public TezGroupedSplitsRecordReader(TezGroupedSplit split, JobConf job,
        Reporter reporter) throws IOException {
      this.groupedSplit = split;
      this.job = job;
      this.reporter = reporter;
      initNextRecordReader();
    }

    /**
     * Reads the next record, moving on to the following wrapped split whenever the current
     * reader is exhausted.
     *
     * @return true if a record was read, false once every wrapped split has been consumed
     */
    @Override
    public boolean next(K key, V value) throws IOException {
      while ((curReader == null) || !curReader.next(key, value)) {
        if (!initNextRecordReader()) {
          return false;
        }
      }
      return true;
    }

    /**
     * Delegates to the current wrapped reader. Throws a NullPointerException when no wrapped
     * reader is open (for example after the last split was consumed or after close()).
     */
    @Override
    public K createKey() {
      return curReader.createKey();
    }

    /**
     * Delegates to the current wrapped reader. Throws a NullPointerException when no wrapped
     * reader is open (for example after the last split was consumed or after close()).
     */
    @Override
    public V createValue() {
      return curReader.createValue();
    }

    /**
     * Reports the fraction of the grouped split consumed so far, capped at 1.0.
     *
     * @return {@link #getPos()} divided by the grouped split's length; when that length is
     *         not positive, 1.0 once all wrapped splits have been consumed and 0.0 before
     */
    @Override
    public float getProgress() throws IOException {
      long totalLength = groupedSplit.getLength();
      if (totalLength <= 0) {
        // Dividing by a non-positive length would yield NaN or a negative/infinite value.
        // Without a usable length, only "finished" versus "not finished" can be reported.
        boolean exhausted = (curReader == null) && (idx == groupedSplit.wrappedSplits.size());
        return exhausted ? 1.0f : 0.0f;
      }
      return Math.min(1.0f, getPos() / (float) totalLength);
    }

    /** Closes the currently open wrapped reader, if there is one. */
    @Override
    public void close() throws IOException {
      if (curReader != null) {
        curReader.close();
        curReader = null;
      }
    }

    /**
     * Closes the current wrapped reader (crediting its split's length to the accumulated
     * progress) and opens the reader of the next wrapped split.
     *
     * @return true if a new reader was opened, false if all wrapped splits were consumed
     * @throws IOException      if closing the current reader fails
     * @throws RuntimeException wrapping any exception raised while opening the next reader
     */
    protected boolean initNextRecordReader() throws IOException {
      if (curReader != null) {
        curReader.close();
        curReader = null;
        if (idx > 0) {
          // idx points one past the split that was just finished
          progress += groupedSplit.wrappedSplits.get(idx - 1).getLength();
        }
      }

      // if all chunks have been processed, nothing more to do.
      if (idx == groupedSplit.wrappedSplits.size()) {
        return false;
      }

      LOG.debug("Init record reader for index {} of {}", idx,
          groupedSplit.wrappedSplits.size());

      // get a record reader for the idx-th chunk
      try {
        curReader = wrappedInputFormat.getRecordReader(
            groupedSplit.wrappedSplits.get(idx), job, reporter);
      } catch (Exception e) {
        // NOTE(review): IOExceptions are wrapped too; kept as-is since callers may rely on it.
        throw new RuntimeException(e);
      }
      idx++;
      return true;
    }

    /**
     * Returns the accumulated length of the finished wrapped splits plus the position
     * reported by the currently open wrapped reader, if any.
     */
    @Override
    public long getPos() throws IOException {
      long subprogress = 0; // position reported by the current wrapped reader
      if (null != curReader) {
        // idx is always one past the current subsplit's true index.
        subprogress = curReader.getPos();
      }
      return (progress + subprogress);
    }
  }

  /** Stores the configuration used for grouping and for instantiating the wrapped format. */
  @Override
  public void setConf(Configuration conf) {
    this.conf = conf;
  }

  /** Returns the configuration previously supplied through {@link #setConf(Configuration)}. */
  @Override
  public Configuration getConf() {
    return conf;
  }
}