blob: 396192d1d865a22658363179bd4af3d03fc194bc [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.samza.sql.translator;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.calcite.DataContext;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.logical.LogicalFilter;
import org.apache.samza.application.descriptors.StreamApplicationDescriptorImpl;
import org.apache.samza.context.ContainerContext;
import org.apache.samza.context.Context;
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.MessageStreamImpl;
import org.apache.samza.operators.functions.FilterFunction;
import org.apache.samza.operators.spec.OperatorSpec;
import org.apache.samza.operators.spec.StreamOperatorSpec;
import org.apache.samza.sql.data.Expression;
import org.apache.samza.sql.data.RexToJavaCompiler;
import org.apache.samza.sql.data.SamzaSqlExecutionContext;
import org.apache.samza.sql.data.SamzaSqlRelMessage;
import org.apache.samza.sql.data.SamzaSqlRelMsgMetadata;
import org.apache.samza.sql.runner.SamzaSqlApplicationContext;
import org.apache.samza.sql.util.TestMetricsRegistryImpl;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.internal.util.reflection.Whitebox;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.powermock.modules.junit4.PowerMockRunner;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link FilterTranslator}.
 *
 * <p>The single test drives {@code FilterTranslator.translate()} against mocked Calcite and Samza
 * collaborators, then inspects the resulting operator spec and filter function via reflection
 * ({@link Whitebox}) and checks the metrics registered under {@link #LOGICAL_OP_ID}.
 */
@RunWith(PowerMockRunner.class)
@PrepareForTest(LogicalFilter.class)
public class TestFilterTranslator extends TranslatorTestBase {
  /** Logical operator id handed to {@code translate()}; also the key under which metrics are looked up. */
  private static final String LOGICAL_OP_ID = "sql0_filter_0";

  /**
   * Translates a mocked {@link LogicalFilter} and verifies that:
   * <ul>
   *   <li>the rel node and the output message stream are registered with the translator context,</li>
   *   <li>the registered stream carries a FILTER operator spec,</li>
   *   <li>initializing the transform function wires the translator context, filter and compiled
   *       expression into the filter function,</li>
   *   <li>the filter function returns whatever boolean the compiled expression writes into its
   *       result array, and updates the counters while doing so.</li>
   * </ul>
   *
   * @throws IOException declared for compatibility with the original test signature
   * @throws ClassNotFoundException declared for compatibility with the original test signature
   */
  @Test
  public void testTranslate() throws IOException, ClassNotFoundException {
    // Set up the mocks that FilterTranslator.translate() interacts with.
    LogicalFilter mockFilter = PowerMockito.mock(LogicalFilter.class);
    Context mockContext = mock(Context.class);
    ContainerContext mockContainerContext = mock(ContainerContext.class);
    TranslatorContext mockTranslatorContext = mock(TranslatorContext.class);
    TestMetricsRegistryImpl metricsRegistry = new TestMetricsRegistryImpl();
    RelNode mockInput = mock(RelNode.class);
    when(mockFilter.getInput()).thenReturn(mockInput);
    when(mockInput.getId()).thenReturn(1);
    when(mockFilter.getId()).thenReturn(2);

    // The input stream is a real MessageStreamImpl so that filter() produces a real operator spec.
    StreamApplicationDescriptorImpl mockGraph = mock(StreamApplicationDescriptorImpl.class);
    OperatorSpec<Object, SamzaSqlRelMessage> mockInputOp = mock(OperatorSpec.class);
    MessageStream<SamzaSqlRelMessage> mockStream = new MessageStreamImpl<>(mockGraph, mockInputOp);
    when(mockTranslatorContext.getMessageStream(eq(1))).thenReturn(mockStream);
    doAnswer(this.getRegisterMessageStreamAnswer()).when(mockTranslatorContext)
        .registerMessageStream(eq(2), any(MessageStream.class));

    RexToJavaCompiler mockCompiler = mock(RexToJavaCompiler.class);
    when(mockTranslatorContext.getExpressionCompiler()).thenReturn(mockCompiler);
    Expression mockExpr = mock(Expression.class);
    when(mockCompiler.compile(any(), any())).thenReturn(mockExpr);
    when(mockContext.getContainerContext()).thenReturn(mockContainerContext);
    when(mockContainerContext.getContainerMetricsRegistry()).thenReturn(metricsRegistry);

    // Apply translate() to verify that the correct filter operator is constructed.
    FilterTranslator filterTranslator = new FilterTranslator(1);
    filterTranslator.translate(mockFilter, LOGICAL_OP_ID, mockTranslatorContext);

    // Make sure the LogicalFilter and the output message stream have been registered with the context.
    verify(mockTranslatorContext, times(1)).registerRelNode(2, mockFilter);
    verify(mockTranslatorContext, times(1)).registerMessageStream(2, this.getRegisteredMessageStream(2));
    when(mockTranslatorContext.getRelNode(2)).thenReturn(mockFilter);
    when(mockTranslatorContext.getMessageStream(2)).thenReturn(this.getRegisteredMessageStream(2));
    StreamOperatorSpec<?, ?> filterSpec =
        (StreamOperatorSpec<?, ?>) Whitebox.getInternalState(this.getRegisteredMessageStream(2), "operatorSpec");
    assertNotNull(filterSpec);
    // JUnit convention is (expected, actual); the reverse order yields a misleading failure message.
    assertEquals(OperatorSpec.OpCode.FILTER, filterSpec.getOpCode());

    // Verify that init() on the transform function establishes the context for the filter function.
    Map<Integer, TranslatorContext> mockContexts = new HashMap<>();
    mockContexts.put(1, mockTranslatorContext);
    when(mockContext.getApplicationTaskContext()).thenReturn(new SamzaSqlApplicationContext(mockContexts));
    filterSpec.getTransformFn().init(mockContext);
    FilterFunction filterFn = (FilterFunction) Whitebox.getInternalState(filterSpec, "filterFn");
    assertNotNull(filterFn);
    assertEquals(mockTranslatorContext, Whitebox.getInternalState(filterFn, "translatorContext"));
    assertEquals(mockFilter, Whitebox.getInternalState(filterFn, "filter"));
    assertEquals(mockExpr, Whitebox.getInternalState(filterFn, "expr"));

    // Verify that the MetricsRegistry is populated for the filter operator.
    assertEquals(1, metricsRegistry.getGauges().size());
    assertTrue(metricsRegistry.getGauges().get(LOGICAL_OP_ID).size() > 0);
    assertEquals(1, metricsRegistry.getCounters().size());
    assertEquals(3, metricsRegistry.getCounters().get(LOGICAL_OP_ID).size());
    assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(0, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());

    // Call filterFn.apply() to verify that the filter function is correctly applied to the input message.
    SamzaSqlRelMessage mockInputMsg = new SamzaSqlRelMessage(new ArrayList<>(), new ArrayList<>(),
        new SamzaSqlRelMsgMetadata(0L, 0L));
    SamzaSqlExecutionContext executionContext = mock(SamzaSqlExecutionContext.class);
    DataContext dataContext = mock(DataContext.class);
    when(mockTranslatorContext.getExecutionContext()).thenReturn(executionContext);
    when(mockTranslatorContext.getDataContext()).thenReturn(dataContext);
    Object[] fieldValues = mockInputMsg.getSamzaSqlRelRecord().getFieldValues().toArray();

    stubFilterResult(mockExpr, executionContext, mockContext, dataContext, fieldValues, true);
    assertTrue(filterFn.apply(mockInputMsg));

    // Re-stubbing with the same matchers replaces the previous answer.
    stubFilterResult(mockExpr, executionContext, mockContext, dataContext, fieldValues, false);
    assertFalse(filterFn.apply(mockInputMsg));

    // Verify that filterFn.apply() updates the counters: after two calls (one accepted, one rejected)
    // the first counter reads 2 and the second reads 1.
    // NOTE(review): presumably these are the input-event and filtered-event counters — confirm against
    // the registration order in FilterTranslator.
    assertEquals(2, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(0).getCount());
    assertEquals(1, metricsRegistry.getCounters().get(LOGICAL_OP_ID).get(1).getCount());
  }

  /**
   * Stubs {@code expr.execute(...)} so that, when invoked with the given arguments and an empty
   * one-element result array, it writes {@code accept} into slot 0 of that result array.
   *
   * @param expr the mocked compiled expression to stub
   * @param executionContext the execution context the call is expected to carry
   * @param context the Samza context the call is expected to carry
   * @param dataContext the Calcite data context the call is expected to carry
   * @param fieldValues the input field values the call is expected to carry
   * @param accept the boolean the stubbed expression reports as the filter outcome
   */
  private static void stubFilterResult(Expression expr, SamzaSqlExecutionContext executionContext,
      Context context, DataContext dataContext, Object[] fieldValues, boolean accept) {
    doAnswer(invocation -> {
      // Argument 4 is the caller-supplied output array; the expression reports its result through it.
      Object[] retValue = invocation.getArgumentAt(4, Object[].class);
      retValue[0] = Boolean.valueOf(accept);
      return null;
    }).when(expr).execute(eq(executionContext), eq(context), eq(dataContext), eq(fieldValues),
        eq(new Object[1]));
  }
}