blob: 45efe4715f030148f2e7864c53de7d0760433e6f [file] [log] [blame]
"""Parse caffe's protobuf
"""
import re
# Prefer a full pycaffe installation; `use_caffe` records which backend the
# helper functions in this module should use.
try:
    import caffe
    from caffe.proto import caffe_pb2
    use_caffe = True
except ImportError:
    # Fall back to a standalone caffe_pb2 module generated from caffe.proto.
    try:
        import caffe_pb2
    except ImportError:
        # Neither pycaffe nor a generated caffe_pb2 is importable: tell the
        # user how to produce the latter.
        raise ImportError('You need to compile caffe.proto first with: '
                          'protoc --python_out=./ ./caffe.proto')
    use_caffe = False
from google.protobuf import text_format
def read_prototxt(fname):
    """Load a text-format prototxt file into a caffe_pb2.NetParameter.

    Parameters
    ----------
    fname : path of the prototxt file, opened in text mode.

    Returns
    -------
    The caffe_pb2.NetParameter message populated from the file's contents.
    """
    net_param = caffe_pb2.NetParameter()
    with open(fname, 'r') as handle:
        prototxt_text = str(handle.read())
    # Merge the textual definition into the (initially empty) message.
    text_format.Merge(prototxt_text, net_param)
    return net_param
def get_layers(proto):
    """Return the non-empty layer list of a caffe_pb2.NetParameter object.

    The ``layer`` field is checked first and ``layers`` second; the first
    one that contains at least one entry is returned.

    Raises
    ------
    ValueError
        If both fields are empty.
    """
    # Look the fields up lazily so `layers` is only touched when `layer`
    # turned out to be empty.
    for field_name in ('layer', 'layers'):
        entries = getattr(proto, field_name)
        if len(entries) != 0:
            return entries
    raise ValueError('Invalid proto file.')
def read_caffemodel(prototxt_fname, caffemodel_fname):
    """Load the layers stored in a binary caffemodel file.

    Parameters
    ----------
    prototxt_fname : path of the network definition; only used when the
        caffe python package is available.
    caffemodel_fname : path of the binary caffemodel file.

    Returns
    -------
    A ``(layers, layer_names)`` tuple. With caffe available these come from
    a ``caffe.Net`` instance; otherwise ``layers`` is obtained by parsing
    the file into a caffe_pb2.NetParameter and ``layer_names`` is ``None``.
    """
    if not use_caffe:
        # No pycaffe: decode the serialized NetParameter directly.
        net_param = caffe_pb2.NetParameter()
        with open(caffemodel_fname, 'rb') as model_file:
            net_param.ParseFromString(model_file.read())
        return (get_layers(net_param), None)

    caffe.set_mode_cpu()
    net = caffe.Net(prototxt_fname, caffemodel_fname, caffe.TEST)
    names = net._layer_names
    net_layers = net.layers
    return (net_layers, names)
def layer_iter(layers, layer_names):
    """Yield ``(name, type, blobs)`` for every layer in ``layers``.

    With caffe available the name of the i-th layer is ``layer_names[i]``;
    otherwise it is the layer's own ``name`` attribute (``layer_names`` is
    not used). In both cases every ``-`` and ``/`` in the name is replaced
    by ``_``.
    """
    if use_caffe:
        # Names live in a parallel sequence indexed by layer position.
        named_layers = ((layer_names[idx], entry)
                        for idx, entry in enumerate(layers))
    else:
        named_layers = ((entry.name, entry) for entry in layers)

    for raw_name, entry in named_layers:
        yield (re.sub('[-/]', '_', raw_name), entry.type, entry.blobs)