blob: 800838822ed9e31424729f50ac4ce1b110208c3c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.livy.server.auth
import java.nio.charset.StandardCharsets
import java.util.Properties
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse

import org.apache.commons.codec.binary.Base64
import org.apache.directory.server.annotations.CreateLdapServer
import org.apache.directory.server.annotations.CreateTransport
import org.apache.directory.server.core.annotations.ApplyLdifs
import org.apache.directory.server.core.annotations.ContextEntry
import org.apache.directory.server.core.annotations.CreateDS
import org.apache.directory.server.core.annotations.CreatePartition
import org.apache.directory.server.core.integ.AbstractLdapTestUnit
import org.apache.directory.server.core.integ.FrameworkRunner
import org.apache.hadoop.security.authentication.client.AuthenticationException
import org.junit.Assert
import org.junit.Before
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito
/**
 * Unit tests for [[LdapAuthenticationHandlerImpl]].
 *
 * The class is run with ApacheDS's `FrameworkRunner`. Via the annotations below it gets an
 * LDAP server with a transport bound to `localhost` (`@CreateLdapServer`), a partition rooted
 * at `dc=example,dc=com` (`@CreateDS`), and a single user `uid=bjones` whose password is
 * `p@ssw0rd` (`@ApplyLdifs`). Each test drives the handler with Mockito-mocked servlet
 * request/response objects. It then checks the returned token and the calls made on the
 * response.
 */
@RunWith(classOf[FrameworkRunner])
@CreateLdapServer(transports = Array(
  new CreateTransport(
    protocol = "LDAP",
    address = "localhost"
  )))
@CreateDS(
  allowAnonAccess = true,
  partitions = Array(
    new CreatePartition(
      name = "Test_Partition",
      suffix = "dc=example,dc=com",
      contextEntry = new ContextEntry(entryLdif = "dn: dc=example," +
        "dc=com \ndc: example\nobjectClass: top\nobjectClass: domain\n\n")
    )))
@ApplyLdifs(Array(
  "dn: uid=bjones,dc=example,dc=com",
  "cn: Bob Jones",
  "sn: Jones",
  "objectClass: inetOrgPerson",
  "uid: bjones",
  "userPassword: p@ssw0rd"
))
class TestLdapAuthenticationHandlerImpl extends AbstractLdapTestUnit {
  private val handler: LdapAuthenticationHandlerImpl = new LdapAuthenticationHandlerImpl

  // HTTP header used by the server endpoint during an authentication sequence.
  val WWW_AUTHENTICATE_HEADER = "WWW-Authenticate"
  // HTTP header used by the client endpoint during an authentication sequence.
  val AUTHORIZATION_HEADER = "Authorization"
  // HTTP header prefix used during the Basic authentication sequence.
  val BASIC = "Basic"

  /** Initializes the handler with [[getDefaultProperties]] before each test. */
  @Before
  def setup(): Unit = {
    handler.init(getDefaultProperties)
  }

  /**
   * Handler configuration for the test directory. It supplies the base DN of the partition
   * declared in `@CreateDS` and an `ldap://localhost:<port>` provider URL. The port is the
   * one reported by the running test LDAP server.
   */
  protected def getDefaultProperties: Properties = {
    val p = new Properties
    p.setProperty("ldap.basedn", "dc=example,dc=com")
    p.setProperty("ldap.providerurl",
      s"ldap://localhost:${AbstractLdapTestUnit.getLdapServer.getPort}")
    p
  }

  /** Base64-encodes the UTF-8 bytes of `text` as a single line (line length 0 = no chunking). */
  private def encodeBase64(text: String): String =
    new Base64(0).encodeToString(text.getBytes(StandardCharsets.UTF_8))

  /** Builds a Basic `Authorization` header value: `"Basic " + base64("user:password")`. */
  private def basicAuthHeader(user: String, password: String): String =
    s"$BASIC ${encodeBase64(s"$user:$password")}"

  /**
   * Valid Basic credentials for the seeded user must yield a token whose user name and name
   * are that user, and must set a 200 status on the response.
   */
  @Test
  def testRequestWithAuthorization(): Unit = {
    val request = Mockito.mock(classOf[HttpServletRequest])
    val response = Mockito.mock(classOf[HttpServletResponse])
    val authHeader = basicAuthHeader("bjones", "p@ssw0rd")
    Mockito.when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn(authHeader)
    val token = handler.authenticate(request, response)
    Assert.assertNotNull(token)
    Mockito.verify(response).setStatus(HttpServletResponse.SC_OK)
    Assert.assertEquals("bjones", token.getUserName)
    Assert.assertEquals("bjones", token.getName)
  }

  /**
   * A request with no `Authorization` header must yield no token. The response must carry
   * a `WWW-Authenticate: Basic` challenge and a 401 status.
   */
  @Test
  def testRequestWithoutAuthorization(): Unit = {
    val request = Mockito.mock(classOf[HttpServletRequest])
    val response = Mockito.mock(classOf[HttpServletResponse])
    Assert.assertNull(handler.authenticate(request, response))
    Mockito.verify(response).setHeader(WWW_AUTHENTICATE_HEADER, BASIC)
    Mockito.verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED)
  }

  /**
   * An `Authorization` header that carries Base64 credentials but lacks the `Basic` scheme
   * prefix must yield no token. The response must carry a `WWW-Authenticate: Basic`
   * challenge and a 401 status.
   */
  @Test
  def testRequestWithInvalidAuthorization(): Unit = {
    val request = Mockito.mock(classOf[HttpServletRequest])
    val response = Mockito.mock(classOf[HttpServletResponse])
    // Deliberately omit the "Basic " scheme prefix so the header is malformed.
    val headerWithoutScheme = encodeBase64("bjones:invalidpassword")
    Mockito.when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn(headerWithoutScheme)
    Assert.assertNull(handler.authenticate(request, response))
    Mockito.verify(response).setHeader(WWW_AUTHENTICATE_HEADER, BASIC)
    Mockito.verify(response).setStatus(HttpServletResponse.SC_UNAUTHORIZED)
  }

  /** A bare `Basic` scheme with no credentials after it must yield no token. */
  @Test
  def testRequestWithIncompleteAuthorization(): Unit = {
    val request = Mockito.mock(classOf[HttpServletRequest])
    val response = Mockito.mock(classOf[HttpServletResponse])
    Mockito.when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn(BASIC)
    Assert.assertNull(handler.authenticate(request, response))
  }

  /**
   * A well-formed Basic header with the wrong password for the seeded user must make
   * `authenticate` throw [[AuthenticationException]]. Any other exception is not caught.
   * It propagates so JUnit reports it with its original type and stack trace.
   */
  @Test
  def testRequestWithWrongCredentials(): Unit = {
    val request = Mockito.mock(classOf[HttpServletRequest])
    val response = Mockito.mock(classOf[HttpServletResponse])
    val authHeader = basicAuthHeader("bjones", "foo123")
    Mockito.when(request.getHeader(AUTHORIZATION_HEADER)).thenReturn(authHeader)
    try {
      handler.authenticate(request, response)
      // Assert.fail throws AssertionError (an Error), which the catch below does not match.
      Assert.fail("Expected AuthenticationException for wrong credentials")
    } catch {
      case _: AuthenticationException => // Expected.
    }
  }
}