| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """utility functions for io.py""" |
| from collections import OrderedDict |
| |
| import numpy as np |
| try: |
| import h5py |
| except ImportError: |
| h5py = None |
| |
| from ..ndarray.sparse import CSRNDArray |
| from ..ndarray.sparse import array as sparse_array |
| from ..ndarray import NDArray |
| from ..ndarray import array |
| |
| def _init_data(data, allow_empty, default_name): |
| """Convert data into canonical form.""" |
| assert (data is not None) or allow_empty |
| if data is None: |
| data = [] |
| |
| if isinstance(data, (np.ndarray, NDArray, h5py.Dataset) |
| if h5py else (np.ndarray, NDArray)): |
| data = [data] |
| if isinstance(data, list): |
| if not allow_empty: |
| assert(len(data) > 0) |
| if len(data) == 1: |
| data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type |
| else: |
| data = OrderedDict( # pylint: disable=redefined-variable-type |
| [(f'_{i}_{default_name}', d) for i, d in enumerate(data)]) |
| if not isinstance(data, dict): |
| raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + |
| "a list of them or dict with them as values") |
| for k, v in data.items(): |
| if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray): |
| try: |
| data[k] = array(v) |
| except: |
| raise TypeError((f"Invalid type '{type(v)}' for {k}, ") + |
| "should be NDArray, numpy.ndarray or h5py.Dataset") |
| |
| return list(sorted(data.items())) |
| |
| |
| def _has_instance(data, dtype): |
| """Return True if ``data`` has instance of ``dtype``. |
| This function is called after _init_data. |
| ``data`` is a list of (str, NDArray)""" |
| for item in data: |
| _, arr = item |
| if isinstance(arr, dtype): |
| return True |
| return False |
| |
| |
| def _getdata_by_idx(data, idx): |
| """Shuffle the data.""" |
| shuffle_data = [] |
| |
| for k, v in data: |
| if (isinstance(v, h5py.Dataset) if h5py else False): |
| shuffle_data.append((k, v)) |
| elif isinstance(v, CSRNDArray): |
| shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context))) |
| else: |
| shuffle_data.append((k, array(v.asnumpy()[idx], v.context))) |
| |
| return shuffle_data |