blob: d53f8e3278689869468f0eb5f8f8b2282b33070b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.record.sink.db
import org.apache.nifi.attribute.expression.language.StandardPropertyValue
import org.apache.nifi.components.PropertyValue
import org.apache.nifi.components.state.StateManager
import org.apache.nifi.controller.ConfigurationContext
import org.apache.nifi.controller.ControllerServiceInitializationContext
import org.apache.nifi.dbcp.DBCPConnectionPool
import org.apache.nifi.dbcp.DBCPService
import org.apache.nifi.logging.ComponentLog
import org.apache.nifi.record.sink.RecordSinkService
import org.apache.nifi.reporting.InitializationException
import org.apache.nifi.serialization.RecordSetWriterFactory
import org.apache.nifi.serialization.SimpleRecordSchema
import org.apache.nifi.serialization.WriteResult
import org.apache.nifi.serialization.record.ListRecordSet
import org.apache.nifi.serialization.record.MapRecord
import org.apache.nifi.serialization.record.MockRecordWriter
import org.apache.nifi.serialization.record.RecordField
import org.apache.nifi.serialization.record.RecordFieldType
import org.apache.nifi.serialization.record.RecordSchema
import org.apache.nifi.serialization.record.RecordSet
import org.apache.nifi.state.MockStateManager
import org.apache.nifi.util.MockControllerServiceInitializationContext
import org.apache.nifi.util.MockPropertyValue
import org.apache.nifi.util.file.FileUtils
import org.junit.AfterClass
import org.junit.BeforeClass
import org.junit.Test
import java.sql.DriverManager
import java.sql.ResultSet
import java.sql.SQLException
import java.sql.SQLNonTransientConnectionException
import java.sql.Statement
import static org.apache.nifi.dbcp.DBCPConnectionPool.DATABASE_URL
import static org.apache.nifi.dbcp.DBCPConnectionPool.DB_DRIVERNAME
import static org.apache.nifi.dbcp.DBCPConnectionPool.DB_DRIVER_LOCATION
import static org.apache.nifi.dbcp.DBCPConnectionPool.DB_PASSWORD
import static org.apache.nifi.dbcp.DBCPConnectionPool.DB_USER
import static org.apache.nifi.dbcp.DBCPConnectionPool.EVICTION_RUN_PERIOD
import static org.apache.nifi.dbcp.DBCPConnectionPool.KERBEROS_CREDENTIALS_SERVICE
import static org.apache.nifi.dbcp.DBCPConnectionPool.KERBEROS_PASSWORD
import static org.apache.nifi.dbcp.DBCPConnectionPool.KERBEROS_PRINCIPAL
import static org.apache.nifi.dbcp.DBCPConnectionPool.KERBEROS_USER_SERVICE
import static org.apache.nifi.dbcp.DBCPConnectionPool.MAX_CONN_LIFETIME
import static org.apache.nifi.dbcp.DBCPConnectionPool.MAX_IDLE
import static org.apache.nifi.dbcp.DBCPConnectionPool.MAX_TOTAL_CONNECTIONS
import static org.apache.nifi.dbcp.DBCPConnectionPool.MAX_WAIT_TIME
import static org.apache.nifi.dbcp.DBCPConnectionPool.MIN_EVICTABLE_IDLE_TIME
import static org.apache.nifi.dbcp.DBCPConnectionPool.MIN_IDLE
import static org.apache.nifi.dbcp.DBCPConnectionPool.SOFT_MIN_EVICTABLE_IDLE_TIME
import static org.apache.nifi.dbcp.DBCPConnectionPool.VALIDATION_QUERY
import static org.junit.Assert.assertEquals
import static org.junit.Assert.assertFalse
import static org.junit.Assert.assertNotNull
import static org.junit.Assert.assertTrue
import static org.junit.Assert.fail
import static org.mockito.Mockito.mock
import static org.mockito.Mockito.when
/**
 * Tests for {@link DatabaseRecordSink} that run against an embedded Derby database
 * stored under {@link #DB_LOCATION}.
 *
 * Each test builds a sink through {@link #initTask(String)}, which wires the sink to a
 * {@link DBCPConnectionPool} pointing at the same Derby database, and then sends a small
 * record set through {@code sendData}.
 */
class DatabaseRecordSinkTest {

    /** Directory (relative to the working directory) holding the embedded Derby database. */
    final static String DB_LOCATION = "target/db"

    /** Connection pool handed to the sink; re-created on every call to {@link #initTask(String)}. */
    DBCPService dbcpService

    /** Mocked configuration context of the sink under test; re-created by {@link #initTask(String)}. */
    private ConfigurationContext context

    /**
     * Redirects the Derby error log into the build output directory.
     */
    @BeforeClass
    static void setup() {
        System.setProperty("derby.stream.error.file", "target/derby.log")
    }

    /**
     * Shuts the Derby database down and deletes its files once all tests have run.
     */
    @AfterClass
    static void cleanUpAfterClass() throws Exception {
        try {
            DriverManager.getConnection("jdbc:derby:" + DB_LOCATION + ";shutdown=true")
        } catch (SQLNonTransientConnectionException ignore) {
            // Do nothing, this is what happens at Derby shutdown
        }

        // remove previous test database, if any
        final File dbLocation = new File(DB_LOCATION)
        try {
            FileUtils.deleteFile(dbLocation, true)
        } catch (IOException ignore) {
            // Do nothing, may not have existed
        }
    }

    /**
     * Sends two records whose fields match the table columns and verifies the write result
     * as well as the rows that ended up in the table.
     */
    @Test
    void testRecordFormat() throws IOException, InitializationException {
        DatabaseRecordSink task = initTask('TESTTABLE')

        recreateTestTable('CREATE TABLE testTable (field1 integer, field2 varchar(20))')

        List<RecordField> recordFields = Arrays.asList(
                new RecordField("field1", RecordFieldType.INT.getDataType()),
                new RecordField("field2", RecordFieldType.STRING.getDataType())
        )
        RecordSchema recordSchema = new SimpleRecordSchema(recordFields)

        Map<String, Object> row1 = new HashMap<>()
        row1.put('field1', 15)
        row1.put('field2', 'Hello')

        Map<String, Object> row2 = new HashMap<>()
        row2.put('field1', 6)
        row2.put('field2', 'World!')

        RecordSet recordSet = new ListRecordSet(recordSchema, Arrays.asList(
                new MapRecord(recordSchema, row1),
                new MapRecord(recordSchema, row2)
        ))

        WriteResult writeResult = task.sendData(recordSet, ['a': 'Hello'], true)
        assertNotNull(writeResult)
        assertEquals(2, writeResult.recordCount)
        assertEquals('Hello', writeResult.attributes['a'])

        // Read the rows back over a separate connection and release every JDBC resource afterwards
        def con = openDerbyConnection()
        try {
            final Statement st = con.createStatement()
            try {
                final ResultSet resultSet = st.executeQuery('select * from testTable')
                try {
                    assertNextRow(resultSet, 15, 'Hello')
                    assertNextRow(resultSet, 6, 'World!')
                    assertFalse(resultSet.next())
                } finally {
                    resultSet.close()
                }
            } finally {
                st.close()
            }
        } finally {
            con.close()
        }
    }

    /**
     * Configures the sink with the table name NO_SUCH_TABLE, which this test never creates,
     * and expects {@code sendData} to fail with an {@link IOException}.
     */
    @Test(expected = IOException.class)
    void testMissingTable() throws IOException, InitializationException {
        DatabaseRecordSink task = initTask('NO_SUCH_TABLE')

        List<RecordField> recordFields = Arrays.asList(
                new RecordField("field1", RecordFieldType.INT.getDataType()),
                new RecordField("field2", RecordFieldType.STRING.getDataType())
        )
        RecordSchema recordSchema = new SimpleRecordSchema(recordFields)

        Map<String, Object> row1 = new HashMap<>()
        row1.put('field1', 15)
        row1.put('field2', 'Hello')

        RecordSet recordSet = new ListRecordSet(recordSchema, Collections.singletonList(new MapRecord(recordSchema, row1)))
        task.sendData(recordSet, new HashMap<>(), true)
        fail('Should have generated an exception for table not present')
    }

    /**
     * The table declares field1 and a NOT NULL field2 while the record schema only declares
     * field1; {@code sendData} is expected to fail with an {@link IOException}.
     */
    @Test(expected = IOException.class)
    void testMissingField() throws IOException, InitializationException {
        DatabaseRecordSink task = initTask('TESTTABLE')

        recreateTestTable('CREATE TABLE testTable (field1 integer, field2 varchar(20) not null)')

        List<RecordField> recordFields = Arrays.asList(
                new RecordField("field1", RecordFieldType.INT.getDataType())
        )
        RecordSchema recordSchema = new SimpleRecordSchema(recordFields)

        // The row map carries more keys than the schema declares
        Map<String, Object> row1 = new HashMap<>()
        row1.put('field1', 15)
        row1.put('field2', 'Hello')
        row1.put('field3', 'fail')

        RecordSet recordSet = new ListRecordSet(recordSchema, Collections.singletonList(new MapRecord(recordSchema, row1)))
        task.sendData(recordSet, new HashMap<>(), true)
        fail('Should have generated an exception for column not present')
    }

    /**
     * The record schema declares field1, field2 and field3 while the table only has field1
     * and field2; {@code sendData} is expected to fail with an {@link IOException}.
     */
    @Test(expected = IOException.class)
    void testMissingColumn() throws IOException, InitializationException {
        DatabaseRecordSink task = initTask('TESTTABLE')

        recreateTestTable('CREATE TABLE testTable (field1 integer, field2 varchar(20))')

        List<RecordField> recordFields = Arrays.asList(
                new RecordField("field1", RecordFieldType.INT.getDataType()),
                new RecordField("field2", RecordFieldType.STRING.getDataType()),
                new RecordField("field3", RecordFieldType.STRING.getDataType())
        )
        RecordSchema recordSchema = new SimpleRecordSchema(recordFields)

        Map<String, Object> row1 = new HashMap<>()
        row1.put('field1', 15)

        RecordSet recordSet = new ListRecordSet(recordSchema, Collections.singletonList(new MapRecord(recordSchema, row1)))
        task.sendData(recordSet, new HashMap<>(), true)
        fail('Should have generated an exception for field not present')
    }

    /**
     * Builds a {@link DatabaseRecordSink} backed by mocked configuration and a
     * {@link DBCPConnectionPool} that points at the embedded Derby database.
     *
     * Side effects: replaces the {@code context} and {@code dbcpService} fields.
     *
     * @param tableName value returned for the sink's TABLE_NAME property; 'TESTTABLE' is used
     *                  when the argument is null or empty
     * @return the initialized and enabled sink
     */
    DatabaseRecordSink initTask(String tableName) throws InitializationException, IOException {
        final ComponentLog logger = mock(ComponentLog.class)
        final DatabaseRecordSink task = new DatabaseRecordSink()
        context = mock(ConfigurationContext.class)
        final StateManager stateManager = new MockStateManager(task)

        // A single mocked property value serves both controller-service lookups (writer factory and DBCP service)
        final PropertyValue pValue = mock(StandardPropertyValue.class)
        final MockRecordWriter writer = new MockRecordWriter(null, false) // No header, don't quote values
        when(context.getProperty(RecordSinkService.RECORD_WRITER_FACTORY)).thenReturn(pValue)
        when(pValue.asControllerService(RecordSetWriterFactory.class)).thenReturn(writer)

        when(context.getProperty(DatabaseRecordSink.CATALOG_NAME)).thenReturn(new MockPropertyValue(null))
        when(context.getProperty(DatabaseRecordSink.SCHEMA_NAME)).thenReturn(new MockPropertyValue(null))
        when(context.getProperty(DatabaseRecordSink.TABLE_NAME)).thenReturn(new MockPropertyValue(tableName ?: 'TESTTABLE'))
        when(context.getProperty(DatabaseRecordSink.QUOTED_IDENTIFIERS)).thenReturn(new MockPropertyValue('false'))
        when(context.getProperty(DatabaseRecordSink.QUOTED_TABLE_IDENTIFIER)).thenReturn(new MockPropertyValue('true'))
        when(context.getProperty(DatabaseRecordSink.QUERY_TIMEOUT)).thenReturn(new MockPropertyValue('5 sec'))
        when(context.getProperty(DatabaseRecordSink.TRANSLATE_FIELD_NAMES)).thenReturn(new MockPropertyValue('true'))
        when(context.getProperty(DatabaseRecordSink.UNMATCHED_FIELD_BEHAVIOR)).thenReturn(new MockPropertyValue(DatabaseRecordSink.FAIL_UNMATCHED_FIELD.value))
        when(context.getProperty(DatabaseRecordSink.UNMATCHED_COLUMN_BEHAVIOR)).thenReturn(new MockPropertyValue(DatabaseRecordSink.FAIL_UNMATCHED_COLUMN.value))

        // Set up the DBCPService to connect to the embedded Derby test database
        dbcpService = new DBCPConnectionPool()
        when(pValue.asControllerService(DBCPService.class)).thenReturn(dbcpService)
        when(context.getProperty(DatabaseRecordSink.DBCP_SERVICE)).thenReturn(pValue)

        final ConfigurationContext dbContext = mock(ConfigurationContext.class)
        final StateManager dbStateManager = new MockStateManager(dbcpService)

        when(dbContext.getProperty(DATABASE_URL)).thenReturn(new MockPropertyValue("jdbc:derby:${DB_LOCATION}"))
        when(dbContext.getProperty(DB_USER)).thenReturn(new MockPropertyValue(null))
        when(dbContext.getProperty(DB_PASSWORD)).thenReturn(new MockPropertyValue(null))
        when(dbContext.getProperty(DB_DRIVERNAME)).thenReturn(new MockPropertyValue('org.apache.derby.jdbc.EmbeddedDriver'))
        when(dbContext.getProperty(DB_DRIVER_LOCATION)).thenReturn(new MockPropertyValue(''))
        when(dbContext.getProperty(MAX_TOTAL_CONNECTIONS)).thenReturn(new MockPropertyValue('1'))
        when(dbContext.getProperty(VALIDATION_QUERY)).thenReturn(new MockPropertyValue(''))
        when(dbContext.getProperty(MAX_WAIT_TIME)).thenReturn(new MockPropertyValue('5 sec'))
        when(dbContext.getProperty(MIN_IDLE)).thenReturn(new MockPropertyValue('0'))
        when(dbContext.getProperty(MAX_IDLE)).thenReturn(new MockPropertyValue('0'))
        when(dbContext.getProperty(MAX_CONN_LIFETIME)).thenReturn(new MockPropertyValue('5 sec'))
        when(dbContext.getProperty(EVICTION_RUN_PERIOD)).thenReturn(new MockPropertyValue('5 sec'))
        when(dbContext.getProperty(MIN_EVICTABLE_IDLE_TIME)).thenReturn(new MockPropertyValue('5 sec'))
        when(dbContext.getProperty(SOFT_MIN_EVICTABLE_IDLE_TIME)).thenReturn(new MockPropertyValue('5 sec'))
        when(dbContext.getProperty(KERBEROS_CREDENTIALS_SERVICE)).thenReturn(new MockPropertyValue(null))
        when(dbContext.getProperty(KERBEROS_USER_SERVICE)).thenReturn(new MockPropertyValue(null))
        when(dbContext.getProperty(KERBEROS_PRINCIPAL)).thenReturn(new MockPropertyValue(null))
        when(dbContext.getProperty(KERBEROS_PASSWORD)).thenReturn(new MockPropertyValue(null))

        final ControllerServiceInitializationContext dbInitContext = new MockControllerServiceInitializationContext(dbcpService, UUID.randomUUID().toString(), logger, dbStateManager)
        dbcpService.initialize(dbInitContext)
        dbcpService.onConfigured(dbContext)

        // NOTE(review): the sink's init context is built around the writer, not the sink itself - confirm this is intended
        final ControllerServiceInitializationContext initContext = new MockControllerServiceInitializationContext(writer, UUID.randomUUID().toString(), logger, stateManager)
        task.initialize(initContext)
        task.onEnabled(context)
        return task
    }

    /**
     * Loads the embedded Derby driver and opens a connection to the test database,
     * creating the database if it does not exist yet. The caller must close the connection.
     */
    private static def openDerbyConnection() {
        Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
        return DriverManager.getConnection("jdbc:derby:${DB_LOCATION};create=true")
    }

    /**
     * Drops TESTTABLE (ignoring the failure when it does not exist) and runs the given
     * CREATE TABLE statement, closing the statement and connection afterwards.
     *
     * @param createTableSql DDL statement used to create the table
     */
    private static void recreateTestTable(final String createTableSql) {
        def con = openDerbyConnection()
        try {
            final Statement stmt = con.createStatement()
            try {
                try {
                    stmt.execute("drop table TESTTABLE")
                } catch (final SQLException ignored) {
                    // Ignore, usually due to Derby not having DROP TABLE IF EXISTS
                }
                stmt.executeUpdate(createTableSql)
            } finally {
                stmt.close()
            }
        } finally {
            con.close()
        }
    }

    /**
     * Advances the result set by one row and checks that column 1 is the expected Integer
     * and column 2 the expected String.
     */
    private static void assertNextRow(final ResultSet resultSet, final int expectedField1, final String expectedField2) {
        assertTrue(resultSet.next())

        def f1 = resultSet.getObject(1)
        assertNotNull(f1)
        assertTrue(f1 instanceof Integer)
        assertEquals(expectedField1, f1)

        def f2 = resultSet.getObject(2)
        assertNotNull(f2)
        assertTrue(f2 instanceof String)
        assertEquals(expectedField2, f2)
    }
}