Data providers are wrappers that load external data, be it images, text, or general tensors, and split it into mini-batches so that the model can consume the data in a uniform way.
mx.AbstractDataProvider
The difference between data and label is that during the training stage both data and labels are fed into the model, while during the prediction stage only data is loaded. Otherwise, they can be anything, with any names and any shapes. The data and label names provided here should match the input names in a target `SymbolicNode`.
A data provider should also implement the Julia iteration interface, in order to allow iterating through the data set. The provider will be called in the following way:
```julia
for batch in eachbatch(provider)
    data = get_data(provider, batch)
end
```

which the Julia compiler translates into

```julia
iter = eachbatch(provider)
state = Base.start(iter)
while !Base.done(iter, state)
    (batch, state) = Base.next(iter, state)
    data = get_data(provider, batch)
end
```
By default, `eachbatch` simply returns the provider itself, so the iterator interface is implemented on the provider type itself. The extra layer of abstraction, however, makes it easy to implement a data provider via a Julia `Task` coroutine. See the data provider defined in the char-lstm example for an example of using a coroutine to define a data provider.
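The coroutine idea can be sketched without MXNet at all. The snippet below is a minimal illustration, not the char-lstm provider: it assumes Julia 1.3 or newer, where a `Channel` constructed from a function is the usual way to run a producer task, and `batch_channel` is a hypothetical helper name introduced only for this example.

```julia
# Hypothetical helper (not part of MXNet.jl): a producer task that yields
# consecutive mini-batches of `xs`, each holding at most `batch_size` items.
function batch_channel(xs::AbstractVector, batch_size::Integer)
    return Channel{Vector{eltype(xs)}}() do ch
        for lo in 1:batch_size:length(xs)
            hi = min(lo + batch_size - 1, length(xs))
            put!(ch, xs[lo:hi])   # suspends until the consumer takes the batch
        end
    end
end

# The consumer iterates over the channel like any other iterator.
batches = collect(batch_channel(collect(1:7), 3))
println(batches)
```

The producer keeps its position in ordinary local variables (`lo`, `hi`), so no explicit state object is needed. A real provider would presumably return such an iterator from `eachbatch`.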
The interface functions of the iterator API are listed below:
`Base.eltype(provider) -> AbstractDataBatch`

Returns the specific subtype representing a data batch. See `AbstractDataBatch`.

* `provider::AbstractDataProvider`: the data provider.
`Base.start(provider) -> AbstractDataProviderState`

This function is always called before iterating over the dataset. It should initialize the iterator, reset the index, and shuffle the data if needed.

* `provider::AbstractDataProvider`: the data provider.
`Base.done(provider, state) -> Bool`

Returns `true` if there is no more data to iterate over in this dataset.

* `provider::AbstractDataProvider`: the data provider.
* `state::AbstractDataProviderState`: the state returned by `Base.start` and `Base.next`.
`Base.next(provider, state) -> (AbstractDataBatch, AbstractDataProviderState)`

Returns the current data batch and the state for the next iteration.

* `provider::AbstractDataProvider`: the data provider.
* `state::AbstractDataProviderState`: the state returned by `Base.start` and `Base.next`.

Note that sometimes you are wrapping an existing data iterator (e.g. the built-in libmxnet data iterator) that is built with a different convention, and it might be difficult to adapt it to the interface stated here. In this case, you can safely assume that

* `Base.start` will always be called, and called only once, before the iteration starts.
* `Base.done` will always be called at the beginning of every iteration, and always called exactly once.
* If `Base.done` returns `true`, the iteration will stop until the next round, which again starts with a call to `Base.start`.
* `Base.next` will always be called only once in each iteration. It will always be called after one and only one call to `Base.done`; but if `Base.done` returns `true`, `Base.next` will not be called.

With those assumptions, it is relatively easy to adapt any existing iterator. See the implementation of the built-in `MXDataProvider` for an example.
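To make the calling assumptions concrete, here is a minimal sketch of wrapping an iterator that follows an "advance, then read" convention. It is not the `MXDataProvider` implementation. `ExternalIter`, `WrappedProvider`, `rewind!`, `advance!`, `current` and `collect_batches` are hypothetical names, and `start`, `done` and `next` are local stand-ins for the `Base` functions described above, defined here so that the sketch also runs on Julia 1.x, where `Base` uses `iterate` instead.

```julia
# Hypothetical external iterator: it must be advanced before a batch can be read.
mutable struct ExternalIter
    batches::Vector{Vector{Int}}
    pos::Int   # index of the current batch; 0 means "before the first batch"
end

rewind!(it::ExternalIter) = (it.pos = 0; nothing)

# Moves to the next batch and reports whether that batch exists.
function advance!(it::ExternalIter)
    it.pos += 1
    return it.pos <= length(it.batches)
end

current(it::ExternalIter) = it.batches[it.pos]

struct WrappedProvider
    iter::ExternalIter
end

# Local stand-ins for the three protocol functions (assumption: same call order
# as listed above). The state is unused because the wrapped iterator is stateful.
start(p::WrappedProvider) = (rewind!(p.iter); nothing)
done(p::WrappedProvider, state) = !advance!(p.iter)   # called exactly once per iteration
next(p::WrappedProvider, state) = (current(p.iter), state)

# Drives the protocol in the same order as the loop expansion shown earlier.
function collect_batches(p)
    out = Vector{Vector{Int}}()
    state = start(p)
    while !done(p, state)
        batch, state = next(p, state)
        push!(out, batch)
    end
    return out
end

provider = WrappedProvider(ExternalIter([[1, 2], [3, 4], [5]], 0))
first_pass = collect_batches(provider)
second_pass = collect_batches(provider)   # start rewinds, so a second round works
```

The design relies on `done` being called exactly once at the top of each iteration, which is why it is safe to advance the underlying iterator there and let `next` simply read the current batch.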
!!! note
    Please do not use the same data provider simultaneously in two different places, either in parallel or in a nested loop. For example, the behavior of the following code is undefined:

    ```julia
    for batch in data
        # updating the parameters

        # now let's test the performance on the training set
        for b2 in data
            # ...
        end
    end
    ```
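One way such sharing can go wrong is shown below with a stateful iterator from Julia's standard library. This is only an analogy, assuming Julia 1.x: `Iterators.Stateful` is not a data provider, and a real provider may fail in a different way, but the mechanism (both loops advancing the same position) is the same kind of hazard.

```julia
# Both loops draw from the same stateful iterator, so the inner loop
# consumes the elements the outer loop was expected to see.
shared = Iterators.Stateful(1:4)
outer_seen = Int[]
inner_seen = Int[]

for a in shared
    push!(outer_seen, a)
    for b in shared
        push!(inner_seen, b)
    end
end

println(outer_seen)   # [1]
println(inner_seen)   # [2, 3, 4]
```

The outer loop ends after a single iteration because the inner loop has already exhausted the shared state. Using two separate provider instances, one per loop, avoids this kind of interference.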
mx.get_batch_size mx.provide_data mx.provide_label
mx.AbstractDataProviderState mx.count_samples mx.get_data mx.get_label mx.get mx.load_data! mx.load_label!