/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.io;

import static org.apache.beam.sdk.io.Compression.AUTO;
import static org.apache.beam.sdk.io.Compression.DEFLATE;
import static org.apache.beam.sdk.io.Compression.GZIP;
import static org.apache.beam.sdk.io.Compression.UNCOMPRESSED;
import static org.apache.beam.sdk.transforms.display.DisplayDataMatchers.hasDisplayItem;
import static org.hamcrest.Matchers.isIn;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

import com.google.common.collect.Lists;
import com.google.common.io.BaseEncoding;
import com.google.common.io.ByteStreams;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitResult;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.SimpleFileVisitor;
import java.nio.file.attribute.BasicFileAttributes;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.values.PCollection;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.rules.ExpectedException;
import org.junit.rules.RuleChain;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

| /** |
| * Tests for TFRecordIO Read and Write transforms. |
| */ |
| @RunWith(JUnit4.class) |
| public class TFRecordIOTest { |
| |
| /* |
| From https://github.com/apache/beam/blob/master/sdks/python/apache_beam/io/tfrecordio_test.py |
| Created by running following code in python: |
| >>> import tensorflow as tf |
| >>> import base64 |
| >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord') |
| >>> writer.write('foo') |
| >>> writer.close() |
| >>> with open('/tmp/python_foo.tfrecord', 'rb') as f: |
| ... data = base64.b64encode(f.read()) |
| ... print data |
| */ |
| private static final String FOO_RECORD_BASE64 = "AwAAAAAAAACwmUkOZm9vYYq+/g=="; |
| |
| // Same as above but containing two records ['foo', 'bar'] |
| private static final String FOO_BAR_RECORD_BASE64 = |
| "AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg="; |
| private static final String BAR_FOO_RECORD_BASE64 = |
| "AwAAAAAAAACwmUkOYmFyRgDlyAMAAAAAAAAAsJlJDmZvb2GKvv4="; |
| |
| private static final String[] FOO_RECORDS = {"foo"}; |
| private static final String[] FOO_BAR_RECORDS = {"foo", "bar"}; |
| |
| private static final Iterable<String> EMPTY = Collections.emptyList(); |
| private static final Iterable<String> LARGE = makeLines(1000); |
| |
| private static Path tempFolder; |
| |
| |
| public TestPipeline p = TestPipeline.create(); |
| |
| public ExpectedException expectedException = ExpectedException.none(); |
| |
| @Rule |
| public TestPipeline p2 = TestPipeline.create(); |
| |
| @Rule |
| public RuleChain ruleChain = RuleChain.outerRule(expectedException).around(p); |
| |
| @BeforeClass |
| public static void setupClass() throws IOException { |
| tempFolder = Files.createTempDirectory("TFRecordIOTest"); |
| } |
| |
| @AfterClass |
| public static void teardownClass() throws IOException { |
| Files.walkFileTree(tempFolder, new SimpleFileVisitor<Path>() { |
| @Override |
| public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { |
| Files.delete(file); |
| return FileVisitResult.CONTINUE; |
| } |
| |
| @Override |
| public FileVisitResult postVisitDirectory(Path dir, IOException exc) throws IOException { |
| Files.delete(dir); |
| return FileVisitResult.CONTINUE; |
| } |
| }); |
| } |
| |
| @Test |
| public void testReadNamed() { |
| p.enableAbandonedNodeEnforcement(false); |
| |
| assertEquals( |
| "TFRecordIO.Read/Read.out", |
| p.apply(TFRecordIO.read().from("foo.*").withoutValidation()).getName()); |
| assertEquals( |
| "MyRead/Read.out", |
| p.apply("MyRead", TFRecordIO.read().from("foo.*").withoutValidation()).getName()); |
| } |
| |
| @Test |
| public void testReadDisplayData() { |
| TFRecordIO.Read read = TFRecordIO.read() |
| .from("foo.*") |
| .withCompression(GZIP) |
| .withoutValidation(); |
| |
| DisplayData displayData = DisplayData.from(read); |
| |
| assertThat(displayData, hasDisplayItem("filePattern", "foo.*")); |
| assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString())); |
| assertThat(displayData, hasDisplayItem("validation", false)); |
| } |
| |
| @Test |
| public void testWriteDisplayData() { |
| TFRecordIO.Write write = TFRecordIO.write() |
| .to("/foo") |
| .withSuffix("bar") |
| .withShardNameTemplate("-SS-of-NN-") |
| .withNumShards(100) |
| .withCompression(GZIP); |
| |
| DisplayData displayData = DisplayData.from(write); |
| |
| assertThat(displayData, hasDisplayItem("filePrefix", "/foo")); |
| assertThat(displayData, hasDisplayItem("fileSuffix", "bar")); |
| assertThat(displayData, hasDisplayItem("shardNameTemplate", "-SS-of-NN-")); |
| assertThat(displayData, hasDisplayItem("numShards", 100)); |
| assertThat(displayData, hasDisplayItem("compressionType", GZIP.toString())); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testReadOne() throws Exception { |
| runTestRead(FOO_RECORD_BASE64, FOO_RECORDS); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testReadTwo() throws Exception { |
| runTestRead(FOO_BAR_RECORD_BASE64, FOO_BAR_RECORDS); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testWriteOne() throws Exception { |
| runTestWrite(FOO_RECORDS, FOO_RECORD_BASE64); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testWriteTwo() throws Exception { |
| runTestWrite(FOO_BAR_RECORDS, FOO_BAR_RECORD_BASE64, BAR_FOO_RECORD_BASE64); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testReadInvalidRecord() throws Exception { |
| expectedException.expect(IllegalStateException.class); |
| expectedException.expectMessage("Not a valid TFRecord. Fewer than 12 bytes."); |
| System.out.println("abr".getBytes().length); |
| runTestRead("bar".getBytes(), new String[0]); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testReadInvalidLengthMask() throws Exception { |
| expectedException.expect(IllegalStateException.class); |
| expectedException.expectMessage("Mismatch of length mask"); |
| byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64); |
| data[9] += 1; |
| runTestRead(data, FOO_RECORDS); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void testReadInvalidDataMask() throws Exception { |
| expectedException.expect(IllegalStateException.class); |
| expectedException.expectMessage("Mismatch of data mask"); |
| byte[] data = BaseEncoding.base64().decode(FOO_RECORD_BASE64); |
| data[16] += 1; |
| runTestRead(data, FOO_RECORDS); |
| } |
| |
| private void runTestRead(String base64, String[] expected) throws IOException { |
| runTestRead(BaseEncoding.base64().decode(base64), expected); |
| } |
| |
| private void runTestRead(byte[] data, String[] expected) throws IOException { |
| File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile(); |
| String filename = tmpFile.getPath(); |
| |
| FileOutputStream fos = new FileOutputStream(tmpFile); |
| fos.write(data); |
| fos.close(); |
| |
| TFRecordIO.Read read = TFRecordIO.read().from(filename); |
| PCollection<String> output = p.apply(read).apply(ParDo.of(new ByteArrayToString())); |
| |
| PAssert.that(output).containsInAnyOrder(expected); |
| p.run(); |
| } |
| |
| private void runTestWrite(String[] elems, String ...base64) throws IOException { |
| File tmpFile = Files.createTempFile(tempFolder, "file", ".tfrecords").toFile(); |
| String filename = tmpFile.getPath(); |
| |
| PCollection<byte[]> input = p.apply(Create.of(Arrays.asList(elems))) |
| .apply(ParDo.of(new StringToByteArray())); |
| |
| TFRecordIO.Write write = TFRecordIO.write().to(filename).withoutSharding(); |
| input.apply(write); |
| |
| p.run(); |
| |
| FileInputStream fis = new FileInputStream(tmpFile); |
| String written = BaseEncoding.base64().encode(ByteStreams.toByteArray(fis)); |
| // bytes written may vary depending the order of elems |
| assertThat(written, isIn(base64)); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTrip() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripWithEmptyData() throws IOException { |
| runTestRoundTrip(EMPTY, 10, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripWithOneShards() throws IOException { |
| runTestRoundTrip(LARGE, 1, ".tfrecords", UNCOMPRESSED, UNCOMPRESSED); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripWithSuffix() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".suffix", UNCOMPRESSED, UNCOMPRESSED); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripGzip() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, GZIP); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripZlib() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", DEFLATE, DEFLATE); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripUncompressedFilesWithAuto() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", UNCOMPRESSED, AUTO); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripGzipFilesWithAuto() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", GZIP, AUTO); |
| } |
| |
| @Test |
| @Category(NeedsRunner.class) |
| public void runTestRoundTripZlibFilesWithAuto() throws IOException { |
| runTestRoundTrip(LARGE, 10, ".tfrecords", DEFLATE, AUTO); |
| } |
| |
| private void runTestRoundTrip(Iterable<String> elems, |
| int numShards, |
| String suffix, |
| Compression writeCompression, |
| Compression readCompression) throws IOException { |
| String outputName = "file"; |
| Path baseDir = Files.createTempDirectory(tempFolder, "test-rt"); |
| String baseFilename = baseDir.resolve(outputName).toString(); |
| |
| TFRecordIO.Write write = TFRecordIO.write().to(baseFilename) |
| .withNumShards(numShards) |
| .withSuffix(suffix) |
| .withCompression(writeCompression); |
| p.apply(Create.of(elems).withCoder(StringUtf8Coder.of())) |
| .apply(ParDo.of(new StringToByteArray())) |
| .apply(write); |
| p.run(); |
| |
| TFRecordIO.Read read = TFRecordIO.read().from(baseFilename + "*") |
| .withCompression(readCompression); |
| PCollection<String> output = p2.apply(read).apply(ParDo.of(new ByteArrayToString())); |
| |
| PAssert.that(output).containsInAnyOrder(elems); |
| p2.run(); |
| } |
| |
| private static Iterable<String> makeLines(int n) { |
| List<String> ret = Lists.newArrayList(); |
| for (int i = 0; i < n; ++i) { |
| ret.add("word" + i); |
| } |
| return ret; |
| } |
| |
| static class ByteArrayToString extends DoFn<byte[], String> { |
| @ProcessElement |
| public void processElement(ProcessContext c) { |
| c.output(new String(c.element())); |
| } |
| } |
| |
| static class StringToByteArray extends DoFn<String, byte[]> { |
| @ProcessElement |
| public void processElement(ProcessContext c) { |
| c.output(c.element().getBytes()); |
| } |
| } |
| |
| } |