/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.samza.sql.avro;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.generic.IndexedRecord;
import org.apache.avro.io.BinaryDecoder;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.DatumWriter;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.Encoder;
import org.apache.avro.io.EncoderFactory;
import org.apache.avro.specific.SpecificDatumWriter;
import org.apache.avro.util.Utf8;
import org.apache.calcite.avatica.util.ByteString;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelRecordType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.samza.config.MapConfig;
import org.apache.samza.operators.KV;
import org.apache.samza.sql.avro.schemas.AddressRecord;
import org.apache.samza.sql.avro.schemas.ComplexRecord;
import org.apache.samza.sql.avro.schemas.ComplexUnion;
import org.apache.samza.sql.avro.schemas.Kind;
import org.apache.samza.sql.avro.schemas.MyFixed;
import org.apache.samza.sql.avro.schemas.PhoneNumber;
import org.apache.samza.sql.avro.schemas.Profile;
import org.apache.samza.sql.avro.schemas.SimpleRecord;
import org.apache.samza.sql.avro.schemas.StreetNumRecord;
import org.apache.samza.sql.data.SamzaSqlRelMessage;
import org.apache.samza.sql.planner.RelSchemaConverter;
import org.apache.samza.sql.schema.SqlSchema;
import org.apache.samza.system.SystemStream;
import org.junit.Assert;
import org.junit.Ignore;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


public class TestAvroRelConversion {

  private static final Logger LOG = LoggerFactory.getLogger(TestAvroRelConversion.class);
  private static final byte[] DEFAULT_TRACKING_ID_BYTES =
    {76, 75, -24, 10, 33, -117, 24, -52, -110, -39, -5, 102, 65, 57, -62, -1};

  private final AvroRelConverter simpleRecordAvroRelConverter;
  private final AvroRelConverter complexRecordAvroRelConverter;
  private final AvroRelConverter nestedRecordAvroRelConverter;
  private final AvroRelConverter complexUnionAvroRelConverter;
  private final AvroRelSchemaProvider simpleRecordSchemaProvider;
  private final AvroRelSchemaProvider complexRecordSchemaProvider;
  private final AvroRelSchemaProvider nestedRecordSchemaProvider;
  private final AvroRelSchemaProvider complexUnionSchemaProvider;

  private int id = 1;
  private boolean boolValue = true;
  private double doubleValue = 0.6;
  private float floatValue = 0.6f;
  private String testStrValue = "testString";
  private ByteBuffer testBytes = ByteBuffer.wrap("testBytes".getBytes());
  private MyFixed fixedBytes = new MyFixed();
  private long longValue = 200L;

  private HashMap<String, String> mapValue = new HashMap<String, String>() { {
      put("key1", "val1");
      put("key2", "val2");
      put("key3", "val3");
    } };
  private List<String> arrayValue = Arrays.asList("val1", "val2", "val3");
  RelSchemaConverter relSchemaConverter = new RelSchemaConverter();

  /**
   * Registers the Avro schema of each test stream in a config map, then builds a schema provider and an
   * {@link AvroRelConverter} for every stream, and finally loads {@code DEFAULT_TRACKING_ID_BYTES} into the
   * {@code MyFixed} fixture.
   */
  public TestAvroRelConversion() {
    Map<String, String> props = new HashMap<>();
    SystemStream ss1 = new SystemStream("test", "complexRecord");
    SystemStream ss2 = new SystemStream("test", "simpleRecord");
    SystemStream ss3 = new SystemStream("test", "nestedRecord");
    SystemStream ss4 = new SystemStream("test", "complexUnion");
    props.put(
        String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, ss1.getSystem(), ss1.getStream()),
        ComplexRecord.SCHEMA$.toString());
    props.put(
        String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, ss2.getSystem(), ss2.getStream()),
        SimpleRecord.SCHEMA$.toString());
    props.put(
        String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, ss3.getSystem(), ss3.getStream()),
        Profile.SCHEMA$.toString());
    props.put(
        String.format(ConfigBasedAvroRelSchemaProviderFactory.CFG_SOURCE_SCHEMA, ss4.getSystem(), ss4.getStream()),
        ComplexUnion.SCHEMA$.toString());

    ConfigBasedAvroRelSchemaProviderFactory factory = new ConfigBasedAvroRelSchemaProviderFactory();

    complexRecordSchemaProvider = (AvroRelSchemaProvider) factory.create(ss1, new MapConfig(props));
    simpleRecordSchemaProvider = (AvroRelSchemaProvider) factory.create(ss2, new MapConfig(props));
    nestedRecordSchemaProvider = (AvroRelSchemaProvider) factory.create(ss3, new MapConfig(props));
    // The complex-union provider must be created for its own stream (ss4); it was previously created
    // for ss3, the nested-record stream, which is registered with the Profile schema above.
    complexUnionSchemaProvider = (AvroRelSchemaProvider) factory.create(ss4, new MapConfig(props));
    complexRecordAvroRelConverter = new AvroRelConverter(ss1, complexRecordSchemaProvider, new MapConfig());
    simpleRecordAvroRelConverter = new AvroRelConverter(ss2, simpleRecordSchemaProvider, new MapConfig());
    nestedRecordAvroRelConverter = new AvroRelConverter(ss3, nestedRecordSchemaProvider, new MapConfig());
    complexUnionAvroRelConverter = new AvroRelConverter(ss4, complexUnionSchemaProvider, new MapConfig());

    fixedBytes.bytes(DEFAULT_TRACKING_ID_BYTES);
  }

  /**
   * Converts the SimpleRecord SQL schema to a relational schema and checks that the result is a record
   * type with one field per Avro field, where "id" is INTEGER and "name" is VARCHAR.
   */
  @Test
  public void testSimpleSchemaConversion() {
    SqlSchema sqlSchema = simpleRecordSchemaProvider.getSqlSchema();
    RelDataType dataType = relSchemaConverter.convertToRelSchema(sqlSchema);
    Assert.assertTrue(dataType instanceof RelRecordType);
    RelRecordType recordType = (RelRecordType) dataType;

    // Expected value first, so that a failure message labels the two sides correctly.
    Assert.assertEquals(SimpleRecord.SCHEMA$.getFields().size(), recordType.getFieldCount());
    Assert.assertEquals(SqlTypeName.INTEGER,
        recordType.getField("id", true, false).getType().getSqlTypeName());
    Assert.assertEquals(SqlTypeName.VARCHAR,
        recordType.getField("name", true, false).getType().getSqlTypeName());

    LOG.info("Relational schema {}", dataType);
  }

  /**
   * Converts the ComplexRecord SQL schema to a relational schema and logs it; the test only
   * fails if the conversion throws.
   */
  @Test
  public void testComplexSchemaConversion() {
    SqlSchema complexSqlSchema = complexRecordSchemaProvider.getSqlSchema();
    RelDataType converted = relSchemaConverter.convertToRelSchema(complexSqlSchema);

    String description = "Relational schema " + converted;
    LOG.info(description);
  }

  /**
   * Converts the nested-record (Profile) SQL schema to a relational schema and logs it; the test only
   * fails if the conversion throws.
   */
  @Test
  public void testNestedSchemaConversion() {
    SqlSchema nestedSqlSchema = nestedRecordSchemaProvider.getSqlSchema();
    RelDataType converted = relSchemaConverter.convertToRelSchema(nestedSqlSchema);

    String description = "Relational schema " + converted;
    LOG.info(description);
  }

  /**
   * Converts a populated SimpleRecord into a rel message and logs it; the test only fails if the
   * conversion throws.
   */
  @Test
  public void testSimpleRecordConversion() {
    GenericData.Record simpleRecord = new GenericData.Record(SimpleRecord.SCHEMA$);
    simpleRecord.put("id", 1);
    simpleRecord.put("name", "name1");

    SamzaSqlRelMessage relMessage =
        simpleRecordAvroRelConverter.convertToRelMessage(new KV<>("key", simpleRecord));
    String rendered = relMessage.toString();
    LOG.info(rendered);
  }

  /**
   * Converts a SimpleRecord with no fields set and checks that the resulting rel record has as many
   * field values as field names.
   */
  @Test
  public void testEmptyRecordConversion() {
    GenericData.Record emptyRecord = new GenericData.Record(SimpleRecord.SCHEMA$);
    SamzaSqlRelMessage relMessage =
        simpleRecordAvroRelConverter.convertToRelMessage(new KV<>("key", emptyRecord));

    int nameCount = relMessage.getSamzaSqlRelRecord().getFieldNames().size();
    int valueCount = relMessage.getSamzaSqlRelRecord().getFieldValues().size();
    Assert.assertEquals(nameCount, valueCount);
  }

  /**
   * Converts a message whose value is {@code null} and checks that the resulting rel record has as many
   * field values as field names.
   */
  @Test
  public void testNullRecordConversion() {
    SamzaSqlRelMessage relMessage =
        simpleRecordAvroRelConverter.convertToRelMessage(new KV<>("key", null));

    int nameCount = relMessage.getSamzaSqlRelRecord().getFieldNames().size();
    int valueCount = relMessage.getSamzaSqlRelRecord().getFieldValues().size();
    Assert.assertEquals(nameCount, valueCount);
  }

  /**
   * Serializes an Avro specific record to its binary encoding.
   *
   * @param clazz  class handed to the {@link SpecificDatumWriter}
   * @param record the record instance to serialize
   * @param <T>    record type
   * @return the bytes written by the binary encoder
   * @throws IOException if writing or flushing the encoder fails
   */
  public static <T> byte[] encodeAvroSpecificRecord(Class<T> clazz, T record) throws IOException {
    DatumWriter<T> specificWriter = new SpecificDatumWriter<>(clazz);
    ByteArrayOutputStream buffer = new ByteArrayOutputStream();
    Encoder binaryEncoder = EncoderFactory.get().binaryEncoder(buffer, null);

    specificWriter.write(record, binaryEncoder);
    // The encoder buffers internally, so flush before reading the bytes back.
    binaryEncoder.flush();

    return buffer.toByteArray();
  }

  /**
   * Builds the same ComplexRecord fixture twice -- once as a specific record and once as a generic
   * record -- serializes each one and runs the bytes through {@code validateAvroSerializedData}.
   */
  @Test
  public void testComplexRecordConversion() throws IOException {
    // Specific-record flavour of the fixture.
    ComplexRecord specificFixture = new ComplexRecord();
    specificFixture.id = id;
    specificFixture.bool_value = boolValue;
    specificFixture.double_value = doubleValue;
    specificFixture.float_value0 = floatValue;
    specificFixture.string_value = testStrValue;
    specificFixture.bytes_value = testBytes;
    specificFixture.fixed_value = fixedBytes;
    specificFixture.long_value = longValue;
    specificFixture.array_values = new ArrayList<>(arrayValue);
    specificFixture.map_values = new HashMap<>(mapValue);
    specificFixture.union_value = testStrValue;

    // Generic-record flavour carrying the same values.
    GenericData.Record genericFixture = new GenericData.Record(ComplexRecord.SCHEMA$);
    genericFixture.put("id", id);
    genericFixture.put("bool_value", boolValue);
    genericFixture.put("double_value", doubleValue);
    genericFixture.put("float_value0", floatValue);
    genericFixture.put("string_value", testStrValue);
    genericFixture.put("bytes_value", testBytes);
    genericFixture.put("fixed_value", fixedBytes);
    genericFixture.put("long_value", longValue);
    genericFixture.put("array_values", arrayValue);
    genericFixture.put("map_values", mapValue);
    genericFixture.put("union_value", testStrValue);

    byte[] genericBytes = bytesFromGenericRecord(genericFixture);
    validateAvroSerializedData(genericBytes, testStrValue);

    byte[] specificBytes = encodeAvroSpecificRecord(ComplexRecord.class, specificFixture);
    validateAvroSerializedData(specificBytes, testStrValue);
  }

  /**
   * Exercises the ComplexUnion schema, whose non-nullable union field is populated first with a String
   * and then with an Integer. For each value the generic record is serialized, decoded and converted to
   * a rel message, and the specific record is serialized and decoded; the field value is asserted on
   * both paths.
   */
  @Test
  public void testComplexUnionConversionShouldWorkWithBothStringAndIntTypes() throws Exception {
    final String unionField = "non_nullable_union_value";
    final Integer intValue = Integer.valueOf(123);

    GenericData.Record genericUnion = new GenericData.Record(ComplexUnion.SCHEMA$);
    ComplexUnion specificUnion = new ComplexUnion();

    // ---- String member of the union ----
    genericUnion.put(unionField, testStrValue);
    specificUnion.non_nullable_union_value = testStrValue;

    byte[] encoded = bytesFromGenericRecord(genericUnion);
    GenericRecord decoded = genericRecordFromBytes(encoded, ComplexUnion.SCHEMA$);
    SamzaSqlRelMessage relMessage = complexUnionAvroRelConverter.convertToRelMessage(new KV<>("key", decoded));
    Assert.assertEquals(testStrValue, relMessage.getSamzaSqlRelRecord().getField(unionField).get().toString());

    encoded = encodeAvroSpecificRecord(ComplexUnion.class, specificUnion);
    decoded = genericRecordFromBytes(encoded, ComplexUnion.SCHEMA$);
    Assert.assertEquals(testStrValue, decoded.get(unionField).toString());

    // ---- Integer member of the union ----
    genericUnion.put(unionField, intValue);
    specificUnion.non_nullable_union_value = intValue;

    encoded = bytesFromGenericRecord(genericUnion);
    decoded = genericRecordFromBytes(encoded, ComplexUnion.SCHEMA$);
    relMessage = complexUnionAvroRelConverter.convertToRelMessage(new KV<>("key", decoded));
    Assert.assertEquals(intValue, relMessage.getSamzaSqlRelRecord().getField(unionField).get());

    encoded = encodeAvroSpecificRecord(ComplexUnion.class, specificUnion);
    decoded = genericRecordFromBytes(encoded, ComplexUnion.SCHEMA$);
    Assert.assertEquals(intValue, decoded.get(unionField));
  }

  /**
   * Round-trips a fully populated Profile record (nested address, list of phone numbers, map of
   * SimpleRecords) through the nested-record converter and checks that every schema field of the
   * converted-back record equals the original.
   */
  @Test
  public void testNestedRecordConversion() throws IOException {
    // Address with its nested street-number record.
    GenericData.Record streetNum = new GenericData.Record(StreetNumRecord.SCHEMA$);
    streetNum.put("number", 1200);
    GenericData.Record address = new GenericData.Record(AddressRecord.SCHEMA$);
    address.put("zip", 90000);
    address.put("streetnum", streetNum);

    // Two phone numbers: home and cell.
    GenericData.Record homePhone = new GenericData.Record(PhoneNumber.SCHEMA$);
    homePhone.put("kind", Kind.Home);
    homePhone.put("number", "111-111-1111");
    GenericData.Record cellPhone = new GenericData.Record(PhoneNumber.SCHEMA$);
    cellPhone.put("kind", Kind.Cell);
    cellPhone.put("number", "111-111-1112");
    List<GenericData.Record> phones = new ArrayList<>(Arrays.asList(homePhone, cellPhone));

    // Map of simple records keyed by string.
    GenericData.Record firstSimple = new GenericData.Record(SimpleRecord.SCHEMA$);
    firstSimple.put("id", 1);
    firstSimple.put("name", "name1");
    GenericData.Record secondSimple = new GenericData.Record(SimpleRecord.SCHEMA$);
    secondSimple.put("id", 2);
    secondSimple.put("name", "name2");
    HashMap<String, IndexedRecord> simpleByKey = new HashMap<>();
    simpleByKey.put("key1", firstSimple);
    simpleByKey.put("key2", secondSimple);

    GenericData.Record profile = new GenericData.Record(Profile.SCHEMA$);
    profile.put("id", 1);
    profile.put("name", "name1");
    profile.put("companyId", 0);
    profile.put("address", address);
    profile.put("selfEmployed", "True");
    profile.put("phoneNumbers", phones);
    profile.put("mapValues", simpleByKey);

    SamzaSqlRelMessage relMessage = nestedRecordAvroRelConverter.convertToRelMessage(new KV<>("key", profile));
    LOG.info(relMessage.toString());

    KV<Object, Object> convertedBack = nestedRecordAvroRelConverter.convertToSamzaMessage(relMessage);
    GenericRecord roundTripped = (GenericRecord) convertedBack.getValue();

    // equals() on GenericRecord also compares nested records.
    Profile.SCHEMA$.getFields().forEach(
        field -> Assert.assertEquals(profile.get(field.name()), roundTripped.get(field.name())));
  }

  /**
   * Round-trips a message whose payload is {@code null} through the nested-record converter and expects
   * the converted-back value to be {@code null} as well.
   *
   * <p>SAMZA-2110: re-enable this test once null records are truly supported.
   */
  @Ignore
  @Test
  public void testRecordConversionWithNullPayload() throws IOException {
    GenericData.Record record = null;
    SamzaSqlRelMessage relMessage = nestedRecordAvroRelConverter.convertToRelMessage(new KV<>("key", record));

    LOG.info(relMessage.toString());

    KV<Object, Object> samzaMessage = nestedRecordAvroRelConverter.convertToSamzaMessage(relMessage);
    GenericRecord recordPostConversion = (GenericRecord) samzaMessage.getValue();

    // assertNull reports the unexpected value on failure, unlike assertTrue(x == null).
    Assert.assertNull(recordPostConversion);
  }

  /**
   * Round-trips a Profile record whose address, phone-number list and map fields are all {@code null}
   * and checks that every schema field of the converted-back record equals the original.
   */
  @Test
  public void testNestedRecordConversionWithSubRecordsBeingNull() throws IOException {
    GenericData.Record profile = new GenericData.Record(Profile.SCHEMA$);
    profile.put("id", 1);
    profile.put("name", "name1");
    profile.put("companyId", 0);
    profile.put("selfEmployed", "True");
    // The nested record, the array and the map are explicitly set to null.
    profile.put("address", null);
    profile.put("phoneNumbers", null);
    profile.put("mapValues", null);

    SamzaSqlRelMessage relMessage = nestedRecordAvroRelConverter.convertToRelMessage(new KV<>("key", profile));
    LOG.info(relMessage.toString());

    KV<Object, Object> convertedBack = nestedRecordAvroRelConverter.convertToSamzaMessage(relMessage);
    GenericRecord roundTripped = (GenericRecord) convertedBack.getValue();

    for (Schema.Field schemaField : Profile.SCHEMA$.getFields()) {
      String fieldName = schemaField.name();
      // equals() on GenericRecord also compares nested records.
      Assert.assertEquals(profile.get(fieldName), roundTripped.get(fieldName));
    }
  }

  /**
   * Decodes Avro binary data into a generic datum using the given schema.
   *
   * @param bytes  Avro binary-encoded data
   * @param schema schema handed to the {@link GenericDatumReader}
   * @param <T>    type the caller expects the decoded datum to have
   * @return the decoded datum
   * @throws IOException if decoding fails
   */
  private static <T> T genericRecordFromBytes(byte[] bytes, Schema schema) throws IOException {
    // DecoderFactory.get().binaryDecoder(..) is the non-deprecated counterpart of
    // defaultFactory().createBinaryDecoder(..) and mirrors the EncoderFactory.get() usage in this file.
    BinaryDecoder binDecoder = DecoderFactory.get().binaryDecoder(bytes, null);
    GenericDatumReader<T> reader = new GenericDatumReader<>(schema);
    return reader.read(null, binDecoder);
  }

  /**
   * Serializes a generic record to Avro binary using the record's own schema.
   *
   * @param record the record to serialize
   * @return the bytes written by the binary encoder
   * @throws IOException if writing, flushing or closing fails
   */
  private static byte[] bytesFromGenericRecord(GenericRecord record) throws IOException {
    DatumWriter<IndexedRecord> genericWriter = new GenericDatumWriter<>(record.getSchema());
    ByteArrayOutputStream sink = new ByteArrayOutputStream();
    BinaryEncoder binaryEncoder = EncoderFactory.get().binaryEncoder(sink, null);

    genericWriter.write(record, binaryEncoder);
    // Flush the encoder so everything it buffered reaches the stream before the bytes are read.
    binaryEncoder.flush();
    sink.close();

    return sink.toByteArray();
  }

  /**
   * Decodes {@code serializedData} as a ComplexRecord, converts it to a rel message and checks every
   * field against the fixture values held by this test; then converts the rel message back to a Samza
   * message and compares each schema field with the decoded original.
   *
   * @param serializedData Avro binary encoding of a ComplexRecord
   * @param unionValue     value expected in "union_value"; a String is compared as {@link Utf8}
   * @throws IOException if decoding the bytes fails
   */
  private void validateAvroSerializedData(byte[] serializedData, Object unionValue) throws IOException {
    GenericRecord complexRecordValue = genericRecordFromBytes(serializedData, ComplexRecord.SCHEMA$);

    SamzaSqlRelMessage message = complexRecordAvroRelConverter.convertToRelMessage(new KV<>("key", complexRecordValue));
    // NOTE(review): the rel record has one field more than the Avro schema -- presumably the message
    // key added by the converter; confirm against AvroRelConverter.
    Assert.assertEquals(ComplexRecord.SCHEMA$.getFields().size() + 1,
        message.getSamzaSqlRelRecord().getFieldNames().size());

    // Expected value first, so that failure messages label the two sides correctly.
    Assert.assertEquals(id, fieldValue(message, "id"));
    Assert.assertEquals(boolValue, fieldValue(message, "bool_value"));
    Assert.assertEquals(doubleValue, fieldValue(message, "double_value"));
    Assert.assertEquals(new Utf8(testStrValue), fieldValue(message, "string_value"));
    Assert.assertEquals(floatValue, fieldValue(message, "float_value0"));
    Assert.assertEquals(longValue, fieldValue(message, "long_value"));

    Object expectedUnion = unionValue instanceof String ? new Utf8((String) unionValue) : unionValue;
    Assert.assertEquals(expectedUnion, fieldValue(message, "union_value"));

    // Strings in the decoded record are compared as Utf8, so the expected collections are mapped to Utf8.
    List<Utf8> expectedArray = arrayValue.stream()
        .map(Utf8::new)
        .collect(Collectors.toList());
    Assert.assertEquals(expectedArray, fieldValue(message, "array_values"));

    Map<Utf8, Utf8> expectedMap = mapValue.entrySet()
        .stream()
        .collect(Collectors.toMap(x -> new Utf8(x.getKey()), y -> new Utf8(y.getValue())));
    Assert.assertEquals(expectedMap, fieldValue(message, "map_values"));

    Assert.assertArrayEquals(testBytes.array(),
        ((ByteString) fieldValue(message, "bytes_value")).getBytes());
    Assert.assertArrayEquals(DEFAULT_TRACKING_ID_BYTES,
        ((ByteString) fieldValue(message, "fixed_value")).getBytes());

    LOG.info(message.toString());

    KV<Object, Object> samzaMessage = complexRecordAvroRelConverter.convertToSamzaMessage(message);
    GenericRecord record = (GenericRecord) samzaMessage.getValue();

    for (Schema.Field field : ComplexRecord.SCHEMA$.getFields()) {
      String fieldName = field.name();
      Object original = complexRecordValue.get(fieldName);
      Object converted = record.get(fieldName);
      if (fieldName.equals("array_values")) {
        // Keeps the original comparison direction (converted.equals(original)).
        // NOTE(review): presumably equals() is not symmetric between the two list types; confirm.
        Assert.assertTrue("array_values differs after round trip", converted.equals(original));
      } else {
        Assert.assertEquals(original, converted);
      }
    }
  }

  /** Returns the value of the named field of the rel record carried by {@code message}. */
  private static Object fieldValue(SamzaSqlRelMessage message, String fieldName) {
    return message.getSamzaSqlRelRecord().getField(fieldName).get();
  }
}
