blob: d3f56177eb6b3de7730d34dd1dc88bc4e99eed81 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.net;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.stream.Stream;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.exceptions.RequestFailureReason;
import org.apache.cassandra.io.IVersionedSerializer;
import org.apache.cassandra.io.util.DataInputBuffer;
import org.apache.cassandra.io.util.DataInputPlus;
import org.apache.cassandra.io.util.DataOutputBuffer;
import org.apache.cassandra.io.util.DataOutputPlus;
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.tracing.Tracing.TraceType;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.utils.FreeRunningClock;
import static org.apache.cassandra.net.Message.serializer;
import static org.apache.cassandra.net.MessagingService.VERSION_3014;
import static org.apache.cassandra.net.MessagingService.VERSION_30;
import static org.apache.cassandra.net.MessagingService.VERSION_40;
import static org.apache.cassandra.net.NoPayload.noPayload;
import static org.apache.cassandra.net.ParamType.RESPOND_TO;
import static org.apache.cassandra.net.ParamType.TRACE_SESSION;
import static org.apache.cassandra.net.ParamType.TRACE_TYPE;
import static org.apache.cassandra.utils.MonotonicClock.approxTime;
import static org.junit.Assert.*;
public class MessageTest
{
@BeforeClass
public static void setUpClass() throws Exception
{
DatabaseDescriptor.daemonInitialization();
DatabaseDescriptor.setCrossNodeTimeout(true);
Verb._TEST_2.unsafeSetSerializer(() -> new IVersionedSerializer<Integer>()
{
public void serialize(Integer value, DataOutputPlus out, int version) throws IOException
{
out.writeInt(value);
}
public Integer deserialize(DataInputPlus in, int version) throws IOException
{
return in.readInt();
}
public long serializedSize(Integer value, int version)
{
return 4;
}
});
}
@AfterClass
public static void tearDownClass() throws Exception
{
Verb._TEST_2.unsafeSetSerializer(() -> NoPayload.serializer);
}
@Test
public void testInferMessageSize() throws Exception
{
Message<Integer> msg =
Message.builder(Verb._TEST_2, 37)
.withId(1)
.from(FBUtilities.getLocalAddressAndPort())
.withCreatedAt(approxTime.now())
.withExpiresAt(approxTime.now())
.withFlag(MessageFlag.CALL_BACK_ON_FAILURE)
.withFlag(MessageFlag.TRACK_REPAIRED_DATA)
.withParam(TRACE_TYPE, TraceType.QUERY)
.withParam(TRACE_SESSION, UUID.randomUUID())
.build();
testInferMessageSize(msg, VERSION_30);
testInferMessageSize(msg, VERSION_3014);
testInferMessageSize(msg, VERSION_40);
}
private void testInferMessageSize(Message msg, int version) throws Exception
{
try (DataOutputBuffer out = new DataOutputBuffer())
{
serializer.serialize(msg, out, version);
assertEquals(msg.serializedSize(version), out.getLength());
ByteBuffer buffer = out.buffer();
int payloadSize = (int) msg.verb().serializer().serializedSize(msg.payload, version);
int serializedSize = msg.serializedSize(version);
// should return -1 - fail to infer size - for all lengths of buffer until payload length can be read
for (int limit = 0; limit < serializedSize - payloadSize; limit++)
assertEquals(-1, serializer.inferMessageSize(buffer, 0, limit, version));
// once payload size can be read, should correctly infer message size
for (int limit = serializedSize - payloadSize; limit < serializedSize; limit++)
assertEquals(serializedSize, serializer.inferMessageSize(buffer, 0, limit, version));
}
}
@Test
public void testBuilder()
{
long id = 1;
InetAddressAndPort from = FBUtilities.getLocalAddressAndPort();
long createAtNanos = approxTime.now();
long expiresAtNanos = createAtNanos + TimeUnit.SECONDS.toNanos(1);
TraceType traceType = TraceType.QUERY;
UUID traceSession = UUID.randomUUID();
Message<NoPayload> msg =
Message.builder(Verb._TEST_1, noPayload)
.withId(1)
.from(from)
.withCreatedAt(createAtNanos)
.withExpiresAt(expiresAtNanos)
.withFlag(MessageFlag.CALL_BACK_ON_FAILURE)
.withParam(TRACE_TYPE, TraceType.QUERY)
.withParam(TRACE_SESSION, traceSession)
.build();
assertEquals(id, msg.id());
assertEquals(from, msg.from());
assertEquals(createAtNanos, msg.createdAtNanos());
assertEquals(expiresAtNanos, msg.expiresAtNanos());
assertTrue(msg.callBackOnFailure());
assertFalse(msg.trackRepairedData());
assertEquals(traceType, msg.traceType());
assertEquals(traceSession, msg.traceSession());
assertNull(msg.forwardTo());
assertNull(msg.respondTo());
}
@Test
public void testCycleNoPayload() throws IOException
{
Message<NoPayload> msg =
Message.builder(Verb._TEST_1, noPayload)
.withId(1)
.from(FBUtilities.getLocalAddressAndPort())
.withCreatedAt(approxTime.now())
.withExpiresAt(approxTime.now() + TimeUnit.SECONDS.toNanos(1))
.withFlag(MessageFlag.CALL_BACK_ON_FAILURE)
.withParam(TRACE_SESSION, UUID.randomUUID())
.build();
testCycle(msg);
}
@Test
public void testCycleWithPayload() throws Exception
{
testCycle(Message.out(Verb._TEST_2, 42));
testCycle(Message.outWithFlag(Verb._TEST_2, 42, MessageFlag.CALL_BACK_ON_FAILURE));
testCycle(Message.outWithFlags(Verb._TEST_2, 42, MessageFlag.CALL_BACK_ON_FAILURE, MessageFlag.TRACK_REPAIRED_DATA));
testCycle(Message.outWithParam(1, Verb._TEST_2, 42, RESPOND_TO, FBUtilities.getBroadcastAddressAndPort()));
}
@Test
public void testFailureResponse() throws IOException
{
long expiresAt = approxTime.now();
Message<RequestFailureReason> msg = Message.failureResponse(1, expiresAt, RequestFailureReason.INCOMPATIBLE_SCHEMA);
assertEquals(1, msg.id());
assertEquals(Verb.FAILURE_RSP, msg.verb());
assertEquals(expiresAt, msg.expiresAtNanos());
assertEquals(RequestFailureReason.INCOMPATIBLE_SCHEMA, msg.payload);
assertTrue(msg.isFailureResponse());
testCycle(msg);
}
@Test
public void testBuilderAddTraceHeaderWhenTraceSessionPresent()
{
Stream.of(TraceType.values()).forEach(this::testAddTraceHeaderWithType);
}
@Test
public void testBuilderNotAddTraceHeaderWithNoTraceSession()
{
Message<NoPayload> msg = Message.builder(Verb._TEST_1, noPayload).withTracingParams().build();
assertNull(msg.header.traceSession());
}
@Test
public void testCustomParams() throws CharacterCodingException, IOException
{
long id = 1;
InetAddressAndPort from = FBUtilities.getLocalAddressAndPort();
Message<NoPayload> msg =
Message.builder(Verb._TEST_1, noPayload)
.withId(1)
.from(from)
.withCustomParam("custom1", "custom1value".getBytes(StandardCharsets.UTF_8))
.withCustomParam("custom2", "custom2value".getBytes(StandardCharsets.UTF_8))
.build();
assertEquals(id, msg.id());
assertEquals(from, msg.from());
assertEquals(2, msg.header.customParams().size());
assertEquals("custom1value", new String(msg.header.customParams().get("custom1"), StandardCharsets.UTF_8));
assertEquals("custom2value", new String(msg.header.customParams().get("custom2"), StandardCharsets.UTF_8));
DataOutputBuffer out = DataOutputBuffer.scratchBuffer.get();
Message.serializer.serialize(msg, out, VERSION_40);
DataInputBuffer in = new DataInputBuffer(out.buffer(), true);
msg = Message.serializer.deserialize(in, from, VERSION_40);
assertEquals(id, msg.id());
assertEquals(from, msg.from());
assertEquals(2, msg.header.customParams().size());
assertEquals("custom1value", new String(msg.header.customParams().get("custom1"), StandardCharsets.UTF_8));
assertEquals("custom2value", new String(msg.header.customParams().get("custom2"), StandardCharsets.UTF_8));
}
private void testAddTraceHeaderWithType(TraceType traceType)
{
try
{
UUID sessionId = Tracing.instance.newSession(traceType);
Message<NoPayload> msg = Message.builder(Verb._TEST_1, noPayload).withTracingParams().build();
assertEquals(sessionId, msg.header.traceSession());
assertEquals(traceType, msg.header.traceType());
}
finally
{
Tracing.instance.stopSession();
}
}
private void testCycle(Message msg) throws IOException
{
testCycle(msg, VERSION_30);
testCycle(msg, VERSION_3014);
testCycle(msg, VERSION_40);
}
// serialize (using both variants, all in one or header then rest), verify serialized size, deserialize, compare to the original
private void testCycle(Message msg, int version) throws IOException
{
try (DataOutputBuffer out = new DataOutputBuffer())
{
serializer.serialize(msg, out, version);
assertEquals(msg.serializedSize(version), out.getLength());
// deserialize the message in one go, compare outcomes
try (DataInputBuffer in = new DataInputBuffer(out.buffer(), true))
{
Message msgOut = serializer.deserialize(in, msg.from(), version);
assertEquals(0, in.available());
assertMessagesEqual(msg, msgOut);
}
// extract header first, then deserialize the rest of the message and compare outcomes
ByteBuffer buffer = out.buffer();
try (DataInputBuffer in = new DataInputBuffer(out.buffer(), false))
{
Message.Header headerOut = serializer.extractHeader(buffer, msg.from(), approxTime.now(), version);
Message msgOut = serializer.deserialize(in, headerOut, version);
assertEquals(0, in.available());
assertMessagesEqual(msg, msgOut);
}
}
}
private static void assertMessagesEqual(Message msg1, Message msg2)
{
assertEquals(msg1.id(), msg2.id());
assertEquals(msg1.verb(), msg2.verb());
assertEquals(msg1.callBackOnFailure(), msg2.callBackOnFailure());
assertEquals(msg1.trackRepairedData(), msg2.trackRepairedData());
assertEquals(msg1.traceType(), msg2.traceType());
assertEquals(msg1.traceSession(), msg2.traceSession());
assertEquals(msg1.respondTo(), msg2.respondTo());
assertEquals(msg1.forwardTo(), msg2.forwardTo());
Object payload1 = msg1.payload;
Object payload2 = msg2.payload;
if (null == payload1)
assertTrue(payload2 == noPayload || payload2 == null);
else if (null == payload2)
assertSame(payload1, noPayload);
else
assertEquals(payload1, payload2);
}
@Test
public void testCreationTime()
{
long remoteTime = 1632087572480L; // 10111110000000000000000000000000000000000
long localTime = 1632087572479L; // 10111101111111111111111111111111111111111
FreeRunningClock localClock = new FreeRunningClock(TimeUnit.DAYS.toNanos(1), localTime, 0);
int remoteCreatedAt = (int) (remoteTime & 0x00000000FFFFFFFFL);
long localTimeNanos = localClock.now();
assertTrue( Message.Serializer.calculateCreationTimeNanos(remoteCreatedAt, localClock.translate(), localTimeNanos) > 0);
}
}