blob: 19a61334fd773eed4fa58b65843ce961269c4cd4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import io
import os
import sys
import pytest
from cassandra.auth import PlainTextAuthProvider
from cqlshlib.authproviderhandling import load_auth_provider
def construct_config_path(config_file_name):
    """Return the path of a fixture file used by these tests.

    The fixture is looked up inside the ``test_authproviderhandling_config``
    directory that sits next to this module.

    :param config_file_name: name of the file inside the fixture directory.
    :return: the joined path as a string.
    """
    module_dir = os.path.dirname(__file__)
    fixture_dir = os.path.join(module_dir, 'test_authproviderhandling_config')
    return os.path.join(fixture_dir, config_file_name)
# Simple class to help verify AuthProviders that don't need arguments.
class NoUserNamePlainTextAuthProvider(PlainTextAuthProvider):
    """PlainTextAuthProvider whose constructor takes no arguments.

    The parent is always initialised with an empty username and an empty
    password.
    """

    def __init__(self):
        super().__init__('', '')
class ComplexTextAuthProvider(PlainTextAuthProvider):
    """PlainTextAuthProvider variant with an optional password and an extra field.

    :param username: forwarded to the parent constructor.
    :param password: forwarded to the parent constructor; defaults to
        ``'default_pass'``.
    :param extra_flag: stored unchanged on the instance as ``extra_flag``.
    """

    def __init__(self, username, password='default_pass', extra_flag=None):
        super().__init__(username, password)
        self.extra_flag = extra_flag
def _assert_auth_provider_matches(actual, klass, expected_props):
    """
    Check an auth provider's type and its instance attributes.

    * actual ..........: The provider object under test
    * klass ...........: Class the provider must be an instance of
                         (ie PlainTextAuthProvider)
    * expected_props ..: Dict that must equal ``vars(actual)``

    Raises AssertionError when either check fails.
    """
    has_expected_type = isinstance(actual, klass)
    assert has_expected_type
    # Only read the attributes once the type check has passed.
    actual_props = vars(actual)
    assert expected_props == actual_props
class CustomAuthProviderTest(unittest.TestCase):
    """Exercise ``load_auth_provider`` against the fixture files located by
    ``construct_config_path``.

    Every test runs with ``sys.stderr`` replaced by an in-memory buffer so the
    text written while a provider is loaded can be inspected.
    """

    def setUp(self):
        """Redirect ``sys.stderr`` into a fresh ``io.StringIO`` buffer."""
        # Remember the stream that is currently installed (a test runner may
        # already have replaced the interpreter's own stderr) so tearDown can
        # put exactly that object back.
        self._original_std_err = sys.stderr
        self._captured_std_err = io.StringIO()
        sys.stderr = self._captured_std_err

    def tearDown(self):
        """Restore ``sys.stderr`` and release the capture buffer."""
        # The stream to restore is stderr, not stdout. Reinstall it before
        # closing the buffer so sys.stderr never refers to a closed stream.
        sys.stderr = self._original_std_err
        self._captured_std_err.close()

    def test_no_warning_insecure_if_no_pass(self):
        """Loading 'plain_text_partial_example' writes nothing to stderr."""
        load_auth_provider(construct_config_path('plain_text_partial_example'))
        err_msg = self._captured_std_err.getvalue()
        assert err_msg == ''

    def test_insecure_creds(self):
        """Loading 'full_plain_text_example' emits a notice and a warning."""
        load_auth_provider(construct_config_path('full_plain_text_example'))
        err_msg = self._captured_std_err.getvalue()
        assert "Notice:" in err_msg
        assert "Warning:" in err_msg

    def test_creds_not_checked_for_non_plaintext(self):
        """Loading 'complex_auth_provider_with_pass' writes nothing to stderr."""
        load_auth_provider(construct_config_path('complex_auth_provider_with_pass'))
        err_msg = self._captured_std_err.getvalue()
        assert err_msg == ''

    def test_partial_property_example(self):
        """'partial_example' yields a NoUserNamePlainTextAuthProvider."""
        actual = load_auth_provider(construct_config_path('partial_example'))
        _assert_auth_provider_matches(
            actual,
            NoUserNamePlainTextAuthProvider,
            {"username": '',
             "password": ''})

    def test_full_property_example(self):
        """'full_plain_text_example' yields a provider with user1/pass1."""
        actual = load_auth_provider(construct_config_path('full_plain_text_example'))
        _assert_auth_provider_matches(
            actual,
            PlainTextAuthProvider,
            {"username": 'user1',
             "password": 'pass1'})

    def test_empty_example(self):
        """'empty_example' yields no provider."""
        actual = load_auth_provider(construct_config_path('empty_example'))
        assert actual is None

    def test_plaintextauth_when_not_defined(self):
        """A credentials file alone yields a PlainTextAuthProvider."""
        creds_file = construct_config_path('plain_text_full_creds')
        actual = load_auth_provider(cred_file=creds_file)
        _assert_auth_provider_matches(
            actual,
            PlainTextAuthProvider,
            {"username": 'user2',
             "password": 'pass2'})

    def test_no_cqlshrc_file(self):
        """Calling load_auth_provider with no arguments yields no provider."""
        actual = load_auth_provider()
        assert actual is None

    def test_no_classname_example(self):
        """'no_classname_example' yields no provider."""
        actual = load_auth_provider(construct_config_path('no_classname_example'))
        assert actual is None

    def test_improper_config_example(self):
        """'illegal_example' makes load_auth_provider raise ModuleNotFoundError."""
        with pytest.raises(ModuleNotFoundError) as error:
            load_auth_provider(construct_config_path('illegal_example'))
        # Checked after the ``with`` block: a statement placed after the
        # raising call inside the block would never execute.
        assert error is not None

    def test_username_password_passed_from_commandline(self):
        """Explicit username/password arguments end up on the provider."""
        creds_file = construct_config_path('complex_auth_provider_creds')
        cqlshrc = construct_config_path('complex_auth_provider')
        actual = load_auth_provider(cqlshrc, creds_file, 'user-from-legacy', 'pass-from-legacy')
        _assert_auth_provider_matches(
            actual,
            ComplexTextAuthProvider,
            {"username": 'user-from-legacy',
             "password": 'pass-from-legacy',
             "extra_flag": 'flag2'})

    def test_creds_example(self):
        """cqlshrc plus credentials file, with no explicit username/password."""
        creds_file = construct_config_path('complex_auth_provider_creds')
        cqlshrc = construct_config_path('complex_auth_provider')
        actual = load_auth_provider(cqlshrc, creds_file)
        _assert_auth_provider_matches(
            actual,
            ComplexTextAuthProvider,
            {"username": 'user1',
             "password": 'pass2',
             "extra_flag": 'flag2'})

    def test_legacy_example_use_passed_username(self):
        """An explicit username is used together with password 'pass2'."""
        creds_file = construct_config_path('plain_text_partial_creds')
        cqlshrc = construct_config_path('plain_text_partial_example')
        actual = load_auth_provider(cqlshrc, creds_file, 'user3')
        _assert_auth_provider_matches(
            actual,
            PlainTextAuthProvider,
            {"username": 'user3',
             "password": 'pass2'})

    def test_legacy_example_no_auth_provider_given(self):
        """With 'empty_example' as cqlshrc, explicit credentials give a
        PlainTextAuthProvider."""
        cqlshrc = construct_config_path('empty_example')
        creds_file = construct_config_path('complex_auth_provider_creds')
        actual = load_auth_provider(cqlshrc, creds_file, 'user3', 'pass3')
        _assert_auth_provider_matches(
            actual,
            PlainTextAuthProvider,
            {"username": 'user3',
             "password": 'pass3'})

    def test_shouldnt_pass_no_password_when_alt_auth_provider(self):
        """Without a password, ComplexTextAuthProvider keeps 'default_pass'."""
        cqlshrc = construct_config_path('complex_auth_provider')
        creds_file = None
        actual = load_auth_provider(cqlshrc, creds_file, 'user3')
        _assert_auth_provider_matches(
            actual,
            ComplexTextAuthProvider,
            {"username": 'user3',
             "password": 'default_pass',
             "extra_flag": 'flag1'})

    def test_legacy_example_no_password(self):
        """Without a password, PlainTextAuthProvider ends up with password None."""
        cqlshrc = construct_config_path('plain_text_partial_example')
        creds_file = None
        actual = load_auth_provider(cqlshrc, creds_file, 'user3')
        _assert_auth_provider_matches(
            actual,
            PlainTextAuthProvider,
            {"username": 'user3',
             "password": None})