blob: a77809da6a4fbf2f527910ce7601ea06092da785 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""conftest.py contains configuration for pytest."""
import gc
import platform
import types

import mxnet as mx
import pytest
def _references_live_ndarray(roots):
    """Return True if any object reachable from ``roots`` is a live NDArray.

    An NDArray counts as live when its ``_alive`` attribute is truthy, i.e. it
    has not been freed in the backend yet. The object graph is walked
    depth-first with an explicit stack, so deeply nested graphs cannot exceed
    the interpreter recursion limit. The children of an object are:

    * the keys and values of a ``dict`` (including subclasses such as
      ``OrderedDict``);
    * the attribute values of any object that has a ``__dict__``;
    * for any other object, the items produced by iterating it -- but only
      when ``iter(obj)`` returns a new iterator. Objects that are their own
      iterator (generators, files, ...) are skipped, because iterating them
      would consume them or execute arbitrary code.

    Strings, bytes, numbers, ``None``, modules and Symbols are leaves.
    """
    leaf_types = (str, bytes, bytearray, int, float, complex, type(None),
                  types.ModuleType, mx.sym._internal.SymbolBase)
    # Visited objects are tracked by id() instead of by hashing them, so no
    # user-defined __hash__/__eq__ runs (NDArray.__eq__ returns an array).
    # The objects are stored as the values to keep them alive, so that a
    # temporary produced by iteration cannot be freed and its id() reused.
    seen = {}
    stack = list(roots)
    while stack:
        obj = stack.pop()
        if isinstance(obj, leaf_types) or id(obj) in seen:
            continue
        seen[id(obj)] = obj

        if isinstance(obj, mx.nd._internal.NDArrayBase):
            # Only arrays not yet freed in the backend count as leaks.
            if obj._alive:
                return True
            continue

        try:
            attributes = vars(obj)
        except TypeError:  # obj has no __dict__
            attributes = None
        if attributes is not None:
            stack.extend(attributes.values())

        if isinstance(obj, dict):
            stack.extend(obj.keys())
            stack.extend(obj.values())
        elif attributes is None:
            # Plain containers (list, tuple, set, ...). Objects with a
            # __dict__ are deliberately not iterated: only their attributes
            # are inspected.
            try:
                iterator = iter(obj)
                if iterator is not obj:
                    stack.extend(iterator)
            except (TypeError, KeyError):
                # Not iterable, or a __getitem__-based sequence that signals
                # exhaustion with KeyError instead of IndexError.
                pass
    return False


@pytest.fixture(autouse=True)
def check_leak_ndarray(request):
    """Fail the test if it leaks live NDArrays through reference cycles.

    Garbage is collected before the test runs, and the test runs with
    ``gc.DEBUG_SAVEALL`` so that every unreachable object found by the
    collection afterwards is appended to ``gc.garbage`` instead of being
    freed. Those objects are then searched for NDArrays that are still alive
    in the backend. The check is skipped for tests marked
    ``garbage_expected`` and on CentOS.
    """
    if request.node.get_closest_marker('garbage_expected'):
        # Some tests leak references. They should be fixed.
        yield  # run test
        return
    if 'centos' in platform.platform():
        # Multiple tests are failing due to reference leaks on CentOS. It's not
        # yet known why there are more memory leaks in the Python 3.6.9 version
        # shipped on CentOS compared to the Python 3.6.9 version shipped in
        # Ubuntu.
        yield
        return

    # Drop garbage saved by earlier tests and collect garbage prior to
    # running this test, so pre-existing cycles are not attributed to it.
    del gc.garbage[:]
    gc.collect()

    # DEBUG_SAVEALL keeps every unreachable object in gc.garbage for
    # inspection instead of freeing it.
    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
        yield  # run the test
        gc.collect()
    finally:
        # Restore the original flags even if the fixture is closed early.
        gc.set_debug(gc_flags)

    try:
        leaked = _references_live_ndarray(gc.garbage)
    finally:
        # Release the saved objects even if the scan or the check fails.
        del gc.garbage[:]
    assert not leaked, 'Found leaked NDArrays due to reference cycles'