# Source code for deal._imports

from __future__ import annotations

import ast
import sys
from types import ModuleType
from typing import Any, Callable

from _frozen_importlib_external import PathFinder  # pyright: reportMissingImports=false

import deal

from ._state import state

def get_name(expr) -> str | None:
    """Return the dotted name an AST expression refers to.

    `x` gives `'x'` and `deal.module_load` gives `'deal.module_load'`.
    Any other expression, or an attribute chain whose root is not a plain
    name (e.g. `f().x`), gives None.
    """
    if isinstance(expr, ast.Name):
        return expr.id
    if isinstance(expr, ast.Attribute):
        left = get_name(expr.value)
        if left is None:
            return None
        return left + '.' + expr.attr
    return None

class DealFinder(PathFinder):
    """Path-based meta path finder that wraps found loaders in `DealLoader`.

    `activate` puts this class itself (not an instance) into `sys.meta_path`
    in place of `PathFinder`. The import system therefore calls `find_spec`
    on the class, so it must be a classmethod, as it is on `PathFinder`.
    """

    @classmethod
    def find_spec(cls, *args, **kwargs):
        """Find a module spec via `PathFinder` and wrap its loader.

        Returns None when the module cannot be found. A spec without a loader
        (a namespace package) is returned unchanged, so the import system can
        still supply its own namespace loader for it.
        """
        spec = super().find_spec(*args, **kwargs)
        if spec is not None and spec.loader is not None:  # pragma: no cover
            spec.loader = DealLoader(spec.loader)
        return spec

class DealLoader:
    """Proxy around a module loader that enforces `deal.module_load` contracts.

    Every attribute this class does not define itself (e.g. `create_module`,
    `get_source`) is looked up on the wrapped loader.
    """

    __slots__ = ('_loader', )

    def __init__(self, loader) -> None:
        self._loader = loader

    def __getattr__(self, name: str):
        # Only reached for names not found on DealLoader itself, so everything
        # except `exec_module` and the private helpers goes to the real loader.
        return getattr(self._loader, name)

    def exec_module(self, module: ModuleType) -> None:
        """Execute `module`, wrapped in its module-level contracts if it has any.

        The module source is parsed. The arguments of its first top-level
        `deal.module_load(...)` statement are turned into contracts, which
        are chained around the wrapped loader's `exec_module`. A module is
        executed unchanged when its loader has no `get_source`, when the
        source is unavailable, or when it has no such statement.

        Raises:
            RuntimeError: if a `module_load` argument is not a supported contract.
        """
        if not hasattr(self._loader, 'get_source'):
            return self._loader.exec_module(module)

        # get nodes with module-level contracts from the source code
        source = self._loader.get_source(module.__name__)
        if source is None:
            return self._loader.exec_module(module)
        tree = ast.parse(source)
        nodes = self._get_contracts(tree=tree)
        if not nodes:
            return self._loader.exec_module(module)

        # convert contract nodes into real contracts
        contracts = []
        for node in nodes:
            contract = self._exec_contract(node=node)
            if contract is None:
                msg = f'unsupported contract: {ast.dump(node)}'
                raise RuntimeError(msg)
            contracts.append(contract)

        # execute module with contracts
        wrapped = deal.chain(*contracts)(self._loader.exec_module)
        wrapped(module)
        return None

    @staticmethod
    def _get_contracts(tree: ast.Module) -> list[ast.expr]:
        """Return the positional arguments of the first top-level `deal.module_load(...)`.

        Only bare expression statements in the module body are inspected.
        Returns an empty list if there is no such call.
        """
        for node in tree.body:  # type: Any
            if type(node) is not ast.Expr:
                continue
            if type(node.value) is not ast.Call:
                continue
            if get_name(node.value.func) != 'deal.module_load':
                continue
            return node.value.args
        return []

    @classmethod
    def _exec_contract(cls, node: ast.AST) -> Callable | None:
        """Turn an AST node into a contract object from the `deal` module.

        Supported forms:
        - `deal.NAME`
        - a call of a supported form with only literal positional arguments
          and no keywords, e.g. `deal.has()` or `deal.has('stdout')`.

        Returns None for anything else.
        """
        if isinstance(node, ast.Call) and not node.keywords:
            try:
                args = [ast.literal_eval(arg) for arg in node.args]
            except (ValueError, TypeError):
                # ValueError: a non-literal argument (a name, `*args`, ...).
                # TypeError: an unhashable literal key such as `{[1]: 2}`.
                return None
            func = cls._exec_contract(node.func)
            if func is None:
                return None
            return func(*args)

        if not isinstance(node, ast.Attribute):
            return None
        if get_name(node.value) != 'deal':
            return None
        return getattr(deal, node.attr, None)

def module_load(*contracts) -> None:
    """Specify contracts that will be checked at module import time.

    Keep in mind that [deal.activate](#deal.activate) must be called
    before importing a module with a `module_load` contract.

    ```pycon
    >>> import deal
    >>> deal.module_load(deal.has())

    ```

    See the "Contracts for importing modules" documentation for more details.

    Raises:
        RuntimeError: in debug mode, if no contracts are given or if
            `deal.activate` has not been called.
    """
    # The contracts are not applied here. The import-time loader reads this
    # call's arguments from the module source and applies them while
    # executing the module; this function only validates the call.
    if not state.debug:
        return
    if not contracts:
        raise RuntimeError('no contracts specified')
    if DealFinder not in sys.meta_path:
        msg = 'deal.activate must be called '
        msg += 'before importing anything with deal.module_load contract'
        raise RuntimeError(msg)
def activate() -> bool:
    """Activate module-level checks.

    This function must be called before importing anything
    with a [deal.module_load](#deal.module_load) contract.
    Otherwise, the contract won't be executed.

    The best practice is to call it as early as possible in your project,
    for example at the top of its entry point or top-level package:

    ```pycon
    >>> import deal
    >>> deal.activate()

    ```

    See the "Contracts for importing modules" documentation for more details.

    Returns:
        True if this call installed the import hook; False if debug mode
        is off or the hook is already installed.

    Raises:
        ValueError: if `PathFinder` is not present in `sys.meta_path`.
    """
    if not state.debug:
        return False
    if DealFinder in sys.meta_path:
        return False
    # Replace PathFinder in place rather than prepending, so the order of
    # the other meta path finders is unchanged.
    index = sys.meta_path.index(PathFinder)
    sys.meta_path[index] = DealFinder
    return True
def deactivate() -> bool:
    """Undo `activate` by putting `PathFinder` back into `sys.meta_path`.

    Used in tests.

    Returns:
        True if `DealFinder` was installed and has been replaced;
        False if it was not installed.
    """
    if DealFinder not in sys.meta_path:
        return False
    index = sys.meta_path.index(DealFinder)
    sys.meta_path[index] = PathFinder
    return True