blob: d2c69b96a714e5b765ae3a5715daef55810672f7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.predictionio.data.storage.elasticsearch
import org.apache.http.HttpHost
import org.apache.http.auth.{AuthScope, UsernamePasswordCredentials}
import org.apache.http.impl.client.BasicCredentialsProvider
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder
import org.apache.predictionio.data.storage.BaseStorageClient
import org.apache.predictionio.data.storage.StorageClientConfig
import org.apache.predictionio.data.storage.StorageClientException
import org.apache.predictionio.workflow.CleanupFunctions
import org.elasticsearch.client.RestClient
import org.elasticsearch.client.RestClientBuilder.HttpClientConfigCallback
import grizzled.slf4j.Logging
import scala.util.control.NonFatal
/**
 * Process-wide holder for a single shared Elasticsearch [[RestClient]].
 *
 * The first successful call to [[open]] builds the client; subsequent calls
 * return that same instance until [[close]] is called. Access to the cached
 * client is synchronized so concurrent callers cannot each build (and leak)
 * their own client.
 */
object ESClient extends Logging {
  private var _sharedRestClient: Option[RestClient] = None

  /**
   * Returns the shared REST client, creating it on first use.
   *
   * Note: once a shared client exists, `hosts` and `basicAuth` are ignored
   * and the cached client is returned as-is.
   *
   * @param hosts     Elasticsearch hosts used when a new client is built
   * @param basicAuth optional (username, password) pair for HTTP basic auth
   * @return the shared [[RestClient]]
   * @throws StorageClientException if building the client fails with a
   *         non-fatal error
   */
  def open(
      hosts: Seq[HttpHost],
      basicAuth: Option[(String, String)] = None): RestClient = synchronized {
    try {
      _sharedRestClient.getOrElse {
        val newClient = buildClient(hosts, basicAuth)
        _sharedRestClient = Some(newClient)
        newClient
      }
    } catch {
      // Only wrap recoverable failures; let fatal JVM errors propagate untouched.
      case NonFatal(e) =>
        throw new StorageClientException(e.getMessage, e)
    }
  }

  /** Builds a fresh RestClient for `hosts`, wiring basic auth when supplied. */
  private def buildClient(
      hosts: Seq[HttpHost],
      basicAuth: Option[(String, String)]): RestClient = {
    val builder = RestClient.builder(hosts: _*)
    val configured = basicAuth match {
      case Some((username, password)) =>
        builder.setHttpClientConfigCallback(
          new BasicAuthProvider(username, password))
      case None => builder
    }
    configured.build()
  }

  /**
   * Closes and forgets the shared client, if any. Safe to call repeatedly.
   * An exception from the underlying `close()` still propagates to the caller.
   */
  def close(): Unit = synchronized {
    _sharedRestClient.foreach { client =>
      // Clear the cache first so a failing close() cannot leave an already
      // closed client behind for the next open() to hand out.
      _sharedRestClient = None
      client.close()
    }
  }
}
/**
 * Elasticsearch-backed storage client.
 *
 * Reads optional `USERNAME` / `PASSWORD` properties from `config`. When
 * either is present, basic auth is enabled and a missing half defaults to "".
 * Registers [[ESClient.close]] as a cleanup hook and obtains the shared REST
 * client for the configured hosts.
 */
class StorageClient(val config: StorageClientConfig)
  extends BaseStorageClient with Logging {
  override val prefix = "ES"

  // Raw (username, password) lookup; each side may be absent.
  val usernamePassword = (
    config.properties.get("USERNAME"),
    config.properties.get("PASSWORD"))

  // Basic auth is disabled only when neither credential was configured.
  val optionalBasicAuth: Option[(String, String)] = {
    val (user, pass) = usernamePassword
    if (user.isEmpty && pass.isEmpty) None
    else Some((user.getOrElse(""), pass.getOrElse("")))
  }

  CleanupFunctions.add { ESClient.close }

  val client = ESClient.open(ESUtils.getHttpHosts(config), optionalBasicAuth)
}
/**
 * HTTP client configuration callback that installs a credentials provider
 * carrying the given username/password for every auth scope.
 *
 * @param username basic-auth user name
 * @param password basic-auth password
 */
class BasicAuthProvider(
  val username: String,
  val password: String)
  extends HttpClientConfigCallback {

  // Credentials registered against AuthScope.ANY.
  val credentialsProvider = {
    val provider = new BasicCredentialsProvider()
    val creds = new UsernamePasswordCredentials(username, password)
    provider.setCredentials(AuthScope.ANY, creds)
    provider
  }

  /** Attaches [[credentialsProvider]] as the builder's default provider. */
  override def customizeHttpClient(
    httpClientBuilder: HttpAsyncClientBuilder): HttpAsyncClientBuilder =
    httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider)
}