blob: e5f4c9ee9f6060a789c47a7e2674f23c7cef0e3c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.cep.nfa;
import org.apache.flink.cep.Event;
import org.apache.flink.cep.nfa.sharedbuffer.SharedBuffer;
import org.apache.flink.cep.nfa.sharedbuffer.SharedBufferAccessor;
import org.apache.flink.cep.pattern.Pattern;
import org.apache.flink.cep.pattern.conditions.BooleanConditions;
import org.apache.flink.cep.pattern.conditions.IterativeCondition;
import org.apache.flink.cep.pattern.conditions.SimpleCondition;
import org.apache.flink.cep.utils.TestSharedBuffer;
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.util.TestLogger;
import org.junit.Test;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.apache.flink.cep.utils.NFAUtils.compile;
import static org.junit.Assert.assertEquals;
/**
 * Tests for {@link NFA}: matching on a hand-built start/end NFA, window-based pruning and its
 * border semantics, and round-trip serialization of {@link NFAState}.
 */
public class NFATest extends TestLogger {

	@Test
	public void testSimpleNFA() throws Exception {
		// Same start/end NFA as the window tests, but built with a window length of 0; the test
		// expects the "start" at t=1 to still match the "end" at t=4.
		NFA<Event> nfa = createStartEndNFA(0L);

		List<StreamRecord<Event>> streamEvents = new ArrayList<>();
		streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
		streamEvents.add(new StreamRecord<>(new Event(2, "bar", 2.0), 2L));
		streamEvents.add(new StreamRecord<>(new Event(3, "start", 3.0), 3L));
		streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 4L));

		Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
		expectedPatterns.add(startEndMatch(new Event(1, "start", 1.0), new Event(4, "end", 4.0)));
		expectedPatterns.add(startEndMatch(new Event(3, "start", 3.0), new Event(4, "end", 4.0)));

		Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, nfa.createInitialNFAState(), streamEvents);
		assertEquals(expectedPatterns, actualPatterns);
	}

	/**
	 * Tests that, with a window length of 2, the "start" at t=1 is pruned before the "end" at t=4
	 * arrives, so only the "start" at t=3 produces a match.
	 */
	@Test
	public void testTimeoutWindowPruning() throws Exception {
		NFA<Event> nfa = createStartEndNFA();

		List<StreamRecord<Event>> streamEvents = new ArrayList<>();
		streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
		streamEvents.add(new StreamRecord<>(new Event(2, "bar", 2.0), 2L));
		streamEvents.add(new StreamRecord<>(new Event(3, "start", 3.0), 3L));
		streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 4L));

		Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
		expectedPatterns.add(startEndMatch(new Event(3, "start", 3.0), new Event(4, "end", 4.0)));

		Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, nfa.createInitialNFAState(), streamEvents);
		assertEquals(expectedPatterns, actualPatterns);
	}

	/**
	 * Tests that elements whose timestamp difference is exactly the window length are not matched.
	 * The reason is that the right window side (later elements) is exclusive.
	 */
	@Test
	public void testWindowBorders() throws Exception {
		NFA<Event> nfa = createStartEndNFA();

		List<StreamRecord<Event>> streamEvents = new ArrayList<>();
		streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
		streamEvents.add(new StreamRecord<>(new Event(2, "end", 2.0), 3L));

		Set<Map<String, List<Event>>> expectedPatterns = Collections.emptySet();

		Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, nfa.createInitialNFAState(), streamEvents);
		assertEquals(expectedPatterns, actualPatterns);
	}

	/**
	 * Tests that pruning shared buffer elements and computations state use the same window border
	 * semantics (left side inclusive and right side exclusive).
	 */
	@Test
	public void testTimeoutWindowPruningWindowBorders() throws Exception {
		NFA<Event> nfa = createStartEndNFA();

		List<StreamRecord<Event>> streamEvents = new ArrayList<>();
		streamEvents.add(new StreamRecord<>(new Event(1, "start", 1.0), 1L));
		streamEvents.add(new StreamRecord<>(new Event(2, "start", 2.0), 2L));
		streamEvents.add(new StreamRecord<>(new Event(3, "foobar", 3.0), 3L));
		streamEvents.add(new StreamRecord<>(new Event(4, "end", 4.0), 3L));

		Set<Map<String, List<Event>>> expectedPatterns = new HashSet<>();
		expectedPatterns.add(startEndMatch(new Event(2, "start", 2.0), new Event(4, "end", 4.0)));

		Collection<Map<String, List<Event>>> actualPatterns = runNFA(nfa, nfa.createInitialNFAState(), streamEvents);
		assertEquals(expectedPatterns, actualPatterns);
	}

	/**
	 * Feeds {@code inputs} through {@code nfa} in order and collects every produced match.
	 *
	 * <p>Before each element is processed, {@link NFA#advanceTime} is called with that element's
	 * timestamp. A fresh test shared buffer is created for this run.
	 *
	 * @param nfa the NFA under test
	 * @param nfaState the state to run on; it is mutated by the run
	 * @param inputs the timestamped input elements, processed in list order
	 * @return the distinct matches (as a set) emitted by {@link NFA#process}
	 */
	public Collection<Map<String, List<Event>>> runNFA(
			NFA<Event> nfa, NFAState nfaState, List<StreamRecord<Event>> inputs) throws Exception {
		Set<Map<String, List<Event>>> actualPatterns = new HashSet<>();
		SharedBuffer<Event> sharedBuffer = TestSharedBuffer.createTestBuffer(Event.createTypeSerializer());
		try (SharedBufferAccessor<Event> sharedBufferAccessor = sharedBuffer.getAccessor()) {
			for (StreamRecord<Event> streamEvent : inputs) {
				nfa.advanceTime(sharedBufferAccessor, nfaState, streamEvent.getTimestamp());
				Collection<Map<String, List<Event>>> matchedPatterns = nfa.process(
					sharedBufferAccessor,
					nfaState,
					streamEvent.getValue(),
					streamEvent.getTimestamp());
				actualPatterns.addAll(matchedPatterns);
			}
		}
		return actualPatterns;
	}

	/**
	 * Tests that an {@link NFAState} produced by several compiled patterns survives a
	 * serialize / copy / deserialize round trip unchanged.
	 */
	@Test
	public void testNFASerialization() throws Exception {
		Pattern<Event, ?> pattern1 = Pattern.<Event>begin("start").where(nameEquals("a"))
			.followedByAny("middle").where(nameEquals("b"))
			.oneOrMore().optional().allowCombinations()
			.followedByAny("end").where(nameEquals("d"));

		Pattern<Event, ?> pattern2 = Pattern.<Event>begin("start").where(nameEquals("a"))
			.notFollowedBy("not").where(nameEquals("c"))
			.followedByAny("middle").where(nameEquals("b"))
			.oneOrMore().optional().allowCombinations()
			.followedByAny("end").where(new IterativeCondition<Event>() {
				private static final long serialVersionUID = 8061969839441121955L;

				@Override
				public boolean filter(Event value, IterativeCondition.Context<Event> ctx) throws Exception {
					// Accept once the prices of the events already matched as "middle" sum above 5.0.
					double sum = 0.0;
					for (Event e : ctx.getEventsForPattern("middle")) {
						sum += e.getPrice();
					}
					return sum > 5.0;
				}
			});

		Pattern<Event, ?> pattern3 = Pattern.<Event>begin("start")
			.notFollowedBy("not").where(nameEquals("c"))
			.followedByAny("middle").where(nameEquals("b"))
			.oneOrMore().allowCombinations()
			.followedByAny("end").where(nameEquals("d"));

		List<Pattern<Event, ?>> patterns = new ArrayList<>();
		patterns.add(pattern1);
		patterns.add(pattern2);
		patterns.add(pattern3);

		SharedBuffer<Event> sharedBuffer = TestSharedBuffer.createTestBuffer(Event.createTypeSerializer());
		try (SharedBufferAccessor<Event> sharedBufferAccessor = sharedBuffer.getAccessor()) {
			for (Pattern<Event, ?> p : patterns) {
				NFA<Event> nfa = compile(p, false);

				Event a = new Event(40, "a", 1.0);
				Event b = new Event(41, "b", 2.0);
				Event c = new Event(42, "c", 3.0);
				Event b1 = new Event(41, "b", 3.0);
				Event b2 = new Event(41, "b", 4.0);
				Event b3 = new Event(41, "b", 5.0);
				Event d = new Event(43, "d", 4.0);

				NFAState nfaState = nfa.createInitialNFAState();
				nfa.process(sharedBufferAccessor, nfaState, a, 1);
				nfa.process(sharedBufferAccessor, nfaState, b, 2);
				nfa.process(sharedBufferAccessor, nfaState, c, 3);
				nfa.process(sharedBufferAccessor, nfaState, b1, 4);
				nfa.process(sharedBufferAccessor, nfaState, b2, 5);
				nfa.process(sharedBufferAccessor, nfaState, b3, 6);
				nfa.process(sharedBufferAccessor, nfaState, d, 7);
				nfa.process(sharedBufferAccessor, nfaState, a, 8);

				assertEquals(nfaState, serializeCopyAndDeserialize(nfaState));
			}
		}
	}

	/**
	 * Serializes {@code nfaState} with {@link NFAStateSerializer#INSTANCE}, copies the bytes
	 * through a duplicated serializer, and deserializes the copy with another duplicate.
	 *
	 * @param nfaState the state to round-trip
	 * @return the state read back from the copied bytes
	 */
	private static NFAState serializeCopyAndDeserialize(NFAState nfaState) throws Exception {
		NFAStateSerializer serializer = NFAStateSerializer.INSTANCE;

		// serialize
		ByteArrayOutputStream serialized = new ByteArrayOutputStream();
		serializer.serialize(nfaState, new DataOutputViewStreamWrapper(serialized));

		// copy (exercises TypeSerializer#copy(DataInputView, DataOutputView))
		ByteArrayOutputStream copied = new ByteArrayOutputStream();
		serializer.duplicate().copy(
			new DataInputViewStreamWrapper(new ByteArrayInputStream(serialized.toByteArray())),
			new DataOutputViewStreamWrapper(copied));

		// deserialize
		return serializer.duplicate().deserialize(
			new DataInputViewStreamWrapper(new ByteArrayInputStream(copied.toByteArray())));
	}

	/**
	 * Builds the expected match for the start/end NFA: one event bound to "start" and one to
	 * "end".
	 */
	private static Map<String, List<Event>> startEndMatch(Event start, Event end) {
		Map<String, List<Event>> match = new HashMap<>();
		match.put("start", Collections.singletonList(start));
		match.put("end", Collections.singletonList(end));
		return match;
	}

	/**
	 * Returns a condition that accepts events whose name equals {@code expectedName}.
	 *
	 * <p>Created in a static context so the anonymous condition does not capture the enclosing
	 * test instance.
	 */
	private static SimpleCondition<Event> nameEquals(final String expectedName) {
		return new SimpleCondition<Event>() {
			private static final long serialVersionUID = -2193839372402913471L;

			@Override
			public boolean filter(Event value) throws Exception {
				return value.getName().equals(expectedName);
			}
		};
	}

	/** Creates the start/end NFA with a window length of 2. */
	private NFA<Event> createStartEndNFA() {
		return createStartEndNFA(2L);
	}

	/**
	 * Creates an NFA that matches an event named "start" followed, possibly after ignored events,
	 * by an event named "end".
	 *
	 * @param windowTime the window length passed to the {@link NFA} constructor
	 * @return the NFA built from the states "start" (start), "end" (normal) and "" (final)
	 */
	private NFA<Event> createStartEndNFA(long windowTime) {
		State<Event> startState = new State<>("start", State.StateType.Start);
		State<Event> endState = new State<>("end", State.StateType.Normal);
		State<Event> endingState = new State<>("", State.StateType.Final);

		startState.addTake(endState, nameEquals("start"));
		endState.addTake(endingState, nameEquals("end"));
		// Any event that is not taken as "end" may be skipped while waiting in "end".
		endState.addIgnore(BooleanConditions.<Event>trueFunction());

		List<State<Event>> states = new ArrayList<>();
		states.add(startState);
		states.add(endState);
		states.add(endingState);

		return new NFA<>(states, windowTime, false);
	}
}