blob: fe16acc3c089c202ead48f54fa33fd0d1aa6480d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import logging
import unittest
import mock
from apache_beam.internal.gcp import auth
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
try:
import google.auth as gauth
import google_auth_httplib2 # pylint: disable=unused-import
except ImportError:
gauth = None # type: ignore
class MockLoggingHandler(logging.Handler):
  """Logging handler that captures emitted messages, grouped by level name.

  Attach an instance to a logger, run the code under test, then inspect
  ``messages`` (a dict mapping lower-cased level names such as ``'warning'``
  to lists of formatted message strings) to assert on what was logged.
  """
  def __init__(self, *args, **kwargs):
    # Arguments are forwarded unchanged to logging.Handler (e.g. ``level``).
    super().__init__(*args, **kwargs)
    self.reset()

  def emit(self, record):
    """Record ``record``'s formatted message under its lower-cased level.

    Only the five standard level names have buckets; a record with any other
    level name raises KeyError here.
    """
    bucket = self.messages[record.levelname.lower()]
    bucket.append(record.getMessage())

  def reset(self):
    """Forget all captured messages, leaving one empty list per level."""
    self.messages = {
        level: []
        for level in ('debug', 'info', 'warning', 'error', 'critical')
    }
@unittest.skipIf(gauth is None, 'Google Auth dependencies are not installed')
class AuthTest(unittest.TestCase):
  """Tests credential lookup and anonymous fallback in ``auth``."""
  def setUp(self):
    # _Credentials caches credentials on the class, so state would otherwise
    # leak between tests. Start from a clean cache, and register the reset as
    # a cleanup so it also runs when a test fails part-way through.
    self._reset_credentials_cache()
    self.addCleanup(self._reset_credentials_cache)

  @staticmethod
  def _reset_credentials_cache():
    """Clear the class-level credential cache held by ``auth._Credentials``."""
    if auth._Credentials._credentials_init:
      auth._Credentials._credentials_init = False
      auth._Credentials._credentials = None

  @mock.patch('google.auth.default')
  def test_auth_with_retrys(self, mock_default):
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(
        GoogleCloudOptions).impersonate_service_account = False
    credentials = ('creds', 1)
    failed_attempts = []

    def side_effect(scopes=None):
      # Fail the first lookup and succeed on every later one, so the result
      # can only be obtained via a retry.
      if failed_attempts:
        return credentials
      failed_attempts.append(scopes)
      raise IOError('Failed')

    # Configure the mock installed by the decorator rather than reassigning
    # google.auth.default by hand; mock.patch restores the original on exit.
    mock_default.side_effect = side_effect

    returned_credentials = auth.get_service_credentials(pipeline_options)

    self.assertEqual('creds', returned_credentials._google_auth_credentials)

  @mock.patch(
      'apache_beam.internal.gcp.auth._Credentials._get_credentials_with_retrys')
  def test_auth_with_retrys_always_fail(self, mock_get_credentials):
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(
        GoogleCloudOptions).impersonate_service_account = False
    logger_handler = MockLoggingHandler()
    auth._LOGGER.addHandler(logger_handler)
    # Detach the handler even if an assertion below fails.
    self.addCleanup(auth._LOGGER.removeHandler, logger_handler)

    # Replace the retrying method with one that fails immediately; the real
    # retry back-off would make this test take ~10 minutes.
    mock_get_credentials.side_effect = IOError('Failed')

    returned_credentials = auth.get_service_credentials(pipeline_options)

    self.assertIsNone(returned_credentials)
    self.assertEqual([
        'Unable to find default credentials to use: Failed\n'
        'Connecting anonymously. This is expected if no credentials are '
        'needed to access GCP resources.'
    ],
                     logger_handler.messages.get('warning'))
if __name__ == '__main__':
  # Run every test case in this module when executed as a script.
  unittest.main(module='__main__')