| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """conftest.py contains configuration for pytest.""" |
| |
| import gc |
| import platform |
| |
| import mxnet as mx |
| import pytest |
| |
| |
@pytest.fixture(autouse=True)
def check_leak_ndarray(request):
    """Fail a test that leaves live NDArrays behind in reference cycles.

    The fixture runs around every test (``autouse=True``). Before the test
    it empties ``gc.garbage``, runs a collection and switches the collector
    to ``gc.DEBUG_SAVEALL`` so that everything found unreachable during the
    test is kept in ``gc.garbage`` instead of being freed. After the test
    it collects again, restores the previous debug flags and asserts that
    no object in ``gc.garbage`` is, or holds, an NDArray whose ``_alive``
    flag is still set.

    The check is skipped entirely (the test simply runs) when the test
    carries the ``garbage_expected`` marker or when ``platform.platform()``
    contains ``'centos'``.

    Parameters
    ----------
    request
        The pytest ``request`` fixture; only ``request.node`` is used, to
        look up the ``garbage_expected`` marker.
    """
    # Some tests leak references. They should be fixed.
    if request.node.get_closest_marker('garbage_expected'):
        yield  # run test
        return

    if 'centos' in platform.platform():
        # Multiple tests are failing due to reference leaks on CentOS. It's not
        # yet known why there are more memory leaks in the Python 3.6.9 version
        # shipped on CentOS compared to the Python 3.6.9 version shipped in
        # Ubuntu.
        yield
        return

    del gc.garbage[:]
    # Collect garbage prior to running the next test
    gc.collect()
    # Enable gc debug mode to check if the test leaks any arrays
    gc_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)

    # Run the test
    yield

    # Check for leaked NDArrays. Collect while DEBUG_SAVEALL is still active
    # so the unreachable objects land in gc.garbage, then restore the flags.
    gc.collect()
    gc.set_debug(gc_flags)  # reset gc flags

    seen = set()

    def has_array(element):
        """Return True if ``element`` is, or contains, a live NDArray."""
        try:
            if element in seen:
                return False
            seen.add(element)
        except (TypeError, ValueError, NotImplementedError):  # unhashable
            pass

        if isinstance(element, mx.nd._internal.NDArrayBase):
            # We only care about catching NDArray's that haven't been freed
            # in the backend yet
            return element._alive
        if isinstance(element, mx.sym._internal.SymbolBase):
            return False

        if hasattr(element, '__dict__'):
            # Inspect the attribute *values*. Iterating vars(element)
            # directly would only yield the attribute names (strings), so an
            # NDArray stored on an instance would never be examined.
            children = vars(element).values()
        elif isinstance(element, dict):
            children = element.items()
        else:
            children = element

        try:
            return any(has_array(child) for child in children)
        except (TypeError, KeyError, RecursionError):
            # Not iterable, or nested too deeply to follow: treat this
            # subtree as containing no array.
            return False

    try:
        assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
    finally:
        # Release the saved objects even when the assertion fails, so a
        # reported leak is not additionally kept alive by gc.garbage or by
        # the `seen` set referenced from the failure traceback.
        seen.clear()
        del gc.garbage[:]