blob: 2ea7f978da38c694a12fd84285eb9d8f70c2131c [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
from iotdb.mlnode.parser import parse_training_request
from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
def test_parse_training_request():
model_id = 'mid_etth1_dlinear_default'
is_auto = False
model_configs = {
'task_class': 'forecast_training_task',
'source_type': 'thrift',
'dataset_type': 'window', # or use DatasetType.WINDOW,
'filename': 'ETTh1.csv',
'time_embed': 'h',
'input_len': 96,
'pred_len': 96,
'model_name': 'dlinear',
'input_vars': 7,
'output_vars': 7,
'forecast_type': 'endogenous',
'kernel_size': 25,
'learning_rate': 1e-3,
'batch_size': 32,
'num_workers': 0,
'epochs': 10,
'metric_names': ['MSE', 'MAE']
}
query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
query_filter = '0,1501516800000'
req = TCreateTrainingTaskReq(
modelId=str(model_id),
isAuto=is_auto,
modelConfigs={k: str(v) for k, v in model_configs.items()},
queryExpressions=[str(query) for query in query_expressions],
queryFilter=str(query_filter),
)
data_config, model_config, task_config = parse_training_request(req)
for config in model_configs:
if config in data_config:
assert data_config[config] == model_configs[config]
if config in model_config:
assert model_config[config] == model_configs[config]
if config in task_config:
assert task_config[config] == model_configs[config]
def test_missing_argument():
# missing model_name
model_id = 'mid_etth1_dlinear_default'
is_auto = False
model_configs = {
'task_class': 'forecast_training_task',
'source_type': 'thrift',
'dataset_type': 'window',
'filename': 'ETTh1.csv',
'time_embed': 'h',
'input_len': 96,
'pred_len': 96,
'input_vars': 7,
'output_vars': 7,
'forecast_type': 'endogenous',
'kernel_size': 25,
'learning_rate': 1e-3,
'batch_size': 32,
'num_workers': 0,
'epochs': 10,
'metric_names': ['MSE', 'MAE']
}
query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
query_filter = '0,1501516800000'
req = TCreateTrainingTaskReq(
modelId=str(model_id),
isAuto=is_auto,
modelConfigs={k: str(v) for k, v in model_configs.items()},
queryExpressions=[str(query) for query in query_expressions],
queryFilter=str(query_filter),
)
try:
parse_training_request(req)
except Exception as e:
assert e.message == MissingConfigError(config_name='model_name').message
def test_wrong_argument_type():
model_id = 'mid_etth1_dlinear_default'
is_auto = False
model_configs = {
'task_class': 'forecast_training_task',
'source_type': 'thrift',
'dataset_type': 'window',
'filename': 'ETTh1.csv',
'time_embed': 'h',
'input_len': 96.7,
'pred_len': 96,
'model_name': 'dlinear',
'input_vars': 7,
'output_vars': 7,
'forecast_type': 'endogenous',
'kernel_size': 25,
'learning_rate': 1e-3,
'batch_size': 32,
'num_workers': 0,
'epochs': 10,
'metric_names': ['MSE', 'MAE']
}
query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
query_filter = '0,1501516800000'
req = TCreateTrainingTaskReq(
modelId=str(model_id),
isAuto=is_auto,
modelConfigs={k: str(v) for k, v in model_configs.items()},
queryExpressions=[str(query) for query in query_expressions],
queryFilter=str(query_filter),
)
try:
data_config, model_config, task_config = parse_training_request(req)
except Exception as e:
assert e.message == WrongTypeConfigError(config_name='input_len', expected_type='int').message