blob: cb37857513c61455b86b2aa65b72c99b171524b2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pyfory
# Global classes for testing global class method serialization
class GlobalTestClass:
    """Module-level fixture offering one method of each binding kind.

    It exposes an instance method, two classmethods and two staticmethods,
    each returning a plain string built from its inputs.
    """

    class_variable = "global_class_value"

    def __init__(self, value):
        # Kept on the instance so instance_method can echo it back.
        self.instance_value = value

    def instance_method(self):
        """Return the stored value prefixed with ``instance_``."""
        result = f"instance_{self.instance_value}"
        return result

    @classmethod
    def class_method(cls):
        """Return ``cls.class_variable`` prefixed with ``class_``."""
        result = f"class_{cls.class_variable}"
        return result

    @classmethod
    def class_method_with_args(cls, arg1, arg2):
        """Return ``cls.class_variable`` and both arguments joined by underscores."""
        result = f"class_{cls.class_variable}_{arg1}_{arg2}"
        return result

    @staticmethod
    def static_method():
        """Return a fixed marker string."""
        return "static_global_result"

    @staticmethod
    def static_method_with_args(arg1, arg2):
        """Return both arguments joined by underscores after ``static``."""
        result = f"static_{arg1}_{arg2}"
        return result
class AnotherGlobalClass:
    """Second module-level fixture with a single classmethod."""

    @classmethod
    def another_class_method(cls):
        """Return the name of the class the method is bound to, prefixed with ``another_``."""
        result = f"another_{cls.__name__}"
        return result
class GlobalClassWithInheritance(GlobalTestClass):
    """Module-level subclass of GlobalTestClass that overrides ``class_variable``."""

    # Shadows the parent's value, so classmethods bound to this class read it.
    class_variable = "inherited_value"

    @classmethod
    def inherited_class_method(cls):
        """Return ``cls.class_variable`` prefixed with ``inherited_``."""
        result = f"inherited_{cls.class_variable}"
        return result
class TestMethodSerialization:
    """Round-trip tests for methods of classes defined inside the test body."""

    def test_instance_method_serialization(self):
        """A bound instance method returns the same value after a round trip."""
        serializer = pyfory.Fory(strict=False, ref=True)

        class TestClass:
            def __init__(self, value):
                self.value = value

            def instance_method(self):
                return self.value * 2

        bound = TestClass(5).instance_method

        # Serialize, then rebuild the callable straight from the payload.
        restored = serializer.deserialize(serializer.serialize(bound))

        assert bound() == restored()
        assert bound() == 10

    def test_classmethod_serialization(self):
        """A classmethod of a local class returns the same value after a round trip."""
        serializer = pyfory.Fory(strict=False, ref=True)

        class TestClass:
            class_var = 42

            @classmethod
            def class_method(cls):
                return cls.class_var

        original = TestClass.class_method

        restored = serializer.deserialize(serializer.serialize(original))

        assert original() == restored()
        assert original() == 42

    def test_staticmethod_serialization(self):
        """A staticmethod of a local class returns the same value after a round trip."""
        serializer = pyfory.Fory(strict=False, ref=True)

        class TestClass:
            @staticmethod
            def static_method():
                return "static_result"

        original = TestClass.static_method

        restored = serializer.deserialize(serializer.serialize(original))

        assert original() == restored()
        assert original() == "static_result"

    def test_method_with_args_serialization(self):
        """Instance, class and static methods that take arguments survive a round trip."""
        serializer = pyfory.Fory(strict=False, ref=True)

        class TestClass:
            def __init__(self, base):
                self.base = base

            def add(self, x):
                return self.base + x

            @classmethod
            def multiply(cls, a, b):
                return a * b

            @staticmethod
            def subtract(a, b):
                return a - b

        obj = TestClass(10)

        # (callable, call arguments, expected result), checked in this order:
        # instance method, classmethod, staticmethod.
        cases = (
            (obj.add, (5,), 15),
            (TestClass.multiply, (3, 4), 12),
            (TestClass.subtract, (10, 3), 7),
        )
        for original, args, expected in cases:
            restored = serializer.deserialize(serializer.serialize(original))
            assert original(*args) == restored(*args)
            assert original(*args) == expected

    def test_nested_class_method_serialization(self):
        """A classmethod of a class nested in another local class survives a round trip."""
        serializer = pyfory.Fory(strict=False, ref=True)

        class OuterClass:
            class InnerClass:
                @classmethod
                def inner_class_method(cls):
                    return "inner_result"

        original = OuterClass.InnerClass.inner_class_method

        restored = serializer.deserialize(serializer.serialize(original))

        assert original() == restored()
        assert original() == "inner_result"
def test_classmethod_serialization():
    """Round-trip a classmethod of a local class that also defines a staticmethod.

    Standalone case; per the original note it reproduces the error that
    motivated this test.
    """
    serializer = pyfory.Fory(strict=False, ref=True)

    class A:
        @classmethod
        def f(cls):
            pass

        @staticmethod
        def g():
            return A

    original = A.f
    restored = serializer.deserialize(serializer.serialize(original))

    assert isinstance(restored, type(original))
    # Compare by name only: the round trip may hand back a different class object.
    assert restored.__self__.__name__ == original.__self__.__name__
    assert restored.__func__.__name__ == original.__func__.__name__
    # The key check is that the restored method is callable and agrees with
    # the original; f has no return statement, so both calls yield None.
    result_before = original()
    result_after = restored()
    assert result_before == result_after
def test_staticmethod_serialization():
    """Round-trip a staticmethod of a local class (standalone case)."""
    serializer = pyfory.Fory(strict=False, ref=True)

    class A:
        @staticmethod
        def g():
            return "static_result"

    original = A.g
    restored = serializer.deserialize(serializer.serialize(original))

    assert original() == restored()
    assert original() == "static_result"
# Global class method tests
def test_global_classmethod_serialization():
    """Round-trip a classmethod of the module-level GlobalTestClass."""
    serializer = pyfory.Fory(strict=False, ref=True)
    original = GlobalTestClass.class_method

    restored = serializer.deserialize(serializer.serialize(original))

    # The restored object must be of the same type as the original bound method.
    assert isinstance(restored, type(original))
    assert restored() == original()
    assert restored() == "class_global_class_value"
def test_global_classmethod_with_args():
    """Round-trip a module-level classmethod and call it with positional arguments."""
    serializer = pyfory.Fory(strict=False, ref=True)
    original = GlobalTestClass.class_method_with_args

    restored = serializer.deserialize(serializer.serialize(original))

    call_args = ("arg1", "arg2")
    assert restored(*call_args) == original(*call_args)
    assert restored(*call_args) == "class_global_class_value_arg1_arg2"
def test_global_staticmethod_serialization():
    """Round-trip a staticmethod of the module-level GlobalTestClass."""
    serializer = pyfory.Fory(strict=False, ref=True)
    original = GlobalTestClass.static_method

    restored = serializer.deserialize(serializer.serialize(original))

    assert restored() == original()
    assert restored() == "static_global_result"
def test_global_staticmethod_with_args():
    """Round-trip a module-level staticmethod and call it with positional arguments."""
    serializer = pyfory.Fory(strict=False, ref=True)
    original = GlobalTestClass.static_method_with_args

    restored = serializer.deserialize(serializer.serialize(original))

    call_args = ("test1", "test2")
    assert restored(*call_args) == original(*call_args)
    assert restored(*call_args) == "static_test1_test2"
def test_global_instance_method_serialization():
    """Round-trip a method bound to an instance of the module-level GlobalTestClass."""
    serializer = pyfory.Fory(strict=False, ref=True)
    original = GlobalTestClass("test_value").instance_method

    restored = serializer.deserialize(serializer.serialize(original))

    assert restored() == original()
    assert restored() == "instance_test_value"
def test_multiple_global_classes():
    """Round-trip classmethods of two different module-level classes with one Fory."""
    serializer = pyfory.Fory(strict=False, ref=True)
    first = GlobalTestClass.class_method
    second = AnotherGlobalClass.another_class_method

    # Both payloads are produced before either is read back, in this order.
    payloads = [serializer.serialize(method) for method in (first, second)]
    restored_first, restored_second = [serializer.deserialize(data) for data in payloads]

    assert restored_first() == first()
    assert restored_second() == second()
    assert restored_first() == "class_global_class_value"
    assert restored_second() == "another_AnotherGlobalClass"
def test_global_class_inheritance():
    """Round-trip classmethods accessed through a module-level subclass."""
    serializer = pyfory.Fory(strict=False, ref=True)
    cases = (
        # Classmethod defined on the subclass itself.
        (GlobalClassWithInheritance.inherited_class_method, "inherited_inherited_value"),
        # Classmethod defined on the parent but bound to the subclass, so it
        # reads the subclass's class_variable.
        (GlobalClassWithInheritance.class_method, "class_inherited_value"),
    )
    for original, expected in cases:
        restored = serializer.deserialize(serializer.serialize(original))
        assert restored() == original()
        assert restored() == expected
def test_global_methods_without_ref_tracking():
    """Round-trip a module-level classmethod with a Fory created with ``ref=False``."""
    # Same scenario as the ref=True tests, but with reference tracking off.
    serializer = pyfory.Fory(strict=False, ref=False)
    original = GlobalTestClass.class_method

    restored = serializer.deserialize(serializer.serialize(original))

    assert restored() == original()
    assert restored() == "class_global_class_value"
def test_global_method_collection():
    """Round-trip a list whose elements are methods of module-level classes."""
    serializer = pyfory.Fory(strict=False, ref=True)
    originals = [
        GlobalTestClass.class_method,
        GlobalTestClass.static_method,
        AnotherGlobalClass.another_class_method,
    ]

    restored_list = serializer.deserialize(serializer.serialize(originals))

    # Check the length first so the pairwise comparison cannot silently skip items.
    assert len(restored_list) == len(originals)
    for before, after in zip(originals, restored_list):
        assert before() == after()
def test_global_method_in_dict():
    """Round-trip a dict whose values are methods of module-level classes."""
    serializer = pyfory.Fory(strict=False, ref=True)
    originals = {
        "class_method": GlobalTestClass.class_method,
        "static_method": GlobalTestClass.static_method,
        "another_method": AnotherGlobalClass.another_class_method,
    }

    restored_map = serializer.deserialize(serializer.serialize(originals))

    assert len(restored_map) == len(originals)
    # Look each restored method up by the original key and compare results.
    for key, before in originals.items():
        assert before() == restored_map[key]()
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code; hand it to the interpreter so a failing run does not exit with 0.
    import pytest

    raise SystemExit(pytest.main([__file__, "-v"]))