# Source code for deluxe.importers

# Copyright (c) 2024 - Gilles Coissac
# This file is part of standard-deluxe library.
#
# standard-deluxe is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published
# by the Free Software Foundation, either version 3 of the License,
# or (at your option) any later version.
#
# standard-deluxe is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty
# of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
# See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with standard-deluxe. If not, see <https://www.gnu.org/licenses/>
#
# Parts of this module are borrowed from the code of the Python 3.13
# test.support.import_helper module, which is not guaranteed
# to be present in all Python distributions and could be removed without
# notice between releases of Python.
# Copyright (C) 2006 Python Software Foundation.
# vendored under the PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2
#
"""Module loading and import utilities.

This module provides tools for dynamic module loading, fresh imports that
bypass :data:`sys.modules`, monkey patching, and context managers for
isolating import side-effects.

.. note::

    Parts of this module are derived from the Python ``test.support.import_helper``
    module (Copyright (C) 2006 Python Software Foundation, licensed under
    the `PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2`_). That module is part
    of CPython's test suite and is not guaranteed to be available in all Python
    distributions or to remain stable between releases (in fact, it is already
    absent from some standalone Python builds).
    Those vendored portions have been adapted to fit the `standard-deluxe` public API:
    :class:`CleanImport`, :class:`DirsOnSysPath`, :func:`forget_module`, :func:`frozen_modules`,
    :func:`import_fresh_module`.

.. _PYTHON SOFTWARE FOUNDATION LICENSE VERSION 2:
    https://docs.python.org/3/license.html#psf-license
"""

from __future__ import annotations

import contextlib
import importlib.machinery
import importlib.util
import logging
import os
import sys
import time
import warnings
from _imp import (
    _override_frozen_modules_for_tests,  # pyright: ignore[reportAttributeAccessIssue, reportUnknownVariableType]
)
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Self, final

from deluxe.types import AnyFilePath, Unset


if TYPE_CHECKING:
    from pathlib import Path
    from types import ModuleType, TracebackType


# Public API of this module. Keep a single assignment: a second ``__all__``
# would silently replace this one and hide the vendored helpers
# (CleanImport, DirsOnSysPath, forget_module, frozen_modules,
# import_fresh_module) from ``from deluxe.importers import *``.
__all__ = (
    "CleanImport",
    "DirsOnSysPath",
    "Module",
    "Patch",
    "Patchable",
    "forget_module",
    "frozen_modules",
    "import_fresh_module",
    "loads_module",
    "monkey",
)


logger = logging.getLogger(__name__)

Patchable = object
"""Type alias for objects that can be the target of a monkey patch."""

Patch = Callable[..., object]
"""Type alias for monkey patch callables."""


def _waitfor(
    func: Callable[..., Any], pathname: AnyFilePath, waitall: bool = False
) -> None:  # pragma: no cover
    # Perform the operation
    func(pathname)
    # Now setup the wait loop
    if waitall:
        name = ""
        dirname = pathname
    else:
        dirname, name = os.path.split(pathname)
        dirname = dirname or "."
    # Check for `pathname` to be removed from the filesystem.
    # The exponential backoff of the timeout amounts to a total
    # of ~1 second after which the deletion is probably an error
    # anyway.
    # Testing on an i7@4.3GHz shows that usually only 1 iteration is
    # required when contention occurs.
    timeout = 0.001
    while timeout < 1.0:
        # Note we are only testing for the existence of the file(s) in
        # the contents of the directory regardless of any security or
        # access rights. If we have made it this far, we have sufficient
        # permissions to do that much using Python's equivalent of the
        # Windows API FindFirstFile.
        # Other Windows APIs can fail or give incorrect results when
        # dealing with files that are pending deletion.
        l_ = os.listdir(dirname)  # noqa: PTH208
        if not (l_ if waitall else name in l_):
            return
        # Increase the timeout and try again
        time.sleep(timeout)
        timeout *= 2


def _unlink(filename: AnyFilePath) -> None:  # pragma: no cover
    with contextlib.suppress(FileNotFoundError, NotADirectoryError):
        if sys.platform.startswith("win"):
            _waitfor(os.unlink, filename)
        else:
            os.unlink(filename)  # noqa: PTH108


@contextlib.contextmanager
def _ignore_deprecated_imports(ignore: bool = True):  # pragma: no cover
    """Context manager to suppress package and module deprecation warnings when importing them.

    If ignore is False, this context manager has no effect.
    """
    if ignore:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                ".+ (module|package)",
                DeprecationWarning,
            )
            yield
    else:
        yield


def _unload(name: str) -> None:  # pragma: no cover
    with contextlib.suppress(KeyError):
        del sys.modules[name]


def loads_module(name: str, where: Path) -> ModuleType | None:
    """Load a Python module or package found in a given directory.

    Shortcut for building a :class:`Module` with an explicit *where* and
    calling :meth:`Module.load` on it.

    Args:
        name (:obj:`str`): The fully qualified module name to load.
        where (:class:`~pathlib.Path`): The directory searched for the module.

    Returns:
        :class:`~types.ModuleType` | ``None``: The object registered in
        :data:`sys.modules` under *name* once loading has finished (normally
        the freshly loaded module), or ``None`` if no such entry exists.

    Raises:
        :exc:`ModuleNotFoundError`: If the module cannot be found at the given path.
        :exc:`ImportError`: If loading the module fails.
    """
    loader = Module(name=name, where=where)
    loader.load()
    return loader.module
def _save_and_remove_modules(names: Iterable[str]): # pragma: no cover orig_modules: dict[str, ModuleType] = {} prefixes = tuple(name + "." for name in names) for modname in list(sys.modules): if modname in names or modname.startswith(prefixes): orig_modules[modname] = sys.modules.pop(modname) return orig_modules
def forget_module(name: str) -> None:  # pragma: no cover
    """Forget that a module was ever imported.

    The module is dropped from :data:`sys.modules`. Then, for every
    directory on :data:`sys.path`, the legacy ``<name>.pyc`` file and the
    PEP 3147/488 cached bytecode files (all three optimization levels) are
    deleted if present.

    Args:
        name (:obj:`str`): The fully qualified module name to forget.
    """
    _unload(name)
    source_file = name + ".py"
    for entry in sys.path:
        source = os.path.join(entry, source_file)  # noqa: PTH118
        # Existence is not checked: every candidate is unlinked and a
        # missing file is simply ignored by _unlink.
        _unlink(source + "c")  # legacy pyc living next to the source
        for level in ("", 1, 2):  # PEP 3147/488 __pycache__ variants
            _unlink(importlib.util.cache_from_source(source, optimization=level))
def import_fresh_module(
    name: str,
    fresh: Iterable[str] = (),
    blocked: Iterable[str] = (),
    *,
    deprecated: bool = False,
    usefrozen: bool = False,
) -> ModuleType | None:  # pragma: no cover
    """Import and return a fresh copy of *name*, bypassing :data:`sys.modules`.

    Before importing, the named module and every module listed in *fresh* and
    *blocked* are removed from :data:`sys.modules`, together with their
    submodules. The saved entries are put back once the import is done, and
    whatever the fresh import registered under those names is discarded.
    Unlike :func:`importlib.reload`, the previously imported module object is
    therefore left untouched.

    Args:
        name (:obj:`str`): The fully qualified module name to import.
        fresh (Iterable[:obj:`str`]): Extra module names that are also evicted
            and imported first. If any of them raises :exc:`ImportError`,
            ``None`` is returned. Default: ``()``.
        blocked (Iterable[:obj:`str`]): Module names mapped to ``None`` in
            :data:`sys.modules` during the import, so that importing them
            raises :exc:`ImportError`. Default: ``()``.
        deprecated (:obj:`bool`): If ``True``, module and package deprecation
            warnings are suppressed during the import. Default: ``False``.
        usefrozen (:obj:`bool`): If ``False`` (the default), the frozen
            importer is disabled (except for essential modules like
            ``importlib._bootstrap``). Default: ``False``.

    Returns:
        :class:`~types.ModuleType` | ``None``: The freshly imported module, or
        ``None`` if one of the *fresh* modules could not be imported.

    Raises:
        ImportError: If the named module cannot be imported.
    """  # noqa: DOC502
    # NOTE: test_heapq, test_json and test_warnings include extra sanity checks
    # to make sure that this utility function is working as expected
    with _ignore_deprecated_imports(deprecated):
        fresh_names = [*fresh]
        blocked_names = [*blocked]
        touched = {name, *fresh_names, *blocked_names}
        # Stash every affected entry so it can be restored afterwards.
        saved = _save_and_remove_modules(touched)
        # A None entry in sys.modules makes importing that name raise ImportError.
        sys.modules.update(dict.fromkeys(blocked_names))  # pyright: ignore[reportArgumentType]
        try:
            with frozen_modules(usefrozen):
                try:
                    for dependency in fresh_names:
                        __import__(dependency)
                except ImportError:
                    # A "fresh" dependency is unavailable: report it with None.
                    return None
                return importlib.import_module(name)
        finally:
            # Throw away what the fresh import created, then restore the originals.
            _save_and_remove_modules(touched)
            sys.modules.update(saved)
[docs] @contextlib.contextmanager def frozen_modules(enabled: bool = True): # pragma: no cover """Force frozen modules to be used (or not). This context manager controls whether the Python importer uses precompiled frozen modules. When disabled, the standard source-based import machinery is used instead. This only applies to modules that haven't been imported yet. Some essential modules (e.g. ``importlib._bootstrap``) will always be imported frozen regardless of this setting. Args: enabled (:obj:`bool`): If ``True``, frozen modules are enabled. If ``False``, frozen modules are disabled. Default: ``True``. Yields: ``None``: This context manager yields nothing. """ _override_frozen_modules_for_tests(1 if enabled else -1) try: yield finally: _override_frozen_modules_for_tests(0)
class CleanImport:  # pragma: no cover
    """Context manager to force import to return a new module reference.

    This is useful for testing module-level behaviors, such as the emission
    of a :exc:`DeprecationWarning` on import.

    The named modules are removed from :data:`sys.modules` as soon as the
    instance is created (not when the ``with`` block is entered), so that
    subsequent imports return fresh references. On exit, every entry that was
    present in :data:`sys.modules` at construction time is put back. Modules
    imported for the first time inside the block are left in place.

    Args:
        *module_names (:obj:`str`): One or more fully qualified module names
            to remove from :data:`sys.modules`.
        usefrozen (:obj:`bool`): If ``False`` (the default), the frozen
            importer is disabled (except for essential modules like
            ``importlib._bootstrap``) while the context is active.
            Default: ``False``.

    Examples::

        with CleanImport("foo"):
            importlib.import_module("foo")  # fresh import
    """

    def __init__(self, *module_names: str, usefrozen: bool = False) -> None:
        # Snapshot taken before any removal so that __exit__ can restore it.
        self.original_modules: dict[str, ModuleType] = sys.modules.copy()
        for module_name in module_names:
            if module_name not in sys.modules:
                continue
            module = sys.modules.pop(module_name)
            # It is possible that module_name is just an alias for
            # another module (e.g. stub for modules renamed in 3.x).
            # In that case, we also need to delete the real module to clear
            # the import cache. The real entry may already be gone (for
            # example it was listed earlier in module_names), and a blocked
            # entry is None with no __name__, hence the tolerant lookups.
            real_name = getattr(module, "__name__", module_name)
            if real_name != module_name:
                sys.modules.pop(real_name, None)
        self._frozen_modules = frozen_modules(usefrozen)  # pyright: ignore[reportUnannotatedClassAttribute]

    def __enter__(self) -> Self:
        self._frozen_modules.__enter__()
        return self

    def __exit__(
        self, t: type[BaseException] | None, i: BaseException | None, tb: TracebackType | None
    ) -> None:
        sys.modules.update(self.original_modules)
        self._frozen_modules.__exit__(t, i, tb)
class DirsOnSysPath:  # pragma: no cover
    """Temporarily extend :data:`sys.path` with extra directories.

    When the instance is created, it records the list object currently bound
    to :data:`sys.path` and a shallow copy of its contents, then appends
    *paths*. On exit, :data:`sys.path` is re-bound to the recorded object and
    its contents are reset to the recorded copy. Every change made inside the
    block is therefore undone, including replacing :data:`sys.path` with a
    different list.

    Args:
        *paths (:obj:`str`): Directory paths to temporarily add to :data:`sys.path`.

    Examples::

        with DirsOnSysPath("/tmp/my_modules"):
            import my_module
    """

    def __init__(self, *paths: str) -> None:
        self.original_object: list[str] = sys.path
        self.original_value: list[str] = list(sys.path)
        self.original_object.extend(paths)

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self, t: type[BaseException] | None, i: BaseException | None, tb: TracebackType | None
    ) -> None:
        # Re-bind first, then restore contents in place on that same object.
        sys.path = self.original_object
        self.original_object[:] = self.original_value
class Module:
    """Utility class for dynamic loading of a Python module.

    This class wraps the :mod:`importlib` machinery to resolve a module
    specification and load it from an explicit filesystem location or the
    default search path. It tracks the module's package hierarchy and provides
    helper methods for comparing module names.

    Args:
        name (:obj:`str`): The module name to resolve. Either an absolute
            dotted name (e.g. ``"pkg.mod"``) or a relative name starting with
            one or more dots (e.g. ``".mod"``), which is resolved against
            *package*.
        package (:obj:`str`): The package context for resolving relative
            module names. Default: ``""``.
        where (:class:`~pathlib.Path` | ``None``): An explicit directory to
            search for the module. If ``None``, the standard
            :func:`importlib.util.find_spec` search is used. Default: ``None``.

    Raises:
        :exc:`ModuleNotFoundError`: If the module cannot be found.
        :exc:`ImportError`: If *name* is relative and *package* is empty
            (raised by :func:`importlib.util.resolve_name`).

    Examples::

        mod = Module("my_package.my_module")
        mod.load()
        print(mod.full_name)  # "my_package.my_module"
    """

    __slots__: tuple[str, ...] = ("_is_pkg", "_is_root", "_name", "_pkg", "_spec")

    def __init__(self, name: str, *, package: str = "", where: Path | None = None) -> None:
        abs_name = importlib.util.resolve_name(name, package)
        self._spec: importlib.machinery.ModuleSpec = self._find_spec(abs_name, where)
        # A spec with submodule search locations describes a package.
        self._is_pkg: bool = self._spec.submodule_search_locations is not None
        self._pkg: str
        self._name: str
        self._pkg, _, self._name = abs_name.rpartition(".")
        self._is_root: bool = self._is_pkg and not self._pkg

    @staticmethod
    def _find_spec(name: str, where: Path | None) -> importlib.machinery.ModuleSpec:
        """Return the spec for *name*, searching *where* if given, else sys.path."""
        if where is None:
            spec = importlib.util.find_spec(name=name)
        else:
            loader_details = [
                (importlib.machinery.SourceFileLoader, importlib.machinery.SOURCE_SUFFIXES),
                (importlib.machinery.SourcelessFileLoader, importlib.machinery.BYTECODE_SUFFIXES),
                (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES),
            ]
            finder = importlib.machinery.FileFinder(str(where), *loader_details)
            spec = finder.find_spec(name)
            logger.debug("loading module '%s' from '%s'", name, where)
        if spec is None:
            msg = f"No module named {name!r}"
            raise ModuleNotFoundError(msg, name=name)
        return spec

    def load(self) -> None:
        """Load this module if not already loaded.

        The module is registered in :data:`sys.modules` and then executed. If
        it is a submodule, it is also bound as an attribute of its parent
        package, which must already be present in :data:`sys.modules`. If
        execution fails, the half-initialized module is removed from
        :data:`sys.modules` again, as the standard import system does, so
        that a later call retries instead of silently doing nothing.

        Raises:
            ModuleNotFoundError: If the loader reports a missing file.
            ImportError: If the module specification has no loader.
            Exception: Any error raised while executing the module body
                propagates unchanged.
        """
        if self.module is not None:
            return
        key = str(self)
        loader = self._spec.loader
        if loader is None:
            msg = f"no loader available for module {key!r}"
            raise ImportError(msg, name=key)
        module = importlib.util.module_from_spec(self._spec)
        # Register before executing so that imports made by the module body
        # (including circular ones) find it, mirroring the import system.
        sys.modules[key] = module
        try:
            loader.exec_module(module)
            # Bind the module on its *parent* package. ``spec.parent`` must
            # not be used here: for a package it names the package itself.
            if self._pkg:
                setattr(sys.modules[self._pkg], self._name, module)
        except FileNotFoundError as e:
            sys.modules.pop(key, None)
            msg = f"No module named {key!r}"
            raise ModuleNotFoundError(msg, name=key) from e
        except BaseException:
            sys.modules.pop(key, None)
            raise

    @property
    def pkg(self) -> str:
        """Returns the package name this module belongs to."""
        return self._pkg

    @property
    def root(self) -> str:
        """Returns the top package name this module belongs to.

        Empty for a top-level module that is not a package.
        """
        return self._name if self._is_root else self._pkg.split(".")[0]

    @property
    def name(self) -> str:
        """Returns the name of this module without any prefixes."""
        return self._name

    @property
    def full_name(self) -> str:
        """Returns the full name of this module."""
        return f"{self._pkg}.{self.name}" if self._pkg else self.name

    @property
    def module(self) -> ModuleType | None:
        """Return the actual Module if loaded or None otherwise."""
        return sys.modules.get(str(self), None)

    @property
    def is_package(self) -> bool:
        """Returns True if this module is also a package."""
        return self._is_pkg

    @property
    def is_root(self) -> bool:
        """Returns True if this module is a top package."""
        return self._is_root

    def prefix_of(self, other: str) -> bool:
        """Check if this module is a package prefix of another module name.

        Returns ``True`` only if this module is a package and *other* is one
        of its submodules, at any depth (e.g. both ``"pkg.a"`` and
        ``"pkg.a.b"`` for package ``"pkg"``), but not the package itself.

        Args:
            other (:obj:`str`): The module name to test against.

        Returns:
            :obj:`bool`: ``True`` if *other* is a submodule of this package.
        """
        if not self._is_pkg:
            return False
        prefix = f"{self!s}."
        return other != prefix and other.startswith(prefix)

    def share_root(self, other: str) -> bool:
        """Check if this module and another name share a common root package.

        Two module names share a root if their first dotted component is the
        same and nonempty.

        Args:
            other (:obj:`str`): The module name to compare against.

        Returns:
            :obj:`bool`: ``True`` if both names share the same top-level
            package prefix.
        """
        return other.split(".", maxsplit=1)[0] == self.root if self.root else False

    def __str__(self) -> str:
        return self.full_name

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self!s})"

    def __hash__(self) -> int:
        return hash((self.pkg, self.name))

    def __eq__(self, other: object) -> bool:
        # Compare the identifying fields directly: equal hashes alone do not
        # imply equal values.
        return isinstance(other, self.__class__) and (other.pkg, other.name) == (
            self.pkg,
            self.name,
        )
@final
class monkey:  # noqa: N801
    """Decorator-based monkey patcher for module attributes.

    Each patch is declared by decorating a replacement function with a
    :class:`monkey` instance. Instances register themselves in a class-wide
    registry, keyed ``"module.target"``, as soon as they are constructed. A
    later instance for the same key replaces the earlier one. All registered
    patches are applied together via :meth:`apply_all`.

    When a patch is applied, the original attribute value is saved and can be
    retrieved later via :meth:`target`. When a replacement function is
    decorated, every module already present in :data:`sys.modules` is marked
    for reload during :meth:`apply_all` if it is either a submodule of the
    target module or shares its top-level package. The target module itself
    is never marked. Modules imported after decoration are not marked
    automatically; use :meth:`marks_modules` for those.

    Protected modules (``sys``, ``builtins``, ``importlib``,
    ``importlib.util``, ``__main__``) are never reloaded. Marking one
    explicitly raises :exc:`ValueError`, and meeting one during the reload
    phase raises :exc:`RuntimeError`. This check does not prevent patching
    attributes of those modules.

    .. note::

        An instance that is constructed but never used as a decorator still
        gets registered. Its placeholder replacement raises
        :exc:`NotImplementedError` when called.

    Args:
        module (:obj:`str`): The fully qualified name of the module containing
            the attribute to patch.
        target (:obj:`str`): The name of the attribute to replace.

    Examples::

        @monkey(module="os.path", target="join")
        def patched_join(*args):
            return "patched"

        monkey.apply_all()
    """

    _protected: ClassVar[set[str]] = {"sys", "builtins", "importlib", "importlib.util", "__main__"}
    _to_reload: ClassVar[set[str]] = set()
    _patches: ClassVar[dict[str, monkey]] = {}

    __slots__ = ("_module", "_origin", "_patch", "_target")

    def __init__(self, *, module: str, target: str) -> None:
        self._module: Module = Module(module)
        self._target: str = target
        # Placeholder until the instance is used as a decorator.
        self._patch: Patch = monkey._null_patch
        # ``Unset`` means "not applied yet" (checked by apply_all and target).
        self._origin: Patchable = Unset
        monkey._patches[str(self)] = self

    def __call__(self, patch: Patch) -> Patch:
        """Use *patch* as the replacement and return it unchanged.

        Also marks related, already-loaded modules for reload.
        """
        self._patch = patch
        self._mark_modules()
        return patch

    def __str__(self) -> str:
        return f"{self._module}.{self._target}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self!s})"

    @classmethod
    def patches(cls) -> list[str]:
        """Return a list of all registered patched target names.

        Each name has the format ``"module.target"``.

        Returns:
            list[:obj:`str`]: The registered patch target names.
        """
        return list(cls._patches)

    @classmethod
    def apply_all(cls) -> None:
        """Apply all registered patches and reload affected modules.

        Each patch is applied at most once. After all patches are applied,
        any modules marked for reload are re-imported.

        Raises:
            :exc:`RuntimeError`: If a protected module is encountered during
                the reload phase.
        """
        for patch in monkey._patches.values():
            if patch._origin is Unset:
                patch._apply()
        monkey._reload_modules()

    @classmethod
    def target(cls, name: str) -> Patchable:
        """Return the original unpatched target of a registered patch.

        Args:
            name (:obj:`str`): The patch target name in the format
                ``"module.target"``.

        Returns:
            :class:`Patchable`: The original attribute value before patching.

        Raises:
            RuntimeError: If called before this patch was applied.
            KeyError: If *name* is not a registered patch.
        """
        try:
            patch = monkey._patches[name]
        except KeyError as e:
            msg = f"{name} is not a known monkey patch."
            raise KeyError(msg) from e
        else:
            if patch._origin is Unset:
                msg = f"target for {patch!r} is not yet available, call monkey.apply_all() before."
                raise RuntimeError(msg)
            return patch._origin

    @classmethod
    def marks_modules(cls, *modules: str) -> None:
        """Mark module names for explicit reload during :meth:`apply_all`.

        This is useful when a module's behavior depends on a patched
        dependency but is not automatically detected by the patching
        mechanism.

        Args:
            *modules (:obj:`str`): Fully qualified module names to mark.

        Raises:
            ValueError: If a module is in the protected list (``sys``,
                ``builtins``, ``importlib``, ``importlib.util``, or
                ``__main__``).
        """
        for mod in modules:
            if mod in monkey._protected:
                msg = f"{mod} belongs to protected module list, can't mark it"
                raise ValueError(msg)
            cls._to_reload.add(mod)

    def _apply(self) -> None:
        # Ensure the target module is loaded, remember the original value,
        # then install the replacement.
        self._module.load()
        self._origin = getattr(self._module.module, self._target)
        setattr(self._module.module, self._target, self._patch)

    def _mark_modules(self) -> None:
        own_name = self._module.full_name
        # Iterate over a snapshot: sys.modules may change size while we walk
        # it (e.g. an import in another thread), which would make direct dict
        # iteration raise RuntimeError.
        for mod_name in list(sys.modules):
            if mod_name == own_name:
                continue
            if self._module.prefix_of(mod_name) or self._module.share_root(mod_name):
                monkey._to_reload.add(mod_name)

    @classmethod
    def _reload_modules(cls) -> None:
        while monkey._to_reload:
            module = monkey._to_reload.pop()
            if module in monkey._protected:
                # Abort the whole reload phase; drop the remaining marks.
                monkey._to_reload.clear()
                msg = f"{module} belongs to protected module list, can't reload it"
                raise RuntimeError(msg)
            # NOTE(review): a marked name that is not in sys.modules raises
            # KeyError here; remaining marks are kept for a later apply_all.
            importlib.reload(sys.modules[module])

    @staticmethod
    def _null_patch(*_a: Any, **_k: Any) -> Any:
        raise NotImplementedError