[IOTDB-5680] Implement the basic data loader on MLNode (#9372)
diff --git a/mlnode/iotdb/mlnode/datats/offline/data_source.py b/mlnode/iotdb/mlnode/datats/offline/data_source.py new file mode 100644 index 0000000..cd8e9a8 --- /dev/null +++ b/mlnode/iotdb/mlnode/datats/offline/data_source.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import pandas as pd

from iotdb.mlnode import serde
from iotdb.mlnode.client import client_manager


class DataSource(object):
    """
    Base class for a multivariate time series that is pre-fetched into memory.

    Subclasses implement `_read_data` to populate `self.data` and `self.timestamp`.

    Methods:
        get_data: returns self.data, the time series values (2-D numpy array,
            one row per time point, one column per variable)
        get_timestamp: returns self.timestamp, the timestamps aligned with the rows of self.data
    """

    def __init__(self):
        # Both stay None until a subclass' _read_data() fills them in
        self.data = None
        self.timestamp = None

    def _read_data(self):
        """Load the series into self.data / self.timestamp; must be overridden."""
        raise NotImplementedError

    def get_data(self):
        """Return the loaded values (None if nothing has been loaded)."""
        return self.data

    def get_timestamp(self):
        """Return the loaded timestamps (None if nothing has been loaded)."""
        return self.timestamp


class FileDataSource(DataSource):
    """
    Data source backed by a CSV file.

    The first column of the file is parsed as the timestamp; every remaining
    column is treated as one variable of the series.

    Args:
        filename: path of the CSV file, read eagerly in the constructor

    Raises:
        RuntimeError: if the file cannot be read by `pandas.read_csv`
    """

    def __init__(self, filename: str = None):
        super(FileDataSource, self).__init__()
        self.filename = filename
        self._read_data()

    def _read_data(self):
        try:
            raw_data = pd.read_csv(self.filename)
        except Exception as e:
            # Keep the original error as the cause so the real reason is not lost
            raise RuntimeError(f'Fail to load data with filename: {self.filename}') from e
        cols_data = raw_data.columns[1:]
        self.data = raw_data[cols_data].values
        self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values)


class ThriftDataSource(DataSource):
    """
    Data source fetched from a DataNode through the client manager.

    Args:
        query_expressions: expressions passed as `queryExpressions` to `fetch_timeseries`
        query_filter: filter passed as `queryFilter` to `fetch_timeseries`

    Raises:
        RuntimeError: if no client can be borrowed, the fetch fails, or the
            fetched result is empty
    """

    def __init__(self, query_expressions: list = None, query_filter: str = None):
        # Must name this class (not DataSource) so that DataSource.__init__ actually runs
        super(ThriftDataSource, self).__init__()
        self.query_expressions = query_expressions
        self.query_filter = query_filter
        self._read_data()

    def _read_data(self):
        try:
            data_client = client_manager.borrow_data_node_client()
        except Exception as e:
            raise RuntimeError('Fail to establish connection with DataNode') from e

        try:
            res = data_client.fetch_timeseries(
                queryExpressions=self.query_expressions,
                queryFilter=self.query_filter,
            )
        except Exception as e:
            raise RuntimeError(f'Fail to fetch data with query expressions: {self.query_expressions}'
                               f' and query filter: {self.query_filter}') from e

        if len(res.tsDataset) == 0:
            raise RuntimeError(f'No data fetched with query filter: {self.query_filter}')

        # NOTE(review): serde.convert_to_df presumably decodes the response into a
        # DataFrame whose first column is the timestamp -- confirm against serde
        raw_data = serde.convert_to_df(res.columnNameList,
                                       res.columnTypeList,
                                       res.columnNameIndexMap,
                                       res.tsDataset)
        if raw_data.empty:
            raise RuntimeError(f'Fetched empty data with query expressions: '
                               f'{self.query_expressions} and query filter: {self.query_filter}')
        cols_data = raw_data.columns[1:]
        self.data = raw_data[cols_data].values
        # Timestamps are interpreted as epoch milliseconds; the target time zone is hard-coded
        self.timestamp = pd.to_datetime(raw_data[raw_data.columns[0]].values, unit='ms', utc=True) \
            .tz_convert('Asia/Shanghai')  # for iotdb
diff --git a/mlnode/iotdb/mlnode/datats/offline/dataset.py b/mlnode/iotdb/mlnode/datats/offline/dataset.py new file mode 100644 index 0000000..c71aaf8 --- /dev/null +++ b/mlnode/iotdb/mlnode/datats/offline/dataset.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


import argparse

from torch.utils.data import Dataset

from iotdb.mlnode.datats.offline.data_source import DataSource
from iotdb.mlnode.datats.utils.timefeatures import time_features

# currently support for multivariate forecasting only


class TimeSeriesDataset(Dataset):
    """
    Build Row-by-Row dataset (with each element as multivariable time series at
    the same time and corresponding timestamp embedding)

    Args:
        data_source: the whole multivariate time series for a while
        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail

    Returns:
        Random accessible dataset
    """

    def __init__(self, data_source: DataSource, time_embed: str = 'h'):
        self.time_embed = time_embed
        self.data = data_source.get_data()
        # time_features() stacks one row per feature; transpose so rows line up with self.data
        self.data_stamp = time_features(data_source.get_timestamp(), time_embed=self.time_embed).transpose(1, 0)
        self.n_vars = self.data.shape[-1]

    def get_variable_num(self):
        """Return the number of variables (size of the last axis of the data)."""
        return self.n_vars  # number of series in data_source

    def __getitem__(self, index):
        """Return (values, timestamp embedding) of the row at `index`."""
        seq = self.data[index]
        seq_t = self.data_stamp[index]
        return seq, seq_t

    def __len__(self):
        return len(self.data)


class WindowDataset(TimeSeriesDataset):
    """
    Build Windowed dataset (with each element as multivariable time series
    with a sliding window and corresponding timestamps embedding),
    the sliding step is one unit in the given data source

    Args:
        data_source: the whole multivariate time series for a while
        time_embed: embedding frequency, see `utils/timefeatures.py` for more detail
        input_len: input window size (unit) [1, 2, ... I]
        pred_len: output window size (unit) right after the input window [I+1, I+2, ... I+P]

    Returns:
        Random accessible dataset

    Raises:
        RuntimeError: if the windows do not fit into the number of time series points
    """

    def __init__(self,
                 data_source: DataSource = None,
                 input_len: int = 96,
                 pred_len: int = 96,
                 time_embed: str = 'h'):
        # The parent constructor sets self.data, so it has to run before the length checks
        super(WindowDataset, self).__init__(data_source, time_embed)
        self.input_len = input_len
        self.pred_len = pred_len
        n_points = self.data.shape[0]
        if input_len > n_points:
            raise RuntimeError('input_len should not be larger than the number of time series points')
        if pred_len > n_points:
            raise RuntimeError('pred_len should not be larger than the number of time series points')
        if input_len + pred_len > n_points:
            # Otherwise not even one (input, prediction) pair fits and __len__ would be <= 0
            raise RuntimeError('input_len + pred_len should not be larger than '
                               'the number of time series points')

    def __getitem__(self, index):
        """
        Return the window starting at `index` as
        (input values, prediction values, input timestamp embedding, prediction timestamp embedding).

        Raises:
            IndexError: if `index` is out of range; this also lets plain
                iteration over the dataset terminate
        """
        n_windows = len(self)
        if index < 0:
            index += n_windows  # standard negative indexing
        if not 0 <= index < n_windows:
            raise IndexError(f'window index out of range, expected an index in [0, {n_windows})')
        s_begin = index
        s_end = s_begin + self.input_len
        r_begin = s_end
        r_end = s_end + self.pred_len
        seq_x = self.data[s_begin:s_end]
        seq_y = self.data[r_begin:r_end]
        seq_x_t = self.data_stamp[s_begin:s_end]
        seq_y_t = self.data_stamp[r_begin:r_end]
        return seq_x, seq_y, seq_x_t, seq_y_t

    def __len__(self):
        # Number of start positions where both windows fit entirely
        return len(self.data) - self.input_len - self.pred_len + 1


def get_timeseries_dataset(data_config: argparse.Namespace) -> TimeSeriesDataset:
    """Placeholder factory for TimeSeriesDataset; not implemented yet and currently returns None."""
    # TODO (@lcy)
    # init datasource
    # init dataset
    pass


def get_window_dataset(data_config: argparse.Namespace) -> WindowDataset:
    """Placeholder factory for WindowDataset; not implemented yet and currently returns None."""
    # TODO (@lcy)
    # init datasource
    # init dataset
    pass
diff --git a/mlnode/iotdb/mlnode/datats/utils/__init__.py b/mlnode/iotdb/mlnode/datats/utils/__init__.py new file mode 100644 index 0000000..2a1e720 --- /dev/null +++ b/mlnode/iotdb/mlnode/datats/utils/__init__.py
@@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +#
diff --git a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py new file mode 100644 index 0000000..bd1681c --- /dev/null +++ b/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#


from typing import List

import numpy as np
import pandas as pd
from pandas.tseries import offsets
from pandas.tseries.frequencies import to_offset


class TimeFeature:
    """Base class of callables that map a DatetimeIndex to one numeric feature per timestamp."""

    def __init__(self):
        pass

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        pass

    def __repr__(self):
        return self.__class__.__name__ + '()'


class SecondOfMinute(TimeFeature):
    """Second of minute encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.second / 59.0 - 0.5


class MinuteOfHour(TimeFeature):
    """Minute of hour encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.minute / 59.0 - 0.5


class HourOfDay(TimeFeature):
    """Hour of day encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.hour / 23.0 - 0.5


class DayOfWeek(TimeFeature):
    """Day of week encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return index.dayofweek / 6.0 - 0.5


class DayOfMonth(TimeFeature):
    """Day of month encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.day - 1) / 30.0 - 0.5


class DayOfYear(TimeFeature):
    """Day of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.dayofyear - 1) / 365.0 - 0.5


class MonthOfYear(TimeFeature):
    """Month of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        return (index.month - 1) / 11.0 - 0.5


class WeekOfYear(TimeFeature):
    """Week of year encoded as value between [-0.5, 0.5]"""

    def __call__(self, index: pd.DatetimeIndex) -> np.ndarray:
        # isocalendar() yields a nullable-integer Series; convert it to a plain float
        # array (NaT -> NaN) so the result stacks cleanly with the other features
        week = index.isocalendar().week.to_numpy(dtype=np.float64, na_value=np.nan)
        return (week - 1) / 52.0 - 0.5


def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]:
    """
    Embedding timestamp by given frequency string
    Args:
        freq_str: frequency string of the form [multiple][granularity] such as '12H', '5min', '1D' etc.
    Returns:
        a list of time features that will be appropriate for the given frequency string.
    Raises:
        RuntimeError: if the string is not a valid pandas frequency, or if it is
            valid but no feature set is defined for it (e.g. milliseconds)
    """

    features_by_offsets = {
        offsets.YearEnd: [],
        offsets.QuarterEnd: [MonthOfYear],
        offsets.MonthEnd: [MonthOfYear],
        offsets.Week: [DayOfMonth, WeekOfYear],
        offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear],
        offsets.Minute: [
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
        offsets.Second: [
            SecondOfMinute,
            MinuteOfHour,
            HourOfDay,
            DayOfWeek,
            DayOfMonth,
            DayOfYear,
        ],
    }

    supported_freq_msg = f'''
    Unsupported time embedding frequency ({freq_str})
    The following frequencies are supported (case-insensitive):
        Y   - yearly
            alias: A
        M   - monthly
        W   - weekly
        D   - daily
        B   - business days
        H   - hourly
        T   - minutely
            alias: min
        S   - secondly
    '''

    try:
        offset = to_offset(freq_str)
    except ValueError as e:
        raise RuntimeError(supported_freq_msg) from e

    for offset_type, feature_classes in features_by_offsets.items():
        if isinstance(offset, offset_type):
            return [cls() for cls in feature_classes]

    # A valid pandas offset without a feature mapping used to fall through and
    # return None; report it as unsupported instead
    raise RuntimeError(supported_freq_msg)


def time_features(dates, time_embed='h'):
    """
    Compute the time features of `dates` for the given embedding frequency.

    Args:
        dates: DatetimeIndex to embed
        time_embed: frequency string, see `time_features_from_frequency_str`
    Returns:
        array with one row per feature and one column per timestamp
    """
    feats = [feat(dates) for feat in time_features_from_frequency_str(time_embed)]
    if not feats:
        # Yearly frequency defines no features and np.vstack rejects an empty list
        return np.empty((0, len(dates)))
    return np.vstack(feats)


def data_transform(data_raw: pd.DataFrame, freq='h'):
    """
    Split a dataframe into values and timestamps.

    Args:
        data_raw: dataframe, column 0 is the time stamp
        freq: currently unused
    Returns:
        (values of all columns but the first as an array, first column as a Series)
    """
    columns = data_raw.columns
    data = data_raw[columns[1:]]
    data_stamp = data_raw[columns[0]]
    return data.values, data_stamp


def timestamp_transform(timestamp_raw: pd.DataFrame, freq='h'):
    """
    Embed raw epoch-millisecond timestamps as time features.

    Args:
        timestamp_raw: timestamps in milliseconds since the epoch
        freq: embedding frequency, see `time_features_from_frequency_str`
    Returns:
        array with one row per timestamp and one column per feature
    """
    # atleast_1d keeps a single timestamp from being squeezed to a 0-d scalar
    raw_values = np.atleast_1d(timestamp_raw.values.squeeze())
    # The target time zone is hard-coded; hour/day features depend on it
    timestamp = pd.to_datetime(raw_values, unit='ms', utc=True).tz_convert('Asia/Shanghai')
    # time_features() takes the frequency as `time_embed`, not `freq`
    timestamp = time_features(timestamp, time_embed=freq)
    timestamp = timestamp.transpose(1, 0)
    return timestamp