[AINode][Bug fix] Concurrent inference (#16518)

* trigger CI

* Fix SHOW LOADED MODELS when the requested device list is empty
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
index c73fbeb..b5b9875 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
@@ -22,15 +22,23 @@
 import org.apache.iotdb.it.env.EnvFactory;
 import org.apache.iotdb.itbase.env.BaseEnv;
 
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import java.sql.Connection;
+import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Statement;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
 
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
 
@@ -38,6 +46,11 @@
 
   private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
 
+  private static final Map<String, String> MODEL_ID_TO_TYPE_MAP =
+      ImmutableMap.of(
+          "timer_xl", "Timer-XL",
+          "sundial", "Timer-Sundial");
+
   @BeforeClass
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
@@ -91,12 +104,17 @@
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
       final int loop = 10;
+      final int predictLength = 96;
       statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
+      checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
       concurrentInference(
           statement,
-          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+          String.format(
+              "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"cpu\"", modelId));
     }
   }
@@ -111,14 +129,20 @@
       throws SQLException, InterruptedException {
     try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
         Statement statement = connection.createStatement()) {
-      final int threadCnt = 4;
-      final int loop = 10;
-      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
+      final int threadCnt = 10;
+      final int loop = 100;
+      final int predictLength = 512;
+      final String devices = "0,1";
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
+      checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
       concurrentInference(
           statement,
-          String.format("CALL INFERENCE(%s, \"SELECT s FROM root.AI\")", modelId),
+          String.format(
+              "CALL INFERENCE(%s, \"SELECT s FROM root.AI\", predict_length=%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
     }
   }
@@ -134,15 +158,18 @@
         Statement statement = connection.createStatement()) {
       final int threadCnt = 4;
       final int loop = 10;
+      final int predictLength = 96;
       statement.execute(String.format("LOAD MODEL %s TO DEVICES \"cpu\"", modelId));
+      checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), "cpu");
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
           String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
-              modelId),
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, predict_length=>%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
@@ -163,15 +190,19 @@
         Statement statement = connection.createStatement()) {
       final int threadCnt = 10;
       final int loop = 100;
-      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"0,1\"", modelId));
+      final int predictLength = 512;
+      final String devices = "0,1";
+      statement.execute(String.format("LOAD MODEL %s TO DEVICES \"%s\"", modelId, devices));
+      checkModelOnSpecifiedDevice(statement, MODEL_ID_TO_TYPE_MAP.get(modelId), devices);
       long startTime = System.currentTimeMillis();
       concurrentInference(
           statement,
           String.format(
-              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time)",
-              modelId),
+              "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, predict_length=>%d)",
+              modelId, predictLength),
           threadCnt,
-          loop);
+          loop,
+          predictLength);
       long endTime = System.currentTimeMillis();
       LOGGER.info(
           String.format(
@@ -180,4 +211,29 @@
       statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES \"0,1\"", modelId));
     }
   }
+
+  // Polls up to 10 times, 3s apart, until every device reports >1 instance of modelType.
+  private void checkModelOnSpecifiedDevice(Statement statement, String modelType, String device)
+      throws SQLException, InterruptedException {
+    final Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+    for (int retry = 0; retry < 10; retry++) {
+      Set<String> foundDevices = new HashSet<>();
+      try (final ResultSet resultSet =
+          statement.executeQuery(String.format("SHOW LOADED MODELS %s", device))) {
+        while (resultSet.next()) {
+          String deviceId = resultSet.getString(1);
+          String loadedModelType = resultSet.getString(2);
+          int count = resultSet.getInt(3);
+          if (loadedModelType.equals(modelType) && targetDevices.contains(deviceId) && count > 1) {
+            foundDevices.add(deviceId);
+          }
+        }
+        if (foundDevices.containsAll(targetDevices)) {
+          return;
+        }
+      }
+      TimeUnit.SECONDS.sleep(3);
+    }
+    Assert.fail("Model " + modelType + " is not loaded on device " + device);
+  }
 }
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index 31e498a..cbb0b03 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -60,7 +60,8 @@
     }
   }
 
-  public static void concurrentInference(Statement statement, String sql, int threadCnt, int loop)
+  public static void concurrentInference(
+      Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
       throws InterruptedException {
     Thread[] threads = new Thread[threadCnt];
     for (int i = 0; i < threadCnt; i++) {
@@ -70,9 +71,11 @@
                 try {
                   for (int j = 0; j < loop; j++) {
                     try (ResultSet resultSet = statement.executeQuery(sql)) {
+                      int outputCnt = 0;
                       while (resultSet.next()) {
-                        // do nothing
+                        outputCnt++;
                       }
+                      assertEquals(expectedOutputLength, outputCnt);
                     } catch (SQLException e) {
                       fail(e.getMessage());
                     }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index a99c23b..6f14036 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -217,7 +217,7 @@
             status=get_status(TSStatusCode.SUCCESS_STATUS),
             deviceLoadedModelsMap=self._pool_controller.show_loaded_models(
                 req.deviceIdList
-                if req.deviceIdList is not None
+                if req.deviceIdList
                 else get_available_devices()
             ),
         )