blob: 6e5e4da927aa25afd0b44097c27b7d9562c61a62 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io;
import static org.apache.beam.sdk.io.Compression.AUTO;
import static org.apache.beam.sdk.io.Compression.DEFLATE;
import static org.apache.beam.sdk.io.Compression.GZIP;
import static org.apache.beam.sdk.io.Compression.UNCOMPRESSED;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.Matchers.isIn;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

import com.google.common.collect.Lists;
import com.google.common.io.BaseEncoding;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.rules.RuleChain;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
 * Tests for TFRecordIO Read and Write transforms.
 *
 * <p>Fixture files are created under a per-class temporary directory that is deleted
 * recursively once all tests have finished.
 */
@RunWith(JUnit4.class)
public class TFRecordIOTest {

  /*
  From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py
  Created by running the following code in Python:
  >>> import tensorflow as tf
  >>> import base64
  >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
  >>> writer.write('foo')
  >>> writer.close()
  >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
  ...   data = base64.b64encode(f.read())
  ...   print data
  */
  private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g==";

  // Same as above but containing two records ['foo', 'bar'].
  private static final String FOO_BAR_RECORD_BASE64 =
      "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg=";

  // The same two records in the opposite order ['bar', 'foo'].
  private static final String BAR_FOO_RECORD_BASE64 =
      "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4=";

  private static final String[] FOO_RECORDS = {"foo"};
  private static final String[] FOO_BAR_RECORDS = {"foo", "bar"};

  private static final Iterable<String> EMPTY = Collections.emptyList();
  private static final Iterable<String> LARGE = makeLines(1000);

  /** Root directory for every file created by this test class; removed in teardown. */
  private static Path tempFolder;

  // Not annotated with @Rule directly: both are applied through ruleChain below, with
  // expectedException as the outer rule wrapping the pipeline rule.
  public TestPipeline p = TestPipeline.create();
  public ExpectedException expectedException = ExpectedException.none();

  /** Second pipeline, used by the round-trip tests to read back what {@link #p} wrote. */
  @Rule
  public TestPipeline p2 = TestPipeline.create();

  @Rule
  public RuleChain ruleChain = RuleChain.outerRule(expectedException).around(p);

  /** Creates the temporary directory shared by all tests in this class. */
  @BeforeClass
  public static void setupClass() throws IOException {
    tempFolder = Files.createTempDirectory("TFRecordIOTest");
  }

  /** Recursively deletes the shared temporary directory: files first, then directories. */
  @AfterClass
  public static void teardownClass() throws IOException {
    Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() {
      @Override
      public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException {
        Files.delete(file);
        return FileVisitResult.CONTINUE;
      }

      @Override
      public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException {
        // A directory can only be deleted after its contents have been visited and removed.
        Files.delete(dir);
        return FileVisitResult.CONTINUE;
      }
    });
  }

  @Test
  public void testReadNamed() {
    // The pipeline is only constructed here, never run.
    p.enableAbandonedNodeEnforcement(false);

    assertEquals(
        "TFRecordIO.Read/Read.out",
        p.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName());
    assertEquals(
        "MyRead/Read.out",
        p.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName());
  }

  @Test
  public void testReadDisplayData() {
    TFRecordIO.Read read = TFRecordIO.read()
        .from("foo.*")
        .withCompression(GZIP)
        .withoutValidation();

    DisplayData displayData = DisplayData.from(read);

    assertThat(displayData, hasDisplayItem("filePattern", "foo.*"));
    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
    assertThat(displayData, hasDisplayItem("validation", false));
  }

  @Test
  public void testWriteDisplayData() {
    TFRecordIO.Write write = TFRecordIO.write()
        .to("/foo")
        .withSuffix("bar")
        .withShardNameTemplate("-SS-of-NN-")
        .withNumShards(100)
        .withCompression(GZIP);

    DisplayData displayData = DisplayData.from(write);

    assertThat(displayData, hasDisplayItem("filePrefix", "/foo"));
    assertThat(displayData, hasDisplayItem("fileSuffix", "bar"));
    assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-"));
    assertThat(displayData, hasDisplayItem("numShards", 100));
    assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString()));
  }

  @Test
  @Category(NeedsRunner.class)
  public void testReadOne() throws Exception {
    runTestRead(FOO_RECORD_BASE64, FOO_RECORDS);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testReadTwo() throws Exception {
    runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testWriteOne() throws Exception {
    runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testWriteTwo() throws Exception {
    // Either record order is an acceptable file layout.
    runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testReadInvalidRecord() throws Exception {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes.");

    // Three raw bytes: shorter than the 12 bytes named in the expected message.
    runTestRead("bar".getBytes(StandardCharsets.UTF_8), new String[0]);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testReadInvalidLengthMask() throws Exception {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Mismatch of length mask");

    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
    // Corrupt one byte of a valid record; the read is expected to reject the length mask.
    data[9] += 1;
    runTestRead(data, FOO_RECORDS);
  }

  @Test
  @Category(NeedsRunner.class)
  public void testReadInvalidDataMask() throws Exception {
    expectedException.expect(IllegalStateException.class);
    expectedException.expectMessage("Mismatch of data mask");

    byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64);
    // Corrupt one byte of a valid record; the read is expected to reject the data mask.
    data[16] += 1;
    runTestRead(data, FOO_RECORDS);
  }

  /**
   * Decodes {@code base64} and delegates to {@link #runTestRead(byte[], String[])}.
   *
   * @param base64 base64-encoded content of the file to read
   * @param expected records the read is expected to produce, in any order
   */
  private void runTestRead(String base64, String[] expected) throws IOException {
    runTestRead(BaseEncoding.base64().decode(base64), expected);
  }

  /**
   * Writes {@code data} to a fresh temporary file, reads it with {@code TFRecordIO.read()} on
   * pipeline {@link #p}, and asserts the decoded records.
   *
   * @param data raw bytes of the file to read
   * @param expected records the read is expected to produce, in any order
   * @throws IOException if the fixture file cannot be created or written
   */
  private void runTestRead(byte[] data, String[] expected) throws IOException {
    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
    String filename = tmpFile.getPath();

    // try-with-resources guarantees the stream is closed even if the write fails.
    try (FileOutputStream fos = new FileOutputStream(tmpFile)) {
      fos.write(data);
    }

    TFRecordIO.Read read = TFRecordIO.read().from(filename);
    PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString()));

    PAssert.that(output).containsInAnyOrder(expected);
    p.run();
  }

  /**
   * Writes {@code elems} to a single unsharded file with {@code TFRecordIO.write()} on pipeline
   * {@link #p} and asserts that the base64 encoding of the file matches one of {@code base64}.
   *
   * @param elems records to write
   * @param base64 acceptable base64 encodings of the resulting file
   * @throws IOException if the output file cannot be created or read back
   */
  private void runTestWrite(String[] elems, String... base64) throws IOException {
    File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile();
    String filename = tmpFile.getPath();

    PCollection<byte[]> input = p.apply(Create.of(Arrays.asList(elems)))
        .apply(ParDo.of(new StringToByteArray()));

    TFRecordIO.Write write = TFRecordIO.write().to(filename).withoutSharding();

    input.apply(write);
    p.run();

    String written;
    // Close the stream once read so no file handle outlives the test.
    try (FileInputStream fis = new FileInputStream(tmpFile)) {
      written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis));
    }

    // bytes written may vary depending on the order of elems
    assertThat(written, isIn(base64));
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTrip() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripWithEmptyData() throws IOException {
    runTestRoundTrip(EMPTY, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripWithOneShards() throws IOException {
    runTestRoundTrip(LARGE, 1, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripWithSuffix() throws IOException {
    runTestRoundTrip(LARGE, 10, ".suffix", UNCOMPRESSED, UNCOMPRESSED);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripGzip() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripZlib() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", DEFLATE, DEFLATE);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripUncompressedFilesWithAuto() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, AUTO);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripGzipFilesWithAuto() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, AUTO);
  }

  @Test
  @Category(NeedsRunner.class)
  public void runTestRoundTripZlibFilesWithAuto() throws IOException {
    runTestRoundTrip(LARGE, 10, ".tfrecords", DEFLATE, AUTO);
  }

  /**
   * Writes {@code elems} with pipeline {@link #p}, reads the resulting files back with pipeline
   * {@link #p2}, and asserts that the records read equal the records written.
   *
   * @param elems records to write and expect back
   * @param numShards number of shards passed to the write transform
   * @param suffix file suffix passed to the write transform
   * @param writeCompression compression used when writing
   * @param readCompression compression setting used when reading
   * @throws IOException if the output directory cannot be created
   */
  private void runTestRoundTrip(Iterable<String> elems,
                                int numShards,
                                String suffix,
                                Compression writeCompression,
                                Compression readCompression) throws IOException {
    String outputName = "file";
    // Each round trip gets its own directory so the read glob only matches its own output.
    Path baseDir = Files.createTempDirectory(tempFolder, "test-rt");
    String baseFilename = baseDir.resolve(outputName).toString();

    TFRecordIO.Write write = TFRecordIO.write().to(baseFilename)
        .withNumShards(numShards)
        .withSuffix(suffix)
        .withCompression(writeCompression);
    p.apply(Create.of(elems).withCoder(StringUtf8Coder.of()))
        .apply(ParDo.of(new StringToByteArray()))
        .apply(write);
    p.run();

    TFRecordIO.Read read = TFRecordIO.read().from(baseFilename + "*")
        .withCompression(readCompression);
    PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString()));

    PAssert.that(output).containsInAnyOrder(elems);
    p2.run();
  }

  /**
   * Returns {@code n} strings of the form {@code "word" + i} for {@code i} in {@code [0, n)}.
   */
  private static Iterable<String> makeLines(int n) {
    List<String> ret = Lists.newArrayList();
    for (int i = 0; i < n; ++i) {
      ret.add("word" + i);
    }
    return ret;
  }

  /** Decodes each {@code byte[]} element as a UTF-8 string. */
  static class ByteArrayToString extends DoFn<byte[], String> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      // Explicit charset: do not depend on the platform default encoding.
      c.output(new String(c.element(), StandardCharsets.UTF_8));
    }
  }

  /** Encodes each string element as UTF-8 bytes. */
  static class StringToByteArray extends DoFn<String, byte[]> {
    @ProcessElement
    public void processElement(ProcessContext c) {
      // Explicit charset: do not depend on the platform default encoding.
      c.output(c.element().getBytes(StandardCharsets.UTF_8));
    }
  }
}