[AINode][Bug fix] Concurrent inference (#16518)
* trigger CI
* bug fix 4 show loaded models
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
index c73fbeb..b5b9875 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
@@ -22,15 +22,23 @@
import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.itbase.env.BaseEnv;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import org.junit.AfterClass;
+import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.Connection;
+import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
@@ -38,6 +46,11 @@
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
+ private static final Map<String, String> MODEL_ID_TO_TYPE_MAP =
+ ImmutableMap.of(
+ "timer_xl", "Timer-XL",
+ "sundial", "Timer-Sundial");
+
@BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
@@ -91,12 +104,17 @@
Statement statement = connection.createStatement()) {
final int threadCnt = 4;
final int loop = 10;
+ final int predictLength = 96;
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
+ checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
concurrentInference(
statement,
- String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+ String.format(
+ "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
+ modelId, predictLength),
threadCnt,
- loop);
+ loop,
+ predictLength);
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
}
}
@@ -111,14 +129,20 @@
throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- final int threadCnt = 4;
- final int loop = 10;
- statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
+ final int threadCnt = 10;
+ final int loop = 100;
+ final int predictLength = 512;
+ final String devices = "0,1";
+ statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
+ checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
concurrentInference(
statement,
- String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+ String.format(
+ "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
+ modelId, predictLength),
threadCnt,
- loop);
+ loop,
+ predictLength);
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
}
}
@@ -134,15 +158,18 @@
Statement statement = connection.createStatement()) {
final int threadCnt = 4;
final int loop = 10;
+ final int predictLength = 96;
statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
+ checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
long startTime = System.currentTimeMillis();
concurrentInference(
statement,
String.format(
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
- modelId),
+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
+ modelId, predictLength),
threadCnt,
- loop);
+ loop,
+ predictLength);
long endTime = System.currentTimeMillis();
LOGGER.info(
String.format(
@@ -163,15 +190,19 @@
Statement statement = connection.createStatement()) {
final int threadCnt = 10;
final int loop = 100;
- statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
+ final int predictLength = 512;
+ final String devices = "0,1";
+ statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
+ checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
long startTime = System.currentTimeMillis();
concurrentInference(
statement,
String.format(
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
- modelId),
+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d",
+ modelId, predictLength),
threadCnt,
- loop);
+ loop,
+ predictLength);
long endTime = System.currentTimeMillis();
LOGGER.info(
String.format(
@@ -180,4 +211,29 @@
statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
}
}
+
+ private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device)
+ throws SQLException, InterruptedException {
+ for (int retry = 0; retry < 10; retry++) {
+ Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+ Set<String> foundDevices = new HashSet<>();
+ try (final ResultSet resultSet =
+ statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) {
+ while (resultSet.next()) {
+ String deviceId = resultSet.getString(1);
+ String loadedModelType = resultSet.getString(2);
+ int count = resultSet.getInt(3);
+ if (loadedModelType.equals(modelType) && targetDevices.contains(deviceId)) {
+ Assert.assertTrue(count > 1);
+ foundDevices.add(deviceId);
+ }
+ }
+ if (foundDevices.containsAll(targetDevices)) {
+ return;
+ }
+ }
+ TimeUnit.SECONDS.sleep(3);
+ }
+ Assert.fail("Model " + modelType + " is not loaded on device " + device);
+ }
}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 31e498a..cbb0b03 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -60,7 +60,8 @@
}
}
- public static void concurrentInference(Statement statement, String sql, int threadCnt, int loop)
+ public static void concurrentInference(
+ Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
throws InterruptedException {
Thread[] threads = new Thread[threadCnt];
for (int i = 0; i < threadCnt; i++) {
@@ -70,9 +71,11 @@
try {
for (int j = 0; j < loop; j++) {
try (ResultSet resultSet = statement.executeQuery(sql)) {
+ int outputCnt = 0;
while (resultSet.next()) {
- // do nothing
+ outputCnt++;
}
+ assertEquals(expectedOutputLength, outputCnt);
} catch (SQLException e) {
fail(e.getMessage());
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index a99c23b..6f14036 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -217,7 +217,7 @@
status=get_status(TSStatusCode.SUCCESS_STATUS),
deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
req.deviceIdList
- if req.deviceIdList is not None
+ if len(req.deviceIdList) > 0
else get_available_devices()
),
)