<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# Data Providers
Data providers are wrappers that load external data, be it images, text, or general tensors,
and split it into mini-batches so that the model can consume the data in a uniform way.
## AbstractDataProvider interface
```@docs
mx.AbstractDataProvider
```
The difference between *data* and *label* is that during the training stage,
both *data* and *label* are fed into the model, while during the
prediction stage, only *data* is loaded. Otherwise, they can be anything, with any names and
any shapes. The data and label names provided here should match the input names in the target
`SymbolicNode`.
A data provider should also implement the Julia iteration interface, in order to allow iterating
through the data set. The provider will be called in the following way:
```julia
for batch in eachbatch(provider)
    data = get_data(provider, batch)
end
```
which will be translated by the Julia compiler into
```julia
# iterate over whatever `eachbatch` returns (by default, the provider itself)
iter = eachbatch(provider)
state = Base.start(iter)
while !Base.done(iter, state)
    (batch, state) = Base.next(iter, state)
    data = get_data(provider, batch)
end
```
By default, `eachbatch` simply returns the provider itself, so the iterator interface
is implemented on the provider type. The extra layer of abstraction, however, makes it easy to
implement a data provider as a Julia `Task` coroutine. See the
data provider defined in [the char-lstm example](tutorial/char-lstm) for an example of using a coroutine to define a data
provider.
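
As a rough illustration of the coroutine idea, the sketch below backs `eachbatch` with a `Channel` whose producer task emits one batch at a time. The `ToyProvider` type, its fields, and the `eachbatch` function defined here are hypothetical stand-ins rather than MXNet.jl code, and the char-lstm example may be organized differently; `Channel` is simply how task-based producers are written in Julia 1.x.

```julia
# Hypothetical toy provider: holds a matrix with one sample per column.
struct ToyProvider
    samples::Matrix{Float64}
    batch_size::Int
end

# Stand-in for `eachbatch`: returns a Channel fed by a producer task, so the
# provider type itself does not have to implement the iterator protocol.
function eachbatch(p::ToyProvider)
    n = size(p.samples, 2)
    return Channel() do ch
        for lo in 1:p.batch_size:n
            hi = min(lo + p.batch_size - 1, n)
            put!(ch, p.samples[:, lo:hi])   # hand one batch to the consumer
        end
    end
end

# Five samples with a batch size of two: the last batch is smaller.
provider = ToyProvider(reshape(collect(1.0:10.0), 2, 5), 2)
batch_sizes = [size(batch, 2) for batch in eachbatch(provider)]
```

Because each call to this `eachbatch` creates a fresh channel and producer task, the batching logic can be written as an ordinary loop instead of being split across separate iterator functions.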
The detailed interface functions for the iterator API are listed below:

`Base.eltype(provider) -> AbstractDataBatch`

Returns the specific subtype representing a data batch. See `AbstractDataBatch`.
* `provider::AbstractDataProvider`: the data provider.

`Base.start(provider) -> AbstractDataProviderState`

This function is always called before iterating over the dataset. It should initialize
the iterator, reset the index, and shuffle the data if needed.
* `provider::AbstractDataProvider`: the data provider.

`Base.done(provider, state) -> Bool`

Returns true if there is no more data to iterate over in this dataset.
* `provider::AbstractDataProvider`: the data provider.
* `state::AbstractDataProviderState`: the state returned by `Base.start` and `Base.next`.

`Base.next(provider, state) -> (AbstractDataBatch, AbstractDataProviderState)`

Returns the current data batch and the state for the next iteration.
* `provider::AbstractDataProvider`: the data provider.
* `state::AbstractDataProviderState`: the state returned by `Base.start` and `Base.next`.
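
To make the division of labor between these functions concrete, here is a minimal sketch of a provider that slices an in-memory matrix into batches. It is an illustration under stated assumptions: `ArrayProvider` is a hypothetical type, and `start`, `done`, and `next` are plain local functions standing in for `Base.start`, `Base.done`, and `Base.next`, so the sketch needs neither MXNet nor a particular Julia version.

```julia
# Hypothetical provider type; not part of MXNet.jl.
struct ArrayProvider
    samples::Matrix{Float64}   # one sample per column
    batch_size::Int
end

# Local stand-ins for Base.start / Base.done / Base.next.
# The state is simply the column index at which the next batch begins.
start(p::ArrayProvider) = 1
done(p::ArrayProvider, state::Int) = state > size(p.samples, 2)
function next(p::ArrayProvider, state::Int)
    stop = min(state + p.batch_size - 1, size(p.samples, 2))
    return (p.samples[:, state:stop], stop + 1)
end

# Drive the provider by hand, mirroring the expanded loop shown earlier.
function collect_batches(p::ArrayProvider)
    batches = Matrix{Float64}[]
    state = start(p)
    while !done(p, state)
        (batch, state) = next(p, state)
        push!(batches, batch)
    end
    return batches
end

# Four samples with a batch size of three: one full batch, one partial batch.
batches = collect_batches(ArrayProvider(reshape(collect(1.0:12.0), 3, 4), 3))
```

Keeping the position in the returned state rather than in the provider means the provider itself stays immutable in this sketch; a provider that shuffles its data would do so when the pass starts.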
Note that sometimes you are wrapping an existing data iterator (e.g. the built-in libmxnet data iterator) that
is built with a different convention, and it might be difficult to adapt it to the interface stated here. In this
case, you can safely assume that

* `Base.start` will always be called, and called only once, before the iteration starts.
* `Base.done` will always be called exactly once at the beginning of every iteration.
* If `Base.done` returns true, the iteration stops until the next round, which again starts with
  a call to `Base.start`.
* `Base.next` will always be called exactly once in each iteration. It will always be called after
  one and only one call to `Base.done`; if `Base.done` returns true, `Base.next` will
  not be called.
With those assumptions, it is relatively easy to adapt any existing iterator. See the implementation
of the built-in `MXDataProvider` for an example.
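
The sketch below shows one way these guarantees can be used to wrap an iterator that follows an "advance, then read" convention. Everything in it is an assumption made for illustration: `Cursor`, `reset!`, `advance!`, `current`, and `CursorProvider` are hypothetical names, `start`, `done`, and `next` are local stand-ins for `Base.start`, `Base.done`, and `Base.next`, and the built-in `MXDataProvider` may differ in its details.

```julia
# Hypothetical iterator with an "advance, then read" convention: `advance!`
# moves to the next batch and reports whether one exists; `current` reads it.
mutable struct Cursor
    batches::Vector{Vector{Int}}
    pos::Int
end
reset!(c::Cursor) = (c.pos = 0; nothing)
advance!(c::Cursor) = (c.pos += 1; c.pos <= length(c.batches))
current(c::Cursor) = c.batches[c.pos]

# Hypothetical wrapper around the cursor.
struct CursorProvider
    cursor::Cursor
end

# `start` runs once per pass, so it is the place to rewind.
start(p::CursorProvider) = (reset!(p.cursor); nothing)
# `done` runs exactly once per iteration, so it may safely advance the cursor.
done(p::CursorProvider, state) = !advance!(p.cursor)
# `next` runs only after `done` returned false, so a current batch exists.
next(p::CursorProvider, state) = (current(p.cursor), state)

function run_pass(p::CursorProvider)
    seen = Vector{Int}[]
    state = start(p)
    while !done(p, state)
        (batch, state) = next(p, state)
        push!(seen, batch)
    end
    return seen
end

provider = CursorProvider(Cursor([[1, 2], [3, 4], [5]], 0))
first_pass = run_pass(provider)
second_pass = run_pass(provider)
```

Note that `done` has a side effect here, which is only safe because of the guarantee that it is called exactly once per iteration.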
!!! note
    Please do not use the same data provider simultaneously in two different places, either in parallel
    or in a nested loop. For example, the behavior of the following code is undefined:

    ```julia
    for batch in data
        # updating the parameters

        # now let's test the performance on the training set
        for b2 in data
            # ...
        end
    end
    ```
```@docs
mx.get_batch_size
mx.provide_data
mx.provide_label
```
## AbstractDataBatch interface
```@docs
mx.AbstractDataProviderState
mx.count_samples
mx.get_data
mx.get_label
mx.get
mx.load_data!
mx.load_label!
```
## Implemented providers and other methods
```@autodocs
Modules = [MXNet.mx]
Pages = ["io.jl"]
```