| """Parallel building utilities.""" |
| |
| from __future__ import annotations |
| |
| import os |
| import time |
| import traceback |
| from math import sqrt |
| from typing import TYPE_CHECKING, Any, Callable |
| |
| try: |
| import multiprocessing |
| HAS_MULTIPROCESSING = True |
| except ImportError: |
| HAS_MULTIPROCESSING = False |
| |
| from sphinx.errors import SphinxParallelError |
| from sphinx.util import logging |
| |
| if TYPE_CHECKING: |
| from collections.abc import Sequence |
| |
logger = logging.getLogger(__name__)

# Our parallel functionality only works for the forking Process, i.e. on
# POSIX.  Test the import flag rather than the module name: when
# ``import multiprocessing`` failed, the name ``multiprocessing`` is unbound
# and referencing it here would raise NameError at import time.
parallel_available = HAS_MULTIPROCESSING and os.name == 'posix'
| |
| |
class SerialTasks:
    """Has the same interface as ParallelTasks, but executes tasks directly.

    Every task runs synchronously inside :meth:`add_task`. As a result,
    :meth:`join` and :meth:`terminate` have nothing left to do. They exist
    so that callers can use either class interchangeably.
    """

    def __init__(self, nproc: int = 1) -> None:
        # *nproc* is accepted only for signature compatibility with
        # ParallelTasks; tasks always run one at a time in this process.
        pass

    def add_task(
        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
    ) -> None:
        """Run ``task_func(arg)`` immediately (``task_func()`` if *arg* is None).

        If *result_func* is given (and truthy), it is called with the task's
        return value.
        """
        # NOTE(review): ParallelTasks calls ``result_func(arg, result)``,
        # whereas this passes only ``result``. Callers that share one
        # callback between both classes should confirm which arity they
        # rely on.
        if arg is not None:
            res = task_func(arg)
        else:
            res = task_func()
        if result_func:
            result_func(res)

    def join(self) -> None:
        """Do nothing: every task already finished inside :meth:`add_task`."""

    def terminate(self) -> None:
        """Do nothing: no task is ever left running.

        Provided for interface parity with ``ParallelTasks.terminate``.
        """
| |
| |
class ParallelTasks:
    """Executes *nproc* tasks in parallel after forking.

    Task setup:
        Every task passed to :meth:`add_task` gets its own child process,
        created with the ``'fork'`` start method. It also gets a one-way
        pipe on which the child sends back its outcome.

    Scheduling:
        At most *nproc* children run at once. Further tasks wait and are
        started as running ones are collected. The most recently added task
        starts first, because the waiting queue is drained with
        ``dict.popitem``.

    Collection:
        Collecting a task replays the log records captured in the child
        through this module's ``logger``. It then calls the task's result
        function in the parent.
    """

    def __init__(self, nproc: int) -> None:
        # maximum number of child processes running at the same time
        self.nproc = nproc
        # per task id: function called in the parent as
        # ``result_func(arg, result)`` once the task's result arrives
        self._result_funcs: dict[int, Callable] = {}
        # per task id: the argument the task was added with
        self._args: dict[int, list[Any] | None] = {}
        # per task id: process objects of running and waiting tasks
        self._procs: dict[int, Any] = {}
        # per task id: receiving pipe ends of running subprocesses
        self._precvs: dict[int, Any] = {}
        # per task id: receiving pipe ends of waiting (not started) subprocesses
        self._precvsWaiting: dict[int, Any] = {}
        # number of working subprocesses
        self._pworking = 0
        # id that the next added task will get
        self._taskid = 0

    def _process(self, pipe: Any, func: Callable, arg: Any) -> None:
        """Child-side body: run *func* and send ``(failed, logs, result)``.

        *func* is called as ``func(arg)``, or ``func()`` when *arg* is None.
        Log records emitted during the call are collected and made
        serializable before sending. If the call raises (any
        ``BaseException``), *failed* is True and *result* is the pair
        ``(first line of the exception message, formatted traceback)``.
        """
        # Created outside the ``try`` so that ``collector`` is always bound
        # when its logs are read below.
        collector = logging.LogCollector()
        try:
            with collector.collect():
                if arg is None:
                    ret = func()
                else:
                    ret = func(arg)
            failed = False
        except BaseException as err:
            failed = True
            errmsg = traceback.format_exception_only(err.__class__, err)[0].strip()
            ret = (errmsg, traceback.format_exc())
        logging.convert_serializable(collector.logs)
        pipe.send((failed, collector.logs, ret))

    def add_task(
        self, task_func: Callable, arg: Any = None, result_func: Callable | None = None,
    ) -> None:
        """Queue ``task_func(arg)`` (``task_func()`` if *arg* is None).

        The task is queued to run in a forked child. Once its result has
        been received, ``result_func(arg, result)`` is called in the parent.

        This method also collects at most one finished task and starts
        queued ones, so it may raise what :meth:`_join_one` raises. In that
        case, running children are terminated and queued tasks are dropped
        before the exception propagates.
        """
        tid = self._taskid
        self._taskid += 1
        self._result_funcs[tid] = result_func or (lambda arg, result: None)
        self._args[tid] = arg
        precv, psend = multiprocessing.Pipe(False)
        context: Any = multiprocessing.get_context('fork')
        proc = context.Process(target=self._process, args=(psend, task_func, arg))
        self._procs[tid] = proc
        self._precvsWaiting[tid] = precv
        try:
            self._join_one()
        except BaseException:
            # Shut down other child processes on failure (a task error, or
            # e.g. OSError when forking). The exception propagates, so
            # join(), which normally does this cleanup, may never be called.
            self.terminate()
            raise

    def join(self) -> None:
        """Wait until every added task has been collected.

        Polls every 20 ms while no task is ready. On exit, whether normal
        or via an exception from :meth:`_join_one` or a result function,
        remaining children are terminated and pending tasks are dropped.
        """
        try:
            while self._pworking:
                if not self._join_one():
                    time.sleep(0.02)
        finally:
            # shutdown other child processes on failure
            self.terminate()

    def terminate(self) -> None:
        """Kill running children and forget every task not yet collected.

        Waiting tasks are dropped without ever being started. This is safe
        to call repeatedly, and also after a result function raised.
        """
        for tid, pipe in list(self._precvs.items()):
            self._procs[tid].terminate()
            pipe.close()
            del self._precvs[tid]
            self._pworking -= 1
        for pipe in self._precvsWaiting.values():
            pipe.close()
        self._precvsWaiting.clear()
        # Nothing is running or queued any more, so drop all remaining
        # per-task state. Clearing (instead of popping per task id) also
        # covers ids whose callback/argument were already consumed because
        # their result function raised, and processes that failed to start.
        self._procs.clear()
        self._result_funcs.clear()
        self._args.clear()

    def _join_one(self) -> bool:
        """Collect at most one finished task, then start queued tasks.

        Returns True if a finished task was collected.

        Raises SphinxParallelError if a task raised in its child, or if a
        child exited without sending a result. Exceptions from a result
        function propagate unchanged.
        """
        joined_any = False
        for tid, pipe in self._precvs.items():
            proc = self._procs[tid]
            # Sample the exit status *before* polling. Once a child has
            # exited, everything it sent is already readable, so "exited and
            # nothing to read" reliably means it died without reporting
            # (e.g. it was killed, or its result could not be pickled).
            # Without this check, join() would wait for such a task forever.
            exitcode = proc.exitcode
            if pipe.poll():
                exc, logs, result = pipe.recv()
                if exc:
                    raise SphinxParallelError(*result)
                for log in logs:
                    logger.handle(log)
                self._result_funcs.pop(tid)(self._args.pop(tid), result)
                proc.join()
                # Drop the finished process as well. It holds the sending
                # pipe end among its arguments, so keeping it leaks a
                # descriptor per task.
                del self._procs[tid]
                self._precvs.pop(tid).close()
                self._pworking -= 1
                joined_any = True
                break
            if exitcode is not None:
                msg = (f'child process {proc.pid} exited with code {exitcode} '
                       'without sending a result')
                raise SphinxParallelError(msg, '')

        while self._precvsWaiting and self._pworking < self.nproc:
            newtid, newprecv = self._precvsWaiting.popitem()
            # Start before registering as running. If start() raises, no
            # unstarted process is left in _precvs for terminate() to kill.
            self._procs[newtid].start()
            self._precvs[newtid] = newprecv
            self._pworking += 1

        return joined_any
| |
| |
def make_chunks(arguments: Sequence[str], nproc: int, maxbatch: int = 10) -> list[Any]:
    """Split *arguments* into consecutive slices, one per worker Process.

    Choosing the slice length:
        It starts as ``len(arguments) // nproc``. When that reaches
        *maxbatch*, it is replaced by
        ``int(sqrt(len(arguments) / nproc * maxbatch))``, which balances
        batch size against the number of batches. A computed length of
        zero becomes one.

    Result:
        The slices cover *arguments* in order. Only the last one may be
        shorter.
    """
    total = len(arguments)
    size = total // nproc
    if size >= maxbatch:
        # trade batch size against number of batches
        size = int(sqrt(total / nproc * maxbatch))
    # fewer arguments than processes: one argument per chunk
    size = size or 1
    return [arguments[start:start + size] for start in range(0, total, size)]