| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import unittest |
| |
| from unittest.mock import Mock |
| |
| from cassandra import ProtocolVersion, UnsupportedOperation |
| from cassandra.protocol import ( |
| PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation, |
| _PAGING_OPTIONS_FLAG, _WITH_SERIAL_CONSISTENCY_FLAG, |
| _PAGE_SIZE_FLAG, _WITH_PAGING_STATE_FLAG, |
| BatchMessage |
| ) |
| from cassandra.query import BatchType |
| from cassandra.marshal import uint32_unpack |
| from cassandra.cluster import ContinuousPagingOptions |
| |
| |
class MessageTest(unittest.TestCase):
    """
    Tests for the sequence of ``write`` calls protocol messages make in ``send_body``.

    Each test passes a ``Mock`` in place of the output buffer and compares the
    positional arguments of the recorded ``write`` calls with the expected chunks.
    """

    def test_prepare_message(self):
        """
        Test to check the appropriate calls are made

        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written

        @test_category connection
        """
        message = PrepareMessage("a")
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',)])

        io.reset_mock()
        message.send_body(io, 5)

        # Protocol version 5 is expected to append one extra four-byte zero word.
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x00\x00\x00',)])

    def test_execute_message(self):
        """
        Test to check the write calls made by ExecuteMessage for protocol versions 4 and 5.

        For version 5 the expected sequence additionally contains the
        length-prefixed ``result_metadata_id``, and the ``b'\\x01'`` byte of the
        version 4 body is expected as the four-byte word ``b'\\x00\\x00\\x00\\x01'``.
        """
        message = ExecuteMessage('1', [], 4)
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])

        io.reset_mock()
        message.result_metadata_id = 'foo'
        message.send_body(io, 5)

        self._check_calls(io, [(b'\x00\x01',), (b'1',),
                               (b'\x00\x03',), (b'foo',),
                               (b'\x00\x04',),
                               (b'\x00\x00\x00\x01',), (b'\x00\x00',)])

    def test_query_message(self):
        """
        Test to check the appropriate calls are made

        @since 3.9
        @jira_ticket PYTHON-713
        @expected_result the values are correctly written

        @test_category connection
        """
        message = QueryMessage("a", 3)
        io = Mock()

        message.send_body(io, 4)
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00',)])

        io.reset_mock()
        message.send_body(io, 5)
        # Same body as version 4, except the last chunk is four bytes wide instead of one.
        self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)])

    def _check_calls(self, io, expected):
        """
        Assert that ``io.write`` was called with exactly the ``expected`` arguments, in order.

        :param io: the mock whose ``write`` attribute recorded the calls
        :param expected: iterable with one positional-args tuple per expected ``write`` call
        """
        # Index 1 of each recorded call is its tuple of positional arguments.
        self.assertEqual(
            tuple(c[1] for c in io.write.mock_calls),
            tuple(expected)
        )

    def test_continuous_paging(self):
        """
        Test to check continuous paging throws an Exception if it's not supported and the correct values
        are written to the buffer if the option is enabled.

        @since DSE 2.0b3 GRAPH 1.0b1
        @jira_ticket PYTHON-694
        @expected_result the values are correctly written

        @test_category connection
        """
        max_pages = 4
        max_pages_per_second = 3
        continuous_paging_options = ContinuousPagingOptions(max_pages=max_pages,
                                                            max_pages_per_second=max_pages_per_second)
        message = QueryMessage("a", 3, continuous_paging_options=continuous_paging_options)
        io = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            if ProtocolVersion.has_continuous_paging_support(version):
                continue
            # subTest names the offending protocol version when the runner supports sub-tests
            with self.subTest(version=version):
                self.assertRaises(UnsupportedOperation, message.send_body, io, version)

        # Discard whatever the rejected attempts wrote before they raised
        io.reset_mock()
        message.send_body(io, ProtocolVersion.DSE_V1)

        writes = io.write.mock_calls
        # continuous paging adds two write calls to the buffer
        self.assertEqual(len(writes), 6)

        # Check that the appropriate flag is set to True; the flags word is the fourth write
        flags = uint32_unpack(writes[3][1][0])
        self.assertEqual(flags & _WITH_SERIAL_CONSISTENCY_FLAG, 0)
        self.assertEqual(flags & _PAGE_SIZE_FLAG, 0)
        self.assertEqual(flags & _WITH_PAGING_STATE_FLAG, 0)
        self.assertEqual(flags & _PAGING_OPTIONS_FLAG, _PAGING_OPTIONS_FLAG)

        # Test max_pages and max_pages_per_second are correctly written
        self.assertEqual(uint32_unpack(writes[4][1][0]), max_pages)
        self.assertEqual(uint32_unpack(writes[5][1][0]), max_pages_per_second)

    def test_prepare_flag(self):
        """
        Test to check the prepare flag is properly set. This should only happen for V5 at the moment.

        @since 3.9
        @jira_ticket PYTHON-694, PYTHON-713
        @expected_result the values are correctly written

        @test_category connection
        """
        message = PrepareMessage("a")
        io = Mock()
        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            # Reset first so each version starts from an empty call list,
            # even if the previous sub-test failed part-way through.
            io.reset_mock()
            with self.subTest(version=version):
                message.send_body(io, version)
                # Versions that use prepare flags make one additional write call
                expected_writes = 3 if ProtocolVersion.uses_prepare_flags(version) else 2
                self.assertEqual(len(io.write.mock_calls), expected_writes)

    def test_prepare_flag_with_keyspace(self):
        """
        Test to check a PrepareMessage carrying a keyspace against every supported protocol version.

        Versions that use the keyspace flag must write the query, a four-byte
        word with value 1 and the length-prefixed keyspace; every other version
        must raise UnsupportedOperation.
        """
        message = PrepareMessage("a", keyspace='ks')
        io = Mock()

        for version in ProtocolVersion.SUPPORTED_VERSIONS:
            # Reset first so each version starts from an empty call list,
            # even if the previous sub-test failed part-way through.
            io.reset_mock()
            with self.subTest(version=version):
                if ProtocolVersion.uses_keyspace_flag(version):
                    message.send_body(io, version)
                    self._check_calls(io, [
                        (b'\x00\x00\x00\x01',),
                        (b'a',),
                        (b'\x00\x00\x00\x01',),
                        (b'\x00\x02',),
                        (b'ks',),
                    ])
                else:
                    with self.assertRaises(UnsupportedOperation):
                        message.send_body(io, version)

    def test_keyspace_flag_raises_before_v5(self):
        """
        Test to check a QueryMessage with a keyspace raises UnsupportedOperation on protocol version 4.
        """
        keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks')
        io = Mock(name='io')

        with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'):
            keyspace_message.send_body(io, protocol_version=4)
        # NOTE(review): this only asserts that ``io`` itself was never invoked as a
        # callable; calls to ``io.write`` are not examined by this assertion.
        io.assert_not_called()

    def test_keyspace_written_with_length(self):
        """
        Test to check the keyspace of a QueryMessage is written as a two-byte length followed by
        the keyspace itself, after an options word that has the keyspace flag set (protocol version 5).
        """
        io = Mock(name='io')
        base_expected = [
            (b'\x00\x00\x00\x01',),
            (b'a',),
            (b'\x00\x03',),
            (b'\x00\x00\x00\x80',),  # options w/ keyspace flag
        ]

        QueryMessage('a', consistency_level=3, keyspace='ks').send_body(
            io, protocol_version=5
        )
        self._check_calls(io, base_expected + [
            (b'\x00\x02',),  # length of keyspace string
            (b'ks',),
        ])

        io.reset_mock()

        QueryMessage('a', consistency_level=3, keyspace='keyspace').send_body(
            io, protocol_version=5
        )
        self._check_calls(io, base_expected + [
            (b'\x00\x08',),  # length of keyspace string
            (b'keyspace',),
        ])

    def test_batch_message_with_keyspace(self):
        """
        Test to check the write calls of a three-statement logged BatchMessage that carries a
        keyspace (protocol version 5).
        """
        # Show the complete diff on failure; the expected sequence is long
        self.maxDiff = None
        io = Mock(name='io')
        batch = BatchMessage(
            batch_type=BatchType.LOGGED,
            queries=((False, 'stmt a', ('param a',)),
                     (False, 'stmt b', ('param b',)),
                     (False, 'stmt c', ('param c',))
                     ),
            consistency_level=3,
            keyspace='ks'
        )
        batch.send_body(io, protocol_version=5)
        self._check_calls(io,
                          # presumably batch type and number of queries -- TODO confirm
                          ((b'\x00',), (b'\x00\x03',),
                           # one group of writes per query, ending with the parameter
                           # preceded by its four-byte length
                           (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt a',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param a',),
                           (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt b',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param b',),
                           (b'\x00',), (b'\x00\x00\x00\x06',), (b'stmt c',),
                           (b'\x00\x01',), (b'\x00\x00\x00\x07',), ('param c',),
                           # consistency level, options w/ keyspace flag, length-prefixed keyspace
                           (b'\x00\x03',),
                           (b'\x00\x00\x00\x80',), (b'\x00\x02',), (b'ks',))
                          )