blob: 33338bf54f1789894886973570e6c0e87c056053 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Ensure exposed pyclasses default to frozen."""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Iterator
# `#[pyclass]` or `#[pyclass(...)]` attribute. The `args` group captures the
# text between the parentheses and is None for the bare form. DOTALL lets the
# arguments span several lines; the lazy `.*?` stops at the first `)` that is
# followed (after optional whitespace) by the closing `]`.
PYCLASS_RE = re.compile(
    r"#\[\s*pyclass\s*(?:\((?P<args>.*?)\))?\s*\]",
    flags=re.DOTALL,
)

# A `key = "value"` pair inside the attribute arguments, e.g. `name = "Expr"`
# or `module = "datafusion.expr"`. Empty values (`""`) are not matched.
ARG_STRING_RE = re.compile(
    r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
)

# A `struct` or `enum` declaration, optionally prefixed by `pub`; the `name`
# group captures the Rust type identifier.
STRUCT_NAME_RE = re.compile(
    r"\b(?:pub\s+)?(?:struct|enum)\s+"
    r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
)
@dataclass
class PyClass:
    """A single ``#[pyclass]`` declaration found while scanning Rust sources."""

    # Python module the class is exposed under (the `module = "..."` argument,
    # or "datafusion" when iter_pyclasses finds none).
    module: str
    # Exposed class name: the `name = "..."` argument, else the Rust type name,
    # else "<unknown>".
    name: str
    # True when the attribute arguments contain the `frozen` option.
    frozen: bool
    # The `.rs` file in which the attribute was found.
    source: Path
def iter_pyclasses(root: Path) -> Iterator[PyClass]:
    """Yield a :class:`PyClass` for every ``#[pyclass]`` attribute under *root*.

    All ``*.rs`` files below *root* are scanned in sorted path order, so the
    sequence of results (and any failure message built from it) is
    deterministic regardless of filesystem iteration order.

    For each ``#[pyclass]`` / ``#[pyclass(...)]`` attribute:

    * ``frozen`` is True when the bare ``frozen`` option appears in the
      arguments. Quoted string values are blanked out before the search so
      that e.g. ``module = "datafusion.frozen"`` is not mistaken for it.
    * ``module`` comes from a ``module = "..."`` argument, defaulting to
      ``"datafusion"``.
    * ``name`` comes from a ``name = "..."`` argument, falling back to the
      first ``struct``/``enum`` identifier after the attribute, then to
      ``"<unknown>"``.

    Args:
        root: Directory searched recursively for Rust source files.

    Yields:
        One ``PyClass`` per matched attribute, in file then position order.
    """
    frozen_flag = re.compile(r"\bfrozen\b")
    string_literal = re.compile(r'"[^"]*"')
    for path in sorted(root.rglob("*.rs")):
        text = path.read_text(encoding="utf8")
        for match in PYCLASS_RE.finditer(text):
            args = match.group("args") or ""
            # Blank out quoted values so a `name`/`module` string that contains
            # the word "frozen" is not counted as the bare `frozen` option.
            bare_args = string_literal.sub('""', args)
            frozen = frozen_flag.search(bare_args) is not None
            module = None
            name = None
            for arg_match in ARG_STRING_RE.finditer(args):
                key = arg_match.group("key")
                value = arg_match.group("value")
                if key == "module":
                    module = value
                elif key == "name":
                    name = value
            # Assume the attribute applies to the next struct/enum declared
            # after it; search from the match end instead of slicing a copy.
            struct_match = STRUCT_NAME_RE.search(text, match.end())
            struct_name = struct_match.group("name") if struct_match else None
            yield PyClass(
                module=module or "datafusion",
                name=name or struct_name or "<unknown>",
                frozen=frozen,
                source=path,
            )
def test_pyclasses_are_frozen() -> None:
    """Fail if any ``#[pyclass]`` under ``src`` lacks ``frozen`` (minus the allowlist).

    The Rust sources are located through the relative path ``src``, so pytest
    must be launched from the repository root. If the scan finds no pyclasses
    at all, the test fails instead of passing without checking anything.
    """
    allowlist = {
        # NOTE: Any new exceptions must include a justification comment
        # in the Rust source and, ideally, a follow-up issue to remove
        # the exemption.
        ("datafusion.common", "SqlTable"),
        ("datafusion.common", "SqlView"),
        ("datafusion.common", "DataTypeMap"),
        ("datafusion.expr", "TryCast"),
        ("datafusion.expr", "WriteOp"),
    }
    src_root = Path("src")
    pyclasses = list(iter_pyclasses(src_root))
    # An empty scan (wrong working directory, moved sources) would leave
    # `unfrozen` empty and make the check below pass vacuously.
    assert pyclasses, (
        f"No #[pyclass] definitions found under {src_root.resolve()}; "
        "run pytest from the repository root."
    )
    unfrozen = sorted(
        (
            pyclass
            for pyclass in pyclasses
            if not pyclass.frozen
            and (pyclass.module, pyclass.name) not in allowlist
        ),
        key=lambda pyclass: (pyclass.module, pyclass.name),
    )
    # The message expression is only evaluated when the assertion fails.
    assert not unfrozen, (
        "Found pyclasses missing `frozen`; add them to the allowlist only "
        "with a justification comment and follow-up plan:\n"
        + "\n".join(
            f"- {pyclass.module}.{pyclass.name} (defined in {pyclass.source})"
            for pyclass in unfrozen
        )
    )