blob: 5c8af6a4dd93e3006fe21923467466f688df7728 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import collections
class DotDict(dict):
    """Dictionary whose items can also be accessed as attributes.

    Every key stored through item or attribute assignment is mirrored into the
    instance ``__dict__``, so ``d.key`` and ``d['key']`` refer to the same
    value. Reading an attribute that is not a key returns ``None`` rather than
    raising, except for dunder names (``__x__``), which raise
    ``AttributeError`` so that protocol probes (copy, pickle, ...) behave.

    Example:
        d = DotDict({'child': 'dotdict'}, name='dotdict', index=1, contents=['a', 'b'])
        # add new key
        d.new_key = '!'  # or d['new_key'] = '!'
        # update values
        d.new_key = '!!!'
        # delete keys
        del d.new_key
    """

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as ``dict`` and mirror them as attributes."""
        super(DotDict, self).__init__(*args, **kwargs)
        # Re-assign every item dict() accepted (mapping, iterable of pairs or
        # keyword arguments) through __setitem__ so that __dict__ is in sync.
        # Iterate over a snapshot because __setitem__ writes back into self.
        for key, value in list(self.items()):
            self[key] = value

    def __getattr__(self, attr):
        """Return the value stored under ``attr`` or ``None`` if it is absent.

        Only called when normal attribute lookup fails.
        """
        # Answering None for special-method probes makes callers such as
        # copy.copy()/pickle try to call None; report those as missing instead.
        if attr.startswith('__') and attr.endswith('__') and attr not in self:
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        """Attribute assignment is item assignment."""
        self[key] = value

    def __setitem__(self, key, value):
        """Store the item and mirror it into the instance ``__dict__``."""
        super(DotDict, self).__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        """Attribute deletion is item deletion (``KeyError`` if absent)."""
        del self[item]

    def __delitem__(self, key):
        """Remove the item and its attribute mirror.

        Raises ``KeyError`` when ``key`` is not in the dictionary.
        """
        super(DotDict, self).__delitem__(key)
        # The mirror may be missing for keys inserted by dict methods that
        # bypass __setitem__ (update, setdefault, ...); do not fail on those.
        self.__dict__.pop(key, None)
def namedtuple_with_defaults(typename, field_names, default_values=()):
    """Create a namedtuple class whose fields all have default values.

    Parameters
    ----------
    typename : str
        Name of the generated namedtuple class.
    field_names : str or sequence of str
        Field names, in any form accepted by ``collections.namedtuple``.
    default_values : mapping or sequence, optional
        Either a mapping ``{field: default}`` or a sequence of defaults applied
        positionally to the leading fields. Fields without an explicit default
        fall back to ``None``.

    Returns
    -------
    type
        The namedtuple class; every constructor argument is optional.
    """
    # collections.Mapping was removed in Python 3.10; the ABC lives in
    # collections.abc. Fall back to the old location for Python 2.
    try:
        from collections.abc import Mapping
    except ImportError:
        from collections import Mapping

    T = collections.namedtuple(typename, field_names)
    # First make every field default to None so that a prototype can be built
    # from a partial set of defaults.
    T.__new__.__defaults__ = (None,) * len(T._fields)
    if isinstance(default_values, Mapping):
        prototype = T(**default_values)
    else:
        prototype = T(*default_values)
    # The prototype now carries the final default for each field, in order.
    T.__new__.__defaults__ = tuple(prototype)
    return T
def merge_dict(a, b):
    """Return a copy of ``a`` updated with ``b``; on shared keys ``b`` wins.

    Neither input is modified.
    """
    merged = a.copy()
    merged.update(b)
    return merged
def zip_namedtuple(nt_list):
    """Zip namedtuples of one type into ``{field: [value, value, ...]}``.

    A falsy argument yields an empty dict. An argument that is not a list is
    treated as a single namedtuple. All entries must share the same type.
    """
    if not nt_list:
        return dict()
    if not isinstance(nt_list, list):
        nt_list = [nt_list]

    first, rest = nt_list[0], nt_list[1:]
    expected_type = type(first)
    for entry in nt_list:
        assert type(entry) == expected_type

    # Seed one single-element list per field from the first entry, then
    # append the matching field of every remaining entry.
    zipped = {}
    for field, value in first._asdict().items():
        zipped[field] = [value]
    for entry in rest:
        for field, value in entry._asdict().items():
            zipped[field].append(value)
    return zipped
def config_as_dict(cfg):
    """Flatten a raw configuration object into one plain dictionary.

    Starts from a copy of ``cfg.__dict__`` and replaces the structured entries
    ``rand_crop_samplers`` (a list of namedtuples), ``rand_pad`` and
    ``color_jitter`` (namedtuples) with their individual fields. It also adds
    ``num_crop_sampler`` and ``rand_crop_prob``.
    """
    flat = cfg.__dict__.copy()

    # random cropping params: one list per sampler field
    flat.pop('rand_crop_samplers')
    samplers = cfg.rand_crop_samplers
    assert isinstance(samplers, list)
    flat.update(zip_namedtuple(samplers))
    sampler_count = len(samplers)
    flat['num_crop_sampler'] = sampler_count  # must specify the #
    flat['rand_crop_prob'] = 1.0 / (sampler_count + 1) * sampler_count

    # random padding params
    flat.pop('rand_pad')
    flat.update(cfg.rand_pad._asdict())

    # color jitter
    flat.pop('color_jitter')
    flat.update(cfg.color_jitter._asdict())

    return flat