[Enhancement] Add load balance strategy for frontends and backends (#329)
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisBackendHttpClient.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisBackendHttpClient.java
index 1f28ae5..2df5cc1 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisBackendHttpClient.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisBackendHttpClient.java
@@ -18,6 +18,7 @@
package org.apache.doris.spark.client;
import org.apache.doris.spark.client.entity.Backend;
+import org.apache.doris.spark.util.LoadBalanceList;
import org.apache.doris.spark.util.HttpUtil;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
@@ -32,12 +33,13 @@
public class DorisBackendHttpClient implements Serializable {
private static final Logger log = LoggerFactory.getLogger(DorisBackendHttpClient.class);
- private final List<Backend> backends;
+
+ private final LoadBalanceList<Backend> backends;
private transient CloseableHttpClient httpClient;
public DorisBackendHttpClient(List<Backend> backends) {
- this.backends = backends;
+ this.backends = new LoadBalanceList<>(backends);
}
public <T> T executeReq(BiFunction<Backend, CloseableHttpClient, T> reqFunc) throws Exception {
@@ -52,6 +54,7 @@
}
} catch (Exception e) {
log.warn("Failed to execute request on backend: {}:{}", backend.getHost(), backend.getHttpPort(), e);
+ backends.reportFailed(backend);
ex = e;
}
}
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
index b0d1dc1..b99902b 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/DorisFrontendClient.java
@@ -28,6 +28,7 @@
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.HttpUtil;
import org.apache.doris.spark.util.HttpUtils;
+import org.apache.doris.spark.util.LoadBalanceList;
import org.apache.doris.spark.util.URLs;
import com.fasterxml.jackson.databind.JsonNode;
@@ -73,7 +74,7 @@
private final DorisConfig config;
private final String username;
private final String password;
- private final List<Frontend> frontends;
+ private final LoadBalanceList<Frontend> frontends;
private final boolean isHttpsEnabled;
private transient CloseableHttpClient httpClient;
@@ -83,7 +84,7 @@
this.password = null;
this.httpClient = null;
this.isHttpsEnabled = false;
- this.frontends = Collections.emptyList();
+ this.frontends = new LoadBalanceList<>(Collections.emptyList());
}
public DorisFrontendClient(DorisConfig config) throws Exception {
@@ -94,17 +95,18 @@
this.frontends = initFrontends(config);
}
- private List<Frontend> initFrontends(DorisConfig config) throws Exception {
+ private LoadBalanceList<Frontend> initFrontends(DorisConfig config) throws Exception {
String frontendNodes = config.getValue(DorisOptions.DORIS_FENODES);
String[] frontendNodeArray = frontendNodes.split(",");
+ List<Frontend> frontendList = null;
if (config.getValue(DorisOptions.DORIS_FE_AUTO_FETCH)) {
Exception ex = null;
- List<Frontend> frontendList = null;
for (String frontendNode : frontendNodeArray) {
String[] nodeDetails = frontendNode.split(":");
try {
- List<Frontend> list = Collections.singletonList(new Frontend(nodeDetails[0],
- nodeDetails.length > 1 ? Integer.parseInt(nodeDetails[1]) : -1));
+ LoadBalanceList<Frontend> list = new LoadBalanceList<>(
+ Collections.singletonList(new Frontend(nodeDetails[0],
+ nodeDetails.length > 1 ? Integer.parseInt(nodeDetails[1]) : -1)));
frontendList = requestFrontends(list, (frontend, client) -> {
String url = URLs.getFrontEndNodes(frontend.getHost(), frontend.getHttpPort(),
isHttpsEnabled);
@@ -132,18 +134,18 @@
}
throw new DorisException("frontend init fetch failed", ex);
}
- return frontendList;
+ return new LoadBalanceList<>(frontendList);
} else {
int queryPort = config.contains(DorisOptions.DORIS_QUERY_PORT) ?
config.getValue(DorisOptions.DORIS_QUERY_PORT) : -1;
int flightSqlPort = config.contains(DorisOptions.DORIS_READ_FLIGHT_SQL_PORT) ?
config.getValue(DorisOptions.DORIS_READ_FLIGHT_SQL_PORT) : -1;
- return Arrays.stream(frontendNodeArray)
+ return new LoadBalanceList<>(Arrays.stream(frontendNodeArray)
.map(node -> {
String[] nodeParts = node.split(":");
return new Frontend(nodeParts[0], nodeParts.length > 1 ? Integer.parseInt(nodeParts[1]) : -1, queryPort, flightSqlPort);
})
- .collect(Collectors.toList());
+ .collect(Collectors.toList()));
}
}
@@ -151,7 +153,7 @@
return requestFrontends(frontends, reqFunc);
}
- private <T> T requestFrontends(List<Frontend> frontEnds, BiFunction<Frontend, CloseableHttpClient, T> reqFunc) throws Exception {
+ private <T> T requestFrontends(LoadBalanceList<Frontend> frontEnds, BiFunction<Frontend, CloseableHttpClient, T> reqFunc) throws Exception {
if (httpClient == null) {
httpClient = HttpUtils.getHttpClient(config);
}
@@ -163,6 +165,7 @@
}
} catch (Exception e) {
LOG.warn("fe http request on {} failed, err: {}", frontEnd.hostHttpPortString(), e.getMessage());
+ frontEnds.reportFailed(frontEnd);
ex = e;
}
}
@@ -370,6 +373,7 @@
backends.add(new Backend(backendNode.get("ip").asText(), backendNode.get("http_port").asInt(), -1));
}
}
+ Collections.shuffle(backends);
return backends;
});
}
@@ -386,7 +390,7 @@
});
}
- public List<Frontend> getFrontends() {
+ public LoadBalanceList<Frontend> getFrontends() {
return frontends;
}
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/Frontend.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/Frontend.java
index a2858fb..8fd17a4 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/Frontend.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/Frontend.java
@@ -18,6 +18,7 @@
package org.apache.doris.spark.client.entity;
import java.io.Serializable;
+import java.util.Objects;
public class Frontend implements Serializable {
@@ -66,4 +67,19 @@
return host + ":" + queryPort;
}
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (!(o instanceof Frontend)) return false;
+ Frontend frontend = (Frontend) o;
+ return httpPort == frontend.httpPort
+ && queryPort == frontend.queryPort
+ && flightSqlPort == frontend.flightSqlPort
+ && Objects.equals(host, frontend.host);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(host, httpPort, queryPort, flightSqlPort);
+ }
}
\ No newline at end of file
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
index 55298ff..112b2f8 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
@@ -48,7 +48,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -66,10 +65,8 @@
public DorisFlightSqlReader(DorisReaderPartition partition) throws Exception {
super(partition);
this.frontendClient = new DorisFrontendClient(partition.getConfig());
- List<Frontend> frontends = new ArrayList<>(frontendClient.getFrontends());
- Collections.shuffle(frontends);
Exception tx = null;
- for (Frontend frontend : frontends) {
+ for (Frontend frontend : frontendClient.getFrontends()) {
try {
this.connection = initializeConnection(frontend, partition.getConfig());
tx = null;
@@ -77,6 +74,7 @@
} catch (OptionRequiredException e) {
throw new DorisException("init adbc connection failed", e);
} catch (AdbcException e) {
+ frontendClient.getFrontends().reportFailed(frontend);
log.warn("init adbc connection failed with fe: " + frontend.getHost(), e);
tx = new DorisException("init adbc connection failed", e);
}
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/util/LoadBalanceList.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/util/LoadBalanceList.java
new file mode 100644
index 0000000..ff76d46
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/util/LoadBalanceList.java
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.util;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.PriorityQueue;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * A failure-aware round-robin view over a fixed list of servers.
+ *
+ * <p>Each {@link #iterator()} starts one position after the previous one (shared
+ * cursor), so consecutive iterations begin on different servers. Servers reported
+ * via {@link #reportFailed(Object)} are still yielded, but only after all healthy
+ * servers (oldest failure first), until the record expires after FAILED_TIME_OUT.
+ */
+public class LoadBalanceList<T> implements Iterable<T>, Serializable {
+
+    private final List<T> list;
+
+    // Servers recently reported failed, keyed by the server itself.
+    private final Map<T, FailedServer<T>> failedServers;
+
+    // Shared round-robin cursor; each new iterator starts one position later.
+    private final AtomicInteger globalOffset = new AtomicInteger(0);
+
+    // A failure record expires after this many milliseconds (one hour).
+    private static final long FAILED_TIME_OUT = 60 * 60 * 1000;
+
+    public LoadBalanceList(List<T> servers) {
+        this.list = Collections.unmodifiableList(servers);
+        this.failedServers = new ConcurrentHashMap<>();
+    }
+
+    @Override
+    public Iterator<T> iterator() {
+        return new Iterator<T>() {
+            final int offset = globalOffset.getAndIncrement();
+            // Deferred (failed, unexpired) servers, polled oldest-failure first.
+            final Queue<FailedServer<T>> skipServers = new PriorityQueue<>();
+            int index = 0;
+
+            @Override
+            public boolean hasNext() {
+                return index < list.size() || !skipServers.isEmpty();
+            }
+
+            @Override
+            public T next() {
+                while (index < list.size()) {
+                    // floorMod keeps the index non-negative even when the shared
+                    // cursor overflows (Math.abs(Integer.MIN_VALUE) is negative).
+                    T server = list.get(Math.floorMod(offset + index++, list.size()));
+                    FailedServer<T> failedEntry = failedServers.get(server);
+                    if (failedEntry == null) {
+                        return server;
+                    }
+                    if (System.currentTimeMillis() - failedEntry.failedTime > FAILED_TIME_OUT) {
+                        // Failure record expired; treat the server as healthy again.
+                        failedServers.remove(failedEntry.server);
+                        return server;
+                    }
+                    // Still failed: defer until every healthy server was yielded.
+                    skipServers.add(failedEntry);
+                }
+                FailedServer<T> deferred = skipServers.poll();
+                if (deferred == null) {
+                    throw new NoSuchElementException();
+                }
+                return deferred.server;
+            }
+        };
+    }
+
+    public List<T> getList() {
+        return list;
+    }
+
+    // Mark a server failed; it is served last until the record expires.
+    public void reportFailed(T server) {
+        this.failedServers.put(server, new FailedServer<>(server));
+    }
+
+    private static class FailedServer<T> implements Comparable<FailedServer<T>>, Serializable {
+
+        protected final T server;
+
+        protected final Long failedTime;
+
+        public FailedServer(T t) {
+            this.server = t;
+            this.failedTime = System.currentTimeMillis();
+        }
+
+        @Override
+        public int compareTo(FailedServer<T> o) {
+            return this.failedTime.compareTo(o.failedTime);
+        }
+    }
+}
diff --git a/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/util/LoadBalanceListTest.java b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/util/LoadBalanceListTest.java
new file mode 100644
index 0000000..e4226e9
--- /dev/null
+++ b/spark-doris-connector/spark-doris-connector-base/src/test/java/org/apache/doris/spark/util/LoadBalanceListTest.java
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.util;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class LoadBalanceListTest {
+
+    // Each iteration must start one position later (shared round-robin cursor)
+    // and every server must eventually appear first.
+    @Test
+    public void testLoadBalanceList1() {
+        List<String> serverList = Arrays.asList("server1", "server2", "server3");
+        LoadBalanceList<String> loadBalanceList = new LoadBalanceList<>(serverList);
+        Set<String> testHeadSet = new HashSet<>();
+        for (int i = 0; i < 1000; i++) {
+            List<String> testList = new ArrayList<>();
+            int index = 0;
+            for (String server : loadBalanceList) {
+                testList.add(server);
+                if (index++ == 0) {
+                    testHeadSet.add(server);
+                }
+            }
+            if (i % serverList.size() == 0) {
+                Assert.assertEquals(Arrays.asList("server1", "server2", "server3"), testList);
+            } else if (i % serverList.size() == 1) {
+                Assert.assertEquals(Arrays.asList("server2", "server3", "server1"), testList);
+            } else {
+                Assert.assertEquals(Arrays.asList("server3", "server1", "server2"), testList);
+            }
+            Assert.assertEquals(serverList.size(), testList.size());
+        }
+        Assert.assertEquals(serverList.size(), testHeadSet.size());
+    }
+
+    // Failed servers are still iterated, but only after all healthy ones.
+    @Test
+    public void testLoadBalanceList2() throws InterruptedException {
+        List<String> serverList = Arrays.asList("server1", "server2", "server3", "server4");
+        LoadBalanceList<String> loadBalanceList = new LoadBalanceList<>(serverList);
+        Set<String> failedSet = new HashSet<>();
+        failedSet.add("server1");
+        loadBalanceList.reportFailed("server1");
+        // Short pause so the two failure timestamps differ deterministically;
+        // the previous 10-second sleep made this unit test needlessly slow.
+        Thread.sleep(50);
+        failedSet.add("server4");
+        loadBalanceList.reportFailed("server4");
+        Set<String> serverSet = new HashSet<>();
+
+        for (int i = 0; i < 1000; i++) {
+            int index = 0;
+            for (String server : loadBalanceList) {
+                serverSet.add(server);
+                if (++index > loadBalanceList.getList().size() - failedSet.size()) {
+                    Assert.assertTrue(failedSet.contains(server));
+                }
+            }
+            Assert.assertEquals(serverList.size(), serverSet.size());
+        }
+    }
+}