blob: 9feca197a4044b7e7edf410bf4b8e7b6191e3784 [file]
############################################################################
# SPDX-License-Identifier: Apache-2.0
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership. The
# ASF licenses this file to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance with the
# License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
############################################################################
"""NTFC collector plugin for pytest."""
import os
from typing import TYPE_CHECKING, Dict, List, Tuple
import pytest
from ntfc.log.logger import logger
from ntfc.pytest.collecteditem import CollectedItem
from ntfc.testfilter import FilterTest
if TYPE_CHECKING:
from ntfc.envconfig import EnvConfig
###############################################################################
# Class: CollectorPlugin
###############################################################################
class CollectorPlugin:
    """Custom pytest collector plugin.

    Wraps every collected pytest item in a :class:`CollectedItem`, drops
    items rejected by :class:`FilterTest` or by the ``module`` include /
    exclude settings of ``pytest.cfgtest``, optionally reorders the rest by
    the ``module.order`` setting, and (unless in collect-only mode) runs the
    remaining items ``common.loops`` times.
    """

    def __init__(self, config: "EnvConfig", collectonly: bool = True) -> None:
        """Initialize custom pytest collector plugin.

        :param config: NTFC environment configuration; also passed to
          :class:`FilterTest`.
        :param collectonly: when True, ``pytest_runtestloop`` runs no tests.
        """
        self._config = config
        self._filter = FilterTest(config)
        # every collected item, before any filtering
        self._all_items: List[CollectedItem] = []
        # items that passed filtering, in final (possibly reordered) order
        self._filtered_items: List[CollectedItem] = []
        self._collectonly = collectonly
        # (item, reason) for every item removed by the filters
        self._skipped_items: List[Tuple[pytest.Item, str]] = []

    def _collected_item(self, item: pytest.Item) -> CollectedItem:
        """Create a :class:`CollectedItem` describing a pytest item.

        :param item: collected pytest item
        :return: collected item metadata
        """
        # item.location is (path, lineno, name); the path element is not
        # needed because the absolute path is derived from item.path.
        _, lineno, name = item.location
        abs_path = os.path.abspath(item.path)
        directory = os.path.dirname(abs_path)
        # remove every occurrence of the test root from the absolute path
        module = abs_path.replace(pytest.testroot, "")
        # NOTE(review): this re-applies the same replacement to ``module``,
        # so it is normally identical to ``module``; confirm what ``root``
        # is meant to contain.
        root = module.replace(pytest.testroot, "")
        return CollectedItem(
            directory,
            module,
            name,
            abs_path,
            lineno or 0,  # location's lineno may be None
            item.nodeid,
            pytest.ntfcyaml.get("module", "Unknown_"),
            root,
        )

    @property
    def skipped_items(self) -> List[Tuple[pytest.Item, str]]:
        """Get skipped items as (item, reason) pairs."""
        return self._skipped_items

    @property
    def filtered(self) -> List[CollectedItem]:
        """Get filtered items, in run order."""
        return self._filtered_items

    @property
    def allitems(self) -> List[CollectedItem]:
        """Get all items before filtration."""
        return self._all_items

    def pytest_runtestloop(self, session: pytest.Session) -> bool:
        """Run test loop.

        Do not run tests if we are in collect only mode. Otherwise every
        item in ``session.items`` is run ``common.loops`` times (default 1).

        :param session: pytest session
        :return: True, so pytest treats the run loop as handled here
        :raises session.Interrupted: on collection errors or when the
          session asks to stop
        :raises session.Failed: when the session asks to fail
        """
        if session.testsfailed:  # pragma: no cover
            raise session.Interrupted("error during collection")

        # do not run test cases when in collect only mode
        if self._collectonly:
            return True

        loops = self._config.common.get("loops", 1)
        for loop in range(loops):
            if loops > 1:
                print("\n\n" + "=" * 100)
                print("Loop:", loop)
                print("=" * 100)

            for i, item in enumerate(session.items):
                # the next item lets pytest decide which fixtures to tear
                # down; None for the last item
                nextitem = (
                    session.items[i + 1]
                    if i + 1 < len(session.items)
                    else None
                )
                logger.debug(f"run test:{item}")
                item.config.hook.pytest_runtest_protocol(
                    item=item, nextitem=nextitem
                )
                if session.shouldfail:  # pragma: no cover
                    raise session.Failed(session.shouldfail)
                if session.shouldstop:  # pragma: no cover
                    raise session.Interrupted(session.shouldstop)

        return True

    def pytest_collection_finish(self, session: pytest.Session) -> None:
        """Pytest collection finish callback (currently a no-op)."""

    def _filter_modules(
        self, ci: CollectedItem, include: List[str], exclude: List[str]
    ) -> Tuple[bool, str]:
        """Filter modules based on include/exclude lists.

        :param ci: collected item whose ``module2`` is checked
        :param include: if non-empty, only these modules are kept
        :param exclude: modules to drop
        :return: (skip, reason); reason is "" when the item is kept
        """
        if include and ci.module2 not in include:
            return True, "not in include_module"
        if exclude and ci.module2 in exclude:
            return True, "excluded module"
        return False, ""

    def _order_items(
        self, items: List[pytest.Item], order_map: Dict[str, int]
    ) -> List[pytest.Item]:
        """Order test items based on the order map.

        Items whose ``module2`` has a positive value run first, in ascending
        value order. Modules not in the map keep their relative order in the
        middle. Modules with a value <= 0 run last, in ascending value order
        (so -1 is the very last). The sort is stable, so items with equal
        keys keep their collection order.

        :param items: items to order; each must carry ``_collected``
        :param order_map: module name -> order value
        :return: new, ordered list
        """

        def sort_key(test_item: pytest.Item) -> Tuple[int, int]:
            v = order_map.get(test_item._collected.module2)
            if v is None:
                return (1, 0)
            if v > 0:
                return (0, v)
            # v <= 0: placed after all unordered modules
            return (2, v)

        return sorted(items, key=sort_key)

    @staticmethod
    def _build_order_map(order_list: List[Dict]) -> Dict[str, int]:
        """Build a module -> order value map from ``module.order`` entries.

        Entries without a truthy ``module`` or with a missing/None ``value``
        are ignored.
        """
        return {
            e["module"]: int(e["value"])
            for e in order_list
            if e.get("module") and e.get("value") is not None
        }

    def pytest_collection_modifyitems(
        self,
        config: pytest.Config,  # pylint: disable=unused-argument
        items: List[pytest.Item],
    ) -> None:
        """Modify the `items` list after collection is completed.

        Records every item, removes items rejected by the test filter or
        the module include/exclude lists (recording them with a reason and
        a skip marker), applies the optional module ordering and writes the
        result back into ``items`` in place.

        :param config: pytest config (unused)
        :param items: collected items; replaced in place with the items
          that should run
        """
        kept: List[pytest.Item] = []
        # a key present with an empty value loads as None; treat as unset
        module_cfg = pytest.cfgtest.get("module") or {}
        include_module = module_cfg.get("include_module", [])
        exclude_module = module_cfg.get("exclude_module", [])
        order_map = self._build_order_map(module_cfg.get("order") or [])

        for item in items:
            ci = self._collected_item(item)
            # stored on the item so _order_items can read module2 later
            item._collected = ci
            self._all_items.append(ci)

            skip, reason = self._filter.check_test_support(item)
            if not skip:
                skip, reason = self._filter_modules(
                    ci, include_module, exclude_module
                )

            if skip:
                skip_reason = reason or "unknown reason"
                self._skipped_items.append((item, skip_reason))
                item.add_marker(pytest.mark.skip(reason=skip_reason))
                continue

            self._filtered_items.append(ci)
            kept.append(item)

        if order_map:
            kept = self._order_items(kept, order_map)
            # keep the filtered list in the same order as the run order
            self._filtered_items = [item._collected for item in kept]

        # overwrite items in place so pytest sees the new list
        items[:] = kept