/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| package org.apache.hadoop.mapreduce.lib.partition; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.List; |
| |
| import org.junit.Test; |
| import static org.junit.Assert.*; |
| |
| import org.apache.hadoop.io.IntWritable; |
| import org.apache.hadoop.io.NullWritable; |
| import org.apache.hadoop.io.WritableComparable; |
| import org.apache.hadoop.mapreduce.InputFormat; |
| import org.apache.hadoop.mapreduce.InputSplit; |
| import org.apache.hadoop.mapreduce.Job; |
| import org.apache.hadoop.mapreduce.JobContext; |
| import org.apache.hadoop.mapreduce.RecordReader; |
| import org.apache.hadoop.mapreduce.TaskAttemptContext; |
| |
| public class TestInputSampler { |
| |
| static class SequentialSplit extends InputSplit { |
| private int i; |
| SequentialSplit(int i) { |
| this.i = i; |
| } |
| public long getLength() { return 0; } |
| public String[] getLocations() { return new String[0]; } |
| public int getInit() { return i; } |
| } |
| |
| static class TestInputSamplerIF |
| extends InputFormat<IntWritable,NullWritable> { |
| |
| final int maxDepth; |
| final ArrayList<InputSplit> splits = new ArrayList<InputSplit>(); |
| |
| TestInputSamplerIF(int maxDepth, int numSplits, int... splitInit) { |
| this.maxDepth = maxDepth; |
| assert splitInit.length == numSplits; |
| for (int i = 0; i < numSplits; ++i) { |
| splits.add(new SequentialSplit(splitInit[i])); |
| } |
| } |
| |
| public List<InputSplit> getSplits(JobContext context) |
| throws IOException, InterruptedException { |
| return splits; |
| } |
| |
| public RecordReader<IntWritable,NullWritable> createRecordReader( |
| final InputSplit split, TaskAttemptContext context) |
| throws IOException, InterruptedException { |
| return new RecordReader<IntWritable,NullWritable>() { |
| private int maxVal; |
| private final IntWritable i = new IntWritable(); |
| public void initialize(InputSplit split, TaskAttemptContext context) |
| throws IOException, InterruptedException { |
| i.set(((SequentialSplit)split).getInit() - 1); |
| maxVal = i.get() + maxDepth + 1; |
| } |
| public boolean nextKeyValue() { |
| i.set(i.get() + 1); |
| return i.get() < maxVal; |
| } |
| public IntWritable getCurrentKey() { return i; } |
| public NullWritable getCurrentValue() { return NullWritable.get(); } |
| public float getProgress() { return 1.0f; } |
| public void close() { } |
| }; |
| } |
| |
| } |
| |
| /** |
| * Verify SplitSampler contract, that an equal number of records are taken |
| * from the first splits. |
| */ |
| @Test |
| @SuppressWarnings("unchecked") // IntWritable comparator not typesafe |
| public void testSplitSampler() throws Exception { |
| final int TOT_SPLITS = 15; |
| final int NUM_SPLITS = 5; |
| final int STEP_SAMPLE = 5; |
| final int NUM_SAMPLES = NUM_SPLITS * STEP_SAMPLE; |
| InputSampler.Sampler<IntWritable,NullWritable> sampler = |
| new InputSampler.SplitSampler<IntWritable,NullWritable>( |
| NUM_SAMPLES, NUM_SPLITS); |
| int inits[] = new int[TOT_SPLITS]; |
| for (int i = 0; i < TOT_SPLITS; ++i) { |
| inits[i] = i * STEP_SAMPLE; |
| } |
| Job ignored = Job.getInstance(); |
| Object[] samples = sampler.getSample( |
| new TestInputSamplerIF(100000, TOT_SPLITS, inits), ignored); |
| assertEquals(NUM_SAMPLES, samples.length); |
| Arrays.sort(samples, new IntWritable.Comparator()); |
| for (int i = 0; i < NUM_SAMPLES; ++i) { |
| assertEquals(i, ((IntWritable)samples[i]).get()); |
| } |
| } |
| |
| /** |
| * Verify IntervalSampler contract, that samples are taken at regular |
| * intervals from the given splits. |
| */ |
| @Test |
| @SuppressWarnings("unchecked") // IntWritable comparator not typesafe |
| public void testIntervalSampler() throws Exception { |
| final int TOT_SPLITS = 16; |
| final int PER_SPLIT_SAMPLE = 4; |
| final int NUM_SAMPLES = TOT_SPLITS * PER_SPLIT_SAMPLE; |
| final double FREQ = 1.0 / TOT_SPLITS; |
| InputSampler.Sampler<IntWritable,NullWritable> sampler = |
| new InputSampler.IntervalSampler<IntWritable,NullWritable>( |
| FREQ, NUM_SAMPLES); |
| int inits[] = new int[TOT_SPLITS]; |
| for (int i = 0; i < TOT_SPLITS; ++i) { |
| inits[i] = i; |
| } |
| Job ignored = Job.getInstance(); |
| Object[] samples = sampler.getSample(new TestInputSamplerIF( |
| NUM_SAMPLES, TOT_SPLITS, inits), ignored); |
| assertEquals(NUM_SAMPLES, samples.length); |
| Arrays.sort(samples, new IntWritable.Comparator()); |
| for (int i = 0; i < NUM_SAMPLES; ++i) { |
| assertEquals(i, ((IntWritable)samples[i]).get()); |
| } |
| } |
| |
| } |