| """Presets for various network configurations""" |
| import logging |
| import symbol_builder |
| |
| def get_config(network, data_shape, **kwargs): |
| """Configuration factory for various networks |
| |
| Parameters |
| ---------- |
| network : str |
| base network name, such as vgg_reduced, inceptionv3, resnet... |
| data_shape : int |
| input data dimension |
| kwargs : dict |
| extra arguments |
| """ |
| if network == 'vgg16_reduced': |
| if data_shape >= 448: |
| from_layers = ['relu4_3', 'relu7', '', '', '', '', ''] |
| num_filters = [512, -1, 512, 256, 256, 256, 256] |
| strides = [-1, -1, 2, 2, 2, 2, 1] |
| pads = [-1, -1, 1, 1, 1, 1, 1] |
| sizes = [[.07, .1025], [.15,.2121], [.3, .3674], [.45, .5196], [.6, .6708], \ |
| [.75, .8216], [.9, .9721]] |
| ratios = [[1,2,.5], [1,2,.5,3,1./3], [1,2,.5,3,1./3], [1,2,.5,3,1./3], \ |
| [1,2,.5,3,1./3], [1,2,.5], [1,2,.5]] |
| normalizations = [20, -1, -1, -1, -1, -1, -1] |
| steps = [] if data_shape != 512 else [x / 512.0 for x in |
| [8, 16, 32, 64, 128, 256, 512]] |
| else: |
| from_layers = ['relu4_3', 'relu7', '', '', '', ''] |
| num_filters = [512, -1, 512, 256, 256, 256] |
| strides = [-1, -1, 2, 2, 1, 1] |
| pads = [-1, -1, 1, 1, 0, 0] |
| sizes = [[.1, .141], [.2,.272], [.37, .447], [.54, .619], [.71, .79], [.88, .961]] |
| ratios = [[1,2,.5], [1,2,.5,3,1./3], [1,2,.5,3,1./3], [1,2,.5,3,1./3], \ |
| [1,2,.5], [1,2,.5]] |
| normalizations = [20, -1, -1, -1, -1, -1] |
| steps = [] if data_shape != 300 else [x / 300.0 for x in [8, 16, 32, 64, 100, 300]] |
| if not (data_shape == 300 or data_shape == 512): |
| logging.warn('data_shape %d was not tested, use with caucious.' % data_shape) |
| return locals() |
| elif network == 'inceptionv3': |
| from_layers = ['ch_concat_mixed_7_chconcat', 'ch_concat_mixed_10_chconcat', '', '', '', ''] |
| num_filters = [-1, -1, 512, 256, 256, 128] |
| strides = [-1, -1, 2, 2, 2, 2] |
| pads = [-1, -1, 1, 1, 1, 1] |
| sizes = [[.1, .141], [.2,.272], [.37, .447], [.54, .619], [.71, .79], [.88, .961]] |
| ratios = [[1,2,.5], [1,2,.5,3,1./3], [1,2,.5,3,1./3], [1,2,.5,3,1./3], \ |
| [1,2,.5], [1,2,.5]] |
| normalizations = -1 |
| steps = [] |
| return locals() |
| elif network == 'resnet50': |
| num_layers = 50 |
| image_shape = '3,224,224' # resnet require it as shape check |
| network = 'resnet' |
| from_layers = ['_plus12', '_plus15', '', '', '', ''] |
| num_filters = [-1, -1, 512, 256, 256, 128] |
| strides = [-1, -1, 2, 2, 2, 2] |
| pads = [-1, -1, 1, 1, 1, 1] |
| sizes = [[.1, .141], [.2,.272], [.37, .447], [.54, .619], [.71, .79], [.88, .961]] |
| ratios = [[1,2,.5], [1,2,.5,3,1./3], [1,2,.5,3,1./3], [1,2,.5,3,1./3], \ |
| [1,2,.5], [1,2,.5]] |
| normalizations = -1 |
| steps = [] |
| return locals() |
| elif network == 'resnet101': |
| num_layers = 101 |
| image_shape = '3,224,224' |
| network = 'resnet' |
| from_layers = ['_plus12', '_plus15', '', '', '', ''] |
| num_filters = [-1, -1, 512, 256, 256, 128] |
| strides = [-1, -1, 2, 2, 2, 2] |
| pads = [-1, -1, 1, 1, 1, 1] |
| sizes = [[.1, .141], [.2,.272], [.37, .447], [.54, .619], [.71, .79], [.88, .961]] |
| ratios = [[1,2,.5], [1,2,.5,3,1./3], [1,2,.5,3,1./3], [1,2,.5,3,1./3], \ |
| [1,2,.5], [1,2,.5]] |
| normalizations = -1 |
| steps = [] |
| return locals() |
| else: |
| msg = 'No configuration found for %s with data_shape %d' % (network, data_shape) |
| raise NotImplementedError(msg) |
| |
| def get_symbol_train(network, data_shape, **kwargs): |
| """Wrapper for get symbol for train |
| |
| Parameters |
| ---------- |
| network : str |
| name for the base network symbol |
| data_shape : int |
| input shape |
| kwargs : dict |
| see symbol_builder.get_symbol_train for more details |
| """ |
| if network.startswith('legacy'): |
| logging.warn('Using legacy model.') |
| return symbol_builder.import_module(network).get_symbol_train(**kwargs) |
| config = get_config(network, data_shape, **kwargs).copy() |
| config.update(kwargs) |
| return symbol_builder.get_symbol_train(**config) |
| |
| def get_symbol(network, data_shape, **kwargs): |
| """Wrapper for get symbol for test |
| |
| Parameters |
| ---------- |
| network : str |
| name for the base network symbol |
| data_shape : int |
| input shape |
| kwargs : dict |
| see symbol_builder.get_symbol for more details |
| """ |
| if network.startswith('legacy'): |
| logging.warn('Using legacy model.') |
| return symbol_builder.import_module(network).get_symbol(**kwargs) |
| config = get_config(network, data_shape, **kwargs).copy() |
| config.update(kwargs) |
| return symbol_builder.get_symbol(**config) |