blob: 7785a97e53c11abcfcacf9cfbafa3b34d1e242b3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""utility functions for io.py"""
from collections import OrderedDict
import numpy as np
try:
import h5py
except ImportError:
h5py = None
from ..ndarray.sparse import CSRNDArray
from ..ndarray.sparse import array as sparse_array
from ..ndarray import NDArray
from ..ndarray import array
def _init_data(data, allow_empty, default_name):
"""Convert data into canonical form."""
assert (data is not None) or allow_empty
if data is None:
data = []
if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
if h5py else (np.ndarray, NDArray)):
data = [data]
if isinstance(data, list):
if not allow_empty:
assert(len(data) > 0)
if len(data) == 1:
data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
else:
data = OrderedDict( # pylint: disable=redefined-variable-type
[(f'_{i}_{default_name}', d) for i, d in enumerate(data)])
if not isinstance(data, dict):
raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
"a list of them or dict with them as values")
for k, v in data.items():
if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
try:
data[k] = array(v)
except:
raise TypeError((f"Invalid type '{type(v)}' for {k}, ") +
"should be NDArray, numpy.ndarray or h5py.Dataset")
return list(sorted(data.items()))
def _has_instance(data, dtype):
"""Return True if ``data`` has instance of ``dtype``.
This function is called after _init_data.
``data`` is a list of (str, NDArray)"""
for item in data:
_, arr = item
if isinstance(arr, dtype):
return True
return False
def _getdata_by_idx(data, idx):
"""Shuffle the data."""
shuffle_data = []
for k, v in data:
if (isinstance(v, h5py.Dataset) if h5py else False):
shuffle_data.append((k, v))
elif isinstance(v, CSRNDArray):
shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
else:
shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
return shuffle_data