JAV-181 conditional traversal

Signed-off-by: seanyinx <sean.yin@huawei.com>
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/TaskRunner.java b/saga-core/src/main/java/io/servicecomb/saga/core/TaskRunner.java
index 3ec5afc..40c47d5 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/TaskRunner.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/TaskRunner.java
@@ -51,7 +51,7 @@
     }
 
     while (traveller.hasNext()) {
-      traveller.next();
+      traveller.next(null);
       taskConsumer.consume(nodes);
       nodes.clear();
     }
@@ -62,7 +62,7 @@
     boolean played = false;
     Collection<Node<SagaResponse, SagaRequest>> nodes = traveller.nodes();
     while (traveller.hasNext() && !played) {
-      traveller.next();
+      traveller.next(null);
       played = taskConsumer.replay(nodes, completedOperations);
     }
   }
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/application/GraphBuilder.java b/saga-core/src/main/java/io/servicecomb/saga/core/application/GraphBuilder.java
index 6558aeb..0101748 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/application/GraphBuilder.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/application/GraphBuilder.java
@@ -16,16 +16,18 @@
 
 package io.servicecomb.saga.core.application;
 
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
 import io.servicecomb.saga.core.NoOpSagaRequest;
 import io.servicecomb.saga.core.SagaException;
 import io.servicecomb.saga.core.SagaRequest;
 import io.servicecomb.saga.core.SagaResponse;
+import io.servicecomb.saga.core.dag.Edge;
 import io.servicecomb.saga.core.dag.GraphCycleDetector;
 import io.servicecomb.saga.core.dag.Node;
 import io.servicecomb.saga.core.dag.SingleLeafDirectedAcyclicGraph;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
 import kamon.annotation.EnableKamon;
 import kamon.annotation.Segment;
 
@@ -56,9 +58,11 @@
 
     for (SagaRequest sagaRequest : sagaRequests) {
       if (isOrphan(sagaRequest)) {
+        new Edge<>((any) -> true, root, requestNodes.get(sagaRequest.id()));
         root.addChild(requestNodes.get(sagaRequest.id()));
       } else {
         for (String parent : sagaRequest.parents()) {
+          new Edge<>((any) -> true, requestNodes.get(parent), requestNodes.get(sagaRequest.id()));
           requestNodes.get(parent).addChild(requestNodes.get(sagaRequest.id()));
         }
       }
@@ -66,7 +70,10 @@
 
     requestNodes.values().stream()
         .filter((node) -> node.children().isEmpty())
-        .forEach(node -> node.addChild(leaf));
+        .forEach(node -> {
+          new Edge<>((any) -> true, node, leaf);
+          node.addChild(leaf);
+        });
 
     return new SingleLeafDirectedAcyclicGraph<>(root, leaf);
   }
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/ByLevelTraveller.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/ByLevelTraveller.java
index 60b2e0e..6c6bde4 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/ByLevelTraveller.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/ByLevelTraveller.java
@@ -24,6 +24,7 @@
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
+import java.util.function.Consumer;
 
 import kamon.annotation.EnableKamon;
 import kamon.annotation.Segment;
@@ -49,25 +50,29 @@
 
   @Segment(name = "travelNext", category = "application", library = "kamon")
   @Override
-  public void next() {
-    nodes.addAll(nodesBuffer);
-    nodesBuffer.clear();
-    boolean buffered = false;
+  public void next(C condition) {
+    do {
+      Set<Node<C, T>> orphans = new LinkedHashSet<>();
+      nodesBuffer.forEach(node -> {
+        if (!traversalDirection.parents(node, condition).isEmpty()) {
+          nodes.add(node);
+        } else {
+          nodesWithoutParent.remove(node);
+          collectOrphans(node, orphans::add);
+        }
+      });
+      nodesBuffer.clear();
+      nodesBuffer.addAll(orphans);
+    } while (!nodesBuffer.isEmpty());
 
-    while (!nodesWithoutParent.isEmpty() && !buffered) {
+    while (!nodesWithoutParent.isEmpty() && nodesBuffer.isEmpty()) {
       Node<C, T> node = nodesWithoutParent.poll();
       nodes.add(node);
 
-      for (Node<C, T> child : traversalDirection.children(node)) {
-        nodeParents.computeIfAbsent(child.id(), id -> new HashSet<>(traversalDirection.parents(child)));
-        nodeParents.get(child.id()).remove(node);
-
-        if (nodeParents.get(child.id()).isEmpty()) {
-          nodesWithoutParent.offer(child);
-          nodesBuffer.add(child);
-          buffered = true;
-        }
-      }
+      collectOrphans(node, child -> {
+        nodesWithoutParent.offer(child);
+        nodesBuffer.add(child);
+      });
     }
   }
 
@@ -80,4 +85,20 @@
   public Collection<Node<C, T>> nodes() {
     return nodes;
   }
+
+  private void collectOrphans(Node<C, T> node, Consumer<Node<C, T>> orphanConsumer) {
+    for (Node<C, T> child : traversalDirection.children(node)) {
+      removeNodeFromChildParents(node, child);
+
+      if (nodeParents.get(child.id()).isEmpty()) {
+        orphanConsumer.accept(child);
+      }
+    }
+  }
+
+  private void removeNodeFromChildParents(Node<C, T> node, Node<C, T> child) {
+    nodeParents
+        .computeIfAbsent(child.id(), id -> new HashSet<>(traversalDirection.parents(child)))
+        .remove(node);
+  }
 }
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromLeafTraversalDirection.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromLeafTraversalDirection.java
index 8dabb3c..5da19bb 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromLeafTraversalDirection.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromLeafTraversalDirection.java
@@ -34,4 +34,9 @@
   public Set<Node<C, T>> children(Node<C, T> node) {
     return node.parents();
   }
+
+  @Override
+  public Set<Node<C, T>> parents(Node<C, T> node, C condition) {
+    return node.children(condition);
+  }
 }
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromRootTraversalDirection.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromRootTraversalDirection.java
index 422b32b..a333400 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromRootTraversalDirection.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/FromRootTraversalDirection.java
@@ -34,4 +34,9 @@
   public Set<Node<C, T>> children(Node<C, T> node) {
     return node.children();
   }
+
+  @Override
+  public Set<Node<C, T>> parents(Node<C, T> node, C condition) {
+    return node.parents(condition);
+  }
 }
\ No newline at end of file
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/Node.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/Node.java
index c875567..b79a417 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/Node.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/Node.java
@@ -70,6 +70,13 @@
     parentEdges.add(edge);
   }
 
+  public Set<Node<C, T>> parents(C condition) {
+    return parentEdges.stream()
+        .filter(edge -> edge.isSatisfied(condition))
+        .map(Edge::source)
+        .collect(Collectors.toSet());
+  }
+
   public Set<Node<C, T>> children(C condition) {
     return childrenEdges.stream()
         .filter(edge -> edge.isSatisfied(condition))
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/Traveller.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/Traveller.java
index b7aeaa6..cdaebb5 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/Traveller.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/Traveller.java
@@ -20,7 +20,7 @@
 
 public interface Traveller<C, T> {
 
-  void next();
+  void next(C condition);
 
   boolean hasNext();
 
diff --git a/saga-core/src/main/java/io/servicecomb/saga/core/dag/TraversalDirection.java b/saga-core/src/main/java/io/servicecomb/saga/core/dag/TraversalDirection.java
index 3bee42a..36e7cdb 100644
--- a/saga-core/src/main/java/io/servicecomb/saga/core/dag/TraversalDirection.java
+++ b/saga-core/src/main/java/io/servicecomb/saga/core/dag/TraversalDirection.java
@@ -25,4 +25,6 @@
   Set<Node<C, T>> parents(Node<C, T> node);
 
   Set<Node<C, T>> children(Node<C, T> node);
+
+  Set<Node<C, T>> parents(Node<C, T> node, C condition);
 }
diff --git a/saga-core/src/test/java/io/servicecomb/saga/core/SagaIntegrationTest.java b/saga-core/src/test/java/io/servicecomb/saga/core/SagaIntegrationTest.java
index a159402..17aebd4 100644
--- a/saga-core/src/test/java/io/servicecomb/saga/core/SagaIntegrationTest.java
+++ b/saga-core/src/test/java/io/servicecomb/saga/core/SagaIntegrationTest.java
@@ -37,6 +37,8 @@
 import static org.mockito.Mockito.when;
 
 import com.seanyinx.github.unit.scaffolding.Randomness;
+
+import io.servicecomb.saga.core.dag.Edge;
 import io.servicecomb.saga.core.dag.Node;
 import io.servicecomb.saga.core.dag.SingleLeafDirectedAcyclicGraph;
 import io.servicecomb.saga.infrastructure.EmbeddedEventStore;
@@ -91,6 +93,9 @@
     root.addChild(node1);
     node1.addChild(node2);
     node2.addChild(leaf);
+    new Edge<>((any) -> true, root, node1);
+    new Edge<>((any) -> true, node1, node2);
+    new Edge<>((any) -> true, node2, leaf);
 
     SagaStartTask sagaStartTask = new SagaStartTask(sagaId, requestJson, eventStore);
     SagaEndTask sagaEndTask = new SagaEndTask(sagaId, eventStore);
@@ -508,6 +513,8 @@
   private void addExtraChildToNode1() {
     node1.addChild(node3);
     node3.addChild(leaf);
+    new Edge<>((any) -> true, node1, node3);
+    new Edge<>((any) -> true, node3, leaf);
   }
 
   private SagaRequest request(String requestId,
diff --git a/saga-core/src/test/java/io/servicecomb/saga/core/application/GraphBuilderTest.java b/saga-core/src/test/java/io/servicecomb/saga/core/application/GraphBuilderTest.java
index 84c6645..ce39edb 100644
--- a/saga-core/src/test/java/io/servicecomb/saga/core/application/GraphBuilderTest.java
+++ b/saga-core/src/test/java/io/servicecomb/saga/core/application/GraphBuilderTest.java
@@ -104,19 +104,19 @@
     Traveller<SagaResponse, SagaRequest> traveller = new ByLevelTraveller<>(tasks, new FromRootTraversalDirection<>());
     Collection<Node<SagaResponse, SagaRequest>> nodes = traveller.nodes();
 
-    traveller.next();
+    traveller.next(null);
     assertThat(requestsOf(nodes), contains(SAGA_START_REQUEST));
     nodes.clear();
 
-    traveller.next();
+    traveller.next(null);
     assertThat(requestsOf(nodes), contains(request1, request2));
     nodes.clear();
 
-    traveller.next();
+    traveller.next(null);
     assertThat(requestsOf(nodes), contains(request3));
     nodes.clear();
 
-    traveller.next();
+    traveller.next(null);
     assertThat(requestsOf(nodes), contains(SAGA_END_REQUEST));
   }
 
diff --git a/saga-core/src/test/java/io/servicecomb/saga/core/dag/DirectedAcyclicGraphTraversalTest.java b/saga-core/src/test/java/io/servicecomb/saga/core/dag/DirectedAcyclicGraphTraversalTest.java
index 4f9a22b..e49db74 100644
--- a/saga-core/src/test/java/io/servicecomb/saga/core/dag/DirectedAcyclicGraphTraversalTest.java
+++ b/saga-core/src/test/java/io/servicecomb/saga/core/dag/DirectedAcyclicGraphTraversalTest.java
@@ -18,25 +18,40 @@
 
 import static java.util.Arrays.asList;
 import static org.hamcrest.collection.IsIterableContainingInOrder.contains;
+import static org.hamcrest.core.Is.is;
 import static org.junit.Assert.assertThat;
+import static org.mockito.Mockito.when;
 
 import java.util.Collection;
+import java.util.function.Predicate;
 
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.Mockito;
+
+import com.seanyinx.github.unit.scaffolding.Randomness;
 
 @SuppressWarnings("unchecked")
 public class DirectedAcyclicGraphTraversalTest {
 
-  private final String value = "i don't care";
+  private final Node<String, String> root = new Node<>(0, "root");
+  private final Node<String, String> node1 = new Node<>(1, "node1");
+  private final Node<String, String> node2 = new Node<>(2, "node2");
+  private final Node<String, String> node3 = new Node<>(3, "node3");
+  private final Node<String, String> node4 = new Node<>(4, "node4");
+  private final Node<String, String> node5 = new Node<>(5, "node5");
+  private final Node<String, String> leaf = new Node<>(6, "leaf");
 
-  private final Node<String, String> root = new Node<>(0, value);
-  private final Node<String, String> node1 = new Node<>(1, value);
-  private final Node<String, String> node2 = new Node<>(2, value);
-  private final Node<String, String> node3 = new Node<>(3, value);
-  private final Node<String, String> node4 = new Node<>(4, value);
-  private final Node<String, String> node5 = new Node<>(5, value);
-  private final Node<String, String> leaf = new Node<>(6, value);
+  private final String condition = Randomness.uniquify("condition");
+
+  private final Predicate<String> predicate_p_1 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_p_2 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_1_3 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_1_4 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_3_5 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_4_5 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_2_6 = Mockito.mock(Predicate.class);
+  private final Predicate<String> predicate_5_6 = Mockito.mock(Predicate.class);
 
   private final SingleLeafDirectedAcyclicGraph<String, String> dag = new SingleLeafDirectedAcyclicGraph<>(root, leaf);
 
@@ -57,49 +72,137 @@
     node4.addChild(node5);
     node5.addChild(leaf);
     node2.addChild(leaf);
+
+    new Edge<>(predicate_p_1, root, node1);
+    new Edge<>(predicate_p_2, root, node2);
+    new Edge<>(predicate_1_3, node1, node3);
+    new Edge<>(predicate_1_4, node1, node4);
+    new Edge<>(predicate_3_5, node3, node5);
+    new Edge<>(predicate_4_5, node4, node5);
+    new Edge<>(predicate_2_6, node2, leaf);
+    new Edge<>(predicate_5_6, node5, leaf);
   }
 
   @Test
   public void traverseGraphOneLevelPerStepFromRoot() {
+    markAllSatisfied();
     Traveller<String, String> traveller = new ByLevelTraveller<>(dag, new FromRootTraversalDirection<>());
 
     Collection<Node<String, String>> nodes = traveller.nodes();
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(root));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(root, node1, node2));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(root, node1, node2, node3, node4));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(root, node1, node2, node3, node4, node5));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(root, node1, node2, node3, node4, node5, leaf));
+
+    assertThat(traveller.hasNext(), is(false));
+  }
+
+  @Test
+  public void traverseOnlySatisfiedChildrenFromRoot() {
+    Traveller<String, String> traveller = new ByLevelTraveller<>(dag, new FromRootTraversalDirection<>());
+
+    Collection<Node<String, String>> nodes = traveller.nodes();
+
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(root));
+
+    when(predicate_p_1.test(condition)).thenReturn(true);
+    when(predicate_p_2.test(condition)).thenReturn(true);
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(root, node1, node2));
+
+    when(predicate_1_4.test(condition)).thenReturn(true);
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(root, node1, node2, node4));
+
+    when(predicate_2_6.test(condition)).thenReturn(true);
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(root, node1, node2, node4, leaf));
+
+    assertThat(traveller.hasNext(), is(false));
   }
 
   @Test
   public void traverseGraphOneLevelPerStepFromLeaf() {
+    markAllSatisfied();
     Traveller<String, String> traveller = new ByLevelTraveller<>(dag, new FromLeafTraversalDirection<>());
 
     Collection<Node<String, String>> nodes = traveller.nodes();
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(leaf));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(leaf, node2, node5));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(leaf, node2, node5, node3, node4));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(leaf, node2, node5, node3, node4, node1));
 
-    traveller.next();
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
     assertThat(nodes, contains(leaf, node2, node5, node3, node4, node1, root));
+
+    assertThat(traveller.hasNext(), is(false));
+  }
+
+  @Test
+  public void traverseOnlySatisfiedNodesFromLeaf() {
+    Traveller<String, String> traveller = new ByLevelTraveller<>(dag, new FromLeafTraversalDirection<>());
+
+    Collection<Node<String, String>> nodes = traveller.nodes();
+
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(leaf));
+
+    when(predicate_2_6.test(condition)).thenReturn(true);
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(leaf, node2));
+
+    when(predicate_p_2.test(condition)).thenReturn(true);
+    assertThat(traveller.hasNext(), is(true));
+    traveller.next(condition);
+    assertThat(nodes, contains(leaf, node2, root));
+
+    assertThat(traveller.hasNext(), is(false));
+  }
+
+  private void markAllSatisfied() {
+    when(predicate_p_1.test(condition)).thenReturn(true);
+    when(predicate_p_2.test(condition)).thenReturn(true);
+    when(predicate_1_3.test(condition)).thenReturn(true);
+    when(predicate_1_4.test(condition)).thenReturn(true);
+    when(predicate_3_5.test(condition)).thenReturn(true);
+    when(predicate_4_5.test(condition)).thenReturn(true);
+    when(predicate_2_6.test(condition)).thenReturn(true);
+    when(predicate_5_6.test(condition)).thenReturn(true);
   }
 }
\ No newline at end of file
diff --git a/saga-core/src/test/java/io/servicecomb/saga/core/dag/NodeTest.java b/saga-core/src/test/java/io/servicecomb/saga/core/dag/NodeTest.java
index ef5d9de..f09da46 100644
--- a/saga-core/src/test/java/io/servicecomb/saga/core/dag/NodeTest.java
+++ b/saga-core/src/test/java/io/servicecomb/saga/core/dag/NodeTest.java
@@ -107,24 +107,31 @@
   }
 
   @Test
-  public void childrenContainsSatisfiedOnesOnly() throws Exception {
+  public void relativesContainsSatisfiedOnesOnly() throws Exception {
     satisfied_p_1 = true;
     assertThat(parent.children(condition), contains(node1));
+    assertThat(node1.parents(condition), contains(parent));
+    assertThat(node2.parents(condition).isEmpty(), is(true));
 
     satisfied_1_3 = true;
     satisfied_1_4 = true;
     assertThat(node1.children(condition), contains(node3, node4));
+    assertThat(node3.parents(condition), contains(node1));
+    assertThat(node4.parents(condition), contains(node1));
 
     assertThat(node2.children(condition).isEmpty(), is(true));
 
-    assertThat(node3.children(condition).isEmpty(), is(true));
+    satisfied_3_5 = true;
+    assertThat(node3.children(condition), contains(node5));
 
     satisfied_4_5 = true;
     assertThat(node4.children(condition), contains(node5));
+    assertThat(node5.parents(condition), contains(node3, node4));
 
     satisfied_5_6 = true;
     assertThat(node5.children(condition), contains(node6));
 
     assertThat(node6.children(condition).isEmpty(), is(true));
+    assertThat(node6.parents(condition), contains(node5));
   }
 }
\ No newline at end of file