/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 * <p/>
 * http://www.apache.org/licenses/LICENSE-2.0
 * <p/>
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tez.runtime.library.common.shuffle;
import com.google.common.collect.Lists;
import com.google.protobuf.ByteString;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.DataOutputBuffer;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.compress.CompressionCodec;
import org.apache.hadoop.io.compress.CompressionInputStream;
import org.apache.hadoop.io.compress.CompressionOutputStream;
import org.apache.hadoop.io.compress.Compressor;
import org.apache.hadoop.io.compress.Decompressor;
import org.apache.hadoop.yarn.api.records.ApplicationId;
import org.apache.tez.common.TezCommonUtils;
import org.apache.tez.common.TezRuntimeFrameworkConfigs;
import org.apache.tez.common.TezUtilsInternal;
import org.apache.tez.common.counters.TezCounters;
import org.apache.tez.dag.api.TezConfiguration;
import org.apache.tez.runtime.api.Event;
import org.apache.tez.runtime.api.InputContext;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.api.events.CompositeDataMovementEvent;
import org.apache.tez.runtime.api.events.VertexManagerEvent;
import org.apache.tez.runtime.api.impl.ExecutionContextImpl;
import org.apache.tez.runtime.library.api.TezRuntimeConfiguration;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils.FetchStatsLogger;
import org.apache.tez.runtime.library.common.sort.impl.TezIndexRecord;
import org.apache.tez.runtime.library.common.sort.impl.TezSpillRecord;
import org.apache.tez.runtime.library.partitioner.HashPartitioner;
import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.Matchers;
import org.slf4j.Logger;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.BitSet;
import java.util.List;
import java.util.concurrent.ThreadLocalRandom;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyString;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
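/**
 * Unit tests for {@link ShuffleUtils}: spill event generation, error
 * translation in shuffleToMemory, shuffleToDisk checksum handling, and
 * fetch statistics logging.
 */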
public class TestShuffleUtils {
private OutputContext outputContext;
private Configuration conf;
private FileSystem localFs;
private Path workingDir;
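  /**
   * Builds a mocked {@link InputContext} with a fixed application id, a
   * source vertex name of "sourceVertex" and fresh counters.
   */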
private InputContext createTezInputContext() {
ApplicationId applicationId = ApplicationId.newInstance(1, 1);
InputContext inputContext = mock(InputContext.class);
doReturn(applicationId).when(inputContext).getApplicationId();
doReturn("sourceVertex").when(inputContext).getSourceVertexName();
when(inputContext.getCounters()).thenReturn(new TezCounters());
return inputContext;
}
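  /**
   * Builds a mocked {@link OutputContext} whose execution context reports
   * "localhost" and whose shuffle service metadata advertises port 80.
   */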
private OutputContext createTezOutputContext() throws IOException {
ApplicationId applicationId = ApplicationId.newInstance(1, 1);
OutputContext outputContext = mock(OutputContext.class);
ExecutionContextImpl executionContext = mock(ExecutionContextImpl.class);
doReturn("localhost").when(executionContext).getHostName();
doReturn(executionContext).when(outputContext).getExecutionContext();
DataOutputBuffer serviceProviderMetaData = new DataOutputBuffer();
serviceProviderMetaData.writeInt(80);
doReturn(ByteBuffer.wrap(serviceProviderMetaData.getData())).when(outputContext)
.getServiceProviderMetaData
(conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT));
doReturn(1).when(outputContext).getTaskVertexIndex();
doReturn(1).when(outputContext).getOutputIndex();
doReturn(0).when(outputContext).getDAGAttemptNumber();
doReturn("destVertex").when(outputContext).getDestinationVertexName();
when(outputContext.getCounters()).thenReturn(new TezCounters());
return outputContext;
}
@Before
public void setup() throws Exception {
conf = new Configuration();
outputContext = createTezOutputContext();
conf.set("fs.defaultFS", "file:///");
localFs = FileSystem.getLocal(conf);
workingDir = new Path(
new Path(System.getProperty("test.build.data", "/tmp")),
TestShuffleUtils.class.getName())
.makeQualified(localFs.getUri(), localFs.getWorkingDirectory());
String localDirs = workingDir.toString();
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_KEY_CLASS, Text.class.getName());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, Text.class.getName());
conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_PARTITIONER_CLASS,
HashPartitioner.class.getName());
conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, localDirs);
}
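  /**
   * Writes a spill index file with {@code numPartitions} entries. Every even
   * partition (or every partition when {@code allEmptyPartitions} is set) is
   * given a raw length of 0 to mark it as empty, see TEZ-3605.
   */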
private Path createIndexFile(int numPartitions, boolean allEmptyPartitions) throws IOException {
Path path = new Path(workingDir, "file.index.out");
TezSpillRecord spillRecord = new TezSpillRecord(numPartitions);
long startOffset = 0;
long partLen = 200; //compressed
    for (int i = 0; i < numPartitions; i++) {
long rawLen = ThreadLocalRandom.current().nextLong(100, 200);
if (i % 2 == 0 || allEmptyPartitions) {
rawLen = 0; //indicates empty partition, see TEZ-3605
}
TezIndexRecord indexRecord = new TezIndexRecord(startOffset, rawLen, partLen);
startOffset += partLen;
spillRecord.putIndex(indexRecord, i);
}
spillRecord.writeToFile(path, conf, FileSystem.getLocal(conf).getRaw());
return path;
}
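  /**
   * With final merge disabled and more spills to come, a single
   * CompositeDataMovementEvent carrying the spill id and the empty-partition
   * bitmask (5 of 10 partitions empty) is expected.
   */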
@Test
public void testGenerateOnSpillEvent() throws Exception {
List<Event> events = Lists.newLinkedList();
Path indexFile = createIndexFile(10, false);
boolean finalMergeEnabled = false;
boolean isLastEvent = false;
int spillId = 0;
int physicalOutputs = 10;
String pathComponent = "/attempt_x_y_0/file.out";
String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
ShuffleUtils.generateEventOnSpill(events, finalMergeEnabled, isLastEvent,
outputContext, spillId, new TezSpillRecord(indexFile, conf),
physicalOutputs, true, pathComponent, null, false, auxiliaryService, TezCommonUtils.newBestCompressionDeflater());
Assert.assertTrue(events.size() == 1);
Assert.assertTrue(events.get(0) instanceof CompositeDataMovementEvent);
CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(0);
Assert.assertTrue(cdme.getCount() == physicalOutputs);
Assert.assertTrue(cdme.getSourceIndexStart() == 0);
ByteBuffer payload = cdme.getUserPayload();
ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom(payload));
Assert.assertTrue(dmeProto.getSpillId() == 0);
Assert.assertTrue(dmeProto.hasLastEvent() && !dmeProto.getLastEvent());
byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(dmeProto.getEmptyPartitions());
BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 5) = " + emptyPartitionsBitSet
.cardinality(), emptyPartitionsBitSet.cardinality() == 5);
events.clear();
}
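  /**
   * With final merge enabled, a VertexManagerEvent is emitted alongside the
   * CompositeDataMovementEvent and the payload carries no per-spill details.
   */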
@Test
public void testGenerateOnSpillEvent_With_FinalMerge() throws Exception {
List<Event> events = Lists.newLinkedList();
Path indexFile = createIndexFile(10, false);
boolean finalMergeEnabled = true;
boolean isLastEvent = true;
int spillId = 0;
int physicalOutputs = 10;
String pathComponent = "/attempt_x_y_0/file.out";
String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
//normal code path where we do final merge all the time
ShuffleUtils.generateEventOnSpill(events, finalMergeEnabled, isLastEvent,
outputContext, spillId, new TezSpillRecord(indexFile, conf),
physicalOutputs, true, pathComponent, null, false, auxiliaryService, TezCommonUtils.newBestCompressionDeflater());
Assert.assertTrue(events.size() == 2); //one for VM
Assert.assertTrue(events.get(0) instanceof VertexManagerEvent);
Assert.assertTrue(events.get(1) instanceof CompositeDataMovementEvent);
CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(1);
Assert.assertTrue(cdme.getCount() == physicalOutputs);
Assert.assertTrue(cdme.getSourceIndexStart() == 0);
ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom( cdme.getUserPayload()));
//With final merge, spill details should not be present
Assert.assertFalse(dmeProto.hasSpillId());
Assert.assertFalse(dmeProto.hasLastEvent() || dmeProto.getLastEvent());
byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(dmeProto
.getEmptyPartitions());
BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 5) = " + emptyPartitionsBitSet
.cardinality(), emptyPartitionsBitSet.cardinality() == 5);
}
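  /**
   * With final merge disabled and all partitions empty, the last-spill
   * payload keeps the spill id, reports an empty path component and marks
   * every partition in the empty-partition bitmask.
   */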
@Test
public void testGenerateOnSpillEvent_With_All_EmptyPartitions() throws Exception {
List<Event> events = Lists.newLinkedList();
//Create an index file with all empty partitions
Path indexFile = createIndexFile(10, true);
    boolean finalMergeEnabled = false;
boolean isLastEvent = true;
int spillId = 0;
int physicalOutputs = 10;
String pathComponent = "/attempt_x_y_0/file.out";
String auxiliaryService = conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID,
TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT);
    //final merge is disabled here; this is the last spill event
    ShuffleUtils.generateEventOnSpill(events, finalMergeEnabled, isLastEvent,
outputContext, spillId, new TezSpillRecord(indexFile, conf),
physicalOutputs, true, pathComponent, null, false, auxiliaryService, TezCommonUtils.newBestCompressionDeflater());
Assert.assertTrue(events.size() == 2); //one for VM
Assert.assertTrue(events.get(0) instanceof VertexManagerEvent);
Assert.assertTrue(events.get(1) instanceof CompositeDataMovementEvent);
CompositeDataMovementEvent cdme = (CompositeDataMovementEvent) events.get(1);
Assert.assertTrue(cdme.getCount() == physicalOutputs);
Assert.assertTrue(cdme.getSourceIndexStart() == 0);
ShuffleUserPayloads.DataMovementEventPayloadProto dmeProto =
ShuffleUserPayloads.DataMovementEventPayloadProto.parseFrom(ByteString.copyFrom( cdme.getUserPayload()));
//spill details should be present
Assert.assertTrue(dmeProto.getSpillId() == 0);
Assert.assertTrue(dmeProto.hasLastEvent() && dmeProto.getLastEvent());
Assert.assertTrue(dmeProto.getPathComponent().equals(""));
byte[] emptyPartitions = TezCommonUtils.decompressByteStringToByteArray(dmeProto
.getEmptyPartitions());
BitSet emptyPartitionsBitSet = TezUtilsInternal.fromByteArray(emptyPartitions);
Assert.assertTrue("emptyPartitionBitSet cardinality (expecting 10) = " + emptyPartitionsBitSet
.cardinality(), emptyPartitionsBitSet.cardinality() == 10);
}
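  /**
   * An InternalError thrown by the decompression stream during
   * shuffleToMemory should surface as an IOException with the original
   * error as its cause.
   */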
@Test
public void testInternalErrorTranslation() throws Exception {
String codecErrorMsg = "codec failure";
CompressionInputStream mockCodecStream = mock(CompressionInputStream.class);
when(mockCodecStream.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new InternalError(codecErrorMsg));
Decompressor mockDecoder = mock(Decompressor.class);
CompressionCodec mockCodec = mock(ConfigurableCodecForTest.class);
when(((ConfigurableCodecForTest) mockCodec).getConf()).thenReturn(mock(Configuration.class));
when(mockCodec.createDecompressor()).thenReturn(mockDecoder);
when(mockCodec.createInputStream(any(InputStream.class), any(Decompressor.class)))
.thenReturn(mockCodecStream);
byte[] header = new byte[] { (byte) 'T', (byte) 'I', (byte) 'F', (byte) 1};
try {
ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
1024, 128, mockCodec, false, 0, mock(Logger.class), null);
Assert.fail("shuffle was supposed to throw!");
} catch (IOException e) {
Assert.assertTrue(e.getCause() instanceof InternalError);
Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
}
}
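  /**
   * Runtime exceptions and errors thrown by the decompression stream are
   * wrapped in an IOException, while a SocketTimeoutException is rethrown
   * as-is.
   */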
@Test
public void testExceptionTranslation() throws Exception {
String codecErrorMsg = "codec failure";
CompressionInputStream mockCodecStream = mock(CompressionInputStream.class);
when(mockCodecStream.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new IllegalArgumentException(codecErrorMsg));
Decompressor mockDecoder = mock(Decompressor.class);
CompressionCodec mockCodec = mock(ConfigurableCodecForTest.class);
when(((ConfigurableCodecForTest) mockCodec).getConf()).thenReturn(mock(Configuration.class));
when(mockCodec.createDecompressor()).thenReturn(mockDecoder);
when(mockCodec.createInputStream(any(InputStream.class), any(Decompressor.class)))
.thenReturn(mockCodecStream);
byte[] header = new byte[] { (byte) 'T', (byte) 'I', (byte) 'F', (byte) 1};
try {
ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
1024, 128, mockCodec, false, 0, mock(Logger.class), null);
Assert.fail("shuffle was supposed to throw!");
} catch (IOException e) {
Assert.assertTrue(e.getCause() instanceof IllegalArgumentException);
Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
}
CompressionInputStream mockCodecStream1 = mock(CompressionInputStream.class);
when(mockCodecStream1.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new SocketTimeoutException(codecErrorMsg));
CompressionCodec mockCodec1 = mock(ConfigurableCodecForTest.class);
when(((ConfigurableCodecForTest) mockCodec1).getConf()).thenReturn(mock(Configuration.class));
when(mockCodec1.createDecompressor()).thenReturn(mockDecoder);
when(mockCodec1.createInputStream(any(InputStream.class), any(Decompressor.class)))
.thenReturn(mockCodecStream1);
try {
ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
1024, 128, mockCodec1, false, 0, mock(Logger.class), null);
Assert.fail("shuffle was supposed to throw!");
} catch (IOException e) {
Assert.assertTrue(e instanceof SocketTimeoutException);
Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
}
CompressionInputStream mockCodecStream2 = mock(CompressionInputStream.class);
when(mockCodecStream2.read(any(byte[].class), anyInt(), anyInt()))
.thenThrow(new InternalError(codecErrorMsg));
CompressionCodec mockCodec2 = mock(ConfigurableCodecForTest.class);
when(((ConfigurableCodecForTest) mockCodec2).getConf()).thenReturn(mock(Configuration.class));
when(mockCodec2.createDecompressor()).thenReturn(mockDecoder);
when(mockCodec2.createInputStream(any(InputStream.class), any(Decompressor.class)))
.thenReturn(mockCodecStream2);
try {
ShuffleUtils.shuffleToMemory(new byte[1024], new ByteArrayInputStream(header),
1024, 128, mockCodec2, false, 0, mock(Logger.class), null);
Assert.fail("shuffle was supposed to throw!");
} catch (IOException e) {
Assert.assertTrue(e.getCause() instanceof InternalError);
Assert.assertTrue(e.getMessage().contains(codecErrorMsg));
}
}
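  /**
   * shuffleToDisk should copy a stream of zeroes verbatim when checksum
   * validation is off, and fail with an IOException when validation is on.
   */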
@Test
public void testShuffleToDiskChecksum() throws Exception {
// verify sending a stream of zeroes without checksum validation
// does not trigger an exception
byte[] bogusData = new byte[1000];
Arrays.fill(bogusData, (byte) 0);
ByteArrayInputStream in = new ByteArrayInputStream(bogusData);
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ShuffleUtils.shuffleToDisk(baos, "somehost", in,
bogusData.length, 2000, mock(Logger.class), null, false, 0, false);
Assert.assertArrayEquals(bogusData, baos.toByteArray());
// verify sending same stream of zeroes with validation generates an exception
in.reset();
try {
ShuffleUtils.shuffleToDisk(mock(OutputStream.class), "somehost", in,
bogusData.length, 2000, mock(Logger.class), null, false, 0, true);
Assert.fail("shuffle was supposed to throw!");
    } catch (IOException e) {
      // expected: checksum validation fails on the bogus data
    }
}
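  /**
   * With INFO disabled on the active logger, 1000 fetch completions produce
   * a single aggregate log line; with INFO enabled, every completion is
   * logged individually on the active logger.
   */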
@Test
public void testFetchStatsLogger() throws Exception {
Logger activeLogger = mock(Logger.class);
Logger aggregateLogger = mock(Logger.class);
FetchStatsLogger logger = new FetchStatsLogger(activeLogger, aggregateLogger);
InputAttemptIdentifier ident = new InputAttemptIdentifier(1, 1);
when(activeLogger.isInfoEnabled()).thenReturn(false);
for (int i = 0; i < 1000; i++) {
logger.logIndividualFetchComplete(10, 100, 1000, "testType", ident);
}
verify(activeLogger, times(0)).info(anyString());
verify(aggregateLogger, times(1)).info(anyString(), Matchers.<Object[]>anyVararg());
when(activeLogger.isInfoEnabled()).thenReturn(true);
for (int i = 0; i < 1000; i++) {
logger.logIndividualFetchComplete(10, 100, 1000, "testType", ident);
}
verify(activeLogger, times(1000)).info(anyString());
verify(aggregateLogger, times(1)).info(anyString(), Matchers.<Object[]>anyVararg());
}
/**
   * A minimal codec that implements both CompressionCodec and Configurable, for testing purposes.
*/
public static class ConfigurableCodecForTest implements CompressionCodec, Configurable {
@Override
public Compressor createCompressor() {
return null;
}
@Override
public Decompressor createDecompressor() {
return null;
}
@Override
public CompressionInputStream createInputStream(InputStream arg0) throws IOException {
return null;
}
@Override
public CompressionInputStream createInputStream(InputStream arg0, Decompressor arg1)
throws IOException {
return null;
}
@Override
public CompressionOutputStream createOutputStream(OutputStream arg0) throws IOException {
return null;
}
@Override
public CompressionOutputStream createOutputStream(OutputStream arg0, Compressor arg1)
throws IOException {
return null;
}
@Override
public Class<? extends Compressor> getCompressorType() {
return null;
}
@Override
public Class<? extends Decompressor> getDecompressorType() {
return null;
}
@Override
public String getDefaultExtension() {
return null;
}
@Override
public Configuration getConf() {
return null;
}
@Override
public void setConf(Configuration arg0) {
}
}
}