/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.crunch;

import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Locale;

import org.apache.crunch.impl.mem.MemPipeline;
import org.apache.crunch.impl.mr.MRPipeline;
import org.apache.crunch.io.At;
import org.apache.crunch.io.ReadableSource;
import org.apache.crunch.lib.Aggregate;
import org.apache.crunch.test.TemporaryPath;
import org.apache.crunch.test.TemporaryPaths;
import org.apache.crunch.types.PTypeFamily;
import org.apache.crunch.types.writable.WritableTypeFamily;

import org.junit.Rule;
import org.junit.Test;

@SuppressWarnings("serial")
public class TermFrequencyIT implements Serializable {
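  // Provides a fresh temporary directory and Hadoop configuration for each test.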
  @Rule
  public transient TemporaryPath tmpDir = TemporaryPaths.create();

  @Test
  public void testTermFrequencyWithNoTransform() throws IOException {
    run(new MRPipeline(TermFrequencyIT.class, tmpDir.getDefaultConfiguration()), WritableTypeFamily.getInstance(), false);
  }

  @Test
  public void testTermFrequencyWithTransform() throws IOException {
    run(new MRPipeline(TermFrequencyIT.class, tmpDir.getDefaultConfiguration()), WritableTypeFamily.getInstance(), true);
  }

  @Test
  public void testTermFrequencyNoTransformInMemory() throws IOException {
    run(MemPipeline.getInstance(), WritableTypeFamily.getInstance(), false);
  }

  @Test
  public void testTermFrequencyWithTransformInMemory() throws IOException {
    run(MemPipeline.getInstance(), WritableTypeFamily.getInstance(), true);
  }

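  /**
   * Computes per-document term frequencies from docs.txt on the given pipeline
   * and type family, optionally re-keying the counts by word, and verifies that
   * "well" is counted exactly once in document A.
   */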
  public void run(Pipeline pipeline, PTypeFamily typeFamily, boolean transformTF) throws IOException {
    String input = tmpDir.copyResourceFileName("docs.txt");
    File transformedOutput = tmpDir.getFile("transformed-output");
    File tfOutput = tmpDir.getFile("tf-output");

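    // Read the documents; each line holds a document title and its text separated by a tab.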
    PCollection<String> docs = pipeline.readTextFile(input);
    PTypeFamily ptf = docs.getTypeFamily();

    /*
     * Input: String (one title/text line per document)
     *
     * Output: PTable<Pair<String, String>, Long> Pair<Pair<word, title>, count in title>
     */
    PTable<Pair<String, String>, Long> tf = Aggregate.count(docs.parallelDo("term document frequency",
        new DoFn<String, Pair<String, String>>() {
          @Override
          public void process(String doc, Emitter<Pair<String, String>> emitter) {
            String[] kv = doc.split("\t");
            String title = kv[0];
            String text = kv[1];
            for (String word : text.split("\\W+")) {
              if (!word.isEmpty()) {
                Pair<String, String> pair = Pair.of(word.toLowerCase(Locale.ENGLISH), title);
                emitter.emit(pair);
              }
            }
          }
        }, ptf.pairs(ptf.strings(), ptf.strings())));

    if (transformTF) {
      /*
       * Input: Pair<Pair<String, String>, Long> Pair<Pair<word, title>, count in title>
       *
       * Output: PTable<String, Pair<String, Long>> PTable<word, Pair<title, count in title>>
       */
      PTable<String, Pair<String, Long>> wordDocumentCountPair = tf.parallelDo("transform wordDocumentPairCount",
          new MapFn<Pair<Pair<String, String>, Long>, Pair<String, Pair<String, Long>>>() {
            @Override
            public Pair<String, Pair<String, Long>> map(Pair<Pair<String, String>, Long> input) {
              Pair<String, String> wordDocumentPair = input.first();
              return Pair.of(wordDocumentPair.first(), Pair.of(wordDocumentPair.second(), input.second()));
            }
          }, ptf.tableOf(ptf.strings(), ptf.pairs(ptf.strings(), ptf.longs())));
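      // The transformed table is only written out here; the assertions below check the raw counts.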
      pipeline.writeTextFile(wordDocumentCountPair, transformedOutput.getAbsolutePath());
    }

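    // Write the raw (word, title) counts as text and execute the pipeline.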
    SourceTarget<String> st = At.textFile(tfOutput.getAbsolutePath());
    pipeline.write(tf, st);
    pipeline.run();

    // Verify the counts: "well" should be counted exactly once in document A.
    Iterable<String> lines = ((ReadableSource<String>) st).read(pipeline.getConfiguration());
    boolean passed = false;
    for (String line : lines) {
      if ("[well,A]\t0".equals(line)) {
        fail("Found " + line + " but 'well' appears exactly once in document A");
      }
      if ("[well,A]\t1".equals(line)) {
        passed = true;
      }
    }
    assertTrue(passed);
    pipeline.done();
  }
}