blob: 6eea6f307ce3737d0ce9fa6ea7f57877bf0ba289 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.thriftserver
import java.sql.{Date, Statement}
import org.apache.livy.LivyConf
trait CommonThriftTests {
def hiveSupportEnabled(sparkMajorVersion: Int, livyConf: LivyConf): Boolean = {
sparkMajorVersion > 1 || livyConf.getBoolean(LivyConf.ENABLE_HIVE_CONTEXT)
}
def dataTypesTest(statement: Statement, mapSupported: Boolean): Unit = {
val resultSet = statement.executeQuery(
"select 1, 'a', cast(null as int), 1.2345, CAST('2018-08-06' as date)")
resultSet.next()
assert(resultSet.getInt(1) == 1)
assert(resultSet.getString(2) == "a")
assert(resultSet.getInt(3) == 0)
assert(resultSet.wasNull())
assert(resultSet.getDouble(4) == 1.2345)
assert(resultSet.getDate(5) == Date.valueOf("2018-08-06"))
assert(!resultSet.next())
val resultSetWithNulls = statement.executeQuery("select cast(null as string), " +
"cast(null as decimal), cast(null as double), cast(null as date), null")
resultSetWithNulls.next()
assert(resultSetWithNulls.getString(1) == null)
assert(resultSetWithNulls.wasNull())
assert(resultSetWithNulls.getBigDecimal(2) == null)
assert(resultSetWithNulls.wasNull())
assert(resultSetWithNulls.getDouble(3) == 0.0)
assert(resultSetWithNulls.wasNull())
assert(resultSetWithNulls.getDate(4) == null)
assert(resultSetWithNulls.wasNull())
assert(resultSetWithNulls.getString(5) == null)
assert(resultSetWithNulls.wasNull())
assert(!resultSetWithNulls.next())
val complexTypesQuery = if (mapSupported) {
"select array(1.5, 2.4, 1.3), struct('a', 1, 1.5), map(1, 'a', 2, 'b')"
} else {
"select array(1.5, 2.4, 1.3), struct('a', 1, 1.5)"
}
val resultSetComplex = statement.executeQuery(complexTypesQuery)
resultSetComplex.next()
assert(resultSetComplex.getString(1) == "[1.5,2.4,1.3]")
assert(resultSetComplex.getString(2) == "{\"col1\":\"a\",\"col2\":1,\"col3\":1.5}")
if (mapSupported) {
assert(resultSetComplex.getString(3) == "{1:\"a\",2:\"b\"}")
}
assert(!resultSetComplex.next())
}
}
class BinaryThriftServerSuite extends ThriftServerBaseTest with CommonThriftTests {
override def mode: ServerMode.Value = ServerMode.binary
override def port: Int = 20000
test("Reuse existing session") {
withJdbcConnection { _ =>
val sessionManager = LivyThriftServer.getInstance.get.getSessionManager()
val sessionHandle = sessionManager.getSessions.head
// Blocks until the session is ready
val session = sessionManager.getLivySession(sessionHandle)
withJdbcConnection("default", Seq(s"livy.server.sessionId=${session.id}")) { _ =>
val it = sessionManager.getSessions.iterator
// Blocks until all the sessions are ready
while (it.hasNext) {
sessionManager.getLivySession(it.next())
}
assert(LivyThriftServer.getInstance.get.livySessionManager.size() == 1)
}
}
}
test("fetch different data types") {
val supportMap = hiveSupportEnabled(formattedSparkVersion._1, livyConf)
withJdbcStatement { statement =>
dataTypesTest(statement, supportMap)
}
}
test("support default database in connection URIs") {
assume(hiveSupportEnabled(formattedSparkVersion._1, livyConf))
val db = "new_db"
withJdbcConnection { c =>
val s1 = c.createStatement()
s1.execute(s"create database $db")
s1.close()
val sessionManager = LivyThriftServer.getInstance.get.getSessionManager()
val sessionHandle = sessionManager.getSessions.head
// Blocks until the session is ready
val session = sessionManager.getLivySession(sessionHandle)
withJdbcConnection(db, Seq(s"livy.server.sessionId=${session.id}")) { c =>
val statement = c.createStatement()
val resultSet = statement.executeQuery("select current_database()")
resultSet.next()
assert(resultSet.getString(1) === db)
statement.close()
}
val s2 = c.createStatement()
s2.execute(s"drop database $db")
s2.close()
}
}
test("support hivevar") {
assume(hiveSupportEnabled(formattedSparkVersion._1, livyConf))
withJdbcConnection(jdbcUri("default") + "#myVar1=val1;myVar2=val2") { c =>
val statement = c.createStatement()
val selectRes = statement.executeQuery("select \"${myVar1}\", \"${myVar2}\"")
selectRes.next()
assert(selectRes.getString(1) === "val1")
assert(selectRes.getString(2) === "val2")
statement.close()
}
}
}
class HttpThriftServerSuite extends ThriftServerBaseTest with CommonThriftTests {
override def mode: ServerMode.Value = ServerMode.http
override def port: Int = 20001
test("fetch different data types") {
val supportMap = hiveSupportEnabled(formattedSparkVersion._1, livyConf)
withJdbcStatement { statement =>
dataTypesTest(statement, supportMap)
}
}
}