blob: 8fd84d383c8de4b77a4671c27877248e4024b865 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.spark.serialization;
import com.google.common.base.Preconditions;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.FixedSizeBinaryVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeStampMicroVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.StructVector;
import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.types.Types;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.doris.sdk.thrift.TScanBatchResult;
import org.apache.doris.spark.exception.DorisException;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.spark.sql.types.Decimal;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.JavaConverters;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Date;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
/**
* row batch data container.
*/
public class RowBatch {
// Class-wide SLF4J logger.
private static final Logger logger = LoggerFactory.getLogger(RowBatch.class);
// Renders date-times as "yyyy-MM-dd HH:mm:ss" followed by an optional fraction:
// appendFraction(MICRO_OF_SECOND, 0, 6, true) writes between 0 and 6 fractional digits
// and emits the decimal point only when at least one digit is written.
private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder()
.appendPattern("yyyy-MM-dd HH:mm:ss")
.appendFraction(ChronoField.MICRO_OF_SECOND, 0, 6, true)
.toFormatter();
/**
 * One converted row: an ordered list of column values, appended one column at a time.
 */
public static class Row {
    private final List<Object> cols;

    /**
     * Creates an empty row.
     *
     * @param colCount expected number of columns; only used as the initial list capacity
     */
    Row(int colCount) {
        cols = new ArrayList<>(colCount);
    }

    /**
     * Appends the value of the next column to this row.
     *
     * @param o column value, may be {@code null}
     */
    public void put(Object o) {
        this.cols.add(o);
    }

    /**
     * Returns the backing list of column values (the live list, not a copy).
     */
    List<Object> getCols() {
        return this.cols;
    }
}
// Index into rowBatch of the row that the next call to next() will return.
private int offsetInRowBatch = 0;
// Number of rows in the Arrow batch that is currently being converted.
private int rowCountInOneBatch = 0;
// Total number of rows converted so far across all Arrow batches; while a batch is being
// converted it is also the index in rowBatch where that batch's first row lives.
private int readRowCount = 0;
// All converted rows, in read order; cleared by hasNext() once iteration is exhausted.
private final List<Row> rowBatch = new ArrayList<>();
// Reader over the serialized Arrow stream taken from the scan result.
private final ArrowStreamReader arrowStreamReader;
// Column vectors of the Arrow batch that is currently being converted.
private List<FieldVector> fieldVectors;
// Allocator handed to the Arrow reader; created with a limit of Integer.MAX_VALUE.
private final RootAllocator rootAllocator;
// Doris schema; the column type at index i selects how Arrow vector i is decoded.
private final Schema schema;
public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisException {
this.schema = schema;
this.rootAllocator = new RootAllocator(Integer.MAX_VALUE);
this.arrowStreamReader = new ArrowStreamReader(
new ByteArrayInputStream(nextResult.getRows()),
rootAllocator
);
try {
VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot();
while (arrowStreamReader.loadNextBatch()) {
fieldVectors = root.getFieldVectors();
if (fieldVectors.size() > schema.size()) {
logger.error("Data schema size '{}' should not be bigger than arrow field size '{}'.",
schema.size(), fieldVectors.size());
throw new DorisException("Load Doris data failed, schema size of fetch data is wrong.");
}
if (fieldVectors.isEmpty() || root.getRowCount() == 0) {
logger.debug("One batch in arrow has no data.");
continue;
}
rowCountInOneBatch = root.getRowCount();
// init the rowBatch
for (int i = 0; i < rowCountInOneBatch; ++i) {
rowBatch.add(new Row(fieldVectors.size()));
}
convertArrowToRowBatch();
readRowCount += root.getRowCount();
}
} catch (Exception e) {
logger.error("Read Doris Data failed because: ", e);
throw new DorisException(e.getMessage());
} finally {
close();
}
}
/**
 * Reports whether another row can be fetched with {@link #next()}.
 *
 * <p>Once every converted row has been consumed, the stored rows are discarded.
 *
 * @return {@code true} if at least one unread row remains
 */
public boolean hasNext() {
    final boolean remaining = offsetInRowBatch < readRowCount;
    if (!remaining) {
        rowBatch.clear();
    }
    return remaining;
}
/**
 * Appends one column value to a row of the batch that is currently being converted.
 *
 * @param rowIndex zero-based row position inside the current Arrow batch
 * @param obj      converted column value, may be {@code null}
 * @throws NoSuchElementException if {@code rowIndex} is not below the current batch's row count
 */
private void addValueToRow(int rowIndex, Object obj) {
    // Valid indexes are 0 .. rowCountInOneBatch - 1, so the count itself is already out of range.
    if (rowIndex >= rowCountInOneBatch) {
        String errMsg = "Get row offset: " + rowIndex + " larger than row size: " +
                rowCountInOneBatch;
        logger.error(errMsg);
        throw new NoSuchElementException(errMsg);
    }
    // Rows of earlier batches occupy the first readRowCount slots of rowBatch.
    rowBatch.get(readRowCount + rowIndex).put(obj);
}
/**
 * Converts the Arrow vectors of the current batch into column values and appends them,
 * column by column, to the rows pre-created for this batch.
 *
 * <p>The Doris type name of each schema column selects the decoding; the Arrow minor type
 * of the matching vector is checked against the types accepted for that Doris type.
 * Null cells are stored as {@code null}. All text carried in Arrow VARCHAR vectors is
 * decoded as UTF-8. If conversion fails, the Arrow resources are closed before the
 * exception is rethrown.
 *
 * @throws DorisException           if a column type is unsupported or a DECIMAL string is malformed
 * @throws IllegalArgumentException if an Arrow vector type does not match the Doris column type
 */
public void convertArrowToRowBatch() throws DorisException {
    try {
        for (int col = 0; col < fieldVectors.size(); col++) {
            FieldVector curFieldVector = fieldVectors.get(col);
            Types.MinorType mt = curFieldVector.getMinorType();
            final String currentType = schema.get(col).getType();
            switch (currentType) {
                case "NULL_TYPE":
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        addValueToRow(rowIndex, null);
                    }
                    break;
                case "BOOLEAN":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.BIT),
                            typeMismatchMessage(currentType, mt));
                    BitVector bitVector = (BitVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "TINYINT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.TINYINT),
                            typeMismatchMessage(currentType, mt));
                    TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = tinyIntVector.isNull(rowIndex) ? null : tinyIntVector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "SMALLINT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.SMALLINT),
                            typeMismatchMessage(currentType, mt));
                    SmallIntVector smallIntVector = (SmallIntVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "INT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.INT),
                            typeMismatchMessage(currentType, mt));
                    IntVector intVector = (IntVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "BIGINT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.BIGINT),
                            typeMismatchMessage(currentType, mt));
                    BigIntVector bigIntVector = (BigIntVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "LARGEINT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.FIXEDSIZEBINARY) ||
                            mt.equals(Types.MinorType.VARCHAR), typeMismatchMessage(currentType, mt));
                    if (mt.equals(Types.MinorType.FIXEDSIZEBINARY)) {
                        FixedSizeBinaryVector largeIntVector = (FixedSizeBinaryVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (largeIntVector.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            byte[] bytes = largeIntVector.get(rowIndex);
                            // BigInteger(byte[]) reads big-endian two's complement, so the bytes are
                            // reversed first (the vector content is presumably little-endian - confirm).
                            ArrayUtils.reverse(bytes);
                            BigInteger largeInt = new BigInteger(bytes);
                            addValueToRow(rowIndex, Decimal.apply(largeInt));
                        }
                    } else {
                        VarCharVector largeIntVector = (VarCharVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (largeIntVector.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            // Decode explicitly as UTF-8 instead of relying on the platform charset.
                            String stringValue = new String(largeIntVector.get(rowIndex), StandardCharsets.UTF_8);
                            BigInteger largeInt = new BigInteger(stringValue);
                            addValueToRow(rowIndex, Decimal.apply(largeInt));
                        }
                    }
                    break;
                case "FLOAT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT4),
                            typeMismatchMessage(currentType, mt));
                    Float4Vector float4Vector = (Float4Vector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = float4Vector.isNull(rowIndex) ? null : float4Vector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "TIME":
                case "DOUBLE":
                    // TIME columns are decoded from FLOAT8 vectors exactly like DOUBLE.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.FLOAT8),
                            typeMismatchMessage(currentType, mt));
                    Float8Vector float8Vector = (Float8Vector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "BINARY":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.VARBINARY),
                            typeMismatchMessage(currentType, mt));
                    VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
                        addValueToRow(rowIndex, fieldValue);
                    }
                    break;
                case "DECIMAL":
                    // Legacy DECIMAL arrives as text and is parsed into a Spark Decimal.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
                            typeMismatchMessage(currentType, mt));
                    VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (varCharVectorForDecimal.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        // Decode explicitly as UTF-8 instead of relying on the platform charset.
                        String decimalValue = new String(varCharVectorForDecimal.get(rowIndex),
                                StandardCharsets.UTF_8);
                        Decimal decimal = new Decimal();
                        try {
                            decimal.set(new scala.math.BigDecimal(new BigDecimal(decimalValue)));
                        } catch (NumberFormatException e) {
                            String errMsg = "Decimal response result '" + decimalValue + "' is illegal.";
                            logger.error(errMsg, e);
                            throw new DorisException(errMsg);
                        }
                        addValueToRow(rowIndex, decimal);
                    }
                    break;
                case "DECIMALV2":
                case "DECIMAL32":
                case "DECIMAL64":
                case "DECIMAL128I":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.DECIMAL),
                            typeMismatchMessage(currentType, mt));
                    DecimalVector decimalVector = (DecimalVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (decimalVector.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        Decimal decimalV2 = Decimal.apply(decimalVector.getObject(rowIndex));
                        addValueToRow(rowIndex, decimalV2);
                    }
                    break;
                case "DATE":
                case "DATEV2":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR)
                            || mt.equals(Types.MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
                    if (mt.equals(Types.MinorType.VARCHAR)) {
                        VarCharVector date = (VarCharVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (date.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            // Decode explicitly as UTF-8 instead of relying on the platform charset.
                            String stringValue = new String(date.get(rowIndex), StandardCharsets.UTF_8);
                            LocalDate localDate = LocalDate.parse(stringValue);
                            addValueToRow(rowIndex, Date.valueOf(localDate));
                        }
                    } else {
                        DateDayVector date = (DateDayVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (date.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            LocalDate localDate = LocalDate.ofEpochDay(date.get(rowIndex));
                            addValueToRow(rowIndex, Date.valueOf(localDate));
                        }
                    }
                    break;
                case "DATETIME":
                case "DATETIMEV2":
                    // Date-times are surfaced as strings in both encodings.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR)
                                    || mt.equals(Types.MinorType.TIMESTAMPMICRO),
                            typeMismatchMessage(currentType, mt));
                    if (mt.equals(Types.MinorType.VARCHAR)) {
                        VarCharVector varCharVector = (VarCharVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (varCharVector.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            String value = new String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
                            addValueToRow(rowIndex, value);
                        }
                    } else {
                        TimeStampMicroVector vector = (TimeStampMicroVector) curFieldVector;
                        for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                            if (vector.isNull(rowIndex)) {
                                addValueToRow(rowIndex, null);
                                continue;
                            }
                            long time = vector.get(rowIndex);
                            // The unit of the raw value is inferred from its magnitude.
                            // NOTE(review): values near the epoch can land in a different branch
                            // than intended - confirm against what the backend actually sends.
                            Instant instant;
                            if (time / 10000000000L == 0) { // datetime(0)
                                instant = Instant.ofEpochSecond(time);
                            } else if (time / 10000000000000L == 0) { // datetime(3)
                                instant = Instant.ofEpochMilli(time);
                            } else { // datetime(6)
                                instant = Instant.ofEpochSecond(time / 1000000, time % 1000000 * 1000);
                            }
                            LocalDateTime dateTime = LocalDateTime.ofInstant(instant, ZoneId.systemDefault());
                            String formatted = DATE_TIME_FORMATTER.format(dateTime);
                            addValueToRow(rowIndex, formatted);
                        }
                    }
                    break;
                case "CHAR":
                case "VARCHAR":
                case "STRING":
                case "JSONB":
                case "VARIANT":
                    Preconditions.checkArgument(mt.equals(Types.MinorType.VARCHAR),
                            typeMismatchMessage(currentType, mt));
                    VarCharVector varCharVector = (VarCharVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (varCharVector.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        String value = new String(varCharVector.get(rowIndex), StandardCharsets.UTF_8);
                        addValueToRow(rowIndex, value);
                    }
                    break;
                case "ARRAY":
                    // Arrays are surfaced as the string form of the decoded list.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.LIST),
                            typeMismatchMessage(currentType, mt));
                    ListVector listVector = (ListVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (listVector.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        String value = listVector.getObject(rowIndex).toString();
                        addValueToRow(rowIndex, value);
                    }
                    break;
                case "MAP":
                    // Maps are surfaced as a Scala map with stringified keys and values.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.MAP),
                            typeMismatchMessage(currentType, mt));
                    MapVector mapVector = (MapVector) curFieldVector;
                    UnionMapReader reader = mapVector.getReader();
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (mapVector.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        reader.setPosition(rowIndex);
                        Map<String, String> value = new HashMap<>();
                        while (reader.next()) {
                            value.put(Objects.toString(reader.key().readObject(), null),
                                    Objects.toString(reader.value().readObject(), null));
                        }
                        addValueToRow(rowIndex, JavaConverters.mapAsScalaMapConverter(value).asScala());
                    }
                    break;
                case "STRUCT":
                    // Structs are surfaced as the string form of the decoded object.
                    Preconditions.checkArgument(mt.equals(Types.MinorType.STRUCT),
                            typeMismatchMessage(currentType, mt));
                    StructVector structVector = (StructVector) curFieldVector;
                    for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
                        if (structVector.isNull(rowIndex)) {
                            addValueToRow(rowIndex, null);
                            continue;
                        }
                        String value = structVector.getObject(rowIndex).toString();
                        addValueToRow(rowIndex, value);
                    }
                    break;
                default:
                    String errMsg = "Unsupported type " + schema.get(col).getType();
                    logger.error(errMsg);
                    throw new DorisException(errMsg);
            }
        }
    } catch (Exception e) {
        // Release the Arrow resources before propagating the failure unchanged.
        close();
        throw e;
    }
}
/**
 * Returns the column values of the next unread row and advances the read position.
 *
 * @return the column values of the next row
 * @throws NoSuchElementException if all rows have already been consumed
 */
public List<Object> next() {
    if (hasNext()) {
        final int index = offsetInRowBatch++;
        return rowBatch.get(index).getCols();
    }
    String errMsg = "Get row offset:" + offsetInRowBatch + " larger than row size: " + readRowCount;
    logger.error(errMsg);
    throw new NoSuchElementException(errMsg);
}
/**
 * Builds the message used when an Arrow vector type does not fit the expected column type.
 *
 * @param sparkType type name reported on the expected side
 * @param arrowType Arrow minor type actually found
 * @return the formatted mismatch message
 */
private String typeMismatchMessage(final String sparkType, final Types.MinorType arrowType) {
    return String.format("Spark type is %1$s, but arrow type is %2$s.", sparkType, arrowType.name());
}
/**
 * Returns the total number of rows converted from the Arrow stream.
 *
 * @return the number of rows read
 */
public int getReadRowCount() {
    return this.readRowCount;
}
/**
 * Releases the Arrow stream reader and its allocator.
 *
 * <p>Closing is best effort: an {@link IOException} from the reader is logged rather than
 * propagated, and the allocator is closed even when closing the reader fails.
 */
public void close() {
    try {
        if (arrowStreamReader != null) {
            arrowStreamReader.close();
        }
    } catch (IOException ioe) {
        // Best effort: report the failure instead of hiding it, but do not fail the caller.
        logger.warn("Failed to close arrow stream reader.", ioe);
    } finally {
        // Always release the allocator, otherwise a reader failure would leak it.
        if (rootAllocator != null) {
            rootAllocator.close();
        }
    }
}
}