blob: acf49099a99139ae03dfae1529f1ad90aeed58a1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
# pylint: disable=
"""Dataset sampler."""
# Public names exported by `from <this module> import *`.
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'FilterSampler', 'BatchSampler',
'IntervalSampler']
import numpy as np
class Sampler:
    """Base class every sampler in this module derives from.

    Concrete samplers must override both `__iter__` (produce the sampled
    items) and `__len__` (report how many items one pass produces). The
    default implementations here only raise `NotImplementedError`.
    """

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError
class SequentialSampler(Sampler):
    """Yields the integers start, start+1, ..., start+length-1 in order.

    Parameters
    ----------
    length : int
        Length of the sequence.
    start : int, default is 0
        The start of the sequence index.
    """

    def __init__(self, length, start=0):
        self._start = start
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        stop = self._start + self._length
        return iter(range(self._start, stop))
class RandomSampler(Sampler):
    """Yields every index in [0, length) once, in a random order.

    A new permutation is drawn with `np.random.shuffle` each time the
    sampler is iterated.

    Parameters
    ----------
    length : int
        Length of the sequence.
    """

    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        perm = np.arange(self._length)
        # Shuffles `perm` in place using NumPy's module-level random state.
        np.random.shuffle(perm)
        return iter(perm)
class FilterSampler(Sampler):
    """Yields the indices of dataset samples for which `fn` is truthy.

    The dataset is scanned once, eagerly, when the sampler is constructed.
    Iterating afterwards replays the precomputed indices in ascending order.

    Parameters
    ----------
    fn : callable
        A callable function that takes a sample and returns a boolean
    dataset : Dataset
        The dataset to filter.
    """

    def __init__(self, fn, dataset):
        self._fn = fn
        self._dataset = dataset
        self._indices = [
            position
            for position, item in enumerate(dataset)
            if fn(item)
        ]

    def __len__(self):
        return len(self._indices)

    def __iter__(self):
        return iter(self._indices)
class BatchSampler(Sampler):
    """Wraps over another `Sampler` and returns mini-batches of samples.

    Parameters
    ----------
    sampler : Sampler
        The source Sampler.
    batch_size : int
        Size of mini-batch. Must be positive.
    last_batch : {'keep', 'discard', 'rollover'}
        Specifies how the last batch is handled if batch_size does not evenly
        divide sequence length.
        If 'keep', the last batch will be returned directly, but will contain
        fewer elements than `batch_size` requires.
        If 'discard', the last batch will be discarded.
        If 'rollover', the remaining elements will be rolled over to the next
        iteration.

    Raises
    ------
    ValueError
        If `batch_size` is not positive, or if `last_batch` is not one of
        'keep', 'discard' or 'rollover'.

    Examples
    --------
    >>> sampler = gluon.data.SequentialSampler(10)
    >>> batch_sampler = gluon.data.BatchSampler(sampler, 3, 'keep')
    >>> list(batch_sampler)
    [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    """

    _LAST_BATCH_OPTIONS = ('keep', 'discard', 'rollover')

    def __init__(self, sampler, batch_size, last_batch='keep'):
        # Validate eagerly so a misconfiguration surfaces at construction time.
        # Otherwise it would only surface when a partial batch is left over at
        # the end of a pass, or when __len__ is called.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, but got {batch_size}")
        if last_batch not in self._LAST_BATCH_OPTIONS:
            raise self._last_batch_error(last_batch)
        self._sampler = sampler
        self._batch_size = batch_size
        self._last_batch = last_batch
        # Leftover elements carried into the next pass in 'rollover' mode.
        self._prev = []

    @staticmethod
    def _last_batch_error(last_batch):
        """Build the ValueError reported for an unsupported `last_batch`."""
        return ValueError(
            "last_batch must be one of 'keep', 'discard', or 'rollover', "
            f"but got {last_batch}")

    def __iter__(self):
        # Start from any leftover of the previous pass, and clear the
        # carry-over so it is consumed exactly once.
        batch, self._prev = self._prev, []
        for i in self._sampler:
            batch.append(i)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        if batch:
            if self._last_batch == 'keep':
                yield batch
            elif self._last_batch == 'discard':
                return
            elif self._last_batch == 'rollover':
                self._prev = batch
            else:
                # Only reachable if `_last_batch` was modified after construction.
                raise self._last_batch_error(self._last_batch)

    def __len__(self):
        if self._last_batch == 'keep':
            # Ceiling division: the trailing partial batch counts as one.
            return (len(self._sampler) + self._batch_size - 1) // self._batch_size
        if self._last_batch == 'discard':
            return len(self._sampler) // self._batch_size
        if self._last_batch == 'rollover':
            # Elements carried over from the previous pass are batched first.
            return (len(self._prev) + len(self._sampler)) // self._batch_size
        raise self._last_batch_error(self._last_batch)
class IntervalSampler(Sampler):
    """Samples elements from [0, length) at fixed intervals.

    Parameters
    ----------
    length : int
        Length of the sequence.
    interval : int
        The step between two consecutive samples, so ``interval - 1`` items are
        skipped between them. Must satisfy ``0 < interval <= length``.
    rollover : bool, default True
        Whether to start again from the first skipped item after reaching the end.
        If true, this sampler starts again from the first skipped item until all
        items are visited, so every index in [0, length) is yielded exactly once.
        Otherwise, iteration stops when end is reached and skipped items are ignored.

    Examples
    --------
    >>> sampler = contrib.data.IntervalSampler(13, interval=3)
    >>> list(sampler)
    [0, 3, 6, 9, 12, 1, 4, 7, 10, 2, 5, 8, 11]
    >>> sampler = contrib.data.IntervalSampler(13, interval=3, rollover=False)
    >>> list(sampler)
    [0, 3, 6, 9, 12]
    >>> len(sampler)
    5
    """

    def __init__(self, length, interval, rollover=True):
        # NOTE: these checks use `assert` (AssertionError kept for backward
        # compatibility), so they are skipped when Python runs with -O.
        assert interval <= length, \
            f"Interval {interval} must be smaller than or equal to length {length}"
        # A zero step makes `range` raise. A negative step yields nothing while
        # __len__ would still claim `length` items.
        assert interval > 0, f"Interval {interval} must be positive"
        self._length = length
        self._interval = interval
        self._rollover = rollover

    def __iter__(self):
        # With rollover, do one strided pass per starting offset in [0, interval).
        # Without it, do only the pass that starts at 0.
        num_passes = self._interval if self._rollover else 1
        for offset in range(num_passes):
            yield from range(offset, self._length, self._interval)

    def __len__(self):
        if self._rollover:
            # Every index j belongs to exactly one pass (offset j % interval).
            return self._length
        # Without rollover only the first strided pass is produced.
        return len(range(0, self._length, self._interval))