blob: b1acbb5e718e6b2cc7a450fc721c1459a7ea6501 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pytest
from cryptography.fernet import Fernet
from airflow.sdk.crypto import _NullFernet, _RealFernet, get_fernet
from airflow.sdk.exceptions import AirflowException
from tests_common.test_utils.config import conf_vars
class TestNullFernet:
    """Checks for ``_NullFernet``, which is expected to hand data back unchanged."""

    def test_decryption_of_invalid_type(self):
        """Decrypting a value that is neither ``str`` nor ``bytes`` must raise ``ValueError``."""
        passthrough = _NullFernet()
        expected_error = "Expected bytes or str, got <class 'int'>"
        with pytest.raises(ValueError, match=expected_error):
            passthrough.decrypt(123)  # type: ignore[arg-type]

    def test_encrypt_decrypt_roundtrip(self):
        """An encrypt/decrypt cycle must return the payload, and "encryption" must be a no-op."""
        passthrough = _NullFernet()
        payload = b"test message"
        ciphertext = passthrough.encrypt(payload)
        # The round trip gives back exactly what went in ...
        assert passthrough.decrypt(ciphertext) == payload
        # ... and the intermediate "ciphertext" is the plaintext itself.
        assert ciphertext == payload
class TestRealFernet:
    """Checks for ``_RealFernet`` when it wraps a ``MultiFernet`` instance."""

    def test_encryption(self):
        """The ciphertext must differ from the plaintext and decrypt back to it."""
        from cryptography.fernet import MultiFernet

        only_key = Fernet.generate_key()
        wrapper = _RealFernet(MultiFernet([Fernet(only_key)]))
        plaintext = b"secret message"
        ciphertext = wrapper.encrypt(plaintext)
        assert ciphertext != plaintext
        assert wrapper.decrypt(ciphertext) == plaintext

    def test_rotate_reencrypt_with_primary_key(self):
        """``rotate()`` must produce a new token that the primary key alone can decrypt."""
        from cryptography.fernet import MultiFernet

        fallback_key, primary_key = Fernet.generate_key(), Fernet.generate_key()
        payload = b"rotate test"

        # Token created with the fallback key only.
        legacy_token = Fernet(fallback_key).encrypt(payload)

        # First entry is the primary key; the second one is kept only as a fallback.
        wrapper = _RealFernet(MultiFernet([Fernet(primary_key), Fernet(fallback_key)]))

        # Rotation should re-encrypt the token under the primary key.
        rotated_token = wrapper.rotate(legacy_token)

        # The primary key on its own must be able to read the rotated token.
        assert Fernet(primary_key).decrypt(rotated_token) == payload
        assert rotated_token != legacy_token
class TestGetFernet:
    """Checks for ``get_fernet()`` under different ``[core] FERNET_KEY`` settings."""

    @pytest.fixture(autouse=True)
    def _reset_fernet_cache(self):
        """Clear the ``get_fernet`` cache before and after every test in this class.

        Clearing beforehand makes each test build its instance from the config
        override it installs. Clearing afterwards drops the instance that was
        created under that temporary override, so it cannot be served from the
        cache to tests that run later with the original configuration.
        """
        get_fernet.cache_clear()
        yield
        get_fernet.cache_clear()

    @conf_vars({("core", "FERNET_KEY"): ""})
    def test_empty_key(self):
        """Empty FERNET_KEY should return a fernet that leaves data unencrypted."""
        fernet = get_fernet()
        assert not fernet.is_encrypted
        test_data = b"unencrypted"
        assert fernet.encrypt(test_data) == test_data
        assert fernet.decrypt(test_data) == test_data

    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
    def test_valid_key_encrypts_data(self):
        """Valid FERNET_KEY should return working encryption."""
        fernet = get_fernet()
        assert fernet.is_encrypted
        original = b"sensitive data"
        encrypted = fernet.encrypt(original)
        assert encrypted != original
        assert fernet.decrypt(encrypted) == original

    @conf_vars({("core", "FERNET_KEY"): "invalid-key"})
    def test_invalid_key(self):
        """Invalid FERNET_KEY should raise an AirflowException."""
        with pytest.raises(AirflowException, match="Could not create Fernet object"):
            get_fernet()

    def test_multiple_keys(self):
        """Multiple comma separated keys should support key rotation."""
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        # Encrypt with key1 only.
        data_encrypted_with_key1 = Fernet(key1).encrypt(b"secret data")

        # get_fernet with both keys (key2 primary, key1 fallback).
        with conf_vars({("core", "FERNET_KEY"): f"{key2.decode()},{key1.decode()}"}):
            fernet = get_fernet()
            # Data encrypted with the old key1 must still be readable.
            assert fernet.decrypt(data_encrypted_with_key1) == b"secret data"
            # New data must be readable with key2 alone.
            new_encrypted = fernet.encrypt(b"new")
            assert Fernet(key2).decrypt(new_encrypted) == b"new"

    @conf_vars({("core", "FERNET_KEY"): Fernet.generate_key().decode()})
    def test_caching(self):
        """get_fernet() should return cached instance."""
        fernet1 = get_fernet()
        fernet2 = get_fernet()
        assert fernet1 is fernet2