| from imdb import Imdb |
| import random |
| |
class ConcatDB(Imdb):
    """
    ConcatDB concatenates multiple imdbs to form a larger db.

    It is useful for combining several datasets that share the same classes.
    Images are addressed through a single global index; each lookup is
    forwarded to the sub-db that owns the image.

    Parameters
    ----------
    imdbs : Imdb or list of Imdb
        Imdbs to be concatenated. A single Imdb is wrapped into a list.
    shuffle : bool
        whether to shuffle the initial list of indices
    """
    def __init__(self, imdbs, shuffle):
        super(ConcatDB, self).__init__('concatdb')
        if not isinstance(imdbs, list):
            imdbs = [imdbs]
        self.imdbs = imdbs
        self._check_classes()
        self.image_set_index = self._load_image_set_index(shuffle)

    def _check_classes(self):
        """
        Check input imdbs and make sure they all have the same classes.

        The classes of the first imdb are adopted as the classes of the
        concatenated db. If the first imdb has no ``classes`` attribute the
        check is skipped.

        Raises
        ------
        AssertionError
            if any imdb has classes different from the first one
        """
        try:
            self.classes = self.imdbs[0].classes
            self.num_classes = len(self.classes)
        except AttributeError:
            # fine, if no classes is provided
            pass

        # NOTE(review): num_classes is presumably initialised by Imdb.__init__
        # (not visible here); default to 0 so that a first imdb without
        # classes cannot trigger an AttributeError on this read.
        if getattr(self, 'num_classes', 0) > 0:
            for db in self.imdbs:
                assert self.classes == db.classes, "Multiple imdb must have same classes"

    def _load_image_set_index(self, shuffle):
        """
        Compute the total number of images and build the global indices.

        Sets ``self.num_images`` to the sum of ``num_images`` over all
        sub-dbs.

        Parameters
        ----------
        shuffle : bool
            whether to shuffle the initial indices

        Returns
        -------
        list of int
            the indices ``0 .. num_images - 1``, shuffled if requested
        """
        self.num_images = sum(db.num_images for db in self.imdbs)
        # Materialise the range: random.shuffle needs a mutable sequence, and
        # on Python 3 a bare range object is immutable.
        indices = list(range(self.num_images))
        if shuffle:
            random.shuffle(indices)
        return indices

    def _locate_index(self, index):
        """
        Given a global index, find the owning sub-db and the index inside it.

        Parameters
        ----------
        index : int
            index of a specific image

        Returns
        -------
        tuple
            (sub-db position in ``self.imdbs``, index within that sub-db)

        Raises
        ------
        AssertionError
            if index is outside ``[0, num_images)``
        IndexError
            if no sub-db contains the resolved position
        """
        assert index >= 0 and index < self.num_images, "index out of range"
        # Translate through the (possibly shuffled) index list first.
        pos = self.image_set_index[index]
        # Walk the sub-dbs in order, subtracting each one's size until the
        # remaining offset falls inside the current sub-db.
        for k, v in enumerate(self.imdbs):
            if pos >= v.num_images:
                pos -= v.num_images
            else:
                return (k, pos)
        # Only reachable when the sub-db sizes no longer add up to the total
        # recorded at construction; fail loudly instead of returning None.
        raise IndexError("index out of range")

    def image_path_from_index(self, index):
        """
        Given an image index, find the full path of the image.

        Parameters
        ----------
        index : int
            index of a specific image

        Returns
        -------
        full path of this image, as returned by the owning sub-db
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        n_db, n_index = self._locate_index(index)
        return self.imdbs[n_db].image_path_from_index(n_index)

    def label_from_index(self, index):
        """
        Given an image index, return the preprocessed ground-truth.

        Parameters
        ----------
        index : int
            index of a specific image

        Returns
        -------
        ground-truths of this image, as returned by the owning sub-db
        """
        assert self.image_set_index is not None, "Dataset not initialized"
        n_db, n_index = self._locate_index(index)
        return self.imdbs[n_db].label_from_index(n_index)