blob: da9e151054c317ead74044a15d54a55a17e49595 [file] [log] [blame]
from imdb import Imdb
import random
class ConcatDB(Imdb):
    """
    ConcatDB is used to concatenate multiple imdbs to form a larger db.
    It is very useful to combine multiple datasets with the same classes.

    Parameters
    ----------
    imdbs : Imdb or list of Imdb
        Imdbs to be concatenated
    shuffle : bool
        whether to shuffle the initial list
    """
    def __init__(self, imdbs, shuffle):
        super(ConcatDB, self).__init__('concatdb')
        # accept a single imdb as a convenience
        if not isinstance(imdbs, list):
            imdbs = [imdbs]
        self.imdbs = imdbs
        self._check_classes()
        self.image_set_index = self._load_image_set_index(shuffle)

    def _check_classes(self):
        """
        check input imdbs, make sure they have same classes

        The ``classes`` of the first imdb are copied to ``self.classes`` /
        ``self.num_classes``. If the first imdb has no ``classes`` attribute,
        or no imdbs were given, nothing is copied.

        Raises
        ------
        AssertionError
            if the class list is non-empty and some imdb's ``classes``
            differs from it (or is missing)
        """
        if not self.imdbs:
            # nothing to take classes from or compare against
            return
        try:
            classes = self.imdbs[0].classes
        except AttributeError:
            # fine, if no classes is provided
            pass
        else:
            self.classes = classes
            self.num_classes = len(classes)
        # getattr: tolerate a base class that did not initialise num_classes
        if getattr(self, 'num_classes', 0) > 0:
            for db in self.imdbs:
                # getattr: a sub-db without ``classes`` fails this check with
                # a clear message instead of a bare AttributeError
                assert self.classes == getattr(db, 'classes', None), \
                    "Multiple imdb must have same classes"

    def _load_image_set_index(self, shuffle):
        """
        get total number of images, init indices

        Also sets ``self.num_images`` to the sum of ``num_images`` over all
        sub-imdbs.

        Parameters
        ----------
        shuffle : bool
            whether to shuffle the initial indices

        Returns
        ----------
        list of int
            global positions ``0 .. num_images - 1``, shuffled in place when
            ``shuffle`` is true
        """
        self.num_images = sum(db.num_images for db in self.imdbs)
        # must be a list: on Python 3 ``range`` is immutable and
        # ``random.shuffle`` would raise TypeError on it
        indices = list(range(self.num_images))
        if shuffle:
            random.shuffle(indices)
        return indices

    def _locate_index(self, index):
        """
        given index, find out sub-db and sub-index

        Parameters
        ----------
        index : int
            index of a specific image

        Returns
        ----------
        a tuple (sub-db, sub-index): the position of the sub-db in
        ``self.imdbs`` and the image index within that sub-db

        Raises
        ----------
        IndexError
            if ``index`` is outside ``[0, num_images)``
        """
        # explicit check rather than assert: still enforced under
        # ``python -O``, and rejects negative indices instead of letting
        # Python's negative indexing wrap them around
        if not 0 <= index < self.num_images:
            raise IndexError("index out of range")
        pos = self.image_set_index[index]
        # walk the sub-dbs in order, subtracting each one's size until the
        # global position falls inside the current sub-db
        for k, db in enumerate(self.imdbs):
            if pos < db.num_images:
                return (k, pos)
            pos -= db.num_images
        # only reachable if image_set_index and the sub-db sizes no longer
        # agree (e.g. a sub-db's num_images changed after construction)
        raise IndexError("index out of range")

    def image_path_from_index(self, index):
        """
        given image index, find out full path

        Parameters
        ----------
        index: int
            index of a specific image

        Returns
        ----------
        full path of this image, as returned by the owning sub-db's
        ``image_path_from_index``

        Raises
        ----------
        IndexError
            if ``index`` is out of range
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        n_db, n_index = self._locate_index(index)
        return self.imdbs[n_db].image_path_from_index(n_index)

    def label_from_index(self, index):
        """
        given image index, return preprocessed ground-truth

        Parameters
        ----------
        index: int
            index of a specific image

        Returns
        ----------
        ground-truths of this image, as returned by the owning sub-db's
        ``label_from_index``

        Raises
        ----------
        IndexError
            if ``index`` is out of range
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        n_db, n_index = self._locate_index(index)
        return self.imdbs[n_db].label_from_index(n_index)