/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.avro;

import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData.Array;
import org.apache.avro.generic.GenericRecord;
import org.apache.nifi.serialization.RecordSetWriter;
import org.apache.nifi.serialization.SimpleRecordSchema;
import org.apache.nifi.serialization.WriteResult;
import org.apache.nifi.serialization.record.DataType;
import org.apache.nifi.serialization.record.ListRecordSet;
import org.apache.nifi.serialization.record.MapRecord;
import org.apache.nifi.serialization.record.Record;
import org.apache.nifi.serialization.record.RecordField;
import org.apache.nifi.serialization.record.RecordFieldType;
import org.apache.nifi.serialization.record.RecordSchema;
import org.apache.nifi.serialization.record.RecordSet;
import org.junit.Assert;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;
import java.text.DateFormat;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TimeZone;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
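
/**
 * Abstract base for tests of NiFi's Avro {@link RecordSetWriter} implementations. Concrete
 * subclasses supply the writer under test and the matching strategy for reading the written
 * bytes back (for example, with the schema embedded in the stream or supplied externally).
 */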
public abstract class TestWriteAvroResult {

    protected abstract RecordSetWriter createWriter(Schema schema, OutputStream out) throws IOException;

    protected abstract GenericRecord readRecord(InputStream in, Schema schema) throws IOException;

    protected abstract List<GenericRecord> readRecords(InputStream in, Schema schema, int recordCount) throws IOException;
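
    /**
     * No-op hook; subclasses may override to make additional assertions about the
     * {@link WriteResult} returned by the writer (see {@link #testDataTypes()}).
     */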
    protected void verify(final WriteResult writeResult) {
    }
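
    /**
     * A schema that references itself (recursive.avsc) must not send the writer into unbounded
     * recursion; a StackOverflowError during the write is reported as a test failure.
     */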
    @Test
    public void testWriteRecursiveRecord() throws IOException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/recursive.avsc"));
        final RecordSchema recordSchema = AvroTypeUtil.createSchema(schema);
        final FileInputStream in = new FileInputStream("src/test/resources/avro/recursive.avro");

        try (final AvroRecordReader reader = new AvroReaderWithExplicitSchema(in, recordSchema, schema);
             final RecordSetWriter writer = createWriter(schema, new ByteArrayOutputStream())) {
            final GenericRecord avroRecord = reader.nextAvroRecord();
            final Map<String, Object> recordMap = AvroTypeUtil.convertAvroRecordToMap(avroRecord, recordSchema);
            final Record record = new MapRecord(recordSchema, recordMap);
            try {
                writer.write(record);
            } catch (StackOverflowError soe) {
                Assert.fail("Recursive schema resulted in infinite loop during write");
            }
        }
    }
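
    /**
     * Writes a single record with one string field and verifies the value read back.
     */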
    @Test
    public void testWriteRecord() throws IOException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/simple.avsc"));
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final List<RecordField> fields = new ArrayList<>();
        fields.add(new RecordField("msg", RecordFieldType.STRING.getDataType()));
        final RecordSchema recordSchema = new SimpleRecordSchema(fields);

        final Map<String, Object> values = new HashMap<>();
        values.put("msg", "nifi");
        final Record record = new MapRecord(recordSchema, values);

        try (final RecordSetWriter writer = createWriter(schema, baos)) {
            writer.write(record);
        }

        final byte[] data = baos.toByteArray();
        try (final InputStream in = new ByteArrayInputStream(data)) {
            final GenericRecord avroRecord = readRecord(in, schema);
            assertNotNull(avroRecord);
            assertNotNull(avroRecord.get("msg"));
            assertEquals("nifi", avroRecord.get("msg").toString());
        }
    }
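
    /**
     * Writes a RecordSet of three records and verifies that each record is read back in order.
     */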
    @Test
    public void testWriteRecordSet() throws IOException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/simple.avsc"));
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final List<RecordField> fields = new ArrayList<>();
        fields.add(new RecordField("msg", RecordFieldType.STRING.getDataType()));
        final RecordSchema recordSchema = new SimpleRecordSchema(fields);

        final int recordCount = 3;
        final List<Record> records = new ArrayList<>();
        for (int i = 0; i < recordCount; i++) {
            final Map<String, Object> values = new HashMap<>();
            values.put("msg", "nifi" + i);
            final Record record = new MapRecord(recordSchema, values);
            records.add(record);
        }

        try (final RecordSetWriter writer = createWriter(schema, baos)) {
            writer.write(new ListRecordSet(recordSchema, records));
        }

        final byte[] data = baos.toByteArray();
        try (final InputStream in = new ByteArrayInputStream(data)) {
            final List<GenericRecord> avroRecords = readRecords(in, schema, recordCount);
            for (int i = 0; i < recordCount; i++) {
                final GenericRecord avroRecord = avroRecords.get(i);
                assertNotNull(avroRecord);
                assertNotNull(avroRecord.get("msg"));
                assertEquals("nifi" + i, avroRecord.get("msg").toString());
            }
        }
    }
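
    /**
     * Round-trips decimal fields and checks the decoded values. The scale of the written value
     * follows the logical decimal type declared in decimals.avsc (two decimal places for the
     * expected values below), which is why rows 2-4 all read back as 123456.10 regardless of
     * the precision and scale declared on the NiFi record field.
     */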
    @Test
    public void testDecimalType() throws IOException {
        final Object[][] decimals = new Object[][] {
            // id, record field, value, expected value
            // Uses the whole precision and scale
            {1, RecordFieldType.DECIMAL.getDecimalDataType(10, 2), new BigDecimal("12345678.12"), new BigDecimal("12345678.12")},
            // Uses less precision and scale than allowed
            {2, RecordFieldType.DECIMAL.getDecimalDataType(10, 2), new BigDecimal("123456.1"), new BigDecimal("123456.10")},
            // Record schema uses smaller precision and scale than allowed
            {3, RecordFieldType.DECIMAL.getDecimalDataType(8, 1), new BigDecimal("123456.1"), new BigDecimal("123456.10")},
            // Record schema uses bigger precision and scale than allowed
            {4, RecordFieldType.DECIMAL.getDecimalDataType(16, 4), new BigDecimal("123456.1"), new BigDecimal("123456.10")},
        };

        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/decimals.avsc"));
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final List<RecordField> fields = new ArrayList<>();
        final Map<String, Object> values = new HashMap<>();
        for (final Object[] decimal : decimals) {
            fields.add(new RecordField("decimal" + decimal[0], (DataType) decimal[1]));
            values.put("decimal" + decimal[0], decimal[2]);
        }

        final Record record = new MapRecord(new SimpleRecordSchema(fields), values);
        try (final RecordSetWriter writer = createWriter(schema, baos)) {
            writer.write(RecordSet.of(record.getSchema(), record));
        }

        final byte[] data = baos.toByteArray();
        try (final InputStream in = new ByteArrayInputStream(data)) {
            final GenericRecord avroRecord = readRecord(in, schema);
            for (final Object[] decimal : decimals) {
                final Schema decimalSchema = schema.getField("decimal" + decimal[0]).schema();
                final LogicalType logicalType = decimalSchema.getLogicalType();
                assertEquals(decimal[3], new Conversions.DecimalConversion().fromBytes((ByteBuffer) avroRecord.get("decimal" + decimal[0]), decimalSchema, logicalType));
            }
        }
    }
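
    // The logical-type round trip is exercised twice: once with required logical-type fields
    // and once where each field is nullable (a union of null and the logical type).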
    @Test
    public void testLogicalTypes() throws IOException, ParseException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/logical-types.avsc"));
        testLogicalTypes(schema);
    }

    @Test
    public void testNullableLogicalTypes() throws IOException, ParseException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/logical-types-nullable.avsc"));
        testLogicalTypes(schema);
    }
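
    /**
     * Writes one record containing time, timestamp, date, and decimal fields and verifies the
     * Avro representation of each logical type: time-millis as an int count of milliseconds
     * since midnight, time-micros and timestamp-micros as long microsecond values, and
     * timestamp-millis as the long epoch-millisecond value. The decimal is decoded with
     * Avro's {@link Conversions.DecimalConversion}.
     */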
    private void testLogicalTypes(Schema schema) throws ParseException, IOException {
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final List<RecordField> fields = new ArrayList<>();
        fields.add(new RecordField("timeMillis", RecordFieldType.TIME.getDataType()));
        fields.add(new RecordField("timeMicros", RecordFieldType.TIME.getDataType()));
        fields.add(new RecordField("timestampMillis", RecordFieldType.TIMESTAMP.getDataType()));
        fields.add(new RecordField("timestampMicros", RecordFieldType.TIMESTAMP.getDataType()));
        fields.add(new RecordField("date", RecordFieldType.DATE.getDataType()));
        fields.add(new RecordField("decimal", RecordFieldType.DECIMAL.getDecimalDataType(5, 2)));
        final RecordSchema recordSchema = new SimpleRecordSchema(fields);

        final String expectedTime = "2017-04-04 14:20:33.789";
        final DateFormat df = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS");
        df.setTimeZone(TimeZone.getTimeZone("GMT"));
        final long timeLong = df.parse(expectedTime).getTime();

        final Map<String, Object> values = new HashMap<>();
        values.put("timeMillis", new Time(timeLong));
        values.put("timeMicros", new Time(timeLong));
        values.put("timestampMillis", new Timestamp(timeLong));
        values.put("timestampMicros", new Timestamp(timeLong));
        values.put("date", new Date(timeLong));
        values.put("decimal", new BigDecimal("123.45"));

        final Record record = new MapRecord(recordSchema, values);
        try (final RecordSetWriter writer = createWriter(schema, baos)) {
            writer.write(RecordSet.of(record.getSchema(), record));
        }

        final byte[] data = baos.toByteArray();
        try (final InputStream in = new ByteArrayInputStream(data)) {
            final GenericRecord avroRecord = readRecord(in, schema);
            final long secondsSinceMidnight = 33 + (20 * 60) + (14 * 60 * 60);
            final long millisSinceMidnight = (secondsSinceMidnight * 1000L) + 789;

            assertEquals((int) millisSinceMidnight, avroRecord.get("timeMillis"));
            assertEquals(millisSinceMidnight * 1000L, avroRecord.get("timeMicros"));
            assertEquals(timeLong, avroRecord.get("timestampMillis"));
            assertEquals(timeLong * 1000L, avroRecord.get("timestampMicros"));

            // A BigDecimal value is converted into a logical decimal if the Avro schema defines one.
            final Schema decimalSchema = schema.getField("decimal").schema();
            final LogicalType logicalType = decimalSchema.getLogicalType() != null
                ? decimalSchema.getLogicalType()
                // A union type doesn't expose a logical type directly; find the first logical type defined within the union.
                : decimalSchema.getTypes().stream().map(Schema::getLogicalType).filter(Objects::nonNull).findFirst().get();
            final BigDecimal decimal = new Conversions.DecimalConversion().fromBytes((ByteBuffer) avroRecord.get("decimal"), decimalSchema, logicalType);
            assertEquals(new BigDecimal("123.45"), decimal);
        }
    }
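
    /**
     * Round-trips a single record that exercises every supported field type: primitives, a
     * byte array, a nullable long, an int array, a nested record, and a map of records. The
     * read-back Avro record is compared field by field via {@link #assertMatch(Record, GenericRecord)}.
     */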
    @Test
    public void testDataTypes() throws IOException {
        final Schema schema = new Schema.Parser().parse(new File("src/test/resources/avro/datatypes.avsc"));
        final ByteArrayOutputStream baos = new ByteArrayOutputStream();

        final List<RecordField> subRecordFields = Collections.singletonList(new RecordField("field1", RecordFieldType.STRING.getDataType()));
        final RecordSchema subRecordSchema = new SimpleRecordSchema(subRecordFields);
        final DataType subRecordDataType = RecordFieldType.RECORD.getRecordDataType(subRecordSchema);

        final List<RecordField> fields = new ArrayList<>();
        fields.add(new RecordField("string", RecordFieldType.STRING.getDataType()));
        fields.add(new RecordField("int", RecordFieldType.INT.getDataType()));
        fields.add(new RecordField("long", RecordFieldType.LONG.getDataType()));
        fields.add(new RecordField("double", RecordFieldType.DOUBLE.getDataType()));
        fields.add(new RecordField("float", RecordFieldType.FLOAT.getDataType()));
        fields.add(new RecordField("boolean", RecordFieldType.BOOLEAN.getDataType()));
        fields.add(new RecordField("bytes", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.BYTE.getDataType())));
        fields.add(new RecordField("nullOrLong", RecordFieldType.LONG.getDataType()));
        fields.add(new RecordField("array", RecordFieldType.ARRAY.getArrayDataType(RecordFieldType.INT.getDataType())));
        fields.add(new RecordField("record", subRecordDataType));
        fields.add(new RecordField("map", RecordFieldType.MAP.getMapDataType(subRecordDataType)));
        final RecordSchema recordSchema = new SimpleRecordSchema(fields);

        final Record innerRecord = new MapRecord(subRecordSchema, Collections.singletonMap("field1", "hello"));
        final Map<String, Object> innerMap = new HashMap<>();
        innerMap.put("key1", innerRecord);

        final Map<String, Object> values = new HashMap<>();
        values.put("string", "hello");
        values.put("int", 8);
        values.put("long", 42L);
        values.put("double", 3.14159D);
        values.put("float", 1.23456F);
        values.put("boolean", true);
        values.put("bytes", AvroTypeUtil.convertByteArray("hello".getBytes()));
        values.put("nullOrLong", null);
        values.put("array", new Integer[] {1, 2, 3});
        values.put("record", innerRecord);
        values.put("map", innerMap);
        final Record record = new MapRecord(recordSchema, values);

        final WriteResult writeResult;
        try (final RecordSetWriter writer = createWriter(schema, baos)) {
            writeResult = writer.write(RecordSet.of(record.getSchema(), record));
        }
        verify(writeResult);

        final byte[] data = baos.toByteArray();
        try (final InputStream in = new ByteArrayInputStream(data)) {
            final GenericRecord avroRecord = readRecord(in, schema);
            assertMatch(record, avroRecord);
        }
    }
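
    /**
     * Recursively compares a NiFi Record with the Avro record that was read back. Strings are
     * compared via toString() because Avro returns {@link org.apache.avro.util.Utf8} values;
     * byte contents are compared against the Avro {@link ByteBuffer}; arrays are compared
     * element by element; maps and nested records are compared recursively.
     */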
    protected void assertMatch(final Record record, final GenericRecord avroRecord) {
        for (final String fieldName : record.getSchema().getFieldNames()) {
            Object avroValue = avroRecord.get(fieldName);
            final Object recordValue = record.getValue(fieldName);

            if (recordValue instanceof String) {
                assertNotNull(fieldName + " should not have been null", avroValue);
                avroValue = avroValue.toString();
            }

            if (recordValue instanceof Object[] && avroValue instanceof ByteBuffer) {
                final ByteBuffer bb = (ByteBuffer) avroValue;
                final Object[] objectArray = (Object[]) recordValue;
                assertEquals("For field " + fieldName + ", byte buffer remaining should have been " + objectArray.length + " but was " + bb.remaining(),
                        objectArray.length, bb.remaining());
                for (int i = 0; i < objectArray.length; i++) {
                    assertEquals(objectArray[i], bb.get());
                }
            } else if (recordValue instanceof Object[]) {
                assertTrue(fieldName + " should have been instanceof Array", avroValue instanceof Array);
                final Array<?> avroArray = (Array<?>) avroValue;
                final Object[] recordArray = (Object[]) recordValue;
                assertEquals(fieldName + " not equal", recordArray.length, avroArray.size());
                for (int i = 0; i < recordArray.length; i++) {
                    assertEquals(fieldName + "[" + i + "] not equal", recordArray[i], avroArray.get(i));
                }
            } else if (recordValue instanceof byte[]) {
                final ByteBuffer bb = ByteBuffer.wrap((byte[]) recordValue);
                assertEquals(fieldName + " not equal", bb, avroValue);
            } else if (recordValue instanceof Map) {
                assertTrue(fieldName + " should have been instanceof Map", avroValue instanceof Map);
                final Map<?, ?> avroMap = (Map<?, ?>) avroValue;
                final Map<?, ?> recordMap = (Map<?, ?>) recordValue;
                assertEquals(fieldName + " not equal", recordMap.size(), avroMap.size());
                for (final Object key : avroMap.keySet()) {
                    assertMatch((Record) recordMap.get(key.toString()), (GenericRecord) avroMap.get(key));
                }
            } else if (recordValue instanceof Record) {
                assertMatch((Record) recordValue, (GenericRecord) avroValue);
            } else {
                assertEquals(fieldName + " not equal", recordValue, avroValue);
            }
        }
    }
}