[feature] support struct and map type (#212)
Support reading and writing the Doris MAP and STRUCT types.
```java
//doris create table
CREATE TABLE `simple_map2` (
`id` int(11) NULL,
`m` MAP<text,int(11)> NULL,
`s_info` STRUCT<s_id:int(11),s_name:text,s_address:text> NULL
) ENGINE=OLAP
DUPLICATE KEY(`id`)
COMMENT 'OLAP'
DISTRIBUTED BY HASH(`id`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"is_being_synced" = "false",
"storage_format" = "V2",
"light_schema_change" = "true",
"disable_auto_compaction" = "false",
"enable_single_replica_compaction" = "false"
);
//datagen->doris
tEnv.executeSql(
"CREATE TABLE doris_test (" +
" id int,\n" +
" task Map<String,int>,\n" +
" buyer ROW<s_id int,s_name string, s_address string>\n" +
") " +
"WITH (\n" +
" 'connector' = 'datagen', \n" +
" 'number-of-rows' = '11' \n" +
")");
tEnv.executeSql("CREATE TABLE blackhole_table (" +
"id int," +
"m Map<String,int>," +
"s_info Row<s_id int,s_name string, s_address string>" +
") WITH (" +
" 'connector' = 'doris',\n" +
" 'fenodes' = '127.0.0.1:8030',\n" +
" 'table.identifier' = 'test.simple_map2',\n" +
" 'sink.enable-2pc' = 'false',\n" +
" 'username' = 'root',\n" +
" 'password' = '',\n" +
" 'sink.properties.format' = 'json',\n" +
" 'sink.properties.read_json_by_line' = 'true'\n" +
");");
tEnv.executeSql("INSERT INTO blackhole_table select * from doris_test");
// Doris -> Doris (separate example: reads test.simple_map, writes test.simple_map2)
tEnv.executeSql(
"CREATE TABLE doris_source (" +
"id int," +
"m Map<String,int>," +
"s_info Row<s_id int,s_name string, s_address string>" +
") " +
"WITH (\n" +
" 'connector' = 'doris',\n" +
" 'fenodes' = '127.0.0.1:8030',\n" +
" 'table.identifier' = 'test.simple_map',\n" +
" 'username' = 'root',\n" +
" 'password' = ''\n" +
")");
tEnv.executeSql("CREATE TABLE blackhole_table (" +
"id int," +
"m Map<String,int>," +
"s_info Row<s_id int,s_name string, s_address string>" +
") WITH (" +
" 'connector' = 'doris',\n" +
" 'fenodes' = '127.0.0.1:8030',\n" +
" 'table.identifier' = 'test.simple_map2',\n" +
" 'sink.enable-2pc' = 'false',\n" +
" 'username' = 'root',\n" +
" 'password' = '',\n" +
" 'sink.properties.format' = 'json',\n" +
" 'sink.properties.read_json_by_line' = 'true'\n" +
");");
tEnv.executeSql("insert into blackhole_table select * from doris_source");
```
diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java
index 6fa3be9..ebc0ff6 100644
--- a/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java
+++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/deserialization/converter/DorisRowConverter.java
@@ -220,7 +220,9 @@
case ARRAY:
return val -> convertArrayData(((List<?>) val).toArray(), type);
case ROW:
+ return val -> convertRowData((Map<String, ?>) val, type);
case MAP:
+ return val -> convertMapData((Map) val, type);
case MULTISET:
case RAW:
default:
@@ -298,6 +300,33 @@
return arrayData;
}
+    /**
+     * Builds a Flink {@link MapData} from a map value read out of Doris.
+     *
+     * <p>Every key and every value is run through the converter that
+     * {@code createNullableInternalConverter} returns for the key type and value type of the
+     * given {@link MapType}.
+     *
+     * @param map  the raw entries to convert
+     * @param type logical type of the column; it is cast to {@link MapType}
+     * @return a {@link GenericMapData} wrapping the converted entries
+     */
+    private MapData convertMapData(Map<Object, Object> map, LogicalType type) {
+        MapType logicalMap = (MapType) type;
+        DeserializationConverter forKeys = createNullableInternalConverter(logicalMap.getKeyType());
+        DeserializationConverter forValues = createNullableInternalConverter(logicalMap.getValueType());
+        Map<Object, Object> converted = new HashMap<>();
+        // The key is converted before the value for each entry.
+        map.forEach((rawKey, rawValue) ->
+                converted.put(forKeys.deserialize(rawKey), forValues.deserialize(rawValue)));
+        return new GenericMapData(converted);
+    }
+
+    /**
+     * Converts a Doris STRUCT value, delivered as a field-name to value map, into a Flink
+     * {@link RowData} laid out according to the given {@link RowType}.
+     *
+     * <p>The resulting row always has the arity declared by the row type. When every key of
+     * {@code row} is a field name of the row type, values are placed by name, so a member that
+     * is absent from the map is left {@code null} instead of shifting the following members
+     * onto the wrong position and type. If the keys do not match the declared field names, the
+     * values are assigned in map iteration order, which is the previous behaviour.
+     *
+     * @param row  struct members keyed by field name
+     * @param type logical type of the column; it is cast to {@link RowType}
+     * @return a {@link GenericRowData} holding the converted members
+     */
+    private RowData convertRowData(Map<String, ?> row, LogicalType type) {
+        RowType rowType = (RowType) type;
+        int fieldCount = rowType.getFieldCount();
+        List<String> fieldNames = rowType.getFieldNames();
+        // Size the row from the declared type, not from the map: the map may hold fewer
+        // entries than the struct has members. Unset fields of GenericRowData stay null.
+        GenericRowData rowData = new GenericRowData(fieldCount);
+
+        if (fieldNames.containsAll(row.keySet())) {
+            // Name-based placement: robust against members missing from the map.
+            for (int i = 0; i < fieldCount; i++) {
+                Object raw = row.get(fieldNames.get(i));
+                if (raw != null) {
+                    DeserializationConverter converter =
+                            createNullableInternalConverter(rowType.getTypeAt(i));
+                    rowData.setField(i, converter.deserialize(raw));
+                }
+            }
+            return rowData;
+        }
+
+        // Declared names differ from the keys in the map: fall back to positional mapping.
+        int index = 0;
+        for (Map.Entry<String, ?> entry : row.entrySet()) {
+            DeserializationConverter converter =
+                    createNullableInternalConverter(rowType.getTypeAt(index));
+            rowData.setField(index, converter.deserialize(entry.getValue()));
+            index++;
+        }
+        return rowData;
+    }
+
private List<Object> convertArrayData(ArrayData array, LogicalType type){
if(array instanceof GenericArrayData){
return Arrays.asList(((GenericArrayData)array).toObjectArray());
diff --git a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
index de63a6e..ad8bb72 100644
--- a/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
+++ b/flink-doris-connector/src/main/java/org/apache/doris/flink/serialization/RowBatch.java
@@ -32,6 +32,9 @@
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.doris.flink.exception.DorisException;
@@ -50,7 +53,9 @@
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.NoSuchElementException;
/**
@@ -329,6 +334,31 @@
//todo: when the subtype of array is date, conversion is required
addValueToRow(rowIndex, listValue);
break;
+ case "MAP":
+ if (!minorType.equals(Types.MinorType.MAP)) return false;
+ MapVector mapVector = (MapVector) fieldVector;
+ UnionMapReader reader = mapVector.getReader();
+ if (mapVector.isNull(rowIndex)) {
+ addValueToRow(rowIndex, null);
+ break;
+ }
+ reader.setPosition(rowIndex);
+ Map<String, Object> mapValue = new HashMap<>();
+ while (reader.next()) {
+ mapValue.put(reader.key().readObject().toString(), reader.value().readObject());
+ }
+ addValueToRow(rowIndex, mapValue);
+ break;
+ case "STRUCT":
+ if (!minorType.equals(Types.MinorType.STRUCT)) return false;
+ StructVector structVector = (StructVector) fieldVector;
+ if (structVector.isNull(rowIndex)) {
+ addValueToRow(rowIndex, null);
+ break;
+ }
+ Map<String, ?> structValue = structVector.getObject(rowIndex);
+ addValueToRow(rowIndex, structValue);
+ break;
default:
String errMsg = "Unsupported type " + schema.get(col).getType();
logger.error(errMsg);
diff --git a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
index e2ee827..47071f5 100644
--- a/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
+++ b/flink-doris-connector/src/test/java/org/apache/doris/flink/serialization/TestRowBatch.java
@@ -17,6 +17,7 @@
package org.apache.doris.flink.serialization;
+import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
@@ -30,17 +31,25 @@
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.NullableStructWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.util.Text;
+import org.apache.doris.flink.exception.DorisException;
import org.apache.doris.flink.rest.RestService;
import org.apache.doris.flink.rest.models.Schema;
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.sdk.thrift.TStatus;
import org.apache.doris.sdk.thrift.TStatusCode;
+import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableList;
+import org.apache.flink.shaded.guava30.com.google.common.collect.ImmutableMap;
import org.apache.flink.table.data.DecimalData;
import org.junit.Assert;
import org.junit.Rule;
@@ -50,7 +59,9 @@
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
+import java.io.IOException;
import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.util.ArrayList;
@@ -441,4 +452,137 @@
thrown.expectMessage(startsWith("Get row offset:"));
rowBatch.next();
}
+
+    /**
+     * Writes three single-entry maps ({@code k1->0}, {@code k2->1}, {@code k3->2}) into an Arrow
+     * MAP column, serializes them as an Arrow stream and checks that {@link RowBatch} returns
+     * one equal {@link java.util.Map} per row and then reports the end of the batch.
+     */
+    @Test
+    public void testMap() throws IOException, DorisException {
+
+        ImmutableList<Field> mapChildren = ImmutableList.of(
+                new Field("child", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(
+                                new Field("key", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("value", new FieldType(false, new ArrowType.Int(32, true), null),
+                                        null)
+                        )
+                ));
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_map", new FieldType(false, new ArrowType.Map(false), null),
+                        mapChildren)
+        );
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        root.setRowCount(3);
+
+        MapVector mapVector = (MapVector) root.getVector("col_map");
+        mapVector.allocateNew();
+        UnionMapWriter mapWriter = mapVector.getWriter();
+        for (int i = 0; i < 3; i++) {
+            mapWriter.setPosition(i);
+            mapWriter.startMap();
+            mapWriter.startEntry();
+            String key = "k" + (i + 1);
+            byte[] bytes = key.getBytes(StandardCharsets.UTF_8);
+            // The key bytes go through a temporary buffer that is closed right after the write.
+            ArrowBuf buffer = allocator.buffer(bytes.length);
+            buffer.setBytes(0, bytes);
+            mapWriter.key().varChar().writeVarChar(0, bytes.length, buffer);
+            buffer.close();
+            mapWriter.value().integer().writeInt(i);
+            mapWriter.endEntry();
+            mapWriter.endMap();
+        }
+        mapWriter.setValueCount(3);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"MAP\",\"name\":\"col_map\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema).readArrow();
+        // assertEquals reports both maps on failure, unlike assertTrue(a.equals(b)).
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(ImmutableMap.of("k1", 0), rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(ImmutableMap.of("k2", 1), rowBatch.next().get(0));
+        Assert.assertTrue(rowBatch.hasNext());
+        Assert.assertEquals(ImmutableMap.of("k3", 2), rowBatch.next().get(0));
+        Assert.assertFalse(rowBatch.hasNext());
+
+    }
+
+    /**
+     * Writes a single struct {@code {a: "a1", b: 1}} into an Arrow STRUCT column, serializes it
+     * as an Arrow stream and checks that {@link RowBatch} returns it as a map keyed by field
+     * name, followed by the end of the batch.
+     */
+    @Test
+    public void testStruct() throws IOException, DorisException {
+
+        ImmutableList<Field> fields = ImmutableList.of(
+                new Field("col_struct", new FieldType(false, new ArrowType.Struct(), null),
+                        ImmutableList.of(new Field("a", new FieldType(false, new ArrowType.Utf8(), null), null),
+                                new Field("b", new FieldType(false, new ArrowType.Int(32, true), null), null))
+                ));
+
+        RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE);
+        VectorSchemaRoot root = VectorSchemaRoot.create(
+                new org.apache.arrow.vector.types.pojo.Schema(fields, null), allocator);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        ArrowStreamWriter arrowStreamWriter = new ArrowStreamWriter(
+                root,
+                new DictionaryProvider.MapDictionaryProvider(),
+                outputStream);
+
+        arrowStreamWriter.start();
+        // Exactly one struct row is written below, so the batch declares one row.
+        root.setRowCount(1);
+
+        StructVector structVector = (StructVector) root.getVector("col_struct");
+        structVector.allocateNew();
+        NullableStructWriter writer = structVector.getWriter();
+        writer.setPosition(0);
+        writer.start();
+        byte[] bytes = "a1".getBytes(StandardCharsets.UTF_8);
+        // The string bytes go through a temporary buffer that is closed right after the write.
+        ArrowBuf buffer = allocator.buffer(bytes.length);
+        buffer.setBytes(0, bytes);
+        writer.varChar("a").writeVarChar(0, bytes.length, buffer);
+        buffer.close();
+        writer.integer("b").writeInt(1);
+        writer.end();
+        writer.setValueCount(1);
+
+        arrowStreamWriter.writeBatch();
+
+        arrowStreamWriter.end();
+        arrowStreamWriter.close();
+
+        TStatus status = new TStatus();
+        status.setStatusCode(TStatusCode.OK);
+        TScanBatchResult scanBatchResult = new TScanBatchResult();
+        scanBatchResult.setStatus(status);
+        scanBatchResult.setEos(false);
+        scanBatchResult.setRows(outputStream.toByteArray());
+
+        String schemaStr = "{\"properties\":[{\"type\":\"STRUCT\",\"name\":\"col_struct\",\"comment\":\"\"}" +
+                "], \"status\":200}";
+        Schema schema = RestService.parseSchema(schemaStr, logger);
+
+        RowBatch rowBatch = new RowBatch(scanBatchResult, schema).readArrow();
+        Assert.assertTrue(rowBatch.hasNext());
+        // assertEquals reports both maps on failure, unlike assertTrue(a.equals(b)).
+        Assert.assertEquals(ImmutableMap.of("a", new Text("a1"), "b", 1), rowBatch.next().get(0));
+        Assert.assertFalse(rowBatch.hasNext());
+    }
}