blob: 08516eba9e8996d491d5ecaee955db56bbad5abf [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import Mock
from cassandra import ProtocolVersion, UnsupportedOperation
from cassandra.protocol import (
PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
_PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG,
_PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG,
BatchMessage
)
from cassandra.query import BatchType
from cassandra.marshal import uint32_unpack
from cassandra.cluster import ContinuousPagingOptions
class MessageTest(unittest.TestCase):
    """Unit tests for the wire serialization of protocol request messages.

    Each test sends a message body into a ``Mock`` buffer and inspects the
    sequence of ``io.write`` calls it produced.
    """

    def test_prepare_message(self):
        """
        Test to check the appropriate calls are made
        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written
        @test_category connection
        """
        message = PrepareMessage("a")
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)])

        io.reset_mock()
        message.send_body(io, 5)
        # v5 appends one extra four-byte word after the query string.
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)])

    def test_execute_message(self):
        """
        Check the writes made by ``ExecuteMessage.send_body`` for protocol v4
        and v5. In v5 the ``result_metadata_id`` is additionally written as a
        length-prefixed string right after the query id.
        """
        message = ExecuteMessage('1', [], 4)
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])

        io.reset_mock()
        message.result_metadata_id = 'foo'
        message.send_body(io, 5)
        self._check_calls(io, [(b'\x00\x01',), (b'1',),
                               (b'\x00\x03',), (b'foo',),
                               (b'\x00\x04',),
                               (b'\x00\x00\x00\x01',), (b'\x00\x00',)])

    def test_query_message(self):
        """
        Test to check the appropriate calls are made
        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written
        @test_category connection
        """
        message = QueryMessage("a", 3)
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)])

        io.reset_mock()
        message.send_body(io, 5)
        # v5 writes the options as a four-byte word instead of a single byte.
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)])

    def _check_calls(self, io, expected):
        """Assert that ``io.write`` received exactly ``expected``, in order.

        :param io: the ``Mock`` buffer passed to ``send_body``.
        :param expected: iterable of positional-argument tuples, one per
            expected ``io.write`` call.
        """
        self.assertEqual(
            tuple(c[1] for c in io.write.mock_calls),
            tuple(expected)
        )

    def test_continuous_paging(self):
        """
        Test that continuous paging raises UnsupportedOperation on protocol versions
        that do not support it, and that the correct values are written to the buffer
        when the option is enabled.
        @since DSE 2.0b3 GRAPH 1.0b1
        @jira_ticket PYTHON-694
        @expected_result the values are correctly written
        @test_category connection
        """
        max_pages = 4
        max_pages_per_second = 3
        continuous_paging_options = ContinuousPagingOptions(max_pages=max_pages,
                                                            max_pages_per_second=max_pages_per_second)
        message = QueryMessage("a", 3, continuous_paging_options=continuous_paging_options)
        io = Mock()

        unsupported_versions = [version for version in ProtocolVersion.SUPPORTED_VERSIONS
                                if not ProtocolVersion.has_continuous_paging_support(version)]
        for version in unsupported_versions:
            self.assertRaises(UnsupportedOperation, message.send_body, io, version)

        io.reset_mock()
        message.send_body(io, ProtocolVersion.DSE_V1)
        write_calls = io.write.mock_calls

        # continuous paging adds two write calls to the buffer
        self.assertEqual(len(write_calls), 6)

        # The flags word is the fourth write (after query length, query and
        # consistency). The paging-options flag must be set; the serial
        # consistency, page size and paging state flags must be clear.
        flags = uint32_unpack(write_calls[3][1][0])
        self.assertEqual(flags & _WITH_SERIAL_CONSISTENCY_FLAG, 0)
        self.assertEqual(flags & _PAGE_SIZE_FLAG, 0)
        self.assertEqual(flags & _WITH_PAGING_STATE_FLAG, 0)
        self.assertEqual(flags & _PAGING_OPTIONS_FLAG, _PAGING_OPTIONS_FLAG)

        # Test max_pages and max_pages_per_second are correctly written
        self.assertEqual(uint32_unpack(write_calls[4][1][0]), max_pages)
        self.assertEqual(uint32_unpack(write_calls[5][1][0]), max_pages_per_second)

    def test_prepare_flag(self):
        """
        Test to check the prepare flag is properly set. The extra flags write should only
        happen for versions where ``ProtocolVersion.uses_prepare_flags`` is true (V5 at the moment).
        @since 3.9
        @jira_ticket PYTHON-694, PYTHON-713
        @expected_result the values are correctly written
        @test_category connection
        """
        message = PrepareMessage("a")
        io = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            message.send_body(io, version)
            expected_writes = 3 if ProtocolVersion.uses_prepare_flags(version) else 2
            self.assertEqual(len(io.write.mock_calls), expected_writes)
            io.reset_mock()

    def test_prepare_flag_with_keyspace(self):
        """
        Check that a PrepareMessage with a keyspace writes the keyspace on versions
        where ``ProtocolVersion.uses_keyspace_flag`` is true, and raises
        UnsupportedOperation on all other supported versions.
        """
        message = PrepareMessage("a", keyspace='ks')
        io = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            if ProtocolVersion.uses_keyspace_flag(version):
                message.send_body(io, version)
                self._check_calls(io, [
                    (b'\x00\x00\x00\x01',),
                    (b'a',),
                    (b'\x00\x00\x00\x01',),
                    (b'\x00\x02',),
                    (b'ks',),
                ])
            else:
                with self.assertRaises(UnsupportedOperation):
                    message.send_body(io, version)
            io.reset_mock()

    def test_keyspace_flag_raises_before_v5(self):
        """Check that setting a keyspace on a QueryMessage raises for protocol v4."""
        keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks')
        io = Mock(name='io')
        with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'):
            keyspace_message.send_body(io, protocol_version=4)
        # Mock.assert_not_called() only tracks calls on ``io`` itself, not on
        # its ``io.write`` child, so it cannot detect a serialized keyspace.
        # Inspect the written payloads directly instead.
        written = tuple(c[1] for c in io.write.mock_calls)
        self.assertNotIn((b'ks',), written)

    def test_keyspace_written_with_length(self):
        """Check that, in protocol v5, the keyspace is written after a two-byte length prefix."""
        io = Mock(name='io')
        base_expected = [
            (b'\x00\x00\x00\x01',),
            (b'a',),
            (b'\x00\x03',),
            (b'\x00\x00\x00\x80',),  # options w/ keyspace flag
        ]
        cases = (
            ('ks', b'\x00\x02', b'ks'),
            ('keyspace', b'\x00\x08', b'keyspace'),
        )
        for keyspace, encoded_length, encoded_keyspace in cases:
            with self.subTest(keyspace=keyspace):
                io.reset_mock()
                QueryMessage('a', consistency_level=3, keyspace=keyspace).send_body(
                    io, protocol_version=5
                )
                self._check_calls(io, base_expected + [
                    (encoded_length,),  # length of keyspace string
                    (encoded_keyspace,),
                ])

    def test_batch_message_with_keyspace(self):
        """
        Check that a v5 BatchMessage writes each query and its parameters,
        then the consistency, the options word with the keyspace flag, and
        finally the length-prefixed keyspace.
        """
        self.maxDiff = None
        io = Mock(name='io')
        batch = BatchMessage(
            batch_type=BatchType.LOGGED,
            queries=((False, 'stmt a', ('param a',)),
                     (False, 'stmt b', ('param b',)),
                     (False, 'stmt c', ('param c',))
                     ),
            consistency_level=3,
            keyspace='ks'
        )
        batch.send_body(io, protocol_version=5)
        self._check_calls(io,
                          ((b'\x00',), (b'\x00\x03',), (b'\x00',),
                           (b'\x00\x00\x00\x06',), (b'stmt a',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',),
                           (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',),
                           (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',),
                           (b'\x00\x03',),
                           (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
                          )