# git blob: 7ad381c1c300cd41cafb31ff715ef05cdd70114c
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import types
import os
import sys
import random
# Largest Unicode code point.
MAX_UNICODE = 0x10FFFF

# TODO
# - the generated automaton could be more minimal
# - e.g. when a bracket lands on a UTF-8 boundary, as in 3 - 2047, the two * edges could be shared
# - also for 3 - 2048 or 3 - 65536 there should not be an * down the red path, but there is

# MASKS (built further down in this file) holds low-bit masks:
# MASKS[0] is the bottom 1 bit, MASKS[1] the bottom 2 bits, ...

# Inclusive code point range for each UTF-8 encoded length:
# utf8Ranges[k] covers the code points whose encoding is k+1 bytes long.
utf8Ranges = [
    (0x0, 0x7F),
    (0x80, 0x7FF),
    (0x800, 0xFFFF),
    (0x10000, MAX_UNICODE),
]

# Graphviz edge color for each edge label passed to FSA.addEdge (others are black).
typeToColor = dict(startend='purple', start='blue', end='red')
class FSA:
    """Byte-level automaton built from UTF-8 byte-range edges.

    Nodes are small ints allocated by ``addNode``.  Callers assign the
    ``start`` and ``end`` attributes to node ids; ``run`` walks from
    ``self.start``.
    """

    def __init__(self):
        # Maps node id -> [nodeLabel, (edgeLabel, lowByte, highByte, destNode), ...].
        # Element 0 is the node's label (e.g. 'start', 'end' or None); the remaining
        # elements are outgoing edges accepting byte values lowByte..highByte inclusive.
        self.states = {}
        self.nodeUpto = 0
        # Set by callers to the ids returned by addNode().
        self.start = None
        self.end = None

    def run(self, bytes):
        """Feed ``bytes`` (an iterable of ints) through the automaton.

        Returns the node reached after the last byte, or -1 as soon as a byte
        has no matching outgoing edge.  Raises RuntimeError if a byte matches
        more than one outgoing edge of the current node.
        """
        state = self.start
        for b in bytes:
            nextState = None
            for _edgeLabel, low, high, dest in self.states[state][1:]:
                if low <= b <= high:
                    if nextState is not None:
                        raise RuntimeError('state %s has ambiguous output for byte %s' % (state, b))
                    nextState = dest
            if nextState is None:
                return -1
            state = nextState
        return state

    def addEdge(self, n1, n2, v1, v2, label):
        """
        Adds edge from n1-n2, utf8 byte range v1-v2 (inclusive), tagged with label.
        """
        assert n1 in self.states
        assert isinstance(v1, int)
        assert isinstance(v2, int)
        self.states[n1].append((label, v1, v2, n2))

    def addNode(self, label=None):
        """Allocate a new node carrying ``label`` and return its id."""
        node = self.nodeUpto
        self.states[node] = [label]
        self.nodeUpto += 1
        return node

    def toDOT(self, label):
        """Return a Graphviz DOT description of the automaton.

        ``label`` is used as the digraph name.  Edge colors come from
        ``typeToColor`` (black otherwise); 'all*' edges are drawn as '*',
        other edges with their byte value(s) in binary.
        """
        lines = []
        w = lines.append
        endNode = startNode = None
        for nodeId, details in self.states.items():
            if details[0] == 'end':
                endNode = nodeId
            elif details[0] == 'start':
                startNode = nodeId
        w('digraph %s {' % label)
        w(' rankdir=LR;')
        w(' size="8,5";')
        w(' node [color=white label=""]; Ns;')
        w(' node [color=black];')
        w(' node [shape=doublecircle, label=""]; N%s [label="%s"];' % (endNode, endNode))
        w(' node [shape=circle];')
        w(' N%s [label="%s"];' % (startNode, startNode))
        w(' Ns -> N%s;' % startNode)
        for nodeId, details in self.states.items():
            w(' N%s [label="%s"];' % (nodeId, nodeId))
            for edgeType, s, e, dest in details[1:]:
                color = typeToColor.get(edgeType, 'black')
                if edgeType == 'all*':
                    # special case -- matches any utf8 continuation byte at this point
                    edgeLabel = '*'
                elif s == e:
                    edgeLabel = binary(s)
                else:
                    edgeLabel = '%s-%s' % (binary(s), binary(e))
                w(' N%s -> N%s [label="%s" color="%s"];' % (nodeId, dest, edgeLabel, color))
        w('}')
        return '\n'.join(lines)

    def toPNG(self, label, pngOut):
        """Render the automaton to ``pngOut`` with the Graphviz ``dot`` tool.

        Writes the DOT source to ``tmp.dot`` in the current directory first.
        Raises RuntimeError if ``dot`` exits with a non-zero status.
        """
        with open('tmp.dot', 'w') as f:
            f.write(self.toDOT(label))
        if os.system('dot -Tpng tmp.dot -o %s' % pngOut):
            raise RuntimeError('dot failed')
# Low-bit masks: MASKS[k] has the bottom k+1 bits set
# (MASKS[0] == 0b1, MASKS[1] == 0b11, ..., MASKS[31] == 2**32 - 1).
MASKS = [(1 << bits) - 1 for bits in range(1, 33)]
def binary(x):
    """Return ``x`` in binary as space-separated 8-bit groups.

    Groups are most-significant byte first and the leading group is
    zero-padded, e.g. binary(5) == '00000101' and
    binary(256) == '00000001 00000000'.  Raises ValueError for negative x.
    """
    if x < 0:
        raise ValueError('binary() requires a non-negative int, got %r' % (x,))
    bits = '{0:b}'.format(x)
    # Left-pad to a whole number of bytes (ceil(len / 8) * 8 digits).
    bits = bits.zfill(-(-len(bits) // 8) * 8)
    # big endian: first group holds the most significant byte
    return ' '.join(bits[i:i + 8] for i in range(0, len(bits), 8))
def getUTF8Rest(code, numBytes):
    """Return the ``numBytes`` trailing UTF-8 continuation bytes of ``code``.

    Each entry is a (byteValue, 6) pair, most significant first, where
    byteValue is 0b10xxxxxx carrying 6 bits of ``code`` and 6 is the number
    of payload bits in a continuation byte.
    """
    # 0x3F keeps the low 6 bits (same value as MASKS[5]).
    return tuple(
        (0x80 | ((code >> (6 * shift)) & 0x3F), 6)
        for shift in range(numBytes - 1, -1, -1)
    )
def toUTF8(code):
    """Encode Unicode code point ``code`` as UTF-8.

    Returns a tuple with one (byteValue, payloadBits) pair per encoded byte.
    payloadBits is how many low bits of that byte carry code point bits:
    7 for ASCII, 5/4/3 for the lead byte of a 2/3/4-byte sequence, and 6
    for continuation bytes.

    Raises ValueError if ``code`` is outside 0..MAX_UNICODE.
    """
    # Explicit check rather than assert: main() passes user-supplied values.
    if not 0 <= code <= MAX_UNICODE:
        raise ValueError('code point %r out of range 0..%d' % (code, MAX_UNICODE))
    if code < 0x80:
        # 0xxxxxxx
        return ((code, 7),)
    if code < 0x800:
        # 110yyyxx 10xxxxxx
        lead = (6 << 5) | (code >> 6)
        return ((lead, 5),) + getUTF8Rest(code, 1)
    if code < 0x10000:
        # 1110yyyy 10yyyyxx 10xxxxxx
        lead = (14 << 4) | (code >> 12)
        return ((lead, 4),) + getUTF8Rest(code, 2)
    # 11110zzz 10zzyyyy 10yyyyxx 10xxxxxx
    lead = (30 << 3) | (code >> 18)
    return ((lead, 3),) + getUTF8Rest(code, 3)
def all(fsa, startNode, endNode, startCode, endCode, left):
    """Add a path from startNode to endNode that accepts any first byte in
    startCode..endCode followed by len(left) arbitrary continuation bytes
    (0x80..0xBF).

    Only the length of ``left`` matters.  The first edge is labelled 'all'
    and the continuation edges 'all*'.  (This name shadows the builtin
    ``all`` at module level.)
    """
    if not left:
        fsa.addEdge(startNode, endNode, startCode, endCode, 'all')
        return
    current = fsa.addNode()
    fsa.addEdge(startNode, current, startCode, endCode, 'all')
    # Every continuation byte but the last gets its own intermediate node.
    for _ in range(len(left) - 1):
        following = fsa.addNode()
        fsa.addEdge(current, following, 0x80, 0xBF, 'all*')
        current = following
    fsa.addEdge(current, endNode, 0x80, 0xBF, 'all*')
def start(fsa, startNode, endNode, utf8, doAll):
    """Add paths accepting byte sequences of len(utf8) that are >= utf8.

    ``utf8`` is a tuple of (byteValue, payloadBits) pairs as produced by
    toUTF8.  A single remaining byte accepts byteValue up to byteValue with
    all payload bits set.  For a longer sequence, the exact first byte is
    followed recursively.  When doAll is true, any larger first byte (up to
    the all-payload-bits-set value) followed by arbitrary continuation bytes
    is also accepted.
    """
    leadByte, payloadBits = utf8[0]
    rest = utf8[1:]
    if not rest:
        fsa.addEdge(startNode, endNode, leadByte, leadByte | MASKS[payloadBits - 1], 'start')
        return
    node = fsa.addNode()
    fsa.addEdge(startNode, node, leadByte, leadByte, 'start')
    start(fsa, node, endNode, rest, True)
    # highest byte value sharing leadByte's fixed prefix bits
    top = leadByte | MASKS[payloadBits - 1]
    if doAll and leadByte != top:
        all(fsa, startNode, endNode, leadByte + 1, top, rest)
def end(fsa, startNode, endNode, utf8, doAll):
    """Add paths accepting byte sequences of len(utf8) that are <= utf8.

    Mirror image of start().  A single remaining byte accepts the values from
    byteValue with all payload bits cleared up to byteValue.  For a longer
    sequence, the exact first byte is followed recursively.  When doAll is
    true, any smaller first byte followed by arbitrary continuation bytes is
    also accepted.
    """
    leadByte, payloadBits = utf8[0]
    rest = utf8[1:]
    if not rest:
        fsa.addEdge(startNode, endNode, leadByte & ~MASKS[payloadBits - 1], leadByte, 'end')
        return
    if payloadBits == 5:
        # special case -- avoid creating unused edges: 2-byte lead bytes 0xC0/0xC1
        # never occur in valid UTF-8, so the smallest 2-byte lead is 0xC2 (194)
        low = 194
    else:
        low = leadByte & ~MASKS[payloadBits - 1]
    if doAll and leadByte != low:
        all(fsa, startNode, endNode, low, leadByte - 1, rest)
    node = fsa.addNode()
    fsa.addEdge(startNode, node, leadByte, leadByte, 'end')
    end(fsa, node, endNode, rest, True)
def build(fsa,
          startNode, endNode,
          startUTF8, endUTF8):
    """Add edges from startNode to endNode that accept the UTF-8 encodings of
    every code point from the start bound to the end bound (inclusive).

    startUTF8 / endUTF8 are tuples of (byteValue, payloadBits) pairs as
    returned by toUTF8.  Assumes the start bound is <= the end bound (main()
    checks this before calling).  When both encodings share a lead byte, an
    exact edge is added and build recurses on the remaining bytes.
    Otherwise the range is split into a start part (start()), a middle part
    of whole lead-byte ranges (all()), and an end part (end()).
    """
    # Break into start, middle, end:
    if startUTF8[0][0] == endUTF8[0][0]:
        # Degen case: lead with the same byte:
        if len(startUTF8) == 1 and len(endUTF8) == 1:
            # the same single byte: one edge
            fsa.addEdge(startNode, endNode, startUTF8[0][0], endUTF8[0][0], 'startend')
            return
        else:
            # a shared lead byte implies the same encoded length, so both are multi-byte
            assert len(startUTF8) != 1
            assert len(endUTF8) != 1
            n = fsa.addNode()
            # single value edge
            fsa.addEdge(startNode, n, startUTF8[0][0], startUTF8[0][0], 'single')
            build(fsa, n, endNode, startUTF8[1:], endUTF8[1:])
    elif len(startUTF8) == len(endUTF8):
        if len(startUTF8) == 1:
            # both single-byte: one edge covers the whole range
            fsa.addEdge(startNode, endNode, startUTF8[0][0], endUTF8[0][0], 'startend')
        else:
            # same length, different lead bytes: sequences under the start lead
            # byte, complete sequences for lead bytes strictly in between, then
            # sequences under the end lead byte
            start(fsa, startNode, endNode, startUTF8, False)
            if endUTF8[0][0] - startUTF8[0][0] > 1:
                all(fsa, startNode, endNode, startUTF8[0][0]+1, endUTF8[0][0]-1, startUTF8[1:])
            end(fsa, startNode, endNode, endUTF8, False)
    else:
        # different encoded lengths (the start encoding is the shorter one,
        # given start <= end)
        # start
        start(fsa, startNode, endNode, startUTF8, True)
        # possibly middle
        # every encoded length strictly between the two lengths is accepted in
        # full, using the bounds from utf8Ranges
        byteCount = 1+len(startUTF8)
        while byteCount < len(endUTF8):
            s = toUTF8(utf8Ranges[byteCount-1][0])
            e = toUTF8(utf8Ranges[byteCount-1][1])
            all(fsa, startNode, endNode,
                s[0][0],
                e[0][0],
                s[1:])
            byteCount += 1
        # end
        end(fsa, startNode, endNode, endUTF8, True)
def main():
    """Command-line entry point.

    Usage: python <script> startUTF32 endUTF32 [testCode]

    Steps:
    - Print the UTF-8 bytes of the bounds (and of testCode, if given).
    - Build the automaton for startUTF32..endUTF32.
    - Render it to /tmp/outpy.png with FSA.toPNG.
    - Randomly test it.

    Exits with status 1 on bad arguments or if the random test fails.
    """
    if len(sys.argv) not in (3, 4):
        print('')
        print('Usage: python %s startUTF32 endUTF32 [testCode]' % sys.argv[0])
        print('')
        sys.exit(1)
    try:
        utf32Start = int(sys.argv[1])
        utf32End = int(sys.argv[2])
        testCode = int(sys.argv[3]) if len(sys.argv) == 4 else None
    except ValueError:
        print('ERROR: arguments must be integers')
        sys.exit(1)
    if utf32Start > utf32End:
        print('ERROR: start must be <= end')
        sys.exit(1)
    # Reject out-of-range input up front instead of failing inside toUTF8.
    if utf32Start < 0 or utf32End > MAX_UNICODE:
        print('ERROR: start and end must be in 0..%d' % MAX_UNICODE)
        sys.exit(1)
    if testCode is not None and not 0 <= testCode <= MAX_UNICODE:
        print('ERROR: testCode must be in 0..%d' % MAX_UNICODE)
        sys.exit(1)
    fsa = FSA()
    fsa.start = fsa.addNode('start')
    fsa.end = fsa.addNode('end')
    startUTF8 = toUTF8(utf32Start)
    endUTF8 = toUTF8(utf32End)
    print('s=%s' % ' '.join([binary(x[0]) for x in startUTF8]))
    print('e=%s' % ' '.join([binary(x[0]) for x in endUTF8]))
    if testCode is not None:
        testUTF8 = toUTF8(testCode)
        print('t=%s [%s]' %
              (' '.join([binary(x[0]) for x in testUTF8]),
               ' '.join(['%2x' % x[0] for x in testUTF8])))
    build(fsa, fsa.start, fsa.end, startUTF8, endUTF8)
    fsa.toPNG('test', '/tmp/outpy.png')
    print('Saved to /tmp/outpy.png...')
    if not test(fsa, utf32Start, utf32End, 100000):
        sys.exit(1)
def test(fsa, utf32Start, utf32End, count):
    """Randomly check that ``fsa`` accepts exactly utf32Start..utf32End.

    First checks ``count`` random code points inside the inclusive range:
    each one must end in ``fsa.end``.  Then, if any code point in
    0..MAX_UNICODE lies outside the range, checks ``count`` random
    out-of-range code points: ``fsa.run`` must return -1 for each.

    Prints a FAILED line and returns False on the first mismatch; otherwise
    returns True.
    """
    def utf8Bytes(code):
        return [tup[0] for tup in toUTF8(code)]

    def describe(code):
        return ' '.join([binary(b) for b in utf8Bytes(code)])

    # verify correct ints are accepted
    for _ in range(count):
        r = random.randint(utf32Start, utf32End)
        if fsa.run(utf8Bytes(r)) != fsa.end:
            print('FAILED: valid %s (%s) is not accepted' % (r, describe(r)))
            return False
    # Number of code points in 0..MAX_UNICODE outside the range:
    # (MAX_UNICODE + 1) in total minus (utf32End - utf32Start + 1) inside.
    invalidCount = MAX_UNICODE - (utf32End - utf32Start)
    if invalidCount > 0:
        # verify invalid ints are not accepted
        for _ in range(count):
            # r-th invalid code point: 0..utf32Start-1, then utf32End+1..MAX_UNICODE
            r = random.randint(0, invalidCount - 1)
            if r >= utf32Start:
                r = utf32End + 1 + r - utf32Start
            if fsa.run(utf8Bytes(r)) != -1:
                print('FAILED: invalid %s (%s) is accepted' % (r, describe(r)))
                return False
    return True
def stress():
    """Build and randomly test automata for random code point ranges, forever.

    Prints a progress counter every 10 ranges and a FAILED line for every
    range whose automaton fails test().  The loop never exits.
    """
    print('Testing...')
    iteration = 0
    while True:
        if not iteration % 10:
            print('%s...' % iteration)
        iteration += 1
        # two independent random code points, ordered low <= high
        low, high = sorted((random.randint(0, MAX_UNICODE),
                            random.randint(0, MAX_UNICODE)))
        fsa = FSA()
        fsa.start = fsa.addNode('start')
        fsa.end = fsa.addNode('end')
        build(fsa, fsa.start, fsa.end, toUTF8(low), toUTF8(high))
        if not test(fsa, low, high, 10000):
            print('FAILED on utf32Start=%s utf32End=%s' % (low, high))
if __name__ == '__main__':
    # With command-line arguments, check a single range (main); with none,
    # stress-test random ranges (stress).
    entryPoint = main if len(sys.argv) > 1 else stress
    entryPoint()