/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tez.runtime.library.common.writers;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.BitSet;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataInputBuffer;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.DefaultCodec;
import org.apache.hadoop.util.DiskChecker.DiskErrorException;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezJobConfig;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtils;
import org.apache.tez.common.counters.TaskCounter;
import org.apache.tez.common.counters.TezCounter;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.TezOutputContext;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.library.api.Partitioner;
import org.apache.tez.runtime.library.common.sort.impl.IFile;
import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord;
import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord;
import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutput;
import org.apache.tez.runtime.library.common.task.local.output.TezTaskOutputFiles;
import org.apache.tez.runtime.library.partitioner.HashPartitioner;
import org.apache.tez.runtime.library.shuffle.common.ShuffleUtils;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads.DataMovementEventPayloadProto;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.Sets;
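/**
 * Tests for {@link UnorderedPartitionedKVWriter}, covering buffer sizing, spill behaviour,
 * empty and skipped partitions, large keys/values, counter values, and the
 * {@link CompositeDataMovementEvent} plus on-disk IFile data produced by close().
 * Parameterized to run with and without intermediate-data compression.
 */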
@RunWith(value = Parameterized.class)
public class TestUnorderedPartitionedKVWriter {
private static final Log LOG = LogFactory.getLog(TestUnorderedPartitionedKVWriter.class);
private static final String HOST_STRING = "localhost";
private static final int SHUFFLE_PORT = 4000;
private static String testTmpDir = System.getProperty("test.build.data", "/tmp");
private static final Path TEST_ROOT_DIR = new Path(testTmpDir,
TestUnorderedPartitionedKVWriter.class.getSimpleName());
private static FileSystem localFs;
private boolean shouldCompress;
public TestUnorderedPartitionedKVWriter(boolean shouldCompress) {
this.shouldCompress = shouldCompress;
}
@Parameters
public static Collection<Object[]> data() {
Object[][] data = new Object[][] { { false }, { true } };
return Arrays.asList(data);
}
@Before
public void setup() throws IOException {
LOG.info("Setup. Using test dir: " + TEST_ROOT_DIR);
localFs = FileSystem.getLocal(new Configuration());
localFs.delete(TEST_ROOT_DIR, true);
localFs.mkdirs(TEST_ROOT_DIR);
}
@After
public void cleanup() throws IOException {
LOG.info("CleanUp");
localFs.delete(TEST_ROOT_DIR, true);
}
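// The expected buffer counts and sizes below follow from the writer splitting the available
// memory into buffers no larger than the configured max-per-buffer size, with each buffer
// size rounded down to a multiple of 4 bytes and only the first buffer initialized eagerly.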
@Test(timeout = 10000)
public void testBufferSizing() throws IOException {
ApplicationId appId = ApplicationId.newInstance(10000, 1);
TezCounters counters = new TezCounters();
String uniqueId = UUID.randomUUID().toString();
TezOutputContext outputContext = createMockOutputContext(counters, appId, uniqueId);
int maxSingleBufferSizeBytes = 2047;
Configuration conf = createConfiguration(outputContext, IntWritable.class, LongWritable.class,
false, maxSingleBufferSizeBytes);
int numOutputs = 10;
UnorderedPartitionedKVWriter kvWriter =
new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, 2048);
assertEquals(2, kvWriter.numBuffers);
assertEquals(1024, kvWriter.sizePerBuffer);
assertEquals(1, kvWriter.numInitializedBuffers);
kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs,
maxSingleBufferSizeBytes * 3);
assertEquals(3, kvWriter.numBuffers);
assertEquals(maxSingleBufferSizeBytes - maxSingleBufferSizeBytes % 4, kvWriter.sizePerBuffer);
assertEquals(1, kvWriter.numInitializedBuffers);
kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs,
maxSingleBufferSizeBytes * 2 + 1);
assertEquals(3, kvWriter.numBuffers);
assertEquals(1364, kvWriter.sizePerBuffer);
assertEquals(1, kvWriter.numInitializedBuffers);
kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext, conf, numOutputs, 10240);
assertEquals(6, kvWriter.numBuffers);
assertEquals(1704, kvWriter.sizePerBuffer);
assertEquals(1, kvWriter.numInitializedBuffers);
}
@Test(timeout = 10000)
public void testNoSpill() throws IOException, InterruptedException {
baseTest(10, 10, null, shouldCompress);
}
@Test(timeout = 10000)
public void testSingleSpill() throws IOException, InterruptedException {
baseTest(50, 10, null, shouldCompress);
}
@Test(timeout = 10000)
public void testMultipleSpills() throws IOException, InterruptedException {
baseTest(200, 10, null, shouldCompress);
}
@Test(timeout = 10000)
public void testNoRecords() throws IOException, InterruptedException {
baseTest(0, 10, null, shouldCompress);
}
@Test(timeout = 10000)
public void testSkippedPartitions() throws IOException, InterruptedException {
baseTest(200, 10, Sets.newHashSet(2, 5), shouldCompress);
}
@Test(timeout = 10000)
public void testRandomText() throws IOException, InterruptedException {
textTest(100, 10, 2048, 0, 0, 0);
}
@Test(timeout = 10000)
public void testLargeKeys() throws IOException, InterruptedException {
textTest(0, 10, 2048, 10, 0, 0);
}
@Test(timeout = 10000)
public void testLargeValues() throws IOException, InterruptedException {
textTest(0, 10, 2048, 0, 10, 0);
}
@Test(timeout = 10000)
public void testLargeKvPairs() throws IOException, InterruptedException {
textTest(0, 10, 2048, 0, 0, 10);
}
@Test(timeout = 10000)
public void testTextMixedRecords() throws IOException, InterruptedException {
textTest(100, 10, 2048, 10, 10, 10);
}
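/**
 * Writes a mix of regular and oversized Text key/value records, then validates the
 * OUTPUT_LARGE_RECORDS counter, the single CompositeDataMovementEvent emitted on close()
 * (including its empty-partition bitmask), and the data read back from the output IFile.
 */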
public void textTest(int numRegularRecords, int numPartitions, long availableMemory,
int numLargeKeys, int numLargeValues, int numLargeKvPairs) throws IOException,
InterruptedException {
Partitioner partitioner = new HashPartitioner();
ApplicationId appId = ApplicationId.newInstance(10000, 1);
TezCounters counters = new TezCounters();
String uniqueId = UUID.randomUUID().toString();
TezOutputContext outputContext = createMockOutputContext(counters, appId, uniqueId);
Random random = new Random();
Configuration conf = createConfiguration(outputContext, Text.class, Text.class, shouldCompress,
-1, HashPartitioner.class);
CompressionCodec codec = null;
if (shouldCompress) {
codec = new DefaultCodec();
((Configurable) codec).setConf(conf);
}
int numRecordsWritten = 0;
Map<Integer, Multimap<String, String>> expectedValues = new HashMap<Integer, Multimap<String, String>>();
for (int i = 0; i < numPartitions; i++) {
expectedValues.put(i, LinkedListMultimap.<String, String> create());
}
UnorderedPartitionedKVWriter kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext,
conf, numPartitions, availableMemory);
int sizePerBuffer = kvWriter.sizePerBuffer;
BitSet partitionsWithData = new BitSet(numPartitions);
Text keyText = new Text();
Text valText = new Text();
for (int i = 0; i < numRegularRecords; i++) {
String key = createRandomString(Math.abs(random.nextInt(10)));
String val = createRandomString(Math.abs(random.nextInt(20)));
keyText.set(key);
valText.set(val);
int partition = partitioner.getPartition(keyText, valText, numPartitions);
partitionsWithData.set(partition);
expectedValues.get(partition).put(key, val);
kvWriter.write(keyText, valText);
numRecordsWritten++;
}
// Write Large key records
for (int i = 0; i < numLargeKeys; i++) {
String key = createRandomString(sizePerBuffer + Math.abs(random.nextInt(100)));
String val = createRandomString(Math.abs(random.nextInt(20)));
keyText.set(key);
valText.set(val);
int partition = partitioner.getPartition(keyText, valText, numPartitions);
partitionsWithData.set(partition);
expectedValues.get(partition).put(key, val);
kvWriter.write(keyText, valText);
numRecordsWritten++;
}
// Write Large val records
for (int i = 0; i < numLargeValues; i++) {
String key = createRandomString(Math.abs(random.nextInt(10)));
String val = createRandomString(sizePerBuffer + Math.abs(random.nextInt(100)));
keyText.set(key);
valText.set(val);
int partition = partitioner.getPartition(keyText, valText, numPartitions);
partitionsWithData.set(partition);
expectedValues.get(partition).put(key, val);
kvWriter.write(keyText, valText);
numRecordsWritten++;
}
// Write records where key + val are large (but both can fit in the buffer individually)
for (int i = 0; i < numLargeKvPairs; i++) {
String key = createRandomString(sizePerBuffer / 2 + Math.abs(random.nextInt(100)));
String val = createRandomString(sizePerBuffer / 2 + Math.abs(random.nextInt(100)));
keyText.set(key);
valText.set(val);
int partition = partitioner.getPartition(keyText, valText, numPartitions);
partitionsWithData.set(partition);
expectedValues.get(partition).put(key, val);
kvWriter.write(keyText, valText);
numRecordsWritten++;
}
List<Event> events = kvWriter.close();
verify(outputContext, never()).fatalError(any(Throwable.class), any(String.class));
TezCounter outputLargeRecordsCounter = counters.findCounter(TaskCounter.OUTPUT_LARGE_RECORDS);
assertEquals(numLargeKeys + numLargeValues + numLargeKvPairs,
outputLargeRecordsCounter.getValue());
// Validate the event
assertEquals(1, events.size());
assertTrue(events.get(0) instanceof CompositeDataMovementEvent);
CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0);
assertEquals(0, cdme.getSourceIndexStart());
assertEquals(numPartitions, cdme.getCount());
DataMovementEventPayloadProto eventProto = DataMovementEventPayloadProto.parseFrom(cdme
.getUserPayload());
assertFalse(eventProto.hasData());
BitSet emptyPartitionBits = null;
if (partitionsWithData.cardinality() != numPartitions) {
assertTrue(eventProto.hasEmptyPartitions());
byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(eventProto
.getEmptyPartitions());
emptyPartitionBits = TezUtils.fromByteArray(emptyPartitions);
assertEquals(numPartitions - partitionsWithData.cardinality(),
emptyPartitionBits.cardinality());
} else {
assertFalse(eventProto.hasEmptyPartitions());
emptyPartitionBits = new BitSet(numPartitions);
}
assertEquals(HOST_STRING, eventProto.getHost());
assertEquals(SHUFFLE_PORT, eventProto.getPort());
assertEquals(uniqueId, eventProto.getPathComponent());
// Verify the actual data written to disk
TezTaskOutput taskOutput = new TezTaskOutputFiles(conf, uniqueId);
Path outputFilePath = null;
Path spillFilePath = null;
try {
outputFilePath = taskOutput.getOutputFile();
} catch (DiskErrorException e) {
if (numRecordsWritten > 0) {
fail("Output file lookup failed even though records were written: " + e.getMessage());
} else {
// Record checking not required.
return;
}
}
try {
spillFilePath = taskOutput.getOutputIndexFile();
} catch (DiskErrorException e) {
if (numRecordsWritten > 0) {
fail("Output index file lookup failed even though records were written: " + e.getMessage());
} else {
// Record checking not required.
return;
}
}
// Zero-record case handled above (early return when no output file exists).
TezSpillRecord spillRecord = new TezSpillRecord(spillFilePath, conf);
DataInputBuffer keyBuffer = new DataInputBuffer();
DataInputBuffer valBuffer = new DataInputBuffer();
Text keyDeser = new Text();
Text valDeser = new Text();
for (int i = 0; i < numPartitions; i++) {
if (emptyPartitionBits.get(i)) {
continue;
}
TezIndexRecord indexRecord = spillRecord.getIndex(i);
FSDataInputStream inStream = FileSystem.getLocal(conf).open(outputFilePath);
inStream.seek(indexRecord.getStartOffset());
IFile.Reader reader = new IFile.Reader(inStream, indexRecord.getPartLength(), codec, null,
null, false, 0, -1);
while (reader.nextRawKey(keyBuffer)) {
reader.nextRawValue(valBuffer);
keyDeser.readFields(keyBuffer);
valDeser.readFields(valBuffer);
int partition = partitioner.getPartition(keyDeser, valDeser, numPartitions);
assertTrue(expectedValues.get(partition).remove(keyDeser.toString(), valDeser.toString()));
}
inStream.close();
}
for (int i = 0; i < numPartitions; i++) {
assertEquals(0, expectedValues.get(i).size());
expectedValues.remove(i);
}
assertEquals(0, expectedValues.size());
}
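/**
 * Core test: writes IntWritable/LongWritable records (optionally skipping some partitions),
 * then verifies buffer state after close(), the output and spill related counters, the
 * CompositeDataMovementEvent payload, and the per-partition data in the final output file.
 */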
private void baseTest(int numRecords, int numPartitions, Set<Integer> skippedPartitions,
boolean shouldCompress) throws IOException, InterruptedException {
PartitionerForTest partitioner = new PartitionerForTest();
ApplicationId appId = ApplicationId.newInstance(10000, 1);
TezCounters counters = new TezCounters();
String uniqueId = UUID.randomUUID().toString();
TezOutputContext outputContext = createMockOutputContext(counters, appId, uniqueId);
Configuration conf = createConfiguration(outputContext, IntWritable.class, LongWritable.class,
shouldCompress, -1);
CompressionCodec codec = null;
if (shouldCompress) {
codec = new DefaultCodec();
((Configurable) codec).setConf(conf);
}
int numOutputs = numPartitions;
long availableMemory = 2048;
int numRecordsWritten = 0;
Map<Integer, Multimap<Integer, Long>> expectedValues = new HashMap<Integer, Multimap<Integer, Long>>();
for (int i = 0; i < numOutputs; i++) {
expectedValues.put(i, LinkedListMultimap.<Integer, Long> create());
}
UnorderedPartitionedKVWriter kvWriter = new UnorderedPartitionedKVWriterForTest(outputContext,
conf, numOutputs, availableMemory);
int sizePerBuffer = kvWriter.sizePerBuffer;
int sizePerRecord = 4 + 8; // IntW + LongW
int sizePerRecordWithOverhead = sizePerRecord + 12; // Record + META_OVERHEAD
IntWritable intWritable = new IntWritable();
LongWritable longWritable = new LongWritable();
for (int i = 0; i < numRecords; i++) {
intWritable.set(i);
longWritable.set(i);
int partition = partitioner.getPartition(intWritable, longWritable, numOutputs);
if (skippedPartitions != null && skippedPartitions.contains(partition)) {
continue;
}
expectedValues.get(partition).put(intWritable.get(), longWritable.get());
kvWriter.write(intWritable, longWritable);
numRecordsWritten++;
}
List<Event> events = kvWriter.close();
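// Each record occupies sizePerRecordWithOverhead bytes in a buffer, so the test expects one
// spill each time recordsPerBuffer records have been written.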
int recordsPerBuffer = sizePerBuffer / sizePerRecordWithOverhead;
int numExpectedSpills = numRecordsWritten / recordsPerBuffer;
verify(outputContext, never()).fatalError(any(Throwable.class), any(String.class));
// Verify the status of the buffers
if (numExpectedSpills == 0) {
assertEquals(1, kvWriter.numInitializedBuffers);
} else {
assertTrue(kvWriter.numInitializedBuffers > 1);
}
assertNull(kvWriter.currentBuffer);
assertEquals(0, kvWriter.availableBuffers.size());
// Verify the counters
TezCounter outputRecordBytesCounter = counters.findCounter(TaskCounter.OUTPUT_BYTES);
TezCounter outputRecordsCounter = counters.findCounter(TaskCounter.OUTPUT_RECORDS);
TezCounter outputBytesWithOverheadCounter = counters
.findCounter(TaskCounter.OUTPUT_BYTES_WITH_OVERHEAD);
TezCounter fileOutputBytesCounter = counters.findCounter(TaskCounter.OUTPUT_BYTES_PHYSICAL);
TezCounter spilledRecordsCounter = counters.findCounter(TaskCounter.SPILLED_RECORDS);
TezCounter additionalSpillBytesWrittenCounter = counters
.findCounter(TaskCounter.ADDITIONAL_SPILLS_BYTES_WRITTEN);
TezCounter additionalSpillBytesReadCounter = counters
.findCounter(TaskCounter.ADDITIONAL_SPILLS_BYTES_READ);
TezCounter numAdditionalSpillsCounter = counters
.findCounter(TaskCounter.ADDITIONAL_SPILL_COUNT);
assertEquals(numRecordsWritten * sizePerRecord, outputRecordBytesCounter.getValue());
assertEquals(numRecordsWritten, outputRecordsCounter.getValue());
assertEquals(numRecordsWritten * sizePerRecordWithOverhead,
outputBytesWithOverheadCounter.getValue());
long fileOutputBytes = fileOutputBytesCounter.getValue();
if (numRecordsWritten > 0) {
assertTrue(fileOutputBytes > 0);
if (!shouldCompress) {
assertTrue(fileOutputBytes > outputRecordBytesCounter.getValue());
}
} else {
assertEquals(0, fileOutputBytes);
}
assertEquals(recordsPerBuffer * numExpectedSpills, spilledRecordsCounter.getValue());
long additionalSpillBytesWritten = additionalSpillBytesWrittenCounter.getValue();
long additionalSpillBytesRead = additionalSpillBytesReadCounter.getValue();
if (numExpectedSpills == 0) {
assertEquals(0, additionalSpillBytesWritten);
assertEquals(0, additionalSpillBytesRead);
} else {
assertTrue(additionalSpillBytesWritten > 0);
assertTrue(additionalSpillBytesRead > 0);
if (!shouldCompress) {
assertTrue(additionalSpillBytesWritten > (recordsPerBuffer * numExpectedSpills * sizePerRecord));
assertTrue(additionalSpillBytesRead > (recordsPerBuffer * numExpectedSpills * sizePerRecord));
}
}
assertEquals(additionalSpillBytesWritten, additionalSpillBytesRead);
assertEquals(numExpectedSpills, numAdditionalSpillsCounter.getValue());
BitSet emptyPartitionBits = null;
// Verify the event returned
assertEquals(1, events.size());
assertTrue(events.get(0) instanceof CompositeDataMovementEvent);
CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0);
assertEquals(0, cdme.getSourceIndexStart());
assertEquals(numOutputs, cdme.getCount());
DataMovementEventPayloadProto eventProto = DataMovementEventPayloadProto.parseFrom(cdme
.getUserPayload());
assertFalse(eventProto.hasData());
if (skippedPartitions == null && numRecordsWritten > 0) {
assertFalse(eventProto.hasEmptyPartitions());
emptyPartitionBits = new BitSet(numPartitions);
} else {
assertTrue(eventProto.hasEmptyPartitions());
byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(eventProto
.getEmptyPartitions());
emptyPartitionBits = TezUtils.fromByteArray(emptyPartitions);
if (numRecordsWritten == 0) {
assertEquals(numPartitions, emptyPartitionBits.cardinality());
} else {
for (Integer e : skippedPartitions) {
assertTrue(emptyPartitionBits.get(e));
}
assertEquals(skippedPartitions.size(), emptyPartitionBits.cardinality());
}
}
if (emptyPartitionBits.cardinality() != numPartitions) {
assertEquals(HOST_STRING, eventProto.getHost());
assertEquals(SHUFFLE_PORT, eventProto.getPort());
assertEquals(uniqueId, eventProto.getPathComponent());
} else {
assertFalse(eventProto.hasHost());
assertFalse(eventProto.hasPort());
assertFalse(eventProto.hasPathComponent());
}
// Verify the actual data
TezTaskOutput taskOutput = new TezTaskOutputFiles(conf, uniqueId);
Path outputFilePath = null;
Path spillFilePath = null;
try {
outputFilePath = taskOutput.getOutputFile();
} catch (DiskErrorException e) {
if (numRecordsWritten > 0) {
fail("Output file lookup failed even though records were written: " + e.getMessage());
} else {
// Record checking not required.
return;
}
}
try {
spillFilePath = taskOutput.getOutputIndexFile();
} catch (DiskErrorException e) {
if (numRecordsWritten > 0) {
fail("Output index file lookup failed even though records were written: " + e.getMessage());
} else {
// Record checking not required.
return;
}
}
// Zero-record case handled above (early return when no output file exists).
TezSpillRecord spillRecord = new TezSpillRecord(spillFilePath, conf);
DataInputBuffer keyBuffer = new DataInputBuffer();
DataInputBuffer valBuffer = new DataInputBuffer();
IntWritable keyDeser = new IntWritable();
LongWritable valDeser = new LongWritable();
for (int i = 0; i < numOutputs; i++) {
if (skippedPartitions != null && skippedPartitions.contains(i)) {
continue;
}
TezIndexRecord indexRecord = spillRecord.getIndex(i);
FSDataInputStream inStream = FileSystem.getLocal(conf).open(outputFilePath);
inStream.seek(indexRecord.getStartOffset());
IFile.Reader reader = new IFile.Reader(inStream, indexRecord.getPartLength(), codec, null,
null, false, 0, -1);
while (reader.nextRawKey(keyBuffer)) {
reader.nextRawValue(valBuffer);
keyDeser.readFields(keyBuffer);
valDeser.readFields(valBuffer);
int partition = partitioner.getPartition(keyDeser, valDeser, numOutputs);
assertTrue(expectedValues.get(partition).remove(keyDeser.get(), valDeser.get()));
}
inStream.close();
}
for (int i = 0; i < numOutputs; i++) {
assertEquals(0, expectedValues.get(i).size());
expectedValues.remove(i);
}
assertEquals(0, expectedValues.size());
}
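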
private static String createRandomString(int size) {
StringBuilder sb = new StringBuilder(size);
Random random = new Random();
for (int i = 0; i < size; i++) {
// nextInt(26) is always non-negative, avoiding the Math.abs(Integer.MIN_VALUE) corner case.
sb.append((char) ('A' + random.nextInt(26)));
}
return sb.toString();
}
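// Builds a mock TezOutputContext with fixed task/vertex identifiers, the shuffle port encoded
// as service-provider metadata, and a per-test work directory under TEST_ROOT_DIR.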
private TezOutputContext createMockOutputContext(TezCounters counters, ApplicationId appId,
String uniqueId) {
TezOutputContext outputContext = mock(TezOutputContext.class);
doReturn(counters).when(outputContext).getCounters();
doReturn(appId).when(outputContext).getApplicationId();
doReturn(1).when(outputContext).getDAGAttemptNumber();
doReturn("dagName").when(outputContext).getDAGName();
doReturn("destinationVertexName").when(outputContext).getDestinationVertexName();
doReturn(1).when(outputContext).getOutputIndex();
doReturn(1).when(outputContext).getTaskAttemptNumber();
doReturn(1).when(outputContext).getTaskIndex();
doReturn(1).when(outputContext).getTaskVertexIndex();
doReturn("vertexName").when(outputContext).getTaskVertexName();
doReturn(uniqueId).when(outputContext).getUniqueIdentifier();
ByteBuffer portBuffer = ByteBuffer.allocate(4);
portBuffer.mark();
portBuffer.putInt(SHUFFLE_PORT);
portBuffer.reset();
doReturn(portBuffer).when(outputContext).getServiceProviderMetaData(
eq(ShuffleUtils.SHUFFLE_HANDLER_SERVICE_ID));
Path outDirBase = new Path(TEST_ROOT_DIR, "outDir_" + uniqueId);
String[] outDirs = new String[] { outDirBase.toString() };
doReturn(outDirs).when(outputContext).getWorkDirs();
return outputContext;
}
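// Creates the runtime configuration for the writer: key/value/partitioner classes, an optional
// max-per-buffer size limit, and optional intermediate-data compression via DefaultCodec.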
private Configuration createConfiguration(TezOutputContext outputContext,
Class<? extends Writable> keyClass, Class<? extends Writable> valClass,
boolean shouldCompress, int maxSingleBufferSizeBytes) {
return createConfiguration(outputContext, keyClass, valClass, shouldCompress,
maxSingleBufferSizeBytes, PartitionerForTest.class);
}
private Configuration createConfiguration(TezOutputContext outputContext,
Class<? extends Writable> keyClass, Class<? extends Writable> valClass,
boolean shouldCompress, int maxSingleBufferSizeBytes,
Class<? extends Partitioner> partitionerClass) {
Configuration conf = new Configuration(false);
conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, outputContext.getWorkDirs());
conf.set(TezJobConfig.TEZ_RUNTIME_KEY_CLASS, keyClass.getName());
conf.set(TezJobConfig.TEZ_RUNTIME_VALUE_CLASS, valClass.getName());
conf.set(TezJobConfig.TEZ_RUNTIME_PARTITIONER_CLASS, partitionerClass.getName());
if (maxSingleBufferSizeBytes >= 0) {
conf.setInt(TezJobConfig.TEZ_RUNTIME_UNORDERED_OUTPUT_MAX_PER_BUFFER_SIZE_BYTES,
maxSingleBufferSizeBytes);
}
conf.setBoolean(TezJobConfig.TEZ_RUNTIME_COMPRESS, shouldCompress);
if (shouldCompress) {
conf.set(TezJobConfig.TEZ_RUNTIME_COMPRESS_CODEC,
DefaultCodec.class.getName());
}
return conf;
}
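// Deterministic partitioner used by baseTest: partitions IntWritable keys by value modulo the
// partition count.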
public static class PartitionerForTest implements Partitioner {
@Override
public int getPartition(Object key, Object value, int numPartitions) {
if (key instanceof IntWritable) {
return ((IntWritable) key).get() % numPartitions;
} else {
throw new UnsupportedOperationException(
"Test partitioner expected to be called with IntWritable only");
}
}
}
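// Subclass that pins the reported host name so the event payload can be asserted against
// HOST_STRING without depending on the local host configuration.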
private static class UnorderedPartitionedKVWriterForTest extends UnorderedPartitionedKVWriter {
public UnorderedPartitionedKVWriterForTest(TezOutputContext outputContext, Configuration conf,
int numOutputs, long availableMemoryBytes) throws IOException {
super(outputContext, conf, numOutputs, availableMemoryBytes);
}
@Override
String getHost() {
return HOST_STRING;
}
}
}