blob: 4404f39b86980ee78fe7e264747490262a5682f7 [file] [log] [blame]
"""Helper functions for parsing caffe prototxt into a workable DAG
"""
def process_network_proto(caffe_root, deploy_proto):
    """
    Runs the caffe upgrade tool on the prototxt to create a prototxt in the latest format.
    This enables us to work just with the latest structures, instead of supporting all the variants
    :param caffe_root: link to caffe root folder, where the upgrade tool is located
    :param deploy_proto: name of the original prototxt file
    :return: name of new processed prototxt file
    """
    import os
    import subprocess
    from shutil import copyfile

    # work on a copy so the original prototxt is left untouched
    processed_deploy_proto = deploy_proto + ".processed"
    copyfile(deploy_proto, processed_deploy_proto)

    # run upgrade tool on the copy, writing the result back to the same file.
    # an argument list with shell=False (subprocess default) is immune to
    # whitespace and shell metacharacters in the paths, unlike the previous
    # os.system(string) approach
    upgrade_tool = os.path.join(caffe_root, 'build', 'tools', 'upgrade_net_proto_text.bin')
    try:
        subprocess.call([upgrade_tool, processed_deploy_proto, processed_deploy_proto])
    except OSError:
        # preserve the original best-effort behavior: os.system did not raise
        # when the tool was missing, it simply left the copy unprocessed
        pass
    return processed_deploy_proto
class LayerRecord(object):
    """
    A record describing the basic parameters of a single caffe layer.
    """

    def __init__(self, layer_def):
        self.layer_def = layer_def
        self.name = layer_def.name
        self.type = layer_def.type

        # keep filter, stride and pad, normalized to (h, w) pairs
        if layer_def.type == 'Convolution':
            conv = layer_def.convolution_param
            self.filter = LayerRecord._square(LayerRecord._listify(conv.kernel_size), None)
            self.pad = LayerRecord._square(LayerRecord._listify(conv.pad), [0, 0])
            self.stride = LayerRecord._square(LayerRecord._listify(conv.stride), [1, 1])
        elif layer_def.type == 'Pooling':
            pool = layer_def.pooling_param
            # pooling geometry fields are plain scalars, so wrap them directly
            self.filter = LayerRecord._square([pool.kernel_size], None)
            self.pad = LayerRecord._square([pool.pad], [0, 0])
            self.stride = LayerRecord._square([pool.stride], [1, 1])
        else:
            self.filter = [0, 0]
            self.pad = [0, 0]
            self.stride = [1, 1]

        # keep tops and bottoms
        self.tops = list(layer_def.top)
        self.bottoms = list(layer_def.bottom)

        # parent / child links between LayerRecord objects, filled in
        # externally once the whole network has been read
        self.parents = []
        self.children = []

    @staticmethod
    def _is_iterable(obj):
        return hasattr(obj, '__iter__')

    @staticmethod
    def _listify(value):
        # repeated proto fields are iterable, single scalars are not
        return list(value) if LayerRecord._is_iterable(value) else [value]

    @staticmethod
    def _square(values, if_empty):
        # expand a one-element geometry list to a (h, w) pair; substitute a
        # default when the list is empty and a default is given, otherwise
        # leave the (possibly empty) list untouched
        if len(values) == 1:
            return values * 2
        if not values and if_empty is not None:
            return list(if_empty)
        return values
def read_network_dag(processed_deploy_prototxt):
    """
    Reads from the caffe prototxt the network structure
    :param processed_deploy_prototxt: name of prototxt to load, preferably the prototxt should
     be processed before using a call to process_network_proto()
    :return: network_def, layer_name_to_record, top_to_layers
    network_def: caffe network structure, gives access to *all* the network information
    layer_name_to_record: *ordered* dictionary which maps between layer name and a structure which
     describes in a simple form the layer parameters
    top_to_layers: dictionary which maps a blob name to an ordered list of layers which output it
     when a top is used several times, like in inplace layers, the list will contain all the layers
     by order of appearance
    """
    from caffe.proto import caffe_pb2
    from google.protobuf import text_format
    from collections import OrderedDict

    # load prototxt file
    network_def = caffe_pb2.NetParameter()
    with open(processed_deploy_prototxt, 'r') as proto_file:
        text_format.Merge(str(proto_file.read()), network_def)

    # single pass over the layers (the original code filtered the same layers
    # twice in two separate loops): keep only layers active in the TEST phase
    # (no phase restriction, or TEST explicitly listed), building both the
    # name->record map and the top-blob->producing-layers map
    layer_name_to_record = OrderedDict()
    top_to_layers = dict()
    for layer_def in network_def.layer:
        if (len(layer_def.include) == 0) or \
           (caffe_pb2.TEST in [item.phase for item in layer_def.include]):
            layer_name_to_record[layer_def.name] = LayerRecord(layer_def)
            for top in layer_def.top:
                top_to_layers.setdefault(top, []).append(layer_def.name)

    # find parents and children of all layers
    for child_layer_name, child_layer_def in layer_name_to_record.items():
        for bottom in child_layer_def.bottoms:
            for parent_layer_name in top_to_layers.get(bottom, []):
                if parent_layer_name in layer_name_to_record:
                    parent_layer_def = layer_name_to_record[parent_layer_name]
                    if parent_layer_def not in child_layer_def.parents:
                        child_layer_def.parents.append(parent_layer_def)
                    if child_layer_def not in parent_layer_def.children:
                        parent_layer_def.children.append(child_layer_def)

    # update filter, stride, pad for maxout "structures":
    # Convolution/InnerProduct -> Slice -> Eltwise, so the Eltwise layer
    # reports the geometry of the layer that feeds it
    for layer_def in layer_name_to_record.values():
        if layer_def.type == 'Eltwise' and \
           len(layer_def.parents) == 1 and \
           layer_def.parents[0].type == 'Slice' and \
           len(layer_def.parents[0].parents) == 1 and \
           layer_def.parents[0].parents[0].type in ['Convolution', 'InnerProduct']:
            source = layer_def.parents[0].parents[0]
            layer_def.filter = source.filter
            layer_def.stride = source.stride
            layer_def.pad = source.pad

    return network_def, layer_name_to_record, top_to_layers
def read_caffe_mean(caffe_mean_file):
    """
    Reads a caffe formatted mean file
    :param caffe_mean_file: path to caffe mean file, presumably with 'binaryproto' suffix
    :return: mean image, converted from BGR to RGB format
    """
    import caffe_parser
    import numpy as np

    # parse the binary protobuf blob from disk
    mean_blob = caffe_parser.caffe_pb2.BlobProto()
    with open(caffe_mean_file, 'rb') as blob_file:
        mean_blob.ParseFromString(blob_file.read())

    # reshape the flat data vector into channels x height x width
    shape = (mean_blob.channels, mean_blob.height, mean_blob.width)
    mean_img = np.array(mean_blob.data).reshape(shape)

    # swap the first and last channels in place: Caffe BGR -> RGB
    mean_img[[0, 2]] = mean_img[[2, 0]]
    return mean_img