blob: 700b2baa2eaf0534b269e5421ebc2e4beb722eec [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Test max_collection_size and max_binary_size guardrails to prevent OOM attacks
from malicious payloads.
Collections preallocate memory based on declared size, so they need guardrails.
Binary reads are guarded by max_binary_size on the Buffer.
"""
from dataclasses import dataclass
from typing import List
import pytest
import pyfory
from pyfory import Fory
from pyfory.serialization import Buffer
def roundtrip(data, limit, xlang=False, ref=False):
"""Serialize and deserialize with given collection size limit."""
writer = Fory(xlang=xlang, ref=ref)
reader = Fory(xlang=xlang, ref=ref, max_collection_size=limit)
return reader.deserialize(writer.serialize(data))
def roundtrip_binary(data, max_binary_size, xlang=False, ref=False):
"""Serialize and deserialize with given binary size limit."""
writer = Fory(xlang=xlang, ref=ref)
reader = Fory(xlang=xlang, ref=ref, max_binary_size=max_binary_size)
return reader.deserialize(writer.serialize(data))
class TestCollectionSizeLimit:
"""Collections (list/set/dict) preallocate memory, so need size limits."""
@pytest.mark.parametrize("xlang", [False, True])
@pytest.mark.parametrize(
"data,limit",
[
([1, 2, 3], 10), # list within limit
({1, 2, 3}, 10), # set within limit
({"a": 1}, 10), # dict within limit
([], 0), # empty list ok
(set(), 0), # empty set ok
({}, 0), # empty dict ok
],
)
def test_within_limit_succeeds(self, xlang, data, limit):
assert roundtrip(data, limit, xlang=xlang) == data
@pytest.mark.parametrize("xlang", [False, True])
@pytest.mark.parametrize(
"data,limit",
[
(list(range(10)), 5), # list exceeds
(set(range(10)), 5), # set exceeds
({str(i): i for i in range(10)}, 5), # dict exceeds
([[1], list(range(10))], 5), # nested inner exceeds
],
)
def test_exceeds_limit_fails(self, xlang, data, limit):
with pytest.raises(ValueError, match="exceeds the configured limit"):
roundtrip(data, limit, xlang=xlang)
@pytest.mark.parametrize("ref", [False, True])
@pytest.mark.parametrize(
"data,limit,should_fail",
[
((1, 2, 3), 10, False),
(tuple(range(10)), 5, True),
],
)
def test_tuple_limit(self, ref, data, limit, should_fail):
"""Tuple only works in xlang=False mode."""
if should_fail:
with pytest.raises(ValueError, match="exceeds the configured limit"):
roundtrip(data, limit, xlang=False, ref=ref)
else:
assert roundtrip(data, limit, xlang=False, ref=ref) == data
def test_default_limit_is_one_million(self):
assert Fory().max_collection_size == 1_000_000
def test_dataclass_list_field_exceeds_limit(self):
@dataclass
class Container:
items: List[pyfory.int32]
writer = Fory(xlang=True)
writer.register(Container)
reader = Fory(xlang=True, max_collection_size=5)
reader.register(Container)
with pytest.raises(ValueError, match="exceeds the configured limit"):
reader.deserialize(writer.serialize(Container(items=list(range(10)))))
class TestBinarySizeLimit:
"""Binary reads are guarded by max_binary_size on the Buffer."""
def test_default_limit_is_64mib(self):
assert Fory().max_binary_size == 64 * 1024 * 1024
@pytest.mark.parametrize("xlang", [False, True])
def test_within_limit_succeeds(self, xlang):
assert roundtrip_binary(b"x" * 100, max_binary_size=1024, xlang=xlang) == b"x" * 100
@pytest.mark.parametrize("xlang", [False, True])
def test_exceeds_limit_fails(self, xlang):
with pytest.raises(ValueError, match="exceeds the configured limit"):
roundtrip_binary(b"x" * 200, max_binary_size=100, xlang=xlang)
def test_from_stream_respects_limit(self):
import io
payload = Fory().serialize(b"x" * 200)
buf = Buffer.from_stream(io.BytesIO(payload), max_binary_size=100)
with pytest.raises(ValueError, match="exceeds the configured limit"):
Fory(max_binary_size=100).deserialize(buf)