blob: 49d4043b5cd1767c74485eddfa3a3f38fcc79868 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.input;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.PriorityQueue;
import java.util.Set;

import org.apache.hadoop.classification.InterfaceAudience.Public;
import org.apache.hadoop.io.RawComparator;
import org.apache.tez.runtime.api.Input;
import org.apache.tez.runtime.api.MergedInputContext;
import org.apache.tez.runtime.api.MergedLogicalInput;
import org.apache.tez.runtime.api.ProgressFailedException;
import org.apache.tez.runtime.library.api.KeyValuesReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * A {@link MergedLogicalInput} which merges multiple
 * {@link OrderedGroupedKVInput}s and returns a single view of these by merging
 * values which belong to the same key.
 *
 * Combiners and Secondary Sort are not implemented, so there is no guarantee on
 * the order of values.
 */
@Public
public class OrderedGroupedMergedKVInput extends MergedLogicalInput {

  private static final Logger LOG = LoggerFactory.getLogger(OrderedGroupedMergedKVInput.class);

  /**
   * Constituent inputs that have reported themselves ready. Membership is
   * identity-based; all access is synchronized on the set itself.
   */
  private final Set<Input> completedInputs = Collections
      .newSetFromMap(new IdentityHashMap<Input, Boolean>());

  /**
   * @param context the merged-input context, forwarded to the superclass
   * @param inputs the constituent inputs to merge; each is cast to
   *          {@link OrderedGroupedKVInput} when the reader or the progress is
   *          requested
   */
  public OrderedGroupedMergedKVInput(MergedInputContext context, List<Input> inputs) {
    super(context, inputs);
  }

  /**
   * Provides an ordered {@link KeyValuesReader}. A new merging reader over all
   * constituent inputs is created on every call.
   *
   * @return a reader that yields each distinct key once, together with the
   *         values of that key from every constituent input
   * @throws Exception if obtaining or advancing a constituent reader fails
   */
  @Override
  public KeyValuesReader getReader() throws Exception {
    return new OrderedGroupedMergedKeyValuesReader(getInputs(), getContext());
  }

  /**
   * Records that a constituent input is ready and, once every constituent input
   * has reported, calls {@code informInputReady()}.
   *
   * @param input the constituent input that became ready
   */
  @Override
  public void setConstituentInputIsReady(Input input) {
    synchronized (completedInputs) {
      // Only a first-time report can complete the set; a repeated report for an
      // already-recorded input must not trigger informInputReady() again.
      boolean newlyCompleted = completedInputs.add(input);
      if (newlyCompleted && completedInputs.size() == getInputs().size()) {
        informInputReady();
      }
    }
  }

  /**
   * Merges the readers of several inputs through a priority queue ordered by
   * each reader's current key. Readers sitting in the queue are always
   * positioned on a key; readers whose current key has been handed out are
   * parked in {@code finishedReaders} until the next call to {@link #next()}.
   */
  private static class OrderedGroupedMergedKeyValuesReader extends KeyValuesReader {

    /** Readers positioned on a not-yet-returned key, ordered by that key. */
    private final PriorityQueue<KeyValuesReader> pQueue;
    /** Key comparator taken from the first constituent input. */
    @SuppressWarnings("rawtypes")
    private final RawComparator keyComparator;
    /** Readers that contributed to the current key and must be advanced before re-queueing. */
    private final List<KeyValuesReader> finishedReaders;
    /** Single reusable view over the values of the current key. */
    private final ValuesIterable currentValues;
    /** Reader currently supplying values (or the last one polled from the queue). */
    private KeyValuesReader nextKVReader;
    /** Key returned by the most recent successful {@link #next()}. */
    private Object currentKey;
    private final MergedInputContext context;

    /**
     * Opens a reader on every input and queues those that have at least one key.
     *
     * @param inputs the constituent inputs; the first one supplies the key
     *          comparator
     * @param context used to report progress on every {@link #next()} call
     * @throws Exception if a constituent reader cannot be obtained or advanced
     */
    public OrderedGroupedMergedKeyValuesReader(List<Input> inputs, MergedInputContext context)
        throws Exception {
      keyComparator = ((OrderedGroupedKVInput) inputs.get(0))
          .getInputKeyComparator();
      pQueue = new PriorityQueue<KeyValuesReader>(inputs.size(),
          new KVReaderComparator(keyComparator));
      finishedReaders = new ArrayList<KeyValuesReader>(inputs.size());
      for (Input input : inputs) {
        // Readers with no data at all never enter the queue.
        advanceAndAddToQueue((KeyValuesReader) input.getReader());
      }
      currentValues = new ValuesIterable();
      this.context = context;
    }

    /** Advances the reader to its next key and queues it if such a key exists. */
    private void advanceAndAddToQueue(KeyValuesReader reader)
        throws IOException {
      if (reader.next()) {
        pQueue.add(reader);
      }
    }

    /** Puts an already-positioned reader back into the queue; {@code null} is ignored. */
    private void addToQueue(KeyValuesReader reader) {
      if (reader != null) {
        pQueue.add(reader);
      }
    }

    /**
     * Moves to the smallest remaining key across all readers.
     *
     * @return {@code true} if a key is available, {@code false} once all readers
     *         are exhausted
     * @throws IOException if a constituent reader fails
     */
    @Override
    public boolean next() throws IOException {
      // Skip values of current key if not consumed by the user
      currentValues.discardCurrent();
      // Every reader that served the previous key is advanced and re-queued.
      for (KeyValuesReader reader : finishedReaders) {
        advanceAndAddToQueue(reader);
      }
      finishedReaders.clear();
      nextKVReader = pQueue.poll();
      context.notifyProgress();
      if (nextKVReader != null) {
        currentKey = nextKVReader.getCurrentKey();
        currentValues.moveToNext();
        return true;
      }
      // hasCompletedProcessing() and completedProcessing are not defined in this
      // file; presumably inherited from KeyValuesReader -- confirm there.
      hasCompletedProcessing();
      completedProcessing = true;
      return false;
    }

    /** @return the key selected by the last successful {@link #next()} */
    @Override
    public Object getCurrentKey() throws IOException {
      return currentKey;
    }

    /**
     * @return the values of the current key; the returned iterable always hands
     *         out the same single-pass iterator
     */
    @Override
    public Iterable<Object> getCurrentValues() throws IOException {
      return currentValues;
    }

    /** Iterable wrapper around the one shared {@link ValuesIterator}. */
    private class ValuesIterable implements Iterable<Object> {
      private final ValuesIterator iterator = new ValuesIterator();

      @Override
      public Iterator<Object> iterator() {
        return iterator;
      }

      public void discardCurrent() throws IOException {
        iterator.discardCurrent();
      }

      public void moveToNext() throws IOException {
        iterator.moveToNext();
      }
    }

    /**
     * Iterates the values of the current key, rolling over from one reader to
     * the next queued reader as long as that reader is positioned on an equal
     * key.
     */
    @SuppressWarnings("unchecked")
    private class ValuesIterator implements Iterator<Object> {
      /** Values of the current key from the current reader; {@code null} when no key is active. */
      private Iterator<Object> currentValuesIter;

      /** Starts iterating the values of the reader just polled by the enclosing {@code next()}. */
      public void moveToNext() throws IOException {
        currentValuesIter = nextKVReader.getCurrentValues().iterator();
      }

      @Override
      public boolean hasNext() {
        if (currentValuesIter == null) {
          // No current key. next needs to be called.
          return false;
        }
        if (currentValuesIter.hasNext()) {
          return true;
        }
        // Current reader is drained for this key; look at the next queued reader.
        finishedReaders.add(nextKVReader);
        nextKVReader = pQueue.poll();
        try {
          if (nextKVReader != null
              && keyComparator.compare(currentKey, nextKVReader.getCurrentKey()) == 0) {
            currentValuesIter = nextKVReader.getCurrentValues().iterator();
            return true;
          }
          // key changed or no more data.
          // Add the reader back to queue
          addToQueue(nextKVReader);
          currentValuesIter = null;
          return false;
        } catch (IOException e) {
          throw new RuntimeException(e);
        }
      }

      /**
       * Abandons any unconsumed values of the current key: the current reader
       * and every queued reader positioned on an equal key are parked in
       * {@code finishedReaders}. Does nothing when no key is active.
       */
      public void discardCurrent() throws IOException {
        if (currentValuesIter != null) {
          do {
            finishedReaders.add(nextKVReader);
            nextKVReader = pQueue.poll();
          } while (nextKVReader != null
              && keyComparator.compare(currentKey, nextKVReader.getCurrentKey()) == 0);
          // The first reader with a different key (if any) goes back untouched.
          addToQueue(nextKVReader);
          currentValuesIter = null;
        }
      }

      /**
       * @return the next value of the current key
       * @throws NoSuchElementException if the current key has no more values or
       *           no key is active
       */
      @Override
      public Object next() {
        // Going through hasNext() performs the roll-over to the next reader with
        // the same key, so next() is usable without a preceding hasNext() call.
        if (!hasNext()) {
          throw new NoSuchElementException();
        }
        return currentValuesIter.next();
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    }

    /**
     * Comparator that compares KeyValuesReader on their current key
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    private static class KVReaderComparator implements
        Comparator<KeyValuesReader> {

      private final RawComparator keyComparator;

      public KVReaderComparator(RawComparator keyComparator) {
        this.keyComparator = keyComparator;
      }

      /**
       * @throws RuntimeException wrapping the {@link IOException} if either
       *           reader fails to provide its current key
       */
      @Override
      public int compare(KeyValuesReader o1, KeyValuesReader o2) {
        try {
          return keyComparator.compare(o1.getCurrentKey(), o2.getCurrentKey());
        } catch (IOException e) {
          LOG.error("Caught exception while comparing keys in shuffle input", e);
          throw new RuntimeException(e);
        }
      }
    }
  }

  /**
   * @return the arithmetic mean of the progress reported by each constituent
   *         {@link OrderedGroupedKVInput}
   */
  public float getProgress() throws ProgressFailedException, InterruptedException {
    List<Input> inputs = getInputs();
    float totalProgress = 0.0f;
    for (Input input : inputs) {
      totalProgress += ((OrderedGroupedKVInput) input).getProgress();
    }
    return totalProgress / inputs.size();
  }
}