RATIS-1913. Assert that the primary peers in DataStreamClient and RoutingTable are equal (#945)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
index 7a25a8e..353f532 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/DataStreamClientImpl.java
@@ -44,6 +44,7 @@
 import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
 import org.apache.ratis.util.JavaUtils;
 import org.apache.ratis.util.MemoizedSupplier;
+import org.apache.ratis.util.Preconditions;
 import org.apache.ratis.util.SlidingWindow;
 
 import java.io.IOException;
@@ -230,6 +231,12 @@
 
   @Override
   public DataStreamOutputRpc stream(ByteBuffer headerMessage, RoutingTable routingTable) {
+    if (routingTable != null) {
+      // Validate that the primary peer is equal to the primary peer passed by the RoutingTable
+      Preconditions.assertTrue(dataStreamServer.getId().equals(routingTable.getPrimary()),
+          () -> "Primary peer mismatched: the routing table has " + routingTable.getPrimary()
+              + " but the client has " + dataStreamServer.getId());
+    }
     final Message message =
         Optional.ofNullable(headerMessage).map(ByteString::copyFrom).map(Message::valueOf).orElse(null);
     RaftClientRequest request = RaftClientRequest.newBuilder()
diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
index 0157fe4..56181c4 100644
--- a/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RoutingTable.java
@@ -43,6 +43,9 @@
   /** @return the successor peers of the given peer. */
   Set<RaftPeerId> getSuccessors(RaftPeerId peerId);
 
+  /** @return the primary peer. */
+  RaftPeerId getPrimary();
+
   /** @return the proto of this {@link RoutingTable}. */
   RoutingTableProto toProto();
 
@@ -78,15 +81,15 @@
     }
 
     public RoutingTable build() {
-      return Optional.ofNullable(ref.getAndSet(null))
-          .map(RoutingTable::newRoutingTable)
-          .orElseThrow(() -> new IllegalStateException("RoutingTable Already built"));
+      final Map<RaftPeerId, Set<RaftPeerId>> map = ref.getAndSet(null);
+      if (map == null) {
+        throw new IllegalStateException("RoutingTable is already built.");
+      }
+      return RoutingTable.newRoutingTable(map);
     }
 
-    static void validate(Map<RaftPeerId, Set<RaftPeerId>> map) {
-      if (map != null && !map.isEmpty()) {
-        new Builder.Validation(map).run();
-      }
+    static RaftPeerId validate(Map<RaftPeerId, Set<RaftPeerId>> map) {
+      return new Builder.Validation(map).run();
     }
 
     /** Validate if a map represents a valid routing table. */
@@ -131,10 +134,11 @@
         this.unreachablePeers = allPeers;
       }
 
-      private void run() {
+      private RaftPeerId run() {
         depthFirstSearch(primary);
         Preconditions.assertTrue(unreachablePeers.isEmpty() ,
             () -> "Invalid routing table: peer(s) " + unreachablePeers +  " are unreachable, " + this);
+        return primary;
       }
 
       private void depthFirstSearch(RaftPeerId current) {
@@ -159,7 +163,10 @@
 
   /** @return a new {@link RoutingTable} represented by the given map. */
   static RoutingTable newRoutingTable(Map<RaftPeerId, Set<RaftPeerId>> map){
-    Builder.validate(map);
+    if (map == null || map.isEmpty()) {
+      return null;
+    }
+    final RaftPeerId primary = Builder.validate(map);
 
     final Supplier<RoutingTableProto> proto = JavaUtils.memoize(
         () -> RoutingTableProto.newBuilder().addAllRoutes(ProtoUtils.toRouteProtos(map)).build());
@@ -170,6 +177,11 @@
       }
 
       @Override
+      public RaftPeerId getPrimary() {
+        return primary;
+      }
+
+      @Override
       public RoutingTableProto toProto() {
         return proto.get();
       }
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index fd47045..dcb54be 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -62,6 +62,11 @@
     runWithNewCluster(NUM_SERVERS, this::testStreamWrites);
   }
 
+  @Test
+  public void testStreamWithInvalidRoutingTable() throws Exception {
+    runWithNewCluster(NUM_SERVERS, this::runTestInvalidPrimaryInRoutingTable);
+  }
+
   void testStreamWrites(CLUSTER cluster) throws Exception {
     waitForLeader(cluster);
     runTestDataStreamOutput(cluster);
@@ -96,6 +101,31 @@
     assertLogEntry(cluster, request);
   }
 
+  void runTestInvalidPrimaryInRoutingTable(CLUSTER cluster) throws Exception {
+    final RaftPeer primaryServer = CollectionUtils.random(cluster.getGroup().getPeers());
+
+    RaftPeer notPrimary = null;
+    for (RaftPeer peer: cluster.getGroup().getPeers()) {
+      if (!peer.equals(primaryServer)) {
+        notPrimary = peer;
+        break;
+      }
+    }
+
+    Assert.assertNotNull(
+        "Cannot find peer other than the primary", notPrimary);
+    Assert.assertNotEquals(primaryServer, notPrimary);
+
+    try (RaftClient client = cluster.createClient(primaryServer)) {
+      RoutingTable routingTableWithWrongPrimary =
+          getRoutingTable(cluster.getGroup().getPeers(), notPrimary);
+      testFailureCase("",
+          () -> client.getDataStreamApi().stream(null,
+              routingTableWithWrongPrimary),
+          IllegalStateException.class);
+    }
+  }
+
   void runTestWriteFile(CLUSTER cluster, int i,
       CheckedConsumer<DataStreamOutputImpl, Exception> testCase) throws Exception {
     final RaftClientRequest request;
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
index 06702f9..8e423ab 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamSslWithRpcTypeGrpcAndDataStreamTypeNetty.java
@@ -62,6 +62,11 @@
 
   @Ignore
   @Override
+  public void testStreamWithInvalidRoutingTable() {
+  }
+
+  @Ignore
+  @Override
   public void testMultipleStreamsMultipleServers() {
   }