[fix] streaming write execution plan error (#135)

* fix streaming write error and add json data pass through option
* handle stream pass through, force set read_json_by_line is true when format is json
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 2ab200d..09c0416 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -100,4 +100,10 @@
     String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
     boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;
 
+    /**
+     * pass through json data when sink to doris in streaming mode
+     */
+    String DORIS_SINK_STREAMING_PASSTHROUGH = "doris.sink.streaming.passthrough";
+    boolean DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT = false;
+
 }
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 4a7b1e0..ac920cd 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -96,6 +96,8 @@
 
     private boolean readJsonByLine = false;
 
+    private boolean streamingPassthrough = false;
+
     public DorisStreamLoad(SparkSettings settings) {
         String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
         this.db = dbTable[0];
@@ -121,6 +123,8 @@
             }
         }
         LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
+        this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
+                ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
     }
 
     public String getLoadUrlStr() {
@@ -196,6 +200,38 @@
 
     }
 
+    public List<Integer> loadStream(List<List<Object>> rows, String[] dfColumns, Boolean enable2PC)
+            throws StreamLoadException, JsonProcessingException {
+
+        List<String> loadData;
+
+        if (this.streamingPassthrough) {
+            handleStreamPassThrough();
+            loadData = passthrough(rows);
+        } else {
+            loadData = parseLoadData(rows, dfColumns);
+        }
+
+        List<Integer> txnIds = new ArrayList<>(loadData.size());
+
+        try {
+            for (String data : loadData) {
+                txnIds.add(load(data, enable2PC));
+            }
+        } catch (StreamLoadException e) {
+            if (enable2PC && !txnIds.isEmpty()) {
+                LOG.error("load batch failed, abort previously pre-committed transactions");
+                for (Integer txnId : txnIds) {
+                    abort(txnId);
+                }
+            }
+            throw e;
+        }
+
+        return txnIds;
+
+    }
+
     public int load(String value, Boolean enable2PC) throws StreamLoadException {
 
         String label = generateLoadLabel();
@@ -442,4 +478,18 @@
         return hexData;
     }
 
+    private void handleStreamPassThrough() {
+
+        if ("json".equalsIgnoreCase(fileType)) {
+            LOG.info("handle stream pass through, force set read_json_by_line is true for json format");
+            streamLoadProp.put("read_json_by_line", "true");
+            streamLoadProp.remove("strip_outer_array");
+        }
+
+    }
+
+    private List<String> passthrough(List<List<Object>> values) {
+        return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList());
+    }
+
 }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
index 342e940..d1a2b74 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -34,7 +34,7 @@
     if (batchId <= latestBatchId) {
       logger.info(s"Skipping already committed batch $batchId")
     } else {
-      writer.write(data)
+      writer.writeStream(data)
       latestBatchId = batchId
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 2b918e8..e32267e 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -22,6 +22,9 @@
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CollectionAccumulator
 import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.IOException
@@ -76,28 +79,13 @@
      * flush data to Doris and do retry when flush error
      *
      */
-    def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
+    def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = {
       Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-        dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
+        dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
       } match {
-        case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
+        case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
         case Failure(e) =>
-          if (enable2PC) {
-            // if task run failed, acc value will not be returned to driver,
-            // should abort all pre committed transactions inside the task
-            logger.info("load task failed, start aborting previously pre-committed transactions")
-            val abortFailedTxnIds = mutable.Buffer[Int]()
-            preCommittedTxnAcc.value.asScala.foreach(txnId => {
-              Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
-                dorisStreamLoader.abort(txnId)
-              } match {
-                case Success(_) =>
-                case Failure(_) => abortFailedTxnIds += txnId
-              }
-            })
-            if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
-            preCommittedTxnAcc.reset()
-          }
+          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
           throw new IOException(
             s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
       }
@@ -105,5 +93,76 @@
 
   }
 
+  def writeStream(dataFrame: DataFrame): Unit = {
+
+    val sc = dataFrame.sqlContext.sparkContext
+    val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
+    if (enable2PC) {
+      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
+    }
+
+    var resultRdd = dataFrame.queryExecution.toRdd
+    val schema = dataFrame.schema
+    val dfColumns = dataFrame.columns
+    if (Objects.nonNull(sinkTaskPartitionSize)) {
+      resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
+    }
+    resultRdd
+      .foreachPartition(partition => {
+        partition
+          .grouped(batchSize)
+          .foreach(batch =>
+            flush(batch, dfColumns))
+      })
+
+    /**
+     * flush data to Doris and do retry when flush error
+     *
+     */
+    def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
+      Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
+        dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC)
+      } match {
+        case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
+        case Failure(e) =>
+          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+          throw new IOException(
+            s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
+      }
+    }
+
+    def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = {
+      rows.map(row => {
+        row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
+      }).asJava
+    }
+
+  }
+
+  private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = {
+    txnIds.foreach(txnId => acc.add(txnId))
+  }
+
+  def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+    // if task run failed, acc value will not be returned to driver,
+    // should abort all pre committed transactions inside the task
+    logger.info("load task failed, start aborting previously pre-committed transactions")
+    if (acc.isZero) {
+      logger.info("no pre-committed transactions, skip abort")
+      return
+    }
+    val abortFailedTxnIds = mutable.Buffer[Int]()
+    acc.value.asScala.foreach(txnId => {
+      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
+        dorisStreamLoader.abort(txnId)
+      } match {
+        case Success(_) =>
+        case Failure(_) => abortFailedTxnIds += txnId
+      }
+    })
+    if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
+    acc.reset()
+  }
+
 
 }