# blob: 4a14e0df8037a8b5f8ae0f3a179521ec1794aec1 [file] [log] [blame]
import plpy
import re
class Formula(object):
    """Names of the response (``y``) and of the ``n_coef`` predictors (``x``).

    Default names are derived from the two expression strings given to the
    constructor; ``rename`` can later override them from a name
    specification.

    Attributes:
        n_coef: number of predictor names that ``x`` must contain.
        y: response name, with double quotes removed.
        x: list of ``n_coef`` predictor names.
    """

    # A leading ``array[...]`` constructor (case-insensitive) whose body only
    # holds identifier characters, double quotes, commas, spaces and dots.
    _ARRAY_EXPR = re.compile(r'array[[](["a-z0-9_, .]+)[]]', flags=re.I)
    # A leading run of identifier characters and/or double quotes.
    _SIMPLE_COL = re.compile(r'["a-z0-9_]+', flags=re.I)

    def __init__(self, y_str, x_str, coef_len):
        """Store ``coef_len`` and derive default names.

        Args:
            y_str: response expression; double quotes are stripped.
            x_str: predictor expression, interpreted by ``parse``.
            coef_len: number of predictor names to produce.
        """
        self.n_coef = coef_len
        self.y = y_str.replace('"', '')
        self.x = self.parse(x_str)

    def parse(self, x_str):
        """Return a list of ``n_coef`` predictor names derived from ``x_str``.

        * If ``x_str`` starts with ``array[a, b, ...]`` and the bracketed list
          has exactly ``n_coef`` items, those items (quotes stripped) are
          returned.
        * Otherwise, if ``x_str`` starts with identifier characters, the
          names are ``<x_str without quotes>[1] .. [n_coef]``.
        * Otherwise (including an array of the wrong length) the names are
          ``x[1] .. x[n_coef]``.
        """
        prefix = 'x'
        array_match = self._ARRAY_EXPR.match(x_str)
        if array_match is not None:
            # Split only the captured bracket contents; text after the
            # closing bracket (e.g. a type cast) must not leak into a name.
            names = [item.strip().replace('"', '')
                     for item in array_match.group(1).split(',')]
            if len(names) == self.n_coef:
                return names
            # Wrong item count: fall through and use the generic 'x' prefix.
        elif self._SIMPLE_COL.match(x_str) is not None:
            prefix = x_str.replace('"', '')
        return ['{0}[{1}]'.format(prefix, idx + 1)
                for idx in range(self.n_coef)]

    def rename(self, spec):
        """Override ``y`` and/or ``x`` from a name specification.

        Accepted forms of ``spec``:

        * a string starting with ``{``: braces are removed and the rest is
          split on commas, then handled like a sequence;
        * a string ``"y ~ a + b"`` or ``"y ~ a, b"``; the ``y ~`` part is
          optional, in which case the current ``y`` is kept;
        * a sequence of ``n_coef`` names (predictors only) or ``n_coef + 1``
          names (response first).

        A spec that does not yield exactly ``n_coef`` predictor names (or
        that contains more than one ``~``) leaves the current names untouched
        and emits a ``plpy.warning``.
        """
        if not isinstance(spec, str):
            self._rename_from_sequence(spec)
        elif spec.startswith('{'):
            stripped = spec.replace('{', '').replace('}', '')
            self._rename_from_sequence(
                [item.strip() for item in stripped.split(',')])
        else:
            self._rename_from_formula(spec)

    def _rename_from_formula(self, spec):
        """Apply a ``[y ~] a + b`` / ``[y ~] a, b`` style string spec."""
        lhs, tilde, rhs = spec.partition('~')
        if tilde:
            if '~' in rhs:
                # More than one '~' is ambiguous; keep the current names
                # instead of failing on tuple unpacking.
                self._warn_unexpected(spec)
                return
            y_name = lhs.strip()
            terms = rhs
        else:
            y_name = self.y
            terms = spec
        separator = '+' if '+' in terms else ','
        x_names = [item.strip() for item in terms.split(separator)]
        if len(x_names) != self.n_coef:
            self._warn_unexpected(spec)
            return
        self.y = y_name
        self.x = x_names

    def _rename_from_sequence(self, names):
        """Apply a sequence spec of ``n_coef`` or ``n_coef + 1`` names."""
        count = len(names)
        if count == self.n_coef + 1:
            self.y = names[0]
            self.x = list(names[1:])
        elif count == self.n_coef:
            # Copy so later changes to the caller's list do not affect us.
            self.x = list(names)
        else:
            self._warn_unexpected(str(names))

    def _warn_unexpected(self, shown_spec):
        """Emit the 'unexpected namespec' warning through ``plpy``."""
        plpy.warning("PMML warning: unexpected namespec '" +
                     shown_spec + "', using default names")