# blob: 44e588f3872860b6f184b18aa401c433921670d4 [file] [log] [blame]
import struct
import numpy as num
import sys
class KaldiParser:
    """Reader for binary-mode Kaldi model files.

    The wrapped stream must begin with a NUL byte followed by ``B``; any
    other start is rejected because the text format is not implemented.

    The parser works on any seekable file-like object whose ``read(n)``
    returns bytes (or a native ``str``), not only on real OS files.

    Format violations raise ``AssertionError``.  That is the exception type
    the original ``assert`` statements produced; the checks are now explicit
    so they are not stripped by ``python -O``.  Truncated numeric data
    raises ``EOFError``.
    """

    # Error descriptions returned as the first element of next_token().
    NO_OPEN_BRACKET = "found > before <"
    ERR_NO_CLOSE_BRACKET = "reached eof before >"
    ERR_BYTES_BEFORE_TOKEN = "found bytes before <"
    NO_SPACE_AFTER = "missing space after >"

    # read_basic_type() name -> (numpy dtype string, expected size in bytes).
    # http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
    _BASIC_TYPES = {
        "int": ("<i4", 4),
        "float": ("<f4", 4),
        "char": ("S1", 1),
    }

    def __init__(self, f):
        """Wrap stream `f` and consume its two-byte binary-mode marker.

        Raises:
            AssertionError: the stream does not start with the marker.  The
                stream is rewound to offset 0 before raising.
        """
        self.f = f
        self.binary = self._read_str(2) == '\0B'
        if not self.binary:
            self.f.seek(0, 0)
            raise AssertionError("text format not supported yet")

    def _check(self, condition, message):
        """Raise AssertionError(message) unless `condition` is true."""
        if not condition:
            raise AssertionError(message)

    def _read_str(self, n):
        """Read up to `n` bytes and return them as a native ``str``.

        An empty string means end of stream.  Bytes are decoded as latin-1
        so every byte value maps to exactly one character and decoding can
        never fail.
        """
        data = self.f.read(n)
        if data is None:
            # Some raw streams return None when nothing is available; the
            # rest of the parser treats that the same as end of stream.
            return ""
        if not isinstance(data, str):
            data = data.decode("latin-1")
        return data

    def _read_array(self, dtype, count):
        """Read exactly `count` items of `dtype` into a new writable array.

        Raises:
            AssertionError: `count` is negative.
            EOFError: the stream ends before all items were read.
        """
        count = int(count)
        self._check(count >= 0, "negative element count %d" % count)
        dt = num.dtype(dtype)
        if count == 0:
            return num.empty(0, dtype=dt)
        nbytes = dt.itemsize * count
        data = self.f.read(nbytes)
        got = 0 if data is None else len(data)
        if got != nbytes:
            raise EOFError("expected %d bytes, got %d" % (nbytes, got))
        # frombuffer yields a read-only view of `data`; copy so that callers
        # receive an array they own and may modify.
        return num.frombuffer(data, dtype=dt).copy()

    def is_binary(self):
        """Return True when the stream started with the binary marker."""
        return self.binary

    def try_next_token(self):
        """Return the next ``<...>`` token in lower case, or None on error.

        On error the stream is rewound to where it was before the call and
        the error description and partial token are printed.
        """
        pos = self.f.tell()
        err, tok = self.next_token()
        if err is not None:
            self.f.seek(pos, 0)
            # One pre-formatted argument prints identically under the
            # Python 2 print statement and the Python 3 print function.
            print("%s %s" % (err, tok))
            return None
        return tok.lower()

    def next_token(self):
        """Read one ``<...>`` token plus the single space that follows it.

        Returns:
            ``(None, token)`` on success, otherwise ``(error, partial)``
            where ``error`` is one of the class-level error strings and
            ``partial`` is the text from ``<`` onwards, or None when no
            token text is available.  The stream is not rewound on error.
        """
        chars = []
        while True:
            c = self._read_str(1)
            if not c:
                # read() signals end of stream with an empty result, not
                # None; testing for None here would loop forever.
                return KaldiParser.ERR_NO_CLOSE_BRACKET, None
            chars.append(c)
            if c == ">":
                break
        buf = "".join(chars)
        start = buf.find("<")
        if start < 0:
            return KaldiParser.NO_OPEN_BRACKET, None
        # The separator after '>' must be a space; end of stream is accepted
        # too so that the final token of a file can be read.
        c = self._read_str(1)
        if c not in (" ", ""):
            return KaldiParser.NO_SPACE_AFTER, buf[start:]
        if start != 0:
            return KaldiParser.ERR_BYTES_BEFORE_TOKEN, buf[start:]
        return None, buf

    def read_space(self):
        """Consume one separator character: a space, or end of stream.

        Raises:
            AssertionError: any other character was read.
        """
        c = self._read_str(1)
        self._check(c in (" ", ""), "expected a space, got %r" % (c,))

    def read_basic_type(self, type):
        """Read one size-prefixed scalar and return it as a numpy scalar.

        Args:
            type: ``"int"`` (little-endian int32), ``"float"`` (little-endian
                float32) or ``"char"`` (a single byte).  The parameter name
                shadows the builtin but is kept for compatibility.

        Returns:
            The value read, or None after printing a message when `type` is
            not recognised.  In that case nothing is consumed from the
            stream.

        Raises:
            AssertionError: the parser is not in binary mode, or the size
                byte in the stream does not match the requested type.
            EOFError: the stream ends inside the value.
        """
        self._check(self.binary, "not supported yet")
        spec = self._BASIC_TYPES.get(type)
        if spec is None:
            print("unrecognized type")
            return None
        dtype, dsize = spec
        # Each value is preceded by one signed byte giving its size.
        size = self._read_array("i1", 1)[0]
        self._check(size == dsize,
                    "size byte %d does not match %d for %s"
                    % (size, dsize, type))
        return self._read_array(dtype, 1)[0]

    def read_matrix(self):
        """Read an ``FM``-tagged float32 matrix as a (rows, cols) array.

        Layout consumed: the two characters ``FM``, a separator, the row and
        column counts as size-prefixed ints, then rows*cols little-endian
        float32 values.

        Raises:
            AssertionError: the tag is not ``FM`` or a dimension is negative.
            EOFError: the stream ends inside the matrix.
        """
        mode = self._read_str(2)
        self._check(mode == 'FM', "expected 'FM', got %r" % (mode,))
        self.read_space()
        # Convert to Python ints so rows * cols cannot overflow int32.
        rows = int(self.read_basic_type("int"))
        cols = int(self.read_basic_type("int"))
        self._check(rows >= 0 and cols >= 0,
                    "negative matrix shape (%d, %d)" % (rows, cols))
        return self._read_array("<f4", rows * cols).reshape((rows, cols))

    def read_vector(self):
        """Read an ``FV``-tagged float32 vector as a 1-D array.

        Layout consumed: the two characters ``FV``, a separator, the length
        as a size-prefixed int, then that many little-endian float32 values.

        Raises:
            AssertionError: the tag is not ``FV`` or the length is negative.
            EOFError: the stream ends inside the vector.
        """
        mode = self._read_str(2)
        self._check(mode == 'FV', "expected 'FV', got %r" % (mode,))
        self.read_space()
        length = int(self.read_basic_type("int"))
        return self._read_array("<f4", length)
def fileIsBinary(filename):
    """Return True if the file starts with a NUL byte followed by ``B``.

    The file is opened in binary mode and closed before returning, so no
    handle is leaked.

    Args:
        filename: path of the file to inspect.

    Returns:
        bool: whether the first two bytes equal the binary-mode marker.
        Files shorter than two bytes yield False.
    """
    with open(filename, "rb") as f:
        # Bytes literal: equal to '\0B' on Python 2, and the only form that
        # can match what a binary-mode read returns on Python 3.
        return f.read(2) == b'\0B'
def file2nnet_binary(filename):
    """Load the layers of a binary network file into a list of dicts.

    Tokens are read with KaldiParser until ``</nnet>``, an unrecognised
    token, or a tokenising error.  In the latter two cases a message is
    printed and the layers collected so far are returned.

    Each layer dict is started by an ``<affinetransform>`` token and may
    hold these keys:
        "outdim", "indim": the two ints following ``<affinetransform>``.
        "weights": the matrix following ``<maxnorm>``, transposed.
        "bias": the vector following that matrix.
        "type": ``"sigmoid"`` or ``"softmax"``.

    Args:
        filename: path of the file to read.

    Returns:
        list of layer dicts, in file order.

    Raises:
        AssertionError: the two dimensions after ``<sigmoid>``/``<softmax>``
            differ from each other or from the layer's "outdim".  Errors
            raised by KaldiParser propagate unchanged.
    """
    net = []
    layer = None
    # The with-block guarantees the file is closed on every exit path.
    with open(filename, "rb") as f:
        parser = KaldiParser(f)
        while True:
            tok = parser.try_next_token()
            if tok is None:
                print("error")
                break
            if tok == "<nnet>":
                continue
            elif tok == "<affinetransform>":
                # A new affine component begins: store the finished layer.
                if layer is not None:
                    net.append(layer)
                layer = {}
                layer["outdim"] = parser.read_basic_type("int")
                layer["indim"] = parser.read_basic_type("int")
            elif tok in ("<learnratecoef>", "<biaslearnratecoef>"):
                # Value is read to advance the stream and then discarded.
                parser.read_basic_type("float")
            elif tok == "<maxnorm>":
                parser.read_basic_type("float")
                # kaldi writes the transpose!!!!
                layer["weights"] = parser.read_matrix().transpose()
                layer["bias"] = parser.read_vector()
            elif tok in ("<sigmoid>", "<softmax>"):
                layer["type"] = tok[1:-1]
                outdim1 = parser.read_basic_type("int")
                outdim2 = parser.read_basic_type("int")
                # Explicit raise (not assert) so the check survives -O.
                if not (outdim1 == outdim2 and outdim2 == layer["outdim"]):
                    raise AssertionError(
                        "%s dims %s/%s do not match layer outdim %s"
                        % (tok, outdim1, outdim2, layer["outdim"]))
            elif tok == "</nnet>":
                break
            else:
                print("unrecognized token %s" % (tok,))
                break
    # The last layer has no following <affinetransform> to flush it.
    if layer is not None:
        net.append(layer)
    return net
if __name__ == '__main__':
    # Model to load: the first command-line argument when given, otherwise
    # the path this script was originally hard-wired to.
    filename = "exp/dnn4_pretrain-dbn_dnn/nnet_6.dbn_dnn.init"
    #filename = "/usr/users/leoliu/s5/exp/dnn4_pretrain-dbn_dnn/final.feature_transform"
    if len(sys.argv) > 1:
        filename = sys.argv[1]
    # Single pre-formatted arguments print identically on Python 2 and 3.
    print(filename)
    print("isBinary: %s" % (fileIsBinary(filename),))
    a = file2nnet_binary(filename)

    # Disabled debugging loop, kept for reference: dumps every token.
    # while True:
    #     err, tok = parser.next_token()
    #     if err != KaldiParser.NO_SPACE_AFTER and tok != None:
    #         print err, tok

    # Disabled snippet, kept for reference: writes one layer as text.
    # fout.write('<affinetransform> ' + str(output_size) + ' ' + str(input_size) + '\n')
    # fout.write('[' + '\n')
    # for x in xrange(output_size):
    #     fout.write(W_layer[x].strip() + '\n')
    # fout.write(']' + '\n')
    # fout.write('[ ' + b_layer.strip() + ' ]' + '\n')
    # if maxout:
    #     fout.write('<maxout> ' + str(int(layers[i + 1])) + ' ' + str(output_size) + '\n')
    # else:
    #     fout.write('<sigmoid> ' + str(output_size) + ' ' + str(output_size) + '\n')