Move built-in TVF from common sub-module into datanode sub-module
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBSQLFunctionManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBSQLFunctionManagementIT.java index fa8b2cb..7695678 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBSQLFunctionManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/udf/IoTDBSQLFunctionManagementIT.java
@@ -20,7 +20,7 @@ import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction; -import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinTableFunction; +import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.it.env.EnvFactory; import org.apache.iotdb.it.framework.IoTDBTestRunner; import org.apache.iotdb.itbase.category.TableClusterIT;
diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java index 0c73228..a311433 100644 --- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -190,6 +190,7 @@ TRANSFER_LEADER_ERROR(1008), GET_CLUSTER_ID_ERROR(1009), CAN_NOT_CONNECT_CONFIGNODE(1010), + CAN_NOT_CONNECT_AINODE(1011), // Sync, Load TsFile LOAD_FILE_ERROR(1100),
diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py index 7b94d20..0405d77 100644 --- a/iotdb-core/ainode/ainode/core/handler.py +++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -22,7 +22,7 @@ from ainode.thrift.ainode import IAINodeRPCService from ainode.thrift.ainode.ttypes import (TDeleteModelReq, TRegisterModelReq, TAIHeartbeatReq, TInferenceReq, TRegisterModelResp, TInferenceResp, - TAIHeartbeatResp, TTrainingReq) + TAIHeartbeatResp, TTrainingReq, TForecastReq) from ainode.thrift.common.ttypes import TSStatus @@ -39,6 +39,9 @@ def inference(self, req: TInferenceReq) -> TInferenceResp: return InferenceManager.inference(req, self._model_manager) + def forecast(self, req: TForecastReq) -> TSStatus: + return InferenceManager.forecast(req, self._model_manager) + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: return ClusterManager.get_heart_beat(req)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index ebfc6d4..bdae5cd 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -16,7 +16,8 @@ # under the License. # import pandas as pd -from torch import tensor +import torch +from iotdb.tsfile.utils.tsblock_serde import deserialize from ainode.core.constant import TSStatusCode from ainode.core.exception import InvalidWindowArgumentError, InferenceModelInternalError, runtime_error_extractor @@ -24,42 +25,10 @@ from ainode.core.manager.model_manager import ModelManager from ainode.core.util.serde import convert_to_binary, convert_to_df from ainode.core.util.status import get_status -from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp +from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp, TForecastReq, TForecastResp logger = Logger() - -class InferenceManager: - @staticmethod - def inference(req: TInferenceReq, model_manager: ModelManager): - logger.info(f"start inference registered model {req.modelId}") - try: - model_id, full_data, window_interval, window_step, inference_attributes = _parse_inference_request(req) - - if model_id.startswith('_'): - # built-in models - logger.info(f"start inference built-in model {model_id}") - # parse the inference attributes and create the built-in model - model = _get_built_in_model(model_id, model_manager, inference_attributes) - inference_results = _inference_with_built_in_model( - model, full_data) - else: - # user-registered models - model = _get_model(model_id, model_manager, inference_attributes) - inference_results = _inference_with_registered_model( - model, full_data, window_interval, window_step) - for i in range(len(inference_results)): - inference_results[i] = convert_to_binary(inference_results[i]) - return TInferenceResp( - get_status( - TSStatusCode.SUCCESS_STATUS), - inference_results) - except Exception as e: - logger.warning(e) - inference_results = [] - return TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), inference_results) - - def _process_data(full_data): """ Args: @@ -83,11 +52,75 @@ data[data.columns[i]] = 0 elif 
type_list[i] == "BOOLEAN": data[data.columns[i]] = data[data.columns[i]].astype("int") - data = tensor(data.values).unsqueeze(0) + data = torch.tensor(data.values).unsqueeze(0) return data, data_length -def _inference_with_registered_model(model, full_data, window_interval, window_step): +class InferenceManager: + + @staticmethod + def forecast(req: TForecastReq, model_manager:ModelManager): + model_id = req.modelId + logger.info(f"start to forcast by model {model_id}") + try: + data = deserialize(req.inputData) + if model_id.startswith('_'): + # built-in models + logger.info(f"start to forecast built-in model {model_id}") + # parse the inference attributes and create the built-in model + options = req.options + options['predict_length'] = req.outputLength + model = _get_built_in_model(model_id, model_manager, options) + inference_result = convert_to_binary(_inference_with_built_in_model( + model, data)) + else: + # user-registered models + model = _get_model(model_id, model_manager, req.options) + _, dataset, _, dataset_length = data + dataset = torch.tensor(dataset, dtype=torch.float).unsqueeze(2) + inference_results = _inference_with_registered_model( + model, dataset, dataset_length, dataset_length, float('inf')) + inference_result = convert_to_binary(inference_results[0]) + return TForecastResp( + get_status(TSStatusCode.SUCCESS_STATUS), + inference_result + ) + except Exception as e: + logger.warning(e) + inference_results = [] + return TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), inference_results) + + @staticmethod + def inference(req: TInferenceReq, model_manager: ModelManager): + logger.info(f"start inference registered model {req.modelId}") + try: + model_id, full_data, window_interval, window_step, inference_attributes = _parse_inference_request(req) + + if model_id.startswith('_'): + # built-in models + logger.info(f"start inference built-in model {model_id}") + # parse the inference attributes and create the built-in model + 
model = _get_built_in_model(model_id, model_manager, inference_attributes) + inference_results = [_inference_with_built_in_model( + model, full_data)] + else: + # user-registered models + model = _get_model(model_id, model_manager, inference_attributes) + dataset, dataset_length = _process_data(full_data) + inference_results = _inference_with_registered_model( + model, dataset, dataset_length, window_interval, window_step) + for i in range(len(inference_results)): + inference_results[i] = convert_to_binary(inference_results[i]) + return TInferenceResp( + get_status( + TSStatusCode.SUCCESS_STATUS), + inference_results) + except Exception as e: + logger.warning(e) + inference_results = [] + return TInferenceResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)), inference_results) + +def _inference_with_registered_model(model, dataset, dataset_length, window_interval, window_step): """ Args: model: the user-defined model @@ -109,8 +142,6 @@ a list. """ - dataset, dataset_length = _process_data(full_data) - # check the validity of window_interval and window_step, the two arguments must be positive integers, and the # window_interval should not be larger than the dataset length if window_interval is None or window_step is None \ @@ -168,8 +199,7 @@ output = model.inference(data) # output: DataFrame, shape: (H', C') output = pd.DataFrame(output) - outputs = [output] - return outputs + return output def _get_model(model_id: str, model_manager: ModelManager, inference_attributes: {}):
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index f5b083b..7fb6c18 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml
@@ -59,7 +59,7 @@ sktime = "^0.24.1" pmdarima = "^2.0.4" hmmlearn = "^0.3.0" -apache-iotdb = "2.0.1b0" +apache-iotdb = "2.0.4.dev0" [tool.poetry.scripts] ainode = "ainode.core.script:main" \ No newline at end of file
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java index 4a007e7..cce01c8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelException.java
@@ -19,17 +19,18 @@ package org.apache.iotdb.db.exception.ainode; +import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.rpc.TSStatusCode; -public class ModelException extends RuntimeException { - TSStatusCode statusCode; +import static org.apache.iotdb.rpc.TSStatusCode.representOf; + +public class ModelException extends IoTDBRuntimeException { public ModelException(String message, TSStatusCode code) { - super(message); - this.statusCode = code; + super(message, code.getStatusCode()); } public TSStatusCode getStatusCode() { - return statusCode; + return representOf(getErrorCode()); } }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java index 82e8527..f5be781 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/DataNodeInternalRPCServiceImpl.java
@@ -68,7 +68,6 @@ import org.apache.iotdb.commons.subscription.meta.topic.TopicMeta; import org.apache.iotdb.commons.trigger.TriggerInformation; import org.apache.iotdb.commons.udf.UDFInformation; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.utils.PathUtils; import org.apache.iotdb.commons.utils.StatusUtils; import org.apache.iotdb.consensus.ConsensusFactory; @@ -154,6 +153,7 @@ import org.apache.iotdb.db.queryengine.plan.scheduler.load.LoadTsFileScheduler; import org.apache.iotdb.db.queryengine.plan.statement.component.WhereCondition; import org.apache.iotdb.db.queryengine.plan.statement.crud.QueryStatement; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.schemaengine.SchemaEngine; import org.apache.iotdb.db.schemaengine.schemaregion.ISchemaRegion; import org.apache.iotdb.db.schemaengine.schemaregion.read.resp.info.ITimeSeriesSchemaInfo;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/UDAFAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/UDAFAccumulator.java index a0b78cc..73ec966 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/UDAFAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/UDAFAccumulator.java
@@ -19,9 +19,9 @@ package org.apache.iotdb.db.queryengine.execution.aggregation; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.queryengine.plan.expression.Expression; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDFParametersFactory; import org.apache.iotdb.udf.api.State; import org.apache.iotdb.udf.api.UDAF;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/TransformOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/TransformOperator.java index 96c7aae..6dd24e3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/TransformOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/TransformOperator.java
@@ -20,7 +20,6 @@ package org.apache.iotdb.db.queryengine.execution.operator.process; import org.apache.iotdb.commons.udf.service.UDFClassLoaderManager; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.utils.TestOnly; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.exception.query.QueryProcessException; @@ -30,6 +29,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; import org.apache.iotdb.db.queryengine.plan.expression.Expression; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.queryengine.transformation.api.LayerReader; import org.apache.iotdb.db.queryengine.transformation.api.YieldableState; import org.apache.iotdb.db.queryengine.transformation.dag.builder.EvaluationDAGBuilder;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java index 0c9edd7..d61f299 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/AccumulatorFactory.java
@@ -20,7 +20,6 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; import org.apache.iotdb.common.rpc.thrift.TAggregationType; -import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; @@ -45,6 +44,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.hash.MarkDistinctHash; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; import org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager; +import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils; import org.apache.iotdb.udf.api.customizer.parameter.FunctionArguments; import org.apache.iotdb.udf.api.relational.AggregateFunction;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java index 1b94894..472899e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/relational/ColumnTransformerBuilder.java
@@ -20,7 +20,6 @@ package org.apache.iotdb.db.queryengine.execution.relational; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction; -import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.exception.sql.SemanticException; @@ -71,6 +70,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause; import org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager; import org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundException; +import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils; import org.apache.iotdb.db.queryengine.transformation.dag.column.ColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.TableCaseWhenThenColumnTransformer; import org.apache.iotdb.db.queryengine.transformation.dag.column.binary.ArithmeticColumnTransformerApi;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java index 1feecae..586e12e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java
@@ -20,8 +20,12 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; public interface IModelFetcher { /** Get model information by model id from configNode. */ TSStatus fetchModel(String modelId, Analysis analysis); + + // currently only used by table model + ModelInferenceDescriptor fetchModel(String modelName); }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java index 8cefb5e..3638234 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java
@@ -23,6 +23,7 @@ import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; @@ -78,4 +79,29 @@ throw new StatementAnalyzeException(e.getMessage()); } } + + @Override + public ModelInferenceDescriptor fetchModel(String modelName) { + try (ConfigNodeClient client = + configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); + if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + if (getModelInfoResp.modelInfo != null && getModelInfoResp.isSetAiNodeAddress()) { + return new ModelInferenceDescriptor( + getModelInfoResp.aiNodeAddress, + ModelInformation.deserialize(getModelInfoResp.modelInfo)); + } else { + throw new IoTDBRuntimeException( + String.format("model [%s] is not available", modelName), + TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); + } + } else { + throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); + } + } catch (ClientManagerException | TException e) { + throw new IoTDBRuntimeException( + String.format("fetch model [%s] info failed: %s", modelName, e.getMessage()), + TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); + } + } }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 56e36fb..3de1ee0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -70,7 +70,6 @@ import org.apache.iotdb.commons.trigger.service.TriggerExecutableManager; import org.apache.iotdb.commons.udf.service.UDFClassLoader; import org.apache.iotdb.commons.udf.service.UDFExecutableManager; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.utils.CommonDateTimeUtils; import org.apache.iotdb.commons.utils.PathUtils; import org.apache.iotdb.commons.utils.TimePartitionUtils; @@ -278,6 +277,7 @@ import org.apache.iotdb.db.queryengine.plan.statement.sys.quota.SetThrottleQuotaStatement; import org.apache.iotdb.db.queryengine.plan.statement.sys.quota.ShowSpaceQuotaStatement; import org.apache.iotdb.db.queryengine.plan.statement.sys.quota.ShowThrottleQuotaStatement; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.schemaengine.SchemaEngine; import org.apache.iotdb.db.schemaengine.rescon.DataNodeSchemaQuotaManager; import org.apache.iotdb.db.schemaengine.table.InformationSchemaUtils;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java index e0e88b5..7cff012 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ShowFunctionsTask.java
@@ -29,14 +29,14 @@ import org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction; -import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinTableFunction; -import org.apache.iotdb.commons.udf.utils.TableUDFUtils; -import org.apache.iotdb.commons.udf.utils.TreeUDFUtils; import org.apache.iotdb.db.queryengine.common.header.DatasetHeader; import org.apache.iotdb.db.queryengine.common.header.DatasetHeaderFactory; import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; +import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils; +import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils; import org.apache.iotdb.rpc.TSStatusCode; import com.google.common.util.concurrent.ListenableFuture;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java index a63db1f..957856f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/expression/multi/FunctionExpression.java
@@ -23,7 +23,6 @@ import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction; import org.apache.iotdb.commons.udf.builtin.BuiltinScalarFunction; -import org.apache.iotdb.commons.udf.utils.TreeUDFUtils; import org.apache.iotdb.db.queryengine.common.NodeRef; import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; import org.apache.iotdb.db.queryengine.plan.expression.Expression; @@ -32,6 +31,7 @@ import org.apache.iotdb.db.queryengine.plan.expression.multi.builtin.BuiltInScalarFunctionHelperFactory; import org.apache.iotdb.db.queryengine.plan.expression.visitor.ExpressionVisitor; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.InputLocation; +import org.apache.iotdb.db.queryengine.plan.udf.TreeUDFUtils; import org.apache.iotdb.db.queryengine.transformation.dag.memory.LayerMemoryAssigner; import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFExecutor; import org.apache.iotdb.db.queryengine.transformation.dag.udf.UDTFInformationInferrer;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java index f8cec46..756e209 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java
@@ -20,7 +20,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; -import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultExpressionTraversalVisitor; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; @@ -28,6 +27,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Identifier; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.QualifiedName; +import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils; import com.google.common.collect.ImmutableList;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java similarity index 93% rename from iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java index c5984bc..fda10eb 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/TableBuiltinTableFunction.java
@@ -17,7 +17,7 @@ * under the License. */ -package org.apache.iotdb.commons.udf.builtin.relational; +package org.apache.iotdb.db.queryengine.plan.relational.function; import org.apache.iotdb.commons.udf.builtin.relational.tvf.CapacityTableFunction; import org.apache.iotdb.commons.udf.builtin.relational.tvf.CumulateTableFunction; @@ -39,6 +39,7 @@ SESSION("session"), VARIATION("variation"), CAPACITY("capacity"); + // FORECAST("forecast"); private final String functionName; @@ -78,6 +79,8 @@ return new VariationTableFunction(); case "capacity": return new CapacityTableFunction(); + // case "forecast": + // return new ForecastTableFunction(); default: throw new UnsupportedOperationException("Unsupported table function: " + functionName); }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java new file mode 100644 index 0000000..38b2c8d --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -0,0 +1,704 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.function.tvf;
+
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
+import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
+import org.apache.iotdb.db.exception.sql.SemanticException;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
+import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
+import org.apache.iotdb.rpc.TSStatusCode;
+import org.apache.iotdb.udf.api.relational.TableFunction;
+import org.apache.iotdb.udf.api.relational.access.Record;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionAnalysis;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionHandle;
+import org.apache.iotdb.udf.api.relational.table.TableFunctionProcessorProvider;
+import org.apache.iotdb.udf.api.relational.table.argument.Argument;
+import org.apache.iotdb.udf.api.relational.table.argument.DescribedSchema;
+import org.apache.iotdb.udf.api.relational.table.argument.ScalarArgument;
+import org.apache.iotdb.udf.api.relational.table.argument.TableArgument;
+import org.apache.iotdb.udf.api.relational.table.processor.TableFunctionDataProcessor;
+import org.apache.iotdb.udf.api.relational.table.specification.ParameterSpecification;
+import org.apache.iotdb.udf.api.relational.table.specification.ScalarParameterSpecification;
+import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification;
+import org.apache.iotdb.udf.api.type.Type;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static org.apache.iotdb.commons.udf.builtin.relational.tvf.WindowTVFUtils.findColumnIndex;
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+
+/**
+ * Built-in table function {@code FORECAST}.
+ *
+ * <p>Each data processor buffers the most recent input rows it receives and, in {@code finish},
+ * sends them to the AINode hosting {@code MODEL_ID}, emitting {@code OUTPUT_LENGTH} forecast
+ * rows. When {@code KEEP_INPUT} is true the input rows are echoed as well, and the {@code
+ * is_input} column tells them apart from the forecast rows.
+ */
+public class ForecastTableFunction implements TableFunction {
+
+  /**
+   * Analysis result of {@link ForecastTableFunction#analyze}, passed to each data processor. {@link
+   * #serialize()} and {@link #deserialize(byte[])} must keep the same field order.
+   */
+  private static class ForecastTableFunctionHandle implements TableFunctionHandle {
+    // AINode hosting the model
+    TEndPoint targetAINode;
+    String modelId;
+    // first dimension of the model's input shape; 0 disables the input-row limit
+    int maxInputLength;
+    // number of rows to forecast
+    int outputLength;
+    // whether input rows are also emitted, flagged by the is_input column
+    boolean keepInput;
+    // key=value pairs parsed from OPTIONS and forwarded to the AINode
+    Map<String, String> options;
+    // types of the predicated columns, in output order
+    List<Type> types;
+
+    public ForecastTableFunctionHandle() {}
+
+    public ForecastTableFunctionHandle(
+        boolean keepInput,
+        int maxInputLength,
+        String modelId,
+        Map<String, String> options,
+        int outputLength,
+        TEndPoint targetAINode,
+        List<Type> types) {
+      this.keepInput = keepInput;
+      this.maxInputLength = maxInputLength;
+      this.modelId = modelId;
+      this.options = options;
+      this.outputLength = outputLength;
+      this.targetAINode = targetAINode;
+      this.types = types;
+    }
+
+    @Override
+    public byte[] serialize() {
+      try (PublicBAOS publicBAOS = new PublicBAOS();
+          DataOutputStream outputStream = new DataOutputStream(publicBAOS)) {
+        ReadWriteIOUtils.write(targetAINode.getIp(), outputStream);
+        ReadWriteIOUtils.write(targetAINode.getPort(), outputStream);
+        ReadWriteIOUtils.write(modelId, outputStream);
+        ReadWriteIOUtils.write(maxInputLength, outputStream);
+        ReadWriteIOUtils.write(outputLength, outputStream);
+        ReadWriteIOUtils.write(keepInput, outputStream);
+        ReadWriteIOUtils.write(options, outputStream);
+        ReadWriteIOUtils.write(types.size(), outputStream);
+        for (Type type : types) {
+          ReadWriteIOUtils.write(type.getType(), outputStream);
+        }
+        outputStream.flush();
+        return publicBAOS.toByteArray();
+      } catch (IOException e) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Error occurred while serializing ForecastTableFunctionHandle: %s", e.getMessage()),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+    }
+
+    @Override
+    public void deserialize(byte[] bytes) {
+      ByteBuffer buffer = ByteBuffer.wrap(bytes);
+      this.targetAINode =
+          new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer));
+      this.modelId = ReadWriteIOUtils.readString(buffer);
+      this.maxInputLength = ReadWriteIOUtils.readInt(buffer);
+      this.outputLength = ReadWriteIOUtils.readInt(buffer);
+      this.keepInput = ReadWriteIOUtils.readBoolean(buffer);
+      this.options = ReadWriteIOUtils.readMap(buffer);
+      int size = ReadWriteIOUtils.readInt(buffer);
+      this.types = new ArrayList<>(size);
+      for (int i = 0; i < size; i++) {
+        types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
+      }
+    }
+  }
+
+  private static final IModelFetcher MODEL_FETCHER = ModelFetcher.getInstance();
+
+  private static final String INPUT_PARAMETER_NAME = "INPUT";
+  private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
+  private static final String OUTPUT_LENGTH_PARAMETER_NAME = "OUTPUT_LENGTH";
+  private static final int DEFAULT_OUTPUT_LENGTH = 96;
+  private static final String PREDICATED_COLUMNS_PARAMETER_NAME = "PREDICATED_COLUMNS";
+  private static final String DEFAULT_PREDICATED_COLUMNS = "";
+  private static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
+  private static final String DEFAULT_TIME_COL = "time";
+  private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
+  private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
+  private static final String IS_INPUT_COLUMN_NAME = "is_input";
+  private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
+  private static final String DEFAULT_OPTIONS = "";
+
+  private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
+
+  private static final Set<Type> ALLOWED_INPUT_TYPES = new HashSet<>();
+
+  static {
+    ALLOWED_INPUT_TYPES.add(Type.INT32);
+    ALLOWED_INPUT_TYPES.add(Type.INT64);
+    ALLOWED_INPUT_TYPES.add(Type.FLOAT);
+    ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
+  }
+
+  @Override
+  public List<ParameterSpecification> getArgumentsSpecifications() {
+    return Arrays.asList(
+        TableParameterSpecification.builder().name(INPUT_PARAMETER_NAME).setSemantics().build(),
+        ScalarParameterSpecification.builder()
+            .name(MODEL_ID_PARAMETER_NAME)
+            .type(Type.STRING)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(OUTPUT_LENGTH_PARAMETER_NAME)
+            .type(Type.INT32)
+            .defaultValue(DEFAULT_OUTPUT_LENGTH)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(PREDICATED_COLUMNS_PARAMETER_NAME)
+            .type(Type.STRING)
+            .defaultValue(DEFAULT_PREDICATED_COLUMNS)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(TIMECOL_PARAMETER_NAME)
+            .type(Type.STRING)
+            .defaultValue(DEFAULT_TIME_COL)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(KEEP_INPUT_PARAMETER_NAME)
+            .type(Type.BOOLEAN)
+            .defaultValue(DEFAULT_KEEP_INPUT)
+            .build(),
+        ScalarParameterSpecification.builder()
+            .name(OPTIONS_PARAMETER_NAME)
+            .type(Type.STRING)
+            .defaultValue(DEFAULT_OPTIONS)
+            .build());
+  }
+
+  /**
+   * Validates the arguments, resolves the model, and builds the output schema: the time column,
+   * one column per predicated column, and {@code is_input} when {@code KEEP_INPUT} is true.
+   */
+  @Override
+  public TableFunctionAnalysis analyze(Map<String, Argument> arguments) {
+    TableArgument input = (TableArgument) arguments.get(INPUT_PARAMETER_NAME);
+    String modelId = (String) ((ScalarArgument) arguments.get(MODEL_ID_PARAMETER_NAME)).getValue();
+    // modelId should never be null or empty
+    if (modelId == null || modelId.isEmpty()) {
+      throw new SemanticException(
+          String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME));
+    }
+
+    // make sure modelId exists
+    ModelInferenceDescriptor descriptor = getModelInfo(modelId);
+    if (descriptor == null || !descriptor.getModelInformation().available()) {
+      throw new IoTDBRuntimeException(
+          String.format("model [%s] is not available", modelId),
+          TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
+    }
+
+    int maxInputLength = descriptor.getModelInformation().getInputShape()[0];
+    TEndPoint targetAINode = descriptor.getTargetAINode();
+
+    int outputLength =
+        (int) ((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue();
+    if (outputLength <= 0) {
+      throw new SemanticException(
+          String.format("%s should be greater than 0", OUTPUT_LENGTH_PARAMETER_NAME));
+    }
+
+    String predicatedColumns =
+        (String) ((ScalarArgument) arguments.get(PREDICATED_COLUMNS_PARAMETER_NAME)).getValue();
+
+    String timeColumn =
+        (String) ((ScalarArgument) arguments.get(TIMECOL_PARAMETER_NAME)).getValue();
+
+    // predicated columns should never contain partition by columns and time column
+    Set<String> excludedColumns = new HashSet<>(input.getPartitionBy());
+    excludedColumns.add(timeColumn);
+    int timeColumnIndex = findColumnIndex(input, timeColumn, Collections.singleton(Type.TIMESTAMP));
+
+    List<Integer> requiredIndexList = new ArrayList<>();
+    requiredIndexList.add(timeColumnIndex);
+    DescribedSchema.Builder properColumnSchemaBuilder =
+        new DescribedSchema.Builder().addField(timeColumn, Type.TIMESTAMP);
+
+    List<Type> predicatedColumnTypes = new ArrayList<>();
+    List<Optional<String>> allInputColumnsName = input.getFieldNames();
+    List<Type> allInputColumnsType = input.getFieldTypes();
+    if (predicatedColumns.isEmpty()) {
+      // predicated columns by default include all columns from input table except for timecol and
+      // partition by columns
+      for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+        Optional<String> fieldName = allInputColumnsName.get(i);
+        if (!fieldName.isPresent() || !excludedColumns.contains(fieldName.get())) {
+          Type columnType = allInputColumnsType.get(i);
+          predicatedColumnTypes.add(columnType);
+          checkType(columnType, fieldName.orElse(""));
+          requiredIndexList.add(i);
+          properColumnSchemaBuilder.addField(fieldName, columnType);
+        }
+      }
+    } else {
+      String[] predictedColumnsArray = predicatedColumns.split(",");
+      Map<String, Integer> inputColumnIndexMap = new HashMap<>();
+      for (int i = 0, size = allInputColumnsName.size(); i < size; i++) {
+        Optional<String> fieldName = allInputColumnsName.get(i);
+        if (!fieldName.isPresent()) {
+          continue;
+        }
+        inputColumnIndexMap.put(fieldName.get(), i);
+      }
+
+      Set<Integer> requiredIndexSet = new HashSet<>(predictedColumnsArray.length);
+      // columns need to be predicated
+      for (String outputColumn : predictedColumnsArray) {
+        if (excludedColumns.contains(outputColumn)) {
+          throw new SemanticException(
+              String.format("%s is in partition by clause or is time column", outputColumn));
+        }
+        Integer inputColumnIndex = inputColumnIndexMap.get(outputColumn);
+        if (inputColumnIndex == null) {
+          throw new SemanticException(
+              String.format("Column %s don't exist in input", outputColumn));
+        }
+        if (!requiredIndexSet.add(inputColumnIndex)) {
+          throw new SemanticException(String.format("Duplicate column %s", outputColumn));
+        }
+
+        Type columnType = allInputColumnsType.get(inputColumnIndex);
+        predicatedColumnTypes.add(columnType);
+        checkType(columnType, outputColumn);
+        requiredIndexList.add(inputColumnIndex);
+        properColumnSchemaBuilder.addField(outputColumn, columnType);
+      }
+    }
+
+    boolean keepInput =
+        (boolean) ((ScalarArgument) arguments.get(KEEP_INPUT_PARAMETER_NAME)).getValue();
+    if (keepInput) {
+      properColumnSchemaBuilder.addField(IS_INPUT_COLUMN_NAME, Type.BOOLEAN);
+    }
+
+    String options = (String) ((ScalarArgument) arguments.get(OPTIONS_PARAMETER_NAME)).getValue();
+
+    ForecastTableFunctionHandle functionHandle =
+        new ForecastTableFunctionHandle(
+            keepInput,
+            maxInputLength,
+            modelId,
+            parseOptions(options),
+            outputLength,
+            targetAINode,
+            predicatedColumnTypes);
+
+    // outputColumnSchema
+    return TableFunctionAnalysis.builder()
+        .properColumnSchema(properColumnSchemaBuilder.build())
+        .handle(functionHandle)
+        .requiredColumns(INPUT_PARAMETER_NAME, requiredIndexList)
+        .build();
+  }
+
+  @Override
+  public TableFunctionHandle createTableFunctionHandle() {
+    return new ForecastTableFunctionHandle();
+  }
+
+  @Override
+  public TableFunctionProcessorProvider getProcessorProvider(
+      TableFunctionHandle tableFunctionHandle) {
+    return new TableFunctionProcessorProvider() {
+      @Override
+      public TableFunctionDataProcessor getDataProcessor() {
+        return new ForecastDataProcessor((ForecastTableFunctionHandle) tableFunctionHandle);
+      }
+    };
+  }
+
+  private ModelInferenceDescriptor getModelInfo(String modelId) {
+    return MODEL_FETCHER.fetchModel(modelId);
+  }
+
+  // only allow for INT32, INT64, FLOAT, DOUBLE
+  private void checkType(Type type, String columnName) {
+    if (!ALLOWED_INPUT_TYPES.contains(type)) {
+      throw new SemanticException(
+          String.format(
+              "The type of the column [%s] is [%s], only INT32, INT64, FLOAT, DOUBLE is allowed",
+              columnName, type));
+    }
+  }
+
+  /**
+   * Parses {@code OPTIONS}, a comma-separated list of {@code key=value} pairs; keys and values are
+   * trimmed.
+   *
+   * @throws SemanticException if a pair has no '=' or its key or value is blank
+   */
+  private static Map<String, String> parseOptions(String options) {
+    if (options.isEmpty()) {
+      return Collections.emptyMap();
+    }
+    String[] optionArray = options.split(",");
+    if (optionArray.length == 0) {
+      throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, options));
+    }
+
+    Map<String, String> optionsMap = new HashMap<>(optionArray.length);
+    for (String option : optionArray) {
+      int index = option.indexOf('=');
+      if (index == -1) {
+        throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
+      }
+      String key = option.substring(0, index).trim();
+      String value = option.substring(index + 1).trim();
+      // reject pairs such as "=1", "a=" or " = " whose key or value is blank after trimming
+      if (key.isEmpty() || value.isEmpty()) {
+        throw new SemanticException(String.format(INVALID_OPTIONS_FORMAT, option));
+      }
+      optionsMap.put(key, value);
+    }
+    return optionsMap;
+  }
+
+  /** Buffers the most recent input rows and, in {@link #finish}, forecasts from them via AINode. */
+  private static class ForecastDataProcessor implements TableFunctionDataProcessor {
+
+    private static final TsBlockSerde SERDE = new TsBlockSerde();
+    private static final IClientManager<TEndPoint, AINodeClient> CLIENT_MANAGER =
+        AINodeClientManager.getInstance();
+
+    private final TEndPoint targetAINode;
+    private final String modelId;
+    private final int maxInputLength;
+    private final int outputLength;
+    private final boolean keepInput;
+    private final Map<String, String> options;
+    // most recent input rows, at most maxInputLength of them when maxInputLength != 0
+    private final LinkedList<Record> inputRecords;
+    // one appender per predicated column, in output order
+    private final List<ResultColumnAppender> resultColumnAppenderList;
+    private final TsBlockBuilder inputTsBlockBuilder;
+
+    public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) {
+      this.targetAINode = functionHandle.targetAINode;
+      this.modelId = functionHandle.modelId;
+      this.maxInputLength = functionHandle.maxInputLength;
+      this.outputLength = functionHandle.outputLength;
+      this.keepInput = functionHandle.keepInput;
+      this.options = functionHandle.options;
+      this.inputRecords = new LinkedList<>();
+      this.resultColumnAppenderList = new ArrayList<>(functionHandle.types.size());
+      List<TSDataType> tsDataTypeList = new ArrayList<>(functionHandle.types.size());
+      for (Type type : functionHandle.types) {
+        resultColumnAppenderList.add(createResultColumnAppender(type));
+        // ainode currently only accept double input
+        tsDataTypeList.add(TSDataType.DOUBLE);
+      }
+      this.inputTsBlockBuilder = new TsBlockBuilder(tsDataTypeList);
+    }
+
+    private static ResultColumnAppender createResultColumnAppender(Type type) {
+      switch (type) {
+        case INT32:
+          return new Int32Appender();
+        case INT64:
+          return new Int64Appender();
+        case FLOAT:
+          return new FloatAppender();
+        case DOUBLE:
+          return new DoubleAppender();
+        default:
+          throw new IllegalArgumentException("Unsupported column type: " + type);
+      }
+    }
+
+    @Override
+    public void process(
+        Record input,
+        List<ColumnBuilder> properColumnBuilders,
+        ColumnBuilder passThroughIndexBuilder) {
+
+      if (keepInput) {
+        int columnSize = properColumnBuilders.size();
+
+        // time column, will never be null
+        if (input.isNull(0)) {
+          throw new IoTDBRuntimeException(
+              "Time column should never be null", TSStatusCode.SEMANTIC_ERROR.getStatusCode());
+        }
+        properColumnBuilders.get(0).writeLong(input.getLong(0));
+
+        // predicated columns
+        for (int i = 1, size = columnSize - 1; i < size; i++) {
+          resultColumnAppenderList.get(i - 1).append(input, i, properColumnBuilders.get(i));
+        }
+
+        // is_input column
+        properColumnBuilders.get(columnSize - 1).writeBoolean(true);
+      }
+
+      // only keep at most maxInputLength rows
+      // NOTE(review): the Record is retained until finish(); this assumes the caller does not
+      // reuse or mutate Record instances after process() returns - confirm
+      if (maxInputLength != 0 && inputRecords.size() == maxInputLength) {
+        inputRecords.removeFirst();
+      }
+      inputRecords.add(input);
+    }
+
+    @Override
+    public void finish(
+        List<ColumnBuilder> properColumnBuilders, ColumnBuilder passThroughIndexBuilder) {
+      // no input row was received, so there is no history to forecast from
+      if (inputRecords.isEmpty()) {
+        return;
+      }
+
+      int columnSize = properColumnBuilders.size();
+
+      // estimate the sampling interval before forecast() drains inputRecords: n rows span n - 1
+      // intervals, and a single row carries no spacing information, so its interval is 0
+      int inputSize = inputRecords.size();
+      long startTime = inputRecords.getFirst().getLong(0);
+      long endTime = inputRecords.getLast().getLong(0);
+      long interval = inputSize > 1 ? (endTime - startTime) / (inputSize - 1) : 0;
+
+      // ask the model for the forecast and validate its shape before writing any output row
+      TsBlock predicatedResult = forecast();
+      if (predicatedResult.getPositionCount() != outputLength) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Model %s output length is %s, doesn't equal to specified %s",
+                modelId, predicatedResult.getPositionCount(), outputLength),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+      if (predicatedResult.getValueColumnCount() != resultColumnAppenderList.size()) {
+        throw new IoTDBRuntimeException(
+            String.format(
+                "Model %s output column count is %s, doesn't equal to expected %s",
+                modelId, predicatedResult.getValueColumnCount(), resultColumnAppenderList.size()),
+            TSStatusCode.INTERNAL_SERVER_ERROR.getStatusCode());
+      }
+
+      // time column: continue after the last input timestamp with the estimated interval
+      for (int i = 0; i < outputLength; i++) {
+        properColumnBuilders.get(0).writeLong(endTime + interval * (i + 1));
+      }
+
+      // predicated columns
+      for (int columnIndex = 1, size = predicatedResult.getValueColumnCount();
+          columnIndex <= size;
+          columnIndex++) {
+        Column column = predicatedResult.getColumn(columnIndex - 1);
+        ColumnBuilder builder = properColumnBuilders.get(columnIndex);
+        ResultColumnAppender appender = resultColumnAppenderList.get(columnIndex - 1);
+        for (int row = 0; row < outputLength; row++) {
+          if (column.isNull(row)) {
+            builder.appendNull();
+          } else {
+            // convert double to real type
+            appender.writeDouble(column.getDouble(row), builder);
+          }
+        }
+      }
+
+      // is_input column if keep_input is true
+      if (keepInput) {
+        for (int i = 0; i < outputLength; i++) {
+          properColumnBuilders.get(columnSize - 1).writeBoolean(false);
+        }
+      }
+    }
+
+    /**
+     * Drains {@code inputRecords} into a DOUBLE-only {@link TsBlock}, sends it to the AINode, and
+     * returns the deserialized forecast result.
+     */
+    private TsBlock forecast() {
+      while (!inputRecords.isEmpty()) {
+        Record row = inputRecords.removeFirst();
+        inputTsBlockBuilder.getTimeColumnBuilder().writeLong(row.getLong(0));
+        for (int i = 1, size = row.size(); i < size; i++) {
+          // we set null input to 0.0
+          if (row.isNull(i)) {
+            inputTsBlockBuilder.getColumnBuilder(i - 1).writeDouble(0.0);
+          } else {
+            // need to transform other types to DOUBLE
+            inputTsBlockBuilder
+                .getColumnBuilder(i - 1)
+                .writeDouble(resultColumnAppenderList.get(i - 1).getDouble(row, i));
+          }
+        }
+        inputTsBlockBuilder.declarePosition();
+      }
+      TsBlock inputData = inputTsBlockBuilder.build();
+
+      TForecastResp resp;
+      try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) {
+        resp = client.forecast(modelId, inputData, outputLength, options);
+      } catch (Exception e) {
+        throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode());
+      }
+
+      if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+        String message =
+            String.format(
+                "Error occurred while executing forecast:[%s]", resp.getStatus().getMessage());
+        throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
+      }
+
+      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    }
+  }
+
+  /** Converts between a predicated column's type and the DOUBLE values exchanged with AINode. */
+  private interface ResultColumnAppender {
+    // copies column columnIndex of row to properColumnBuilder in its original type, or null
+    void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder);
+
+    // reads column columnIndex of row as a double; callers check isNull first
+    double getDouble(Record row, int columnIndex);
+
+    // writes a model output value back in the column's original type via a narrowing cast
+    void writeDouble(double value, ColumnBuilder columnBuilder);
+  }
+
+  private static class Int32Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeInt(row.getInt(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getInt(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeInt((int) value);
+    }
+  }
+
+  private static class Int64Appender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeLong(row.getLong(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getLong(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeLong((long) value);
+    }
+  }
+
+  private static class FloatAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeFloat(row.getFloat(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getFloat(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeFloat((float) value);
+    }
+  }
+
+  private static class DoubleAppender implements ResultColumnAppender {
+
+    @Override
+    public void append(Record row, int columnIndex, ColumnBuilder properColumnBuilder) {
+      if (row.isNull(columnIndex)) {
+        properColumnBuilder.appendNull();
+      } else {
+        properColumnBuilder.writeDouble(row.getDouble(columnIndex));
+      }
+    }
+
+    @Override
+    public double getDouble(Record row, int columnIndex) {
+      return row.getDouble(columnIndex);
+    }
+
+    @Override
+    public void writeDouble(double value, ColumnBuilder columnBuilder) {
+      columnBuilder.writeDouble(value);
+    }
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index bd4fe3f..44b7f69 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
@@ -27,8 +27,6 @@ import org.apache.iotdb.commons.schema.table.TsTable; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction; -import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinTableFunction; -import org.apache.iotdb.commons.udf.utils.TableUDFUtils; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.exception.load.LoadAnalyzeTableColumnDisorderException; import org.apache.iotdb.db.exception.sql.SemanticException; @@ -37,6 +35,7 @@ import org.apache.iotdb.db.queryengine.plan.analyze.ClusterPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; +import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.AdditionResolver; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.DivisionResolver; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.ModulusResolver; @@ -51,6 +50,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.type.TypeManager; import org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundException; import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature; +import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils; import org.apache.iotdb.db.schemaengine.table.DataNodeTableCache; import org.apache.iotdb.db.utils.constant.SqlConstant; import org.apache.iotdb.rpc.TSStatusCode;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TableUDFUtils.java similarity index 95% rename from iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TableUDFUtils.java index 03ba31e..861bc70 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TableUDFUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TableUDFUtils.java
@@ -17,12 +17,11 @@ * under the License. */ -package org.apache.iotdb.commons.udf.utils; +package org.apache.iotdb.db.queryengine.plan.udf; import org.apache.iotdb.common.rpc.thrift.FunctionType; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.commons.udf.UDFInformation; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.udf.api.exception.UDFException; import org.apache.iotdb.udf.api.relational.AggregateFunction; import org.apache.iotdb.udf.api.relational.ScalarFunction;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TreeUDFUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TreeUDFUtils.java similarity index 92% rename from iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TreeUDFUtils.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TreeUDFUtils.java index 16e7454..f04cfe0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/utils/TreeUDFUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/TreeUDFUtils.java
@@ -17,9 +17,8 @@ * under the License. */ -package org.apache.iotdb.commons.udf.utils; +package org.apache.iotdb.db.queryengine.plan.udf; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.udf.api.UDAF; import org.apache.iotdb.udf.api.UDTF;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java similarity index 96% rename from iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java index 04c84b6..9efa02d 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/service/UDFManagementService.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDFManagementService.java
@@ -17,7 +17,7 @@ * under the License. */ -package org.apache.iotdb.commons.udf.service; +package org.apache.iotdb.db.queryengine.plan.udf; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.commons.udf.UDFInformation; @@ -28,8 +28,11 @@ import org.apache.iotdb.commons.udf.builtin.BuiltinTimeSeriesGeneratingFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction; import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinScalarFunction; -import org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinTableFunction; +import org.apache.iotdb.commons.udf.service.UDFClassLoader; +import org.apache.iotdb.commons.udf.service.UDFClassLoaderManager; +import org.apache.iotdb.commons.udf.service.UDFExecutableManager; import org.apache.iotdb.commons.utils.TestOnly; +import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.udf.api.UDF; import org.apache.iotdb.udf.api.exception.UDFException; import org.apache.iotdb.udf.api.exception.UDFManagementException;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDAFInformationInferrer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDAFInformationInferrer.java index eb51bcc..9df4993 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDAFInformationInferrer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDAFInformationInferrer.java
@@ -19,9 +19,9 @@ package org.apache.iotdb.db.queryengine.transformation.dag.udf; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.udf.api.UDAF; import org.apache.iotdb.udf.api.customizer.config.UDAFConfigurations; import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFExecutor.java index f93de5c..00d91cc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFExecutor.java
@@ -19,8 +19,8 @@ package org.apache.iotdb.db.queryengine.transformation.dag.udf; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.queryengine.transformation.dag.adapter.PointCollectorAdaptor; import org.apache.iotdb.db.queryengine.transformation.dag.util.InputRowUtils; import org.apache.iotdb.db.queryengine.transformation.datastructure.tv.ElasticSerializableTVList;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFInformationInferrer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFInformationInferrer.java index 01baa83..08b1e11 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFInformationInferrer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/transformation/dag/udf/UDTFInformationInferrer.java
@@ -19,9 +19,9 @@ package org.apache.iotdb.db.queryengine.transformation.dag.udf; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.udf.api.UDTF; import org.apache.iotdb.udf.api.customizer.config.UDTFConfigurations; import org.apache.iotdb.udf.api.customizer.parameter.UDFParameterValidator;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java index b554eb9..80a1a30 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java
@@ -50,7 +50,6 @@ import org.apache.iotdb.commons.udf.UDFInformation; import org.apache.iotdb.commons.udf.service.UDFClassLoaderManager; import org.apache.iotdb.commons.udf.service.UDFExecutableManager; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.commons.utils.FileUtils; import org.apache.iotdb.commons.utils.PathUtils; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; @@ -89,6 +88,7 @@ import org.apache.iotdb.db.queryengine.plan.planner.distribution.DistributionPlanContext; import org.apache.iotdb.db.queryengine.plan.planner.distribution.SourceRewriter; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.schemaengine.SchemaEngine; import org.apache.iotdb.db.schemaengine.schemaregion.attribute.update.GeneralRegionAttributeSecurityService; import org.apache.iotdb.db.schemaengine.table.DataNodeTableCache;
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/utils/EnvironmentUtils.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/utils/EnvironmentUtils.java index 4326309..a3ef631 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/utils/EnvironmentUtils.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/utils/EnvironmentUtils.java
@@ -22,13 +22,13 @@ import org.apache.iotdb.commons.conf.CommonConfig; import org.apache.iotdb.commons.conf.CommonDescriptor; import org.apache.iotdb.commons.exception.StartupException; -import org.apache.iotdb.commons.udf.service.UDFManagementService; import org.apache.iotdb.db.conf.DataNodeMemoryConfig; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.exception.StorageEngineException; import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext; import org.apache.iotdb.db.queryengine.execution.fragment.QueryContext; +import org.apache.iotdb.db.queryengine.plan.udf.UDFManagementService; import org.apache.iotdb.db.schemaengine.SchemaEngine; import org.apache.iotdb.db.storageengine.StorageEngine; import org.apache.iotdb.db.storageengine.buffer.BloomFilterCache;
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 1eca6e7..3cc416f 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -22,6 +22,8 @@
 import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService;
 import org.apache.iotdb.ainode.rpc.thrift.TConfigs;
 import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastReq;
+import org.apache.iotdb.ainode.rpc.thrift.TForecastResp;
 import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
 import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
 import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
@@ -53,10 +55,14 @@
 import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
+import java.nio.ByteBuffer;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
+import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
+import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR;
+
 public class AINodeClient implements AutoCloseable, ThriftClient {
 
   private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class);
@@ -188,6 +194,41 @@
     }
   }
 
+  /**
+   * Asks the AINode to run a forecast with the given model.
+   *
+   * <p>Serialization and Thrift failures are not thrown: they are reported through the status of
+   * the returned response, whose {@code forecastResult} is then an empty buffer.
+   *
+   * @param modelId id of the model to forecast with
+   * @param inputTsBlock input data, serialized with {@code tsBlockSerde} before being sent
+   * @param outputLength requested forecast length, sent as the request's {@code outputLength}
+   * @param options model options, sent as the request's {@code options}
+   * @return the AINode's response on success; if the input cannot be serialized, a local
+   *     response with status {@code INTERNAL_SERVER_ERROR}; if the RPC fails, a local response
+   *     with status {@code CAN_NOT_CONNECT_AINODE}
+   */
+  public TForecastResp forecast(
+      String modelId, TsBlock inputTsBlock, int outputLength, Map<String, String> options) {
+    try {
+      TForecastReq forecastReq =
+          new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), outputLength);
+      forecastReq.setOptions(options);
+      return client.forecast(forecastReq);
+    } catch (IOException e) {
+      TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode());
+      tsStatus.setMessage(String.format("Failed to serialize input tsblock: %s", e.getMessage()));
+      return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+    } catch (TException e) {
+      TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
+      tsStatus.setMessage(
+          String.format(
+              "Failed to connect to AINode from DataNode when executing forecast: %s",
+              e.getMessage()));
+      return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
+    }
+  }
+
   public TSStatus createTrainingTask(TTrainingReq req) throws TException {
     try {
       return client.createTrainingTask(req);
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 9ac07b4..5643da7 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -89,6 +89,21 @@
   6: optional string existingModelId
 }
 
+// Forecast request sent by a DataNode to the AINode.
+struct TForecastReq {
+  1: required string modelId
+  // Serialized input TsBlock.
+  2: required binary inputData
+  3: required i32 outputLength
+  4: optional map<string, string> options
+}
+
+// Forecast response: call status plus the raw forecast output.
+struct TForecastResp {
+  1: required common.TSStatus status
+  2: required binary forecastResult
+}
+
 service IAINodeRPCService {
 
   // -------------- For Config Node --------------
@@ -104,4 +119,7 @@
 
   // -------------- For Data Node --------------
   TInferenceResp inference(TInferenceReq req)
+
+  // Runs a forecast with the given model; see TForecastReq / TForecastResp.
+  TForecastResp forecast(TForecastReq req)
 }
\ No newline at end of file