[fix] streaming write execution plan error (#135)
* fix streaming write error and add a JSON data passthrough option
* handle stream passthrough: force read_json_by_line to true when the format is JSON
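For context, a minimal usage sketch of the new option (the source column, FE endpoint, table name, and checkpoint path are illustrative assumptions; the option keys are the connector's):

```scala
// Hypothetical pipeline: the source already emits one serialized JSON
// object per row in a single string column, so the connector can forward
// the payload to stream load verbatim instead of re-serializing each row.
val query = df.selectExpr("CAST(value AS STRING) AS data") // e.g. a Kafka source
  .writeStream
  .format("doris")
  .option("doris.fenodes", "fe-host:8030")             // illustrative FE endpoint
  .option("doris.table.identifier", "db.tbl")          // illustrative table
  .option("user", "root")
  .option("password", "")
  .option("doris.sink.properties.format", "json")      // stream load data format
  .option("doris.sink.streaming.passthrough", "true")  // option added by this PR
  .option("checkpointLocation", "/tmp/doris-ckpt")     // illustrative path
  .start()
```

When passthrough is disabled (the default), rows are serialized by the connector as before.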
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 2ab200d..09c0416 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -100,4 +100,10 @@
String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;
+ /**
+ * pass JSON data through unchanged when sinking to Doris in streaming mode
+ */
+ String DORIS_SINK_STREAMING_PASSTHROUGH = "doris.sink.streaming.passthrough";
+ boolean DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT = false;
+
}
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 4a7b1e0..ac920cd 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -96,6 +96,8 @@
private boolean readJsonByLine = false;
+ private boolean streamingPassthrough = false;
+
public DorisStreamLoad(SparkSettings settings) {
String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
this.db = dbTable[0];
@@ -121,6 +123,8 @@
}
}
LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
+ this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
+ ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
}
public String getLoadUrlStr() {
@@ -196,6 +200,38 @@
}
+ /**
+ * Load each payload of a micro-batch via stream load and collect the
+ * resulting transaction ids; if a payload fails with 2PC enabled, abort
+ * the transactions pre-committed so far and rethrow.
+ */
+ public List<Integer> loadStream(List<List<Object>> rows, String[] dfColumns, Boolean enable2PC)
+ throws StreamLoadException, JsonProcessingException {
+
+ List<String> loadData;
+
+ if (this.streamingPassthrough) {
+ handleStreamPassThrough();
+ loadData = passthrough(rows);
+ } else {
+ loadData = parseLoadData(rows, dfColumns);
+ }
+
+ List<Integer> txnIds = new ArrayList<>(loadData.size());
+
+ try {
+ for (String data : loadData) {
+ txnIds.add(load(data, enable2PC));
+ }
+ } catch (StreamLoadException e) {
+ if (enable2PC && !txnIds.isEmpty()) {
+ LOG.error("load batch failed, abort previously pre-committed transactions");
+ for (Integer txnId : txnIds) {
+ abort(txnId);
+ }
+ }
+ throw e;
+ }
+
+ return txnIds;
+
+ }
+
public int load(String value, Boolean enable2PC) throws StreamLoadException {
String label = generateLoadLabel();
@@ -442,4 +478,18 @@
return hexData;
}
+ private void handleStreamPassThrough() {
+
+ if ("json".equalsIgnoreCase(fileType)) {
+ LOG.info("handle stream pass through, force set read_json_by_line is true for json format");
+ streamLoadProp.put("read_json_by_line", "true");
+ streamLoadProp.remove("strip_outer_array");
+ }
+
+ }
+
+ private List<String> passthrough(List<List<Object>> values) {
+ return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList());
+ }
+
}
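To make the passthrough contract concrete, here is a standalone sketch (illustrative values, not the connector's classes) of what `passthrough` does to a micro-batch:

```scala
// Only column 0 of each row is kept, and it must already hold a serialized
// JSON object; with read_json_by_line=true, stream load then reads one
// record per line.
object PassthroughSketch {
  def passthrough(rows: Seq[Seq[AnyRef]]): Seq[String] =
    rows.map(_.head.toString) // mirrors list.get(0).toString() above

  def main(args: Array[String]): Unit = {
    val batch: Seq[Seq[AnyRef]] = Seq(
      Seq("""{"id":1,"name":"a"}"""),
      Seq("""{"id":2,"name":"b"}"""))
    println(passthrough(batch).mkString("\n"))
  }
}
```

This is also why `handleStreamPassThrough` removes `strip_outer_array`: the payload is newline-delimited JSON objects, not a single JSON array.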
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
index 342e940..d1a2b74 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -34,7 +34,7 @@
if (batchId <= latestBatchId) {
logger.info(s"Skipping already committed batch $batchId")
} else {
- writer.write(data)
+ writer.writeStream(data)
latestBatchId = batchId
}
}
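This one-line sink change is the heart of the fix: inside `Sink.addBatch`, starting a fresh action on the incoming DataFrame (e.g. through `data.rdd`) re-plans the streaming query and raises the execution plan error from the issue title, so the new `writeStream` path consumes the already-planned physical RDD via `queryExecution.toRdd`. A minimal sketch of that rule (hypothetical sink class, not the connector's):

```scala
// Sketch of the rule the fix follows: inside addBatch, consume the
// already-planned physical RDD (queryExecution.toRdd) instead of calling
// data.rdd, which would re-plan the streaming query and fail.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.Sink

class RawRowSink extends Sink {
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val schema = data.schema // capture the schema, not the DataFrame, in the closure
    data.queryExecution.toRdd.foreachPartition { it: Iterator[InternalRow] =>
      // toSeq(schema) yields Spark-internal values (e.g. UTF8String),
      // so real code converts them before use, as convertToObjectList does
      it.foreach(row => println(row.toSeq(schema).mkString(",")))
    }
  }
}
```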
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 2b918e8..e32267e 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -22,6 +22,9 @@
import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
import org.apache.doris.spark.sql.Utils
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CollectionAccumulator
import org.slf4j.{Logger, LoggerFactory}
import java.io.IOException
@@ -76,28 +79,13 @@
* flush data to Doris and do retry when flush error
*
*/
- def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
+ def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = {
Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
- dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
+ dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
} match {
- case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
+ case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
case Failure(e) =>
- if (enable2PC) {
- // if task run failed, acc value will not be returned to driver,
- // should abort all pre committed transactions inside the task
- logger.info("load task failed, start aborting previously pre-committed transactions")
- val abortFailedTxnIds = mutable.Buffer[Int]()
- preCommittedTxnAcc.value.asScala.foreach(txnId => {
- Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
- dorisStreamLoader.abort(txnId)
- } match {
- case Success(_) =>
- case Failure(_) => abortFailedTxnIds += txnId
- }
- })
- if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
- preCommittedTxnAcc.reset()
- }
+ if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
throw new IOException(
s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
}
@@ -105,5 +93,76 @@
}
+ def writeStream(dataFrame: DataFrame): Unit = {
+
+ val sc = dataFrame.sqlContext.sparkContext
+ val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
+ if (enable2PC) {
+ sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
+ }
+
+ var resultRdd = dataFrame.queryExecution.toRdd
+ val schema = dataFrame.schema
+ val dfColumns = dataFrame.columns
+ if (Objects.nonNull(sinkTaskPartitionSize)) {
+ resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
+ }
+ resultRdd
+ .foreachPartition(partition => {
+ partition
+ .grouped(batchSize)
+ .foreach(batch =>
+ flush(batch, dfColumns))
+ })
+
+ /**
+ * flush data to Doris and retry when a flush error occurs
+ *
+ */
+ def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
+ Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
+ dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC)
+ } match {
+ case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
+ case Failure(e) =>
+ if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+ throw new IOException(
+ s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
+ }
+ }
+
+ def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = {
+ rows.map(row => {
+ row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
+ }).asJava
+ }
+
+ }
+
+ private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = {
+ txnIds.foreach(txnId => acc.add(txnId))
+ }
+
+ def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+ // if task run failed, acc value will not be returned to driver,
+ // should abort all pre committed transactions inside the task
+ logger.info("load task failed, start aborting previously pre-committed transactions")
+ if (acc.isZero) {
+ logger.info("no pre-committed transactions, skip abort")
+ return
+ }
+ val abortFailedTxnIds = mutable.Buffer[Int]()
+ acc.value.asScala.foreach(txnId => {
+ Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
+ dorisStreamLoader.abort(txnId)
+ } match {
+ case Success(_) =>
+ case Failure(_) => abortFailedTxnIds += txnId
+ }
+ })
+ if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
+ acc.reset()
+ }
+
}
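Finally, the refactored 2PC bookkeeping: every pre-committed transaction id is recorded in a `CollectionAccumulator`, successful batches hand their ids to the driver via `handleLoadSuccess`, and `handleLoadFailure` aborts whatever was recorded before rethrowing. A local-mode sketch of that accumulator pattern (the `abort` stub stands in for `DorisStreamLoad.abort`):

```scala
// Toy reproduction of the abort-on-failure flow above.
import org.apache.spark.sql.SparkSession
import scala.collection.JavaConverters._
import scala.collection.mutable

object TwoPhaseCommitSketch {
  private def abort(txnId: Int): Unit = println(s"abort txn $txnId") // stand-in

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("2pc-sketch").getOrCreate()
    val acc = spark.sparkContext.collectionAccumulator[Int]("preCommittedTxnAcc")
    Seq(101, 102).foreach(acc.add(_)) // tasks record each pre-committed txn

    // on load failure: abort everything recorded so far, remember what failed
    val abortFailedTxnIds = mutable.Buffer[Int]()
    acc.value.asScala.foreach { txnId =>
      try abort(txnId) catch { case _: Exception => abortFailedTxnIds += txnId }
    }
    if (abortFailedTxnIds.nonEmpty)
      println(s"not aborted txn ids: ${abortFailedTxnIds.mkString(",")}")
    acc.reset()
    spark.stop()
  }
}
```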