blob: a105b1cbbb0fc5f878ad76f54193e919951c3089 [file] [log] [blame]
# -*- coding: utf-8 -*-
import pytest
from tests.utils import TestCaseBase
import sqlparse
from sqlparse.exceptions import SQLParseError
class TestFormat(TestCaseBase):
    """Tests for the non-reindenting options of ``sqlparse.format``.

    Covers keyword/identifier case conversion, comment stripping,
    whitespace stripping, line-break handling around quoted literals and
    rejection of invalid option values with ``SQLParseError``.
    """

    def test_keywordcase(self):
        sql = 'select * from bar; -- select foo\n'
        # Only keywords change case; identifiers and comment text do not.
        res = sqlparse.format(sql, keyword_case='upper')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- select foo\n')
        res = sqlparse.format(sql, keyword_case='capitalize')
        self.ndiffAssertEqual(res, 'Select * From bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), keyword_case='lower')
        self.ndiffAssertEqual(res, 'select * from BAR; -- SELECT FOO\n')
        # An unknown case name is rejected rather than silently ignored.
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          keyword_case='foo')

    def test_identifiercase(self):
        sql = 'select * from bar; -- select foo\n'
        # Only identifiers change case; keywords and comment text do not.
        res = sqlparse.format(sql, identifier_case='upper')
        self.ndiffAssertEqual(res, 'select * from BAR; -- select foo\n')
        res = sqlparse.format(sql, identifier_case='capitalize')
        self.ndiffAssertEqual(res, 'select * from Bar; -- select foo\n')
        res = sqlparse.format(sql.upper(), identifier_case='lower')
        self.ndiffAssertEqual(res, 'SELECT * FROM bar; -- SELECT FOO\n')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          identifier_case='foo')
        # Double-quoted identifiers keep their original case.
        sql = 'select * from "foo"."bar"'
        res = sqlparse.format(sql, identifier_case="upper")
        self.ndiffAssertEqual(res, 'select * from "foo"."bar"')

    def test_strip_comments_single(self):
        # A removed "--" comment leaves a single space between the
        # surrounding tokens, whether or not a space preceded it.
        sql = 'select *-- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select * -- statement starts here\nfrom foo'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select * from foo')
        sql = 'select-- foo\nfrom -- bar\nwhere'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select from where')
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          strip_comments=None)

    def test_strip_comments_multi(self):
        sql = '/* sql starts here */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/* sql starts here */ select'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = '/*\n * sql starts here\n */\nselect'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select')
        sql = 'select (/* sql starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')
        # An inner "/*" does not nest: the comment ends at the first "*/".
        sql = 'select (/* sql /* starts here */ select 2)'
        res = sqlparse.format(sql, strip_comments=True)
        self.ndiffAssertEqual(res, 'select (select 2)')

    def test_strip_ws(self):
        # Runs of spaces, tabs and newlines collapse to one space; whitespace
        # just inside parentheses and at the end of the input is removed.
        s = 'select\n* from      foo\n\twhere  ( 1 = 2 )\n'
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select * from foo where (1 = 2)')
        # The newline terminating a single-line comment must be kept.
        s = 'select -- foo\nfrom bar\n'
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select -- foo\nfrom bar')
        self.assertRaises(SQLParseError, sqlparse.format, s,
                          strip_whitespace=None)

    def test_preserve_ws(self):
        # preserve at least one whitespace after subgroups
        s = 'select\n* /* foo */  from bar '
        self.ndiffAssertEqual(sqlparse.format(s, strip_whitespace=True),
                              'select * /* foo */ from bar')

    def test_notransform_of_quoted_crlf(self):
        # Make sure that CR/CR+LF characters inside string literals don't get
        # affected by the formatter.
        s1 = "SELECT some_column LIKE 'value\r'"
        s2 = "SELECT some_column LIKE 'value\r'\r\nWHERE id = 1\n"
        s3 = "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\r"
        s4 = "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\r\n"
        # Line breaks *outside* literals come back as '\n' (the '\r\n' in s2
        # and s4 and the bare trailing '\r' in s3), while the '\r' inside
        # each quoted literal -- including literals with backslash-escaped
        # quotes (s3, s4) -- must be returned unchanged.
        self.ndiffAssertEqual(sqlparse.format(s1),
                              "SELECT some_column LIKE 'value\r'")
        self.ndiffAssertEqual(sqlparse.format(s2),
                              "SELECT some_column LIKE 'value\r'\nWHERE id = 1\n")
        self.ndiffAssertEqual(sqlparse.format(s3),
                              "SELECT some_column LIKE 'value\\'\r' WHERE id = 1\n")
        self.ndiffAssertEqual(sqlparse.format(s4),
                              "SELECT some_column LIKE 'value\\\\\\'\r' WHERE id = 1\n")

    def test_outputformat(self):
        # An unknown output format is rejected.
        sql = 'select * from foo;'
        self.assertRaises(SQLParseError, sqlparse.format, sql,
                          output_format='foo')
class TestFormatReindent(TestCaseBase):
    """Tests for ``sqlparse.format(..., reindent=True)``.

    All expectations use the default indentation settings (no
    ``indent_width`` / ``indent_tabs`` given). Continuation lines of an
    identifier list are aligned under the first item after the introducing
    keyword, e.g. 7 columns after ``select `` and 5 after ``from ``.
    """

    def _reindent(self, sql):
        """Return *sql* formatted with ``reindent=True`` and no other option."""
        return sqlparse.format(sql, reindent=True)

    def test_option(self):
        # Non-boolean flag values and invalid indent widths are rejected.
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          indent_tabs=2)
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width='foo')
        self.assertRaises(SQLParseError, sqlparse.format, 'foo',
                          reindent=True, indent_width=-12)

    def test_stmts(self):
        # Consecutive statements are separated by one blank line.
        self.ndiffAssertEqual(self._reindent('select foo; select bar'),
                              'select foo;\n\nselect bar')
        self.ndiffAssertEqual(self._reindent('select foo'), 'select foo')
        self.ndiffAssertEqual(self._reindent('select foo; -- test\n select bar'),
                              'select foo; -- test\n\nselect bar')

    def test_keywords(self):
        s = 'select * from foo union select * from bar;'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select *',
            'from foo',
            'union',
            'select *',
            'from bar;']))

    def test_keywords_between(self):  # issue 14
        # don't break AND after BETWEEN: only the boolean AND starts a line
        s = 'and foo between 1 and 2 and bar = 3'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            '',
            'and foo between 1 and 2',
            'and bar = 3']))

    def test_parenthesis(self):
        # A parenthesized subquery moves to its own indented line; its
        # clauses align just inside the opening parenthesis.
        s = 'select count(*) from (select * from foo);'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select count(*)',
            'from',
            '  (select *',
            '   from foo);']))

    def test_where(self):
        s = 'select * from foo where bar = 1 and baz = 2 or bzz = 3;'
        self.ndiffAssertEqual(self._reindent(s), ('select *\nfrom foo\n'
                                                  'where bar = 1\n'
                                                  '  and baz = 2\n'
                                                  '  or bzz = 3;'))
        # Inside parentheses the OR is aligned with the first condition.
        s = 'select * from foo where bar = 1 and (baz = 2 or bzz = 3);'
        self.ndiffAssertEqual(self._reindent(s), ('select *\nfrom foo\n'
                                                  'where bar = 1\n'
                                                  '  and (baz = 2\n'
                                                  '       or bzz = 3);'))

    def test_join(self):
        # Every join flavour starts its own unindented line.
        for join in ('join', 'inner join', 'left outer join', 'straight_join'):
            s = 'select * from foo %s bar on 1 = 2' % join
            self.ndiffAssertEqual(self._reindent(s), '\n'.join([
                'select *',
                'from foo',
                '%s bar on 1 = 2' % join]))

    def test_identifier_list(self):
        s = 'select foo, bar, baz from table1, table2 where 1 = 2'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select foo,',
            '       bar,',
            '       baz',
            'from table1,',
            '     table2',
            'where 1 = 2']))
        s = 'select a.*, b.id from a, b'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select a.*,',
            '       b.id',
            'from a,',
            '     b']))

    def test_identifier_list_with_functions(self):
        s = ("select 'abc' as foo, coalesce(col1, col2)||col3 as bar,"
             "col3 from my_table")
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            "select 'abc' as foo,",
            "       coalesce(col1, col2)||col3 as bar,",
            "       col3",
            "from my_table"]))

    def test_case(self):
        s = 'case when foo = 1 then 2 when foo = 3 then 4 else 5 end'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'case',
            '    when foo = 1 then 2',
            '    when foo = 3 then 4',
            '    else 5',
            'end']))

    def test_case2(self):
        s = 'case(foo) when bar = 1 then 2 else 3 end'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'case(foo)',
            '    when bar = 1 then 2',
            '    else 3',
            'end']))

    def test_nested_identifier_list(self):  # issue4
        # Items align one column in, just after the opening parenthesis.
        s = '(foo as bar, bar1, bar2 as bar3, b4 as b5)'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            '(foo as bar,',
            ' bar1,',
            ' bar2 as bar3,',
            ' b4 as b5)']))

    def test_duplicate_linebreaks(self):  # issue3
        # The newline ending a comment must not be doubled by reindenting,
        # and must still separate the clauses once the comment is stripped.
        s = 'select c1 -- column1\nfrom foo'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1 -- column1',
            'from foo']))
        r = sqlparse.format(s, reindent=True, strip_comments=True)
        self.ndiffAssertEqual(r, '\n'.join(['select c1',
                                            'from foo']))
        s = 'select c1\nfrom foo\norder by c1'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1',
            'from foo',
            'order by c1']))
        s = 'select c1 from t1 where (c1 = 1) order by c1'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select c1',
            'from t1',
            'where (c1 = 1)',
            'order by c1']))

    def test_keywordfunctions(self):  # issue36
        s = 'select max(a) b, foo, bar'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select max(a) b,',
            '       foo,',
            '       bar']))

    def test_identifier_and_functions(self):  # issue45
        s = 'select foo.bar, nvl(1) from dual'
        self.ndiffAssertEqual(self._reindent(s), '\n'.join([
            'select foo.bar,',
            '       nvl(1)',
            'from dual']))
class TestOutputFormat(TestCaseBase):
    """Tests for the ``output_format`` option of ``sqlparse.format``."""

    def test_python(self):
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='python')
        self.ndiffAssertEqual(res, "sql = 'select * from foo;'")
        # With reindent every output line becomes its own string literal,
        # aligned under the opening parenthesis of "sql = (".
        res = sqlparse.format(sql, output_format='python', reindent=True)
        self.ndiffAssertEqual(res, ("sql = ('select * '\n"
                                    "       'from foo;')"))

    def test_php(self):
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='php')
        self.ndiffAssertEqual(res, '$sql = "select * from foo;";')
        # With reindent, later lines are appended with the ".=" operator.
        res = sqlparse.format(sql, output_format='php', reindent=True)
        self.ndiffAssertEqual(res, ('$sql = "select * ";\n'
                                    '$sql .= "from foo;";'))

    def test_sql(self):  # "sql" is an allowed option but has no effect
        sql = 'select * from foo;'
        res = sqlparse.format(sql, output_format='sql')
        self.ndiffAssertEqual(res, 'select * from foo;')
def test_format_column_ordering():  # issue89
    """ORDER BY column lists are split one column per line when reindenting.

    Continuation lines align under the first column, i.e. they are
    indented by ``len('order by ')`` == 9 characters.
    """
    sql = 'select * from foo order by c1 desc, c2, c3;'
    formatted = sqlparse.format(sql, reindent=True)
    expected = '\n'.join(['select *',
                          'from foo',
                          'order by c1 desc,',
                          '         c2,',
                          '         c3;'])
    assert formatted == expected
def test_truncate_strings():
    """Long string literals are cut to ``truncate_strings`` characters.

    The cut is marked with ``[...]`` unless ``truncate_char`` supplies a
    different marker.
    """
    statement = "update foo set value = '" + 'x' * 1000 + "';"
    cases = [
        ({'truncate_strings': 10},
         "update foo set value = 'xxxxxxxxxx[...]';"),
        ({'truncate_strings': 3, 'truncate_char': 'YYY'},
         "update foo set value = 'xxxYYY';"),
    ]
    for options, wanted in cases:
        assert sqlparse.format(statement, **options) == wanted
def test_truncate_strings_invalid_option():
    """A non-numeric, negative or zero ``truncate_strings`` is rejected."""
    for bad_value in ('bar', -1, 0):
        with pytest.raises(SQLParseError):
            sqlparse.format('foo', truncate_strings=bad_value)
@pytest.mark.parametrize('sql', [
    'select verrrylongcolumn from foo',
    'select "verrrylongcolumn" from "foo"',
])
def test_truncate_strings_doesnt_truncate_identifiers(sql):
    """``truncate_strings`` leaves identifiers, bare or double-quoted, intact."""
    unchanged = sqlparse.format(sql, truncate_strings=2)
    assert unchanged == sql
def test_having_produces_newline():
    """HAVING starts its own line when reindenting.

    The second FROM item is aligned under the first one, i.e. indented by
    ``len('from ')`` == 5 characters.
    """
    sql = (
        'select * from foo, bar where bar.id = foo.bar_id'
        ' having sum(bar.value) > 100')
    formatted = sqlparse.format(sql, reindent=True)
    expected = [
        'select *',
        'from foo,',
        '     bar',
        'where bar.id = foo.bar_id',
        'having sum(bar.value) > 100'
    ]
    assert formatted == '\n'.join(expected)