/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.sdk.schemas.utils;

import static org.hamcrest.Matchers.not;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeThat;

import com.pholser.junit.quickcheck.From;
import com.pholser.junit.quickcheck.Property;
import com.pholser.junit.quickcheck.runner.JUnitQuickcheck;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.avro.Conversions;
import org.apache.avro.LogicalType;
import org.apache.avro.LogicalTypes;
import org.apache.avro.RandomData;
import org.apache.avro.Schema.Type;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.GenericRecordBuilder;
import org.apache.avro.reflect.ReflectData;
import org.apache.avro.util.Utf8;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.AvroCoder;
import org.apache.beam.sdk.io.AvroGeneratedUser;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.schemas.utils.AvroGenerators.RecordSchemaGenerator;
import org.apache.beam.sdk.schemas.utils.AvroUtils.TypeWithNullability;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.junit.Test;
import org.junit.runner.RunWith;

/** Tests for conversion between AVRO records and Beam rows. */
@RunWith(JUnitQuickcheck.class)
public class AvroUtilsTest {

  /** The Avro {@code null} schema, used to recognise nullable (two-branch) unions. */
  private static final org.apache.avro.Schema NULL_SCHEMA =
      org.apache.avro.Schema.create(Type.NULL);

  /**
   * Property: every record randomly generated for a translatable Avro schema can be converted to
   * a Beam row without throwing.
   */
  @Property(trials = 1000)
  public void supportsAnyAvroSchema(
      @From(RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) {
    // Skip schemas containing unions other than a two-branch union with null.
    assumeThat(avroSchema, not(containsField(AvroUtilsTest::hasNonNullUnion)));

    Schema beamSchema = AvroUtils.toBeamSchema(avroSchema);
    for (Object datum : new RandomData(avroSchema, 10)) {
      AvroUtils.toBeamRowStrict((GenericRecord) datum, beamSchema);
    }
  }

  /**
   * Property: converting a randomly generated Avro record to a Beam row and back yields a record
   * equal to the one we started from.
   */
  @Property(trials = 1000)
  public void avroToBeamRoundTrip(
      @From(RecordSchemaGenerator.class) org.apache.avro.Schema avroSchema) {
    // Skip schemas containing unions other than a two-branch union with null.
    assumeThat(avroSchema, not(containsField(AvroUtilsTest::hasNonNullUnion)));
    // Beam has no enum type, so enums would come back as strings.
    assumeThat(avroSchema, not(containsField(nested -> nested.getType() == Type.ENUM)));
    // Beam has no FIXED type, so fixed values would come back as bytes.
    assumeThat(avroSchema, not(containsField(nested -> nested.getType() == Type.FIXED)));

    Schema beamSchema = AvroUtils.toBeamSchema(avroSchema);
    for (Object datum : new RandomData(avroSchema, 10)) {
      GenericRecord original = (GenericRecord) datum;
      Row row = AvroUtils.toBeamRowStrict(original, beamSchema);
      GenericRecord roundTripped = AvroUtils.toGenericRecord(row, avroSchema);
      assertEquals(original, roundTripped);
    }
  }

  /** A {@code [null, string]} union unwraps to a nullable {@code string}. */
  @Test
  public void testUnwrapNullableSchema() {
    org.apache.avro.Schema stringSchema = org.apache.avro.Schema.create(Type.STRING);
    org.apache.avro.Schema nullFirstUnion =
        org.apache.avro.Schema.createUnion(
            org.apache.avro.Schema.create(Type.NULL), stringSchema);

    TypeWithNullability unwrapped = new TypeWithNullability(nullFirstUnion);

    assertTrue(unwrapped.nullable);
    assertEquals(stringSchema, unwrapped.type);
  }

  /** A {@code [string, null]} union (null listed last) also unwraps to a nullable string. */
  @Test
  public void testUnwrapNullableSchemaReordered() {
    org.apache.avro.Schema stringSchema = org.apache.avro.Schema.create(Type.STRING);
    org.apache.avro.Schema nullLastUnion =
        org.apache.avro.Schema.createUnion(
            stringSchema, org.apache.avro.Schema.create(Type.NULL));

    TypeWithNullability unwrapped = new TypeWithNullability(nullLastUnion);

    assertTrue(unwrapped.nullable);
    assertEquals(stringSchema, unwrapped.type);
  }

  /** Removing null from a {@code [string, long, null]} union leaves a {@code [string, long]} union. */
  @Test
  public void testUnwrapNullableSchemaToUnion() {
    org.apache.avro.Schema stringSchema = org.apache.avro.Schema.create(Type.STRING);
    org.apache.avro.Schema longSchema = org.apache.avro.Schema.create(Type.LONG);
    org.apache.avro.Schema threeWayUnion =
        org.apache.avro.Schema.createUnion(
            stringSchema, longSchema, org.apache.avro.Schema.create(Type.NULL));

    TypeWithNullability unwrapped = new TypeWithNullability(threeWayUnion);

    assertTrue(unwrapped.nullable);
    org.apache.avro.Schema expectedRemainder =
        org.apache.avro.Schema.createUnion(stringSchema, longSchema);
    assertEquals(expectedRemainder, unwrapped.type);
  }

  /** A nullable Avro array-of-int field maps to a nullable Beam array-of-INT32 field. */
  @Test
  public void testNullableArrayFieldToBeamArrayField() {
    org.apache.avro.Schema nullableIntArray =
        ReflectData.makeNullable(
            org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.INT)));
    org.apache.avro.Schema.Field source =
        new org.apache.avro.Schema.Field("arrayField", nullableIntArray, "", null);

    Field expected = Field.nullable("arrayField", FieldType.array(FieldType.INT32));

    assertEquals(expected, AvroUtils.toBeamField(source));
  }

  /** A nullable Beam array-of-INT32 field maps to a nullable Avro array-of-int field. */
  @Test
  public void testNullableBeamArrayFieldToAvroField() {
    org.apache.avro.Schema nullableIntArray =
        ReflectData.makeNullable(
            org.apache.avro.Schema.createArray(org.apache.avro.Schema.create(Type.INT)));
    org.apache.avro.Schema.Field expected =
        new org.apache.avro.Schema.Field("arrayField", nullableIntArray, "", null);

    Field source = Field.nullable("arrayField", FieldType.array(FieldType.INT32));

    assertEquals(expected, AvroUtils.toAvroField(source, "ignored"));
  }

  /**
   * Builds a fresh list holding the two fields of the nested test record: a boolean named {@code
   * bool} and an int named {@code int}.
   */
  private static List<org.apache.avro.Schema.Field> getAvroSubSchemaFields() {
    org.apache.avro.Schema.Field boolField =
        new org.apache.avro.Schema.Field(
            "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", null);
    org.apache.avro.Schema.Field intField =
        new org.apache.avro.Schema.Field("int", org.apache.avro.Schema.create(Type.INT), "", null);
    return Lists.newArrayList(boolField, intField);
  }

  /**
   * Creates the nested Avro record schema under the given name, in the {@code topLevelRecord}
   * namespace, with the fields from {@link #getAvroSubSchemaFields()}.
   */
  private static org.apache.avro.Schema getAvroSubSchema(String name) {
    List<org.apache.avro.Schema.Field> subFields = getAvroSubSchemaFields();
    return org.apache.avro.Schema.createRecord(name, null, "topLevelRecord", false, subFields);
  }

  /**
   * Builds the top-level Avro record schema used by the conversion tests: one field per primitive
   * type, a decimal and a timestamp-millis logical type, plus a nested record, an array of records
   * and a map of records.
   */
  private static org.apache.avro.Schema getAvroSchema() {
    org.apache.avro.Schema decimalSchema =
        LogicalTypes.decimal(Integer.MAX_VALUE)
            .addToSchema(org.apache.avro.Schema.create(Type.BYTES));
    org.apache.avro.Schema timestampSchema =
        LogicalTypes.timestampMillis().addToSchema(org.apache.avro.Schema.create(Type.LONG));

    List<org.apache.avro.Schema.Field> topLevelFields =
        Lists.newArrayList(
            new org.apache.avro.Schema.Field(
                "bool", org.apache.avro.Schema.create(Type.BOOLEAN), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "int", org.apache.avro.Schema.create(Type.INT), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "long", org.apache.avro.Schema.create(Type.LONG), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "float", org.apache.avro.Schema.create(Type.FLOAT), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "double", org.apache.avro.Schema.create(Type.DOUBLE), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "string", org.apache.avro.Schema.create(Type.STRING), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "bytes", org.apache.avro.Schema.create(Type.BYTES), "", (Object) null),
            new org.apache.avro.Schema.Field("decimal", decimalSchema, "", (Object) null),
            new org.apache.avro.Schema.Field(
                "timestampMillis", timestampSchema, "", (Object) null),
            new org.apache.avro.Schema.Field("row", getAvroSubSchema("row"), "", (Object) null),
            new org.apache.avro.Schema.Field(
                "array",
                org.apache.avro.Schema.createArray(getAvroSubSchema("array")),
                "",
                (Object) null),
            new org.apache.avro.Schema.Field(
                "map",
                org.apache.avro.Schema.createMap(getAvroSubSchema("map")),
                "",
                (Object) null));

    return org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, topLevelFields);
  }

  /** Builds the Beam schema of the nested row: a BOOLEAN {@code bool} and an INT32 {@code int}. */
  private static Schema getBeamSubSchema() {
    Schema.Builder builder = Schema.builder();
    builder.addField(Field.of("bool", FieldType.BOOLEAN));
    builder.addField(Field.of("int", FieldType.INT32));
    return builder.build();
  }

  /**
   * Builds the Beam schema that the tests expect to correspond to {@link #getAvroSchema()}, with
   * fields declared in the same order.
   */
  private Schema getBeamSchema() {
    FieldType subRowType = FieldType.row(getBeamSubSchema());

    Schema.Builder builder = Schema.builder();
    builder.addField(Field.of("bool", FieldType.BOOLEAN));
    builder.addField(Field.of("int", FieldType.INT32));
    builder.addField(Field.of("long", FieldType.INT64));
    builder.addField(Field.of("float", FieldType.FLOAT));
    builder.addField(Field.of("double", FieldType.DOUBLE));
    builder.addField(Field.of("string", FieldType.STRING));
    builder.addField(Field.of("bytes", FieldType.BYTES));
    builder.addField(Field.of("decimal", FieldType.DECIMAL));
    builder.addField(Field.of("timestampMillis", FieldType.DATETIME));
    builder.addField(Field.of("row", subRowType));
    builder.addField(Field.of("array", FieldType.array(subRowType)));
    builder.addField(Field.of("map", FieldType.map(FieldType.STRING, subRowType)));
    return builder.build();
  }

  /** Payload stored in the {@code bytes} field of the fixture row and record. */
  private static final byte[] BYTE_ARRAY = new byte[] {1, 2, 3, 4};

  /**
   * Timestamp stored in the {@code timestampMillis} field of the fixtures.
   *
   * <p>Constructed directly in UTC so the instant is the same on every machine; building it from
   * {@code new DateTime()} would interpret the date and time fields in the JVM default time zone.
   */
  private static final DateTime DATE_TIME = new DateTime(1979, 3, 14, 1, 2, 3, 4, DateTimeZone.UTC);

  /** Value stored in the {@code decimal} field of the fixture row and record. */
  private static final BigDecimal BIG_DECIMAL = new BigDecimal(3600);

  /**
   * Builds the Beam row fixture for {@link #getBeamSchema()}; values are listed in schema field
   * order.
   */
  private Row getBeamRow() {
    Row nested = Row.withSchema(getBeamSubSchema()).addValues(true, 42).build();
    Row.Builder builder = Row.withSchema(getBeamSchema());
    builder.addValue(true);
    builder.addValue(43);
    builder.addValue(44L);
    builder.addValue((float) 44.1);
    builder.addValue((double) 44.2);
    builder.addValue("string");
    builder.addValue(BYTE_ARRAY);
    builder.addValue(BIG_DECIMAL);
    builder.addValue(DATE_TIME);
    builder.addValue(nested);
    builder.addValue(ImmutableList.of(nested, nested));
    builder.addValue(ImmutableMap.of("k1", nested, "k2", nested));
    return builder.build();
  }

  /**
   * Builds a nested Avro record ({@code bool=true}, {@code int=42}) whose schema carries the given
   * record name.
   */
  private static GenericRecord getSubGenericRecord(String name) {
    GenericRecordBuilder builder = new GenericRecordBuilder(getAvroSubSchema(name));
    builder.set("bool", true);
    builder.set("int", 42);
    return builder.build();
  }

  /**
   * Builds the Avro record fixture for {@link #getAvroSchema()}, holding the same values as the
   * Beam row fixture in their Avro representation (Utf8 strings, ByteBuffer bytes, encoded decimal,
   * epoch-millis timestamp).
   */
  private static GenericRecord getGenericRecord() {
    // The decimal field holds the bytes produced by Avro's own decimal conversion.
    LogicalType decimalLogicalType =
        LogicalTypes.decimal(Integer.MAX_VALUE)
            .addToSchema(org.apache.avro.Schema.create(Type.BYTES))
            .getLogicalType();
    Conversions.DecimalConversion decimalConversion = new Conversions.DecimalConversion();
    ByteBuffer decimalBytes = decimalConversion.toBytes(BIG_DECIMAL, null, decimalLogicalType);

    GenericRecordBuilder builder = new GenericRecordBuilder(getAvroSchema());
    builder.set("bool", true);
    builder.set("int", 43);
    builder.set("long", 44L);
    builder.set("float", (float) 44.1);
    builder.set("double", (double) 44.2);
    builder.set("string", new Utf8("string"));
    builder.set("bytes", ByteBuffer.wrap(BYTE_ARRAY));
    builder.set("decimal", decimalBytes);
    builder.set("timestampMillis", DATE_TIME.getMillis());
    builder.set("row", getSubGenericRecord("row"));
    builder.set(
        "array", ImmutableList.of(getSubGenericRecord("array"), getSubGenericRecord("array")));
    builder.set(
        "map",
        ImmutableMap.of(
            new Utf8("k1"), getSubGenericRecord("map"),
            new Utf8("k2"), getSubGenericRecord("map")));
    return builder.build();
  }

  /** The Avro fixture schema converts to the Beam fixture schema. */
  @Test
  public void testFromAvroSchema() {
    Schema converted = AvroUtils.toBeamSchema(getAvroSchema());
    assertEquals(getBeamSchema(), converted);
  }

  /** The Beam fixture schema converts to the Avro fixture schema. */
  @Test
  public void testFromBeamSchema() {
    org.apache.avro.Schema converted = AvroUtils.toAvroSchema(getBeamSchema());
    assertEquals(getAvroSchema(), converted);
  }

  /**
   * The Avro schema produced from the Beam fixture schema survives a print-and-parse cycle through
   * Avro's own schema parser unchanged.
   */
  @Test
  public void testAvroSchemaFromBeamSchemaCanBeParsed() {
    org.apache.avro.Schema converted = AvroUtils.toAvroSchema(getBeamSchema());

    org.apache.avro.Schema.Parser parser = new org.apache.avro.Schema.Parser();
    org.apache.avro.Schema reparsed = parser.parse(converted.toString());

    assertEquals(converted, reparsed);
  }

  /**
   * A Beam schema whose nested rows reuse field names with differing shapes converts to an Avro
   * schema that Avro's parser accepts and reproduces unchanged.
   */
  @Test
  public void testAvroSchemaFromBeamSchemaWithFieldCollisionCanBeParsed() {
    // Two address shapes that differ only in the type of "street".
    Schema singleLineAddress =
        Schema.builder()
            .addField(Field.of("street", FieldType.STRING))
            .addField(Field.of("city", FieldType.STRING))
            .build();
    Schema multiLineAddress =
        Schema.builder()
            .addField(Field.of("street", FieldType.array(FieldType.STRING)))
            .addField(Field.of("city", FieldType.STRING))
            .build();

    // Two otherwise identical contact schemas, each wrapping one of the address shapes.
    Schema contact =
        Schema.builder()
            .addField(Field.of("name", FieldType.STRING))
            .addField(Field.of("address", FieldType.row(singleLineAddress)))
            .build();
    Schema contactMultiline =
        Schema.builder()
            .addField(Field.of("name", FieldType.STRING))
            .addField(Field.of("address", FieldType.row(multiLineAddress)))
            .build();

    // Ensure that no collisions happen between two sibling fields with same-named child fields
    // (with different schemas), between a parent field and a sub-record field with the same name,
    // and artificially with the generated field name.
    Schema collidingSchema =
        Schema.builder()
            .addField(Field.of("home", FieldType.row(contact)))
            .addField(Field.of("work", FieldType.row(contactMultiline)))
            .addField(Field.of("address", FieldType.row(contact)))
            .addField(Field.of("topLevelRecord", FieldType.row(contactMultiline)))
            .build();

    org.apache.avro.Schema converted = AvroUtils.toAvroSchema(collidingSchema);
    org.apache.avro.Schema.Parser parser = new org.apache.avro.Schema.Parser();
    org.apache.avro.Schema reparsed = parser.parse(converted.toString());
    assertEquals(converted, reparsed);
  }

  /**
   * Nullable Avro types (a nullable field, nullable array elements, nullable map values) become
   * nullable Beam types, and null values in a record are carried over into the Beam row.
   */
  @Test
  public void testNullableFieldInAvroSchema() {
    org.apache.avro.Schema nullableIntField =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT));
    org.apache.avro.Schema nullableBytesElement =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.BYTES));
    org.apache.avro.Schema nullableIntValue =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT));

    List<org.apache.avro.Schema.Field> avroFields =
        Lists.newArrayList(
            new org.apache.avro.Schema.Field("int", nullableIntField, "", null),
            new org.apache.avro.Schema.Field(
                "array", org.apache.avro.Schema.createArray(nullableBytesElement), "", null),
            new org.apache.avro.Schema.Field(
                "map", org.apache.avro.Schema.createMap(nullableIntValue), "", null));
    org.apache.avro.Schema avroSchema =
        org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, avroFields);

    // Schema conversion.
    Schema expectedSchema =
        Schema.builder()
            .addNullableField("int", FieldType.INT32)
            .addArrayField("array", FieldType.BYTES.withNullable(true))
            .addMapField("map", FieldType.STRING, FieldType.INT32.withNullable(true))
            .build();
    assertEquals(expectedSchema, AvroUtils.toBeamSchema(avroSchema));

    // Value conversion: a null field, a list holding one null, a map with one null value.
    Map<String, Object> mapWithNullValue = Maps.newHashMap();
    mapWithNullValue.put("k1", null);

    GenericRecordBuilder recordBuilder = new GenericRecordBuilder(avroSchema);
    recordBuilder.set("int", null);
    recordBuilder.set("array", Lists.newArrayList((Object) null));
    recordBuilder.set("map", mapWithNullValue);
    GenericRecord source = recordBuilder.build();

    Row expectedRow =
        Row.withSchema(expectedSchema)
            .addValue(null)
            .addValue(Lists.newArrayList((Object) null))
            .addValue(mapWithNullValue)
            .build();
    assertEquals(expectedRow, AvroUtils.toBeamRowStrict(source, expectedSchema));
  }

  /**
   * Nullable Beam types (a nullable field, nullable array elements, nullable map values) become
   * nullable Avro types, and null values in a row are carried over into the generic record.
   */
  @Test
  public void testNullableFieldsInBeamSchema() {
    Schema beamSchema =
        Schema.builder()
            .addNullableField("int", FieldType.INT32)
            .addArrayField("array", FieldType.INT32.withNullable(true))
            .addMapField("map", FieldType.STRING, FieldType.INT32.withNullable(true))
            .build();

    org.apache.avro.Schema nullableIntField =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT));
    org.apache.avro.Schema nullableIntElement =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT));
    org.apache.avro.Schema nullableIntValue =
        ReflectData.makeNullable(org.apache.avro.Schema.create(Type.INT));

    List<org.apache.avro.Schema.Field> avroFields =
        Lists.newArrayList(
            new org.apache.avro.Schema.Field("int", nullableIntField, "", null),
            new org.apache.avro.Schema.Field(
                "array", org.apache.avro.Schema.createArray(nullableIntElement), "", null),
            new org.apache.avro.Schema.Field(
                "map", org.apache.avro.Schema.createMap(nullableIntValue), "", null));
    org.apache.avro.Schema expectedAvroSchema =
        org.apache.avro.Schema.createRecord("topLevelRecord", null, null, false, avroFields);

    // Schema conversion.
    assertEquals(expectedAvroSchema, AvroUtils.toAvroSchema(beamSchema));

    // Value conversion: the Avro side of the comparison uses Utf8 keys, the Beam row String keys.
    Map<Utf8, Object> avroMapWithNull = Maps.newHashMap();
    avroMapWithNull.put(new Utf8("k1"), null);
    Map<String, Object> beamMapWithNull = Maps.newHashMap();
    beamMapWithNull.put("k1", null);

    GenericRecordBuilder recordBuilder = new GenericRecordBuilder(expectedAvroSchema);
    recordBuilder.set("int", null);
    recordBuilder.set("array", Lists.newArrayList((Object) null));
    recordBuilder.set("map", avroMapWithNull);
    GenericRecord expectedRecord = recordBuilder.build();

    Row source =
        Row.withSchema(beamSchema)
            .addValue(null)
            .addValue(Lists.newArrayList((Object) null))
            .addValue(beamMapWithNull)
            .build();
    assertEquals(expectedRecord, AvroUtils.toGenericRecord(source, expectedAvroSchema));
  }

  /**
   * Converting the Beam row fixture with no explicit Avro schema yields the Avro record fixture,
   * including its schema.
   */
  @Test
  public void testBeamRowToGenericRecord() {
    GenericRecord converted = AvroUtils.toGenericRecord(getBeamRow(), null);

    assertEquals(getAvroSchema(), converted.getSchema());
    assertEquals(getGenericRecord(), converted);
  }

  /**
   * Converting the Avro record fixture with no explicit Beam schema yields the Beam row fixture,
   * both when the timestamp-millis field holds a long and when it holds a joda {@link DateTime}.
   */
  @Test
  public void testGenericRecordToBeamRow() {
    GenericRecord genericRecord = getGenericRecord();
    Row row = AvroUtils.toBeamRowStrict(genericRecord, null);
    assertEquals(getBeamRow(), row);

    // Alternatively, a timestamp-millis logical type can have a joda datum.
    genericRecord.put("timestampMillis", new DateTime(genericRecord.get("timestampMillis")));
    // Convert the record that was just modified; a fresh fixture would still hold the long datum
    // and the joda case would never be exercised.
    row = AvroUtils.toBeamRowStrict(genericRecord, null);
    assertEquals(getBeamRow(), row);
  }

  /**
   * A PCollection encoded with a plain {@link AvroCoder} reports no schema; after its coder is
   * replaced by the one from {@code AvroUtils.schemaCoder} it reports one. Checked for a generic
   * record and for a generated Avro class.
   */
  @Test
  public void testAvroSchemaCoders() {
    Pipeline pipeline = Pipeline.create();
    org.apache.avro.Schema schema =
        org.apache.avro.Schema.createRecord(
            "TestSubRecord",
            "TestSubRecord doc",
            "org.apache.beam.sdk.schemas.utils",
            false,
            getAvroSubSchemaFields());
    // Build the record against the same schema the coders below are created from, so the element
    // and its coder agree on the record schema.
    GenericRecord record =
        new GenericRecordBuilder(schema).set("bool", true).set("int", 42).build();

    PCollection<GenericRecord> records =
        pipeline.apply(Create.of(record).withCoder(AvroCoder.of(schema)));
    assertFalse(records.hasSchema());
    records.setCoder(AvroUtils.schemaCoder(schema));
    assertTrue(records.hasSchema());

    AvroGeneratedUser user = new AvroGeneratedUser("foo", 42, "green");
    PCollection<AvroGeneratedUser> users =
        pipeline.apply(Create.of(user).withCoder(AvroCoder.of(AvroGeneratedUser.class)));
    assertFalse(users.hasSchema());
    users.setCoder(AvroUtils.schemaCoder((AvroCoder<AvroGeneratedUser>) users.getCoder()));
    assertTrue(users.hasSchema());
  }

  /**
   * Creates a matcher that accepts an Avro schema when {@code predicate} holds for the schema or
   * for any schema nested inside it.
   */
  public static ContainsField containsField(Function<org.apache.avro.Schema, Boolean> predicate) {
    ContainsField matcher = new ContainsField(predicate);
    return matcher;
  }

  /**
   * Tells whether {@code schema} is a union that is not a plain nullable type, i.e. anything other
   * than a two-branch union with null as one branch. Such unions are excluded from the property
   * tests because Beam doesn't have unions, only nullable fields.
   */
  public static boolean hasNonNullUnion(org.apache.avro.Schema schema) {
    if (schema.getType() != Type.UNION) {
      return false;
    }

    final List<org.apache.avro.Schema> branches = schema.getTypes();
    return branches.size() != 2 || !branches.contains(NULL_SCHEMA);
  }

  /**
   * Hamcrest matcher that accepts an Avro schema when a predicate holds for the schema itself or
   * for any schema reachable from it through record fields, union branches, array elements or map
   * values.
   */
  static class ContainsField extends BaseMatcher<org.apache.avro.Schema> {

    private final Function<org.apache.avro.Schema, Boolean> predicate;

    ContainsField(final Function<org.apache.avro.Schema, Boolean> predicate) {
      this.predicate = predicate;
    }

    /**
     * Returns {@code true} if {@code item0} is an Avro schema and the predicate accepts it or one
     * of its nested schemas; any other object never matches.
     */
    @Override
    public boolean matches(final Object item0) {
      if (!(item0 instanceof org.apache.avro.Schema)) {
        return false;
      }

      org.apache.avro.Schema item = (org.apache.avro.Schema) item0;

      // The schema itself is tested before descending into its children.
      if (predicate.apply(item)) {
        return true;
      }

      switch (item.getType()) {
        case RECORD:
          return item.getFields().stream().anyMatch(x -> matches(x.schema()));

        case UNION:
          return item.getTypes().stream().anyMatch(this::matches);

        case ARRAY:
          return matches(item.getElementType());

        case MAP:
          return matches(item.getValueType());

        default:
          // Remaining types have no nested schemas to inspect.
          return false;
      }
    }

    /** Describes the matcher so failed assumptions and assertions print a readable message. */
    @Override
    public void describeTo(final Description description) {
      description.appendText("an Avro schema containing a nested schema accepted by the predicate");
    }
  }
}
