| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.cassandra.distributed.test; |
| |
| import java.net.InetSocketAddress; |
| import java.util.Arrays; |
| import java.util.HashSet; |
| import java.util.Set; |
| import java.util.concurrent.atomic.AtomicInteger; |
| |
| import org.junit.Assert; |
| import org.junit.Test; |
| |
| import org.apache.cassandra.distributed.Cluster; |
| import org.apache.cassandra.distributed.api.ConsistencyLevel; |
| import org.apache.cassandra.distributed.api.IIsolatedExecutor; |
| import org.apache.cassandra.distributed.api.IMessage; |
| import org.apache.cassandra.distributed.api.IMessageFilters; |
| import org.apache.cassandra.distributed.impl.Instance; |
| import org.apache.cassandra.distributed.shared.MessageFilters; |
| import org.apache.cassandra.net.MessageIn; |
| import org.apache.cassandra.net.MessagingService; |
| |
| public class MessageFiltersTest extends TestBaseImpl |
| { |
| @Test |
| public void simpleInboundFiltersTest() |
| { |
| simpleFiltersTest(true); |
| } |
| |
| @Test |
| public void simpleOutboundFiltersTest() |
| { |
| simpleFiltersTest(false); |
| } |
| |
| private interface Permit |
| { |
| boolean test(int from, int to, IMessage msg); |
| } |
| |
| private static void simpleFiltersTest(boolean inbound) |
| { |
| int VERB1 = MessagingService.Verb.READ.ordinal(); |
| int VERB2 = MessagingService.Verb.REQUEST_RESPONSE.ordinal(); |
| int VERB3 = MessagingService.Verb.READ_REPAIR.ordinal(); |
| |
| int i1 = 1; |
| int i2 = 2; |
| int i3 = 3; |
| String MSG1 = "msg1"; |
| String MSG2 = "msg2"; |
| |
| MessageFilters filters = new MessageFilters(); |
| Permit permit = inbound ? filters::permitInbound : filters::permitOutbound; |
| |
| IMessageFilters.Filter filter = filters.allVerbs().inbound(inbound).from(1).drop(); |
| Assert.assertFalse(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertFalse(permit.test(i1, i2, msg(VERB2, MSG1))); |
| Assert.assertFalse(permit.test(i1, i2, msg(VERB3, MSG1))); |
| Assert.assertTrue(permit.test(i2, i1, msg(VERB1, MSG1))); |
| filter.off(); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB1, MSG1))); |
| filters.reset(); |
| |
| filters.verbs(VERB1).inbound(inbound).from(1).to(2).drop(); |
| Assert.assertFalse(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB2, MSG1))); |
| Assert.assertTrue(permit.test(i2, i1, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i2, i3, msg(VERB2, MSG1))); |
| |
| filters.reset(); |
| AtomicInteger counter = new AtomicInteger(); |
| filters.verbs(VERB1).inbound(inbound).from(1).to(2).messagesMatching((from, to, msg) -> { |
| counter.incrementAndGet(); |
| return Arrays.equals(msg.bytes(), MSG1.getBytes()); |
| }).drop(); |
| Assert.assertFalse(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertEquals(counter.get(), 1); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB1, MSG2))); |
| Assert.assertEquals(counter.get(), 2); |
| |
| // filter chain gets interrupted because a higher level filter returns no match |
| Assert.assertTrue(permit.test(i2, i1, msg(VERB1, MSG1))); |
| Assert.assertEquals(counter.get(), 2); |
| Assert.assertTrue(permit.test(i2, i1, msg(VERB2, MSG1))); |
| Assert.assertEquals(counter.get(), 2); |
| filters.reset(); |
| |
| filters.allVerbs().inbound(inbound).from(3, 2).to(2, 1).drop(); |
| Assert.assertFalse(permit.test(i3, i1, msg(VERB1, MSG1))); |
| Assert.assertFalse(permit.test(i3, i2, msg(VERB1, MSG1))); |
| Assert.assertFalse(permit.test(i2, i1, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i2, i3, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i1, i3, msg(VERB1, MSG1))); |
| filters.reset(); |
| |
| counter.set(0); |
| filters.allVerbs().inbound(inbound).from(1).to(2).messagesMatching((from, to, msg) -> { |
| counter.incrementAndGet(); |
| return false; |
| }).drop(); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i1, i3, msg(VERB1, MSG1))); |
| Assert.assertTrue(permit.test(i1, i2, msg(VERB1, MSG1))); |
| Assert.assertEquals(2, counter.get()); |
| } |
| |
| private static IMessage msg(int verb, String msg) |
| { |
| return new IMessage() |
| { |
| public int verb() { return verb; } |
| public byte[] bytes() { return msg.getBytes(); } |
| public int id() { return 0; } |
| public int version() { return 0; } |
| public InetSocketAddress from() { return null; } |
| }; |
| } |
| @Test |
| public void testFilters() throws Throwable |
| { |
| String read = "SELECT * FROM " + KEYSPACE + ".tbl"; |
| String write = "INSERT INTO " + KEYSPACE + ".tbl (pk, ck, v) VALUES (1, 1, 1)"; |
| |
| try (Cluster cluster = Cluster.create(2)) |
| { |
| cluster.schemaChange("CREATE KEYSPACE " + KEYSPACE + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': " + cluster.size() + "};"); |
| cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); |
| |
| // Reads and writes are going to time out in both directions |
| cluster.filters().allVerbs().from(1).to(2).drop(); |
| for (int i : new int[]{ 1, 2 }) |
| assertTimeOut(() -> cluster.coordinator(i).execute(read, ConsistencyLevel.ALL)); |
| for (int i : new int[]{ 1, 2 }) |
| assertTimeOut(() -> cluster.coordinator(i).execute(write, ConsistencyLevel.ALL)); |
| |
| cluster.filters().reset(); |
| // Reads are going to timeout only when 1 serves as a coordinator |
| cluster.verbs(MessagingService.Verb.RANGE_SLICE).from(1).to(2).drop(); |
| assertTimeOut(() -> cluster.coordinator(1).execute(read, ConsistencyLevel.ALL)); |
| cluster.coordinator(2).execute(read, ConsistencyLevel.ALL); |
| |
| // Writes work in both directions |
| for (int i : new int[]{ 1, 2 }) |
| cluster.coordinator(i).execute(write, ConsistencyLevel.ALL); |
| } |
| } |
| |
| @Test |
| public void testMessageMatching() throws Throwable |
| { |
| String read = "SELECT * FROM " + KEYSPACE + ".tbl"; |
| String write = "INSERT INTO " + KEYSPACE + ".tbl (pk, ck, v) VALUES (1, 1, 1)"; |
| |
| try (Cluster cluster = Cluster.create(2)) |
| { |
| cluster.schemaChange("CREATE KEYSPACE " + KEYSPACE + " WITH replication = {'class': 'SimpleStrategy', 'replication_factor': " + cluster.size() + "};"); |
| cluster.schemaChange("CREATE TABLE " + KEYSPACE + ".tbl (pk int, ck int, v int, PRIMARY KEY (pk, ck))"); |
| |
| AtomicInteger counter = new AtomicInteger(); |
| |
| Set<Integer> verbs = new HashSet<>(Arrays.asList(MessagingService.Verb.RANGE_SLICE.ordinal(), |
| MessagingService.Verb.MUTATION.ordinal())); |
| |
| // Reads and writes are going to time out in both directions |
| IMessageFilters.Filter filter = cluster.filters() |
| .allVerbs() |
| .from(1) |
| .to(2) |
| .messagesMatching((from, to, msg) -> { |
| // Decode and verify message on instance; return the result back here |
| Integer id = cluster.get(1).callsOnInstance((IIsolatedExecutor.SerializableCallable<Integer>) () -> { |
| MessageIn decoded = Instance.deserializeMessage(msg); |
| if (decoded != null) |
| return (Integer) decoded.verb.ordinal(); |
| return -1; |
| }).call(); |
| if (id > 0) |
| Assert.assertTrue(verbs.contains(id)); |
| counter.incrementAndGet(); |
| return false; |
| }).drop(); |
| |
| for (int i : new int[]{ 1, 2 }) |
| cluster.coordinator(i).execute(read, ConsistencyLevel.ALL); |
| for (int i : new int[]{ 1, 2 }) |
| cluster.coordinator(i).execute(write, ConsistencyLevel.ALL); |
| |
| filter.off(); |
| Assert.assertEquals(4, counter.get()); |
| } |
| } |
| |
| private static void assertTimeOut(Runnable r) |
| { |
| try |
| { |
| r.run(); |
| Assert.fail("Should have timed out"); |
| } |
| catch (Throwable t) |
| { |
| if (!t.toString().contains("TimeoutException")) |
| throw t; |
| // ignore |
| } |
| } |
| } |