# blob: ef790ef74252ac00715d86881a3f7b3571b48e83 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from pyfory import Fory, DeserializationPolicy
class BlockClassPolicy(DeserializationPolicy):
    """Deserialization policy that rejects classes whose name is on a block list.

    Args:
        blocked_class_names: Container of class ``__name__`` strings to reject.
    """

    def __init__(self, blocked_class_names):
        self.blocked_class_names = blocked_class_names

    def validate_class(self, cls, is_local, **kwargs):
        """Raise ``ValueError`` if ``cls.__name__`` is blocked; otherwise return ``None``.

        ``is_local`` and any extra keyword arguments are accepted but not used.
        """
        name = cls.__name__
        if name not in self.blocked_class_names:
            return None
        raise ValueError(f"Class {name} is blocked")
class ReplaceObjectPolicy(DeserializationPolicy):
    """Deserialization policy that swaps reduced objects for a fixed value.

    Args:
        replacement_value: Value returned in place of any inspected object
            that exposes a ``value`` attribute.
    """

    def __init__(self, replacement_value):
        self.replacement_value = replacement_value

    def inspect_reduced_object(self, obj, **kwargs):
        """Return the replacement when ``obj`` has a ``value`` attribute, else ``None``."""
        return self.replacement_value if hasattr(obj, "value") else None
class BlockReduceCallPolicy(DeserializationPolicy):
    """Deserialization policy that vetoes reduce callables by ``__name__``.

    Args:
        blocked_names: Container of callable names that must not be invoked.
    """

    def __init__(self, blocked_names):
        self.blocked_names = blocked_names

    def intercept_reduce_call(self, callable_obj, args, **kwargs):
        """Raise ``ValueError`` for a blocked callable; otherwise return ``None``.

        Callables without a ``__name__`` attribute are never blocked.
        ``args`` and any extra keyword arguments are accepted but not used.
        """
        if not hasattr(callable_obj, "__name__"):
            return None
        name = callable_obj.__name__
        if name in self.blocked_names:
            raise ValueError(f"Callable {name} is blocked")
        return None
class SanitizeStatePolicy(DeserializationPolicy):
    """Deserialization policy that redacts a ``password`` entry in object state."""

    def intercept_setstate(self, obj, state, **kwargs):
        """Overwrite ``state["password"]`` in place when ``state`` is a dict holding it.

        Always returns ``None``; the redaction happens by mutating ``state``.
        """
        has_password = isinstance(state, dict) and "password" in state
        if has_password:
            state["password"] = "***REDACTED***"
        return None
def test_block_class_type_deserialization():
    """A class object (not an instance) can be vetoed by name on deserialize."""

    class SafeClass:
        pass

    class UnsafeClass:
        pass

    serializer = Fory(
        ref=True,
        strict=False,
        policy=BlockClassPolicy(blocked_class_names=["UnsafeClass"]),
    )

    # A class type whose name is not on the block list round-trips.
    restored = serializer.deserialize(serializer.serialize(SafeClass))
    assert restored.__name__ == "SafeClass"

    # The blocked class type serializes, but reading it back must raise.
    blocked_payload = serializer.serialize(UnsafeClass)
    with pytest.raises(ValueError, match="UnsafeClass is blocked"):
        serializer.deserialize(blocked_payload)
def test_block_reduce_call():
    """A blocked ``__reduce__`` callable makes deserialization raise."""

    class ReducibleClass:
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            # Rebuild through the class itself, so the callable's name is "ReducibleClass".
            return ReducibleClass, (self.value,)

    serializer = Fory(
        ref=True,
        strict=False,
        policy=BlockReduceCallPolicy(blocked_names=["ReducibleClass"]),
    )
    payload = serializer.serialize(ReducibleClass(42))

    with pytest.raises(ValueError, match="ReducibleClass is blocked"):
        serializer.deserialize(payload)
def test_replace_reduced_object():
    """An object rebuilt via ``__reduce__`` can be substituted by the policy."""

    class ReducibleClass:
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            return ReducibleClass, (self.value,)

    substitute = "REPLACED"
    serializer = Fory(
        ref=True,
        strict=False,
        policy=ReplaceObjectPolicy(replacement_value=substitute),
    )

    payload = serializer.serialize(ReducibleClass(42))
    restored = serializer.deserialize(payload)

    # The rebuilt instance has a ``value`` attribute, so the policy swaps it out.
    assert restored == substitute
def test_sanitize_state():
    """State passed to ``__setstate__`` is redacted by the policy first."""

    class SecretHolder:
        def __init__(self, username, password):
            self.username = username
            self.password = password

        def __getstate__(self):
            return dict(username=self.username, password=self.password)

        def __setstate__(self, state):
            vars(self).update(state)

    serializer = Fory(ref=False, strict=False, policy=SanitizeStatePolicy())
    payload = serializer.serialize(SecretHolder("admin", "secret123"))
    restored = serializer.deserialize(payload)

    # Only the password entry is rewritten; the username passes through untouched.
    assert restored.username == "admin"
    assert restored.password == "***REDACTED***"
def test_policy_with_local_class():
    """The class-name policy also applies to a class defined inside a factory function."""

    def make_local_class():
        class LocalClass:
            pass

        return LocalClass

    serializer = Fory(
        ref=True,
        strict=False,
        policy=BlockClassPolicy(blocked_class_names=["LocalClass"]),
    )

    # Serialize the class type produced by the factory, not an instance of it.
    payload = serializer.serialize(make_local_class())

    with pytest.raises(ValueError, match="LocalClass is blocked"):
        serializer.deserialize(payload)
def test_policy_with_ref_tracking():
    """The reduce-call policy still fires when ``ref=True`` is enabled.

    NOTE(review): this scenario is currently the same as ``test_block_reduce_call``
    (single object, ``ref=True``); consider serializing a shared reference to
    exercise reference tracking specifically.
    """

    class ReducibleClass:
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            return ReducibleClass, (self.value,)

    blocker = BlockReduceCallPolicy(blocked_names=["ReducibleClass"])
    serializer = Fory(ref=True, strict=False, policy=blocker)
    payload = serializer.serialize(ReducibleClass(42))

    with pytest.raises(ValueError, match="ReducibleClass is blocked"):
        serializer.deserialize(payload)
def test_policy_allows_safe_operations():
    """With an empty block list, built-in values round-trip unchanged."""
    serializer = Fory(
        ref=False,
        strict=False,
        policy=BlockClassPolicy(blocked_class_names=[]),
    )

    # One int, one str and one list, checked in the same order as before.
    for sample in (42, "test", [1, 2, 3]):
        assert serializer.deserialize(serializer.serialize(sample)) == sample
def test_multiple_policy_hooks():
    """A single policy may implement several hooks; each records its invocation."""

    class MultiHookPolicy(DeserializationPolicy):
        def __init__(self):
            # Ordered log of (hook name, subject name) tuples.
            self.hooks_called = []

        def validate_class(self, cls, is_local, **kwargs):
            self.hooks_called.append(("validate_class", cls.__name__))
            return None

        def intercept_reduce_call(self, callable_obj, args, **kwargs):
            # Callables without a __name__ are simply not recorded.
            if hasattr(callable_obj, "__name__"):
                entry = ("intercept_reduce_call", callable_obj.__name__)
                self.hooks_called.append(entry)
            return None

        def inspect_reduced_object(self, obj, **kwargs):
            self.hooks_called.append(("inspect_reduced_object", type(obj).__name__))
            return None

    class TestClass:
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            return TestClass, (self.value,)

    recorder = MultiHookPolicy()
    serializer = Fory(ref=True, strict=False, policy=recorder)

    payload = serializer.serialize(TestClass(42))
    restored = serializer.deserialize(payload)

    # Both reduce-related hooks must have fired for TestClass ...
    calls = recorder.hooks_called
    assert ("intercept_reduce_call", "TestClass") in calls
    assert ("inspect_reduced_object", "TestClass") in calls
    # ... and since every hook returned None, the object comes back intact.
    assert restored.value == 42
def test_policy_with_nested_reduce():
    """Blocking an inner ``__reduce__`` callable fails the whole deserialization."""

    class Inner:
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            return Inner, (self.value,)

    class Outer:
        def __init__(self, inner):
            self.inner = inner

        def __reduce__(self):
            # The Inner instance travels as a reduce argument of Outer.
            return Outer, (self.inner,)

    serializer = Fory(
        ref=True,
        strict=False,
        policy=BlockReduceCallPolicy(blocked_names=["Inner"]),
    )
    payload = serializer.serialize(Outer(Inner(42)))

    # Only "Inner" is blocked, yet deserializing the enclosing Outer must raise.
    with pytest.raises(ValueError, match="Inner is blocked"):
        serializer.deserialize(payload)
def test_validate_module():
    """Exercise the ``validate_module`` hook: substitute, redirect and block."""
    import json
    import collections

    # Case 1: the hook hands back a module object, which is used as the result.
    class ReturnModulePolicy(DeserializationPolicy):
        def validate_module(self, module_name, **kwargs):
            return collections

    substituting = Fory(ref=True, strict=False, policy=ReturnModulePolicy())
    payload = substituting.serialize(json)
    assert substituting.deserialize(payload) is collections

    # Case 2: the hook hands back a module name, redirecting the import.
    class RedirectPolicy(DeserializationPolicy):
        def validate_module(self, module_name, **kwargs):
            if module_name == "json":
                return "collections"
            return None

    redirecting = Fory(ref=True, strict=False, policy=RedirectPolicy())
    redirected = redirecting.deserialize(redirecting.serialize(json))
    assert redirected.__name__ == "collections"

    # Case 3: the hook raises, which must surface from deserialize().
    class BlockPolicy(DeserializationPolicy):
        def validate_module(self, module_name, **kwargs):
            raise ValueError(f"Module {module_name} blocked")

    blocking = Fory(ref=True, strict=False, policy=BlockPolicy())
    with pytest.raises(ValueError, match="blocked"):
        blocking.deserialize(blocking.serialize(json))