blob: 61c90f89cb41ce90a84c11714adbf81a826bc465 [file] [log] [blame]
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from unittest import mock
from airflow.contrib.hooks.pagerduty_hook import PagerdutyHook
from airflow.models import Connection
from airflow.utils import db
DEFAULT_CONN_ID = "pagerduty_default"
class TestPagerdutyHook(unittest.TestCase):
    """Tests for PagerdutyHook token/routing-key resolution and event creation."""

    # Connection id used only by test_without_routing_key_extra.
    _NO_EXTRA_CONN_ID = "pagerduty_no_extra"

    @staticmethod
    def _success_response():
        """Return a fresh canned payload used as the mocked EventV2.create result."""
        return {
            "status": "success",
            "message": "Event processed",
            "dedup_key": "samplekeyhere",
        }

    @db.provide_session
    def setUp(self, session=None):
        """Insert the default connection: token in password, routing key in extra."""
        session.add(Connection(
            conn_id=DEFAULT_CONN_ID,
            password="pagerduty_token",
            extra='{"routing_key": "route"}',
        ))
        session.commit()

    @db.provide_session
    def tearDown(self, session=None):
        """Delete every connection row these tests insert.

        setUp runs before each test, so without this cleanup duplicate
        ``pagerduty_default`` rows (and the ``pagerduty_no_extra`` row) would
        accumulate in the shared test database and leak into other tests.
        """
        session.query(Connection).filter(
            Connection.conn_id.in_([DEFAULT_CONN_ID, self._NO_EXTRA_CONN_ID])
        ).delete(synchronize_session=False)
        session.commit()

    @db.provide_session
    def test_without_routing_key_extra(self, session=None):
        """A connection with no extra yields the password token and no routing key."""
        session.add(Connection(
            conn_id=self._NO_EXTRA_CONN_ID,
            password="pagerduty_token_without_extra",
        ))
        session.commit()
        hook = PagerdutyHook(pagerduty_conn_id=self._NO_EXTRA_CONN_ID)
        self.assertEqual(hook.token, 'pagerduty_token_without_extra', 'token initialised.')
        self.assertIsNone(hook.routing_key, 'default routing key skipped.')

    def test_get_token_from_password(self):
        """Without an explicit token, the hook reads it from the connection password."""
        hook = PagerdutyHook(pagerduty_conn_id=DEFAULT_CONN_ID)
        self.assertEqual(hook.token, 'pagerduty_token', 'token initialised.')

    def test_token_parameter_override(self):
        """An explicit ``token`` argument takes precedence over the connection password."""
        hook = PagerdutyHook(token="pagerduty_param_token", pagerduty_conn_id=DEFAULT_CONN_ID)
        self.assertEqual(hook.token, 'pagerduty_param_token', 'token initialised.')

    @mock.patch('airflow.contrib.hooks.pagerduty_hook.pypd.EventV2.create')
    def test_create_event(self, mock_event_create):
        """An explicit routing_key is forwarded to EventV2.create as a trigger event."""
        hook = PagerdutyHook(pagerduty_conn_id=DEFAULT_CONN_ID)
        mock_event_create.return_value = self._success_response()
        resp = hook.create_event(
            routing_key="key",
            summary="test",
            source="airflow_test",
            severity="error",
        )
        self.assertEqual(resp["status"], "success")
        mock_event_create.assert_called_once_with(
            api_key="pagerduty_token",
            data={
                "routing_key": "key",
                "event_action": "trigger",
                "payload": {
                    "severity": "error",
                    "source": "airflow_test",
                    "summary": "test",
                },
            })

    @mock.patch('airflow.contrib.hooks.pagerduty_hook.pypd.EventV2.create')
    def test_create_event_with_default_routing_key(self, mock_event_create):
        """With no routing_key argument, the key from the connection extra is used."""
        hook = PagerdutyHook(pagerduty_conn_id=DEFAULT_CONN_ID)
        mock_event_create.return_value = self._success_response()
        resp = hook.create_event(
            summary="test",
            source="airflow_test",
            severity="error",
            custom_details='{"foo": "bar"}',
        )
        self.assertEqual(resp["status"], "success")
        mock_event_create.assert_called_once_with(
            api_key="pagerduty_token",
            data={
                # "route" comes from the extra JSON stored in setUp.
                "routing_key": "route",
                "event_action": "trigger",
                "payload": {
                    "severity": "error",
                    "source": "airflow_test",
                    "summary": "test",
                    "custom_details": '{"foo": "bar"}',
                },
            })
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()