# Source code for deal._cli._test

from __future__ import annotations

import re
import sys
import tokenize
import traceback
from argparse import ArgumentParser
from contextlib import contextmanager
from functools import update_wrapper
from importlib import import_module
from pathlib import Path
from textwrap import indent
from typing import Iterable, Iterator, TextIO, TypeVar

from .._colors import COLORS
from .._testing import TestCase, cases
from .._trace import TraceResult, format_lines, trace
from ..linter._contract import Category
from ..linter._extractors.pre import format_call_args
from ..linter._func import Func
from ._base import Command
from ._common import get_paths


try:
    import pygments
except ImportError:
    pygments = None
else:
    from pygments.formatters import TerminalFormatter
    from pygments.lexers import PythonTracebackLexer


# Generic type variable (used by `fast_iterator` below).
T = TypeVar('T')

# Matches contract exception names qualified with the `deal.` prefix
# (`deal.<Something>ContractError`); `color_exception` replaces each match
# with the bare name (group 1), stripping the prefix from traceback text.
rex_exception = re.compile(r'deal\.(\w*ContractError)')


@contextmanager
def sys_path(path: Path) -> Iterator[None]:
    """Temporarily prepend ``path`` to ``sys.path`` for the duration of the block.

    On exit (normal or via an exception) the entry added here is removed
    again, even if other code pushed its own entries in front of it while
    the block ran (for example, a module that edits ``sys.path`` at import
    time). If the block already removed the entry itself, nothing else is
    touched, so a pre-existing equal entry further down ``sys.path`` is
    preserved.
    """
    entry = str(path)
    # Remember how many copies existed before, so cleanup removes exactly
    # the one we add and never a copy that was already there.
    copies_before = sys.path.count(entry)
    sys.path.insert(0, entry)
    try:
        yield
    finally:
        if sys.path.count(entry) > copies_before:
            # `list.remove` drops the first equal item. Ours was inserted at
            # the very front, so it precedes any pre-existing copy.
            sys.path.remove(entry)


def has_pure_contract(func: Func) -> bool:
    """Tell whether ``func`` should be picked up by the test command.

    A function qualifies when at least one of its contracts is in the
    ``PURE`` category, or is in the ``HAS`` category with no arguments.
    """
    return any(
        contract.category == Category.PURE
        or (contract.category == Category.HAS and not contract.args)
        for contract in func.contracts
    )


def get_func_names(path: Path) -> Iterator[str]:
    """Yield the names of functions in the source file ``path`` that qualify for testing.

    A function qualifies when ``has_pure_contract`` accepts its contracts.

    The file is decoded the way Python decodes source code (a PEP 263 coding
    cookie or a UTF-8 BOM is honored, otherwise UTF-8). ``Path.read_text()``
    would use the locale's preferred encoding instead. That mis-decodes
    non-ASCII sources on non-UTF-8 locales and keeps a leading BOM in the text.
    """
    with tokenize.open(path) as stream:
        text = stream.read()
    for func in Func.from_text(text):
        if has_pure_contract(func):
            yield func.name


def color_exception(text: str) -> str:
    """Shorten ``deal.``-qualified contract error names in ``text`` and colorize it.

    Highlighting uses pygments' traceback lexer with a terminal formatter.
    When pygments is not installed, the text is returned with only the
    prefixes stripped.
    """
    plain = rex_exception.sub(r'\1', text)
    if pygments is None:  # pragma: no cover
        # pygments is an optional dependency: fall back to uncolored output.
        return plain
    lexer = PythonTracebackLexer()
    formatter = TerminalFormatter()
    return pygments.highlight(code=plain, lexer=lexer, formatter=formatter)


def format_exception() -> str:
    """Render the exception currently being handled as an indented traceback.

    Meant to be called while an exception is being handled, since it reads
    the active exception. The traceback is passed through ``color_exception``,
    every line is indented by four spaces, and trailing whitespace is dropped.
    """
    report = color_exception(traceback.format_exc())
    return indent(report, '    ').rstrip()


def fast_iterator(items: Iterable[T]) -> Iterator[T]:
    """
    Iterate over `items`, disabling the tracer while each item is produced.

    This is a trick to avoid using our coverage tracer when calling hypothesis machinery.
    Without it, testing is about 3 times slower.

    The tracer that is active when iteration starts is restored after every
    step, so the consumer's code between steps is still traced.
    """
    iterator = iter(items)
    # Captured when the generator body first runs, i.e. on the first `next()`.
    default_trace = sys.gettrace()
    # NOTE(review): excluded from coverage, presumably because tracing is
    # toggled inside this loop and a coverage tool would see partial results.
    while True:  # pragma: no cover
        # Turn tracing off only while the underlying iterator produces an item.
        sys.settrace(None)
        try:
            case = next(iterator)
        except StopIteration:
            return
        finally:
            # Restore the original tracer however `next()` ended: with a value,
            # with StopIteration, or with any other exception.
            sys.settrace(default_trace)
        # Yield outside the `try` so the consumer runs with tracing restored.
        yield case


def run_cases(
    cases: Iterator[TestCase],
    func_name: str,
    stream: TextIO,
    colors: dict[str, str],
) -> bool:
    """Call each generated test case for one function, stopping at the first failure.

    A ``running <func_name>`` header is printed to ``stream`` first. If a case
    raises, the failing call (name plus arguments) and its colored traceback
    are printed, and False is returned without running the remaining cases.
    Returns True when every case completed without raising.
    """
    header = '  {blue}running {name}{end}'.format(name=func_name, **colors)
    print(header, file=stream)
    for test_case in cases:
        try:
            test_case()
        except Exception:
            call = format_call_args(args=test_case.args, kwargs=test_case.kwargs)
            failure = '    {yellow}{name}({args}){end}'.format(
                name=func_name,
                args=call,
                **colors,
            )
            print(failure, file=stream)
            # Still inside `except`, so the active exception is available.
            print(format_exception(), file=stream)
            return False
    return True


def format_coverage(tresult: TraceResult, colors: dict[str, str]) -> str:
    """Build the coverage summary line for one traced function.

    The percentage is colored green at 85 or more, yellow at 50 or more, and
    red otherwise. The uncovered lines are appended as ``(missing ...)`` only
    for partial coverage (neither 0 nor 100) and only when their formatted
    list is at most 60 characters, presumably to keep the line short.
    """
    percent = tresult.coverage
    if percent >= 85:
        shade = 'green'
    elif percent >= 50:
        shade = 'yellow'
    else:
        shade = 'red'
    color = colors[shade]
    uncovered = format_lines(
        statements=tresult.all_lines,
        lines=tresult.all_lines - tresult.covered_lines,
    )
    template = '    coverage {color}{cov}%{end}'
    is_partial = percent != 0 and percent != 100
    if is_partial and len(uncovered) <= 60:
        template += ' (missing {missing})'
    return template.format(
        cov=percent,
        color=color,
        missing=uncovered,
        **colors,
    )


class TestCommand(Command):
    """Generate and run tests against pure functions.

    ```bash
    python3 -m deal test project/
    ```

    Function must be decorated by one of the following to be run:

    + `@deal.pure`
    + `@deal.has()` (without arguments)

    Options:

    + `--count`: how many input values combinations should be checked.

    Exit code is equal to the number of tested functions that failed.
    See [tests][tests] documentation for more details.

    [tests]: https://deal.readthedocs.io/basic/tests.html
    """

    @staticmethod
    def init_parser(parser: ArgumentParser) -> None:
        parser.add_argument('--count', type=int, default=50)
        parser.add_argument('paths', nargs='+')

    def __call__(self, args) -> int:
        """Test every file found under the given paths.

        Returns the total number of failed functions, used as the exit code.
        """
        failed = 0
        for arg in args.paths:
            for path in get_paths(Path(arg)):
                failed += self.run_tests(
                    path=Path(path),
                    count=args.count,
                )
        return failed

    def run_tests(self, path: Path, count: int) -> int:
        """Test every qualifying function in the source file ``path``.

        The file is imported as a module named after its path relative to
        ``self.root``, with ``self.root`` temporarily on ``sys.path``. For
        each function, ``count`` generated cases are run under the coverage
        tracer. Functions that pass get a coverage line printed.

        Returns the number of functions counted as failed.
        """
        names = list(get_func_names(path))
        if not names:
            return 0
        self.print('{magenta}running {path}{end}'.format(path=path, **COLORS))
        # NOTE(review): `relative_to` raises ValueError if `path` is not under `self.root`.
        module_name = '.'.join(path.relative_to(self.root).with_suffix('').parts)
        with sys_path(path=self.root):
            module = import_module(module_name)
        failed = 0
        for func_name in names:  # pragma: no cover
            func = getattr(module, func_name)
            runner = self._make_runner(func)
            tresult = trace(
                func=runner,
                cases=fast_iterator(cases(func=func, count=count)),
                func_name=func_name,
                stream=self.stream,
                colors=COLORS,
            )
            if tresult.func_result and tresult.all_lines:
                text = format_coverage(tresult=tresult, colors=COLORS)
                self.print(text)
            else:
                # NOTE(review): a passing function whose trace reports no
                # lines is also counted as failed here; confirm this is intended.
                failed += 1
        return failed  # pragma: no cover

    @staticmethod
    def _make_runner(func):
        """Return a fresh wrapper around `run_cases` whose `__wrapped__` is `func`.

        `trace` follows `__wrapped__` to find the original function. The
        wrapper is created per call so that `update_wrapper` never mutates
        the shared module-level `run_cases` (its name, docstring and
        `__dict__` would otherwise be overwritten and accumulate attributes
        from every tested function).
        """
        def runner(*args, **kwargs):
            return run_cases(*args, **kwargs)

        return update_wrapper(wrapper=runner, wrapped=func)