import enum
import re
import textwrap
import time
from re import Pattern
from typing import TYPE_CHECKING, Any, Optional

import jinja2

import tmt.log
import tmt.utils
import tmt.utils.themes
from tmt.checks import Check, CheckPlugin, _RawCheck, provides_check
from tmt.container import container, field
from tmt.result import CheckResult, ResultOutcome, save_failures
from tmt.utils import (
    CommandOutput,
    Path,
    ShellScript,
    Stopwatch,
    render_command_report,
    safe_call,
)
from tmt.utils.hints import hints_as_notes

if TYPE_CHECKING:
    import tmt.base.core
    from tmt.guest import Guest
    from tmt.steps.execute import TestInvocation


class TestMethod(enum.Enum):
    """
    Method used to limit ``ausearch`` to events emitted during a test.
    """

    #: A date and time recorded before the test are handed to ``ausearch -ts``.
    TIMESTAMP = 'timestamp'
    #: ``ausearch`` maintains its own ``--checkpoint`` file.
    CHECKPOINT = 'checkpoint'

    @classmethod
    def from_spec(cls, spec: str) -> 'TestMethod':
        """
        Convert a method name into the corresponding member.

        :param spec: name of the method, e.g. ``timestamp``.
        :returns: the matching member.
        :raises tmt.utils.SpecificationError: when ``spec`` names no known
            method.
        """

        try:
            method = cls(spec)
        except ValueError as error:
            raise tmt.utils.SpecificationError(f"Invalid AVC check method '{spec}'.") from error

        return method

    @classmethod
    def normalize(
        cls,
        key_address: str,
        value: Any,
        logger: tmt.log.Logger,
    ) -> 'TestMethod':
        """
        Normalize a raw ``test-method`` value into a member.

        Strings are converted via :py:meth:`from_spec`; members are
        returned unchanged.

        :param key_address: location of the key, used in error messages.
        :param value: the raw value.
        :param logger: unused, required by the normalization interface.
        :raises tmt.utils.SpecificationError: for any other value.
        """

        # Members are not strings (plain ``Enum``), so the checks may go in
        # either order.
        if isinstance(value, str):
            return cls.from_spec(value)

        if not isinstance(value, TestMethod):
            raise tmt.utils.SpecificationError(
                f"Invalid AVC check method '{value}' at {key_address}."
            )

        return value


#: The filename of the final check report file.
TEST_POST_AVC_FILENAME = 'avc.txt'

#: The filename of the ``ausearch`` "mark" file on the guest.
AUSEARCH_MARK_FILENAME = 'avc-mark.txt'

#: Packages related to SELinux and AVC reporting. Their versions are made
#: part of the report.
INTERESTING_PACKAGES = ['audit', 'selinux-policy']

#: Default message types to be collected by ``ausearch``.
DEFAULT_MESSAGE_TYPES = ['AVC', 'USER_AVC', 'SELINUX_ERR']

#: Compiled regex matching lines of the relevant message types. The trailing
#: word boundary keeps e.g. ``type=AVC_PATH`` lines from matching ``AVC``.
DENIAL_PATTERN = re.compile(
    '^type=(?:'
    + '|'.join(re.escape(message_type) for message_type in DEFAULT_MESSAGE_TYPES)
    + r')\b'
)

#: Default list of patterns to be ignored.
DEFAULT_IGNORE_PATTERNS = [
    # Informative messages shown when policy is reloaded
    re.compile(r'type=USER_AVC.*received policyload notice'),
]


#: Shell script run before the test to record the ``ausearch`` mark.
#: Rendered with ``CHECK`` (the check configuration), ``MARK_FILEPATH``
#: (path of the mark file) and ``MESSAGE_TYPES`` (message types to search
#: for). With the ``timestamp`` method, an ``export AVC_SINCE="<date> <time>"``
#: line is written into the mark file; otherwise, ``ausearch --checkpoint``
#: is expected to create the checkpoint file itself. The mark file is printed
#: at the end, presumably so its content becomes part of the check report.
SETUP_SCRIPT = jinja2.Template(
    textwrap.dedent("""
set -x
export LC_ALL=C

{% if CHECK.test_method.value == 'timestamp' %}
echo "export AVC_SINCE=\\"$( date "+%x %H:%M:%S")\\"" > {{ MARK_FILEPATH }}
{% else %}
ausearch --input-logs --checkpoint {{ MARK_FILEPATH }} -m {{ MESSAGE_TYPES | join(',') }}
{% endif %}

cat {{ MARK_FILEPATH }}
""")
)

#: Shell script run after the test to list denials recorded since the mark.
#: Rendered with the same variables as ``SETUP_SCRIPT``. With the
#: ``timestamp`` method, the mark file is sourced and ``$AVC_SINCE`` is left
#: unquoted - presumably on purpose, so it splits into the separate date and
#: time arguments ``ausearch -ts`` expects. With ``checkpoint``, ``ausearch``
#: starts from its own checkpoint file via ``-ts checkpoint``.
TEST_SCRIPT = jinja2.Template(
    textwrap.dedent(
        """
set -x
export LC_ALL=C

{% if CHECK.test_method.value == 'timestamp' %}
source {{ MARK_FILEPATH }}
ausearch -i --input-logs -m {{ MESSAGE_TYPES | join(',') }} -ts $AVC_SINCE
{% else %}
ausearch --input-logs --checkpoint {{ MARK_FILEPATH }} -m {{ MESSAGE_TYPES | join(',') }} -i -ts checkpoint
{% endif %}
"""  # noqa: E501
    )
)


def _run_script(
    *,
    invocation: 'TestInvocation',
    script: ShellScript,
    needs_sudo: bool = False,
    logger: tmt.log.Logger,
) -> CommandOutput:
    """
    A helper to run a script on the guest.

    Failed commands are not intercepted here. The exception raised by the
    guest, e.g. :py:class:`tmt.utils.RunError`, propagates to the caller.
    Callers are expected to wrap this helper with :py:func:`safe_call` or
    :py:meth:`Stopwatch.measure`. Both capture the exception and return it
    together with the command output, which lets us log failures in the
    report file.

    :param invocation: test invocation whose guest runs the script.
    :param script: the script to run.
    :param needs_sudo: if set, the script is run behind the guest's sudo
        prefix.
    :param logger: messages emitted while running the script are forwarded
        to its ``verbose()`` method.
    :returns: output of the script.
    :raises tmt.utils.RunError: when the script fails.
    """

    if needs_sudo:
        # ``to_shell_command()`` presumably wraps the script in a shell
        # invocation, so the sudo prefix applies to the whole script.
        script = ShellScript(f'{invocation.guest.facts.sudo_prefix} {script.to_shell_command()}')

    def _output_logger(
        key: str,
        value: Optional[str] = None,
        color: tmt.utils.themes.Style = None,
        shift: int = 2,
        level: int = 3,
        topic: Optional[tmt.log.Topic] = None,
        stacklevel: int = 1,
    ) -> None:
        # Forward guest logging to ``logger.verbose()``. ``stacklevel`` is
        # bumped by one to account for this wrapper's frame.
        logger.verbose(
            key=key,
            value=value,
            color=color,
            shift=shift,
            level=level,
            topic=topic,
            stacklevel=stacklevel + 1,
        )

    return invocation.guest.execute(script, log=_output_logger, silent=True)


def create_ausearch_mark(
    invocation: 'TestInvocation', check: 'AvcCheck', logger: tmt.log.Logger
) -> None:
    """
    Save a mark for ``ausearch`` in a file on the guest.

    Depending on ``check.test_method``, the mark is either the current
    date and time, or a checkpoint file created by ``ausearch`` itself.
    The output of the command creating the mark is written into the check
    report file. A failure to create the mark is recorded in the report
    instead of being raised.

    :param invocation: test invocation the check runs for.
    :param check: AVC check configuration.
    :param logger: used for logging.
    """

    ausearch_mark_filepath = invocation.check_files_path / AUSEARCH_MARK_FILENAME

    # Wait one second before storing the mark. Otherwise ausearch could
    # catch denials from the previous test if both happen during the same
    # second, because the timestamp mark has only one-second resolution.
    # ``check.delay_before_report`` is deliberately not used here: it
    # controls the wait before running ``ausearch`` after the test.
    time.sleep(1)

    report: list[str] = []

    script = ShellScript(
        SETUP_SCRIPT.render(
            CHECK=check,
            MARK_FILEPATH=ausearch_mark_filepath,
            MESSAGE_TYPES=DEFAULT_MESSAGE_TYPES,
        ).strip()
    )

    # Run with sudo, like the final ``ausearch`` call. With the
    # ``checkpoint`` method the script runs ``ausearch`` itself, and audit
    # logs are usually readable by root only.
    output, exc, timer = Stopwatch.measure(
        _run_script, invocation=invocation, script=script, needs_sudo=True, logger=logger
    )

    if exc is None:
        assert output is not None

        report.extend(render_command_report(label='mark', output=output))

    else:
        report.extend(render_command_report(label='mark', exc=exc))

    report_filepath = invocation.check_files_path / TEST_POST_AVC_FILENAME

    invocation.phase.write_report(
        path=report_filepath,
        label='AVC denials check',
        timer=timer,
        body=iter(report),
    )


def _record_component(
    *,
    label: str,
    output: Optional[CommandOutput],
    exc: Optional[Exception],
    report: list[str],
    failures: list[str],
) -> bool:
    """
    Add the outcome of a simple report component command to the report.

    :param label: label of the command in the report.
    :param output: command output, set when the command succeeded.
    :param exc: captured exception, set when the command failed.
    :param report: report lines, extended with the command report.
    :param failures: collected failures, extended when the command failed.
    :returns: ``True`` if the command succeeded, ``False`` otherwise.
    """

    if exc is None:
        assert output is not None

        report.extend(render_command_report(label=label, output=output))

        return True

    failure = list(render_command_report(label=label, exc=exc))
    report.extend(failure)
    failures.append('\n'.join(failure))

    return False


def _filter_ignored_denials(
    denials: list[str],
    ignore_patterns: list[Pattern[str]],
    logger: tmt.log.Logger,
) -> list[str]:
    """
    Drop denials matched (via ``search``) by any of the ignore patterns.

    :param denials: denial lines reported by ``ausearch``.
    :param ignore_patterns: patterns of denials to ignore.
    :param logger: used for logging the ignored denials.
    :returns: denials not matched by any pattern, in the original order.
    """

    remaining: list[str] = []

    for denial in denials:
        matching_pattern = next(
            (pattern for pattern in ignore_patterns if pattern.search(denial)),
            None,
        )

        if matching_pattern is None:
            remaining.append(denial)
            continue

        logger.info(f"Ignoring AVC denial due to pattern match: '{matching_pattern.pattern}'")
        logger.debug(f"Full ignored AVC denial: {denial}")

    return remaining


def create_final_report(
    invocation: 'TestInvocation',
    check: 'AvcCheck',
    logger: tmt.log.Logger,
) -> tuple[ResultOutcome, list[Path]]:
    """
    Collect the data, evaluate and create the final report.

    The report consists of ``sestatus`` output, versions of
    :py:data:`INTERESTING_PACKAGES` and the output of ``ausearch``
    listing events recorded since the mark created before the test.

    :param invocation: test invocation the check runs for.
    :param check: AVC check configuration.
    :param logger: used for logging.
    :returns: the check outcome, plus paths to the report and failures
        files. The outcome is ``error`` when any report component could
        not be collected, ``fail`` when relevant, not-ignored denials were
        found, and ``pass`` otherwise.
    :raises tmt.utils.GeneralError: when the test start time was not
        recorded.
    """

    if invocation.start_time is None:
        raise tmt.utils.GeneralError(
            "Test does not have start time recorded, cannot run AVC check."
        )

    ausearch_mark_filepath = invocation.check_files_path / AUSEARCH_MARK_FILENAME

    # Give events caused by the test a chance to reach the audit logs
    # before ``ausearch`` looks for them.
    time.sleep(check.delay_before_report)

    # Collect all report components
    report: list[str] = []
    failures: list[str] = []

    # Get the `sestatus` output.
    output, exc = safe_call(
        _run_script, invocation=invocation, script=ShellScript('sestatus'), logger=logger
    )

    got_sestatus = _record_component(
        label='sestatus', output=output, exc=exc, report=report, failures=failures
    )

    # Record NVRs of interesting packages.
    interesting_packages = ' '.join(INTERESTING_PACKAGES)
    rpm_command = f'rpm -q {interesting_packages}'

    output, exc = safe_call(
        _run_script,
        invocation=invocation,
        script=ShellScript(rpm_command),
        logger=logger,
    )

    got_rpm = _record_component(
        label=rpm_command, output=output, exc=exc, report=report, failures=failures
    )

    # Finally, run `ausearch`, to list AVC denials since the mark recorded before the test.
    script = ShellScript(
        TEST_SCRIPT.render(
            CHECK=check,
            MARK_FILEPATH=ausearch_mark_filepath,
            MESSAGE_TYPES=DEFAULT_MESSAGE_TYPES,
        ).strip()
    )

    output, exc, timer = Stopwatch.measure(
        _run_script, invocation=invocation, script=script, needs_sudo=True, logger=logger
    )

    got_ausearch, got_denials = False, False

    # `ausearch` outcome evaluation is a bit more complicated than the one for a simple
    # `rpm -q`, because not all non-zero exit codes mean error.
    if exc is None:
        assert output is not None

        got_ausearch = True

        # Include the complete output in the report, even denials that would be ignored later.
        report.extend(render_command_report(label='ausearch', output=output))

        if output.stdout:
            # ausearch returns complete audit events which could contain multiple message types.
            # Filter them to keep only message types we are interested in.
            denials = [line for line in output.stdout.splitlines() if DENIAL_PATTERN.match(line)]

            if denials and check.ignore_pattern:
                denials = _filter_ignored_denials(denials, check.ignore_pattern, logger)

            if denials:
                got_denials = True
                failures.append(
                    '\n'.join(
                        render_command_report(
                            label='ausearch',
                            output=CommandOutput(stdout='\n'.join(denials), stderr=output.stderr),
                        )
                    )
                )

    else:
        failure = list(render_command_report(label='ausearch', exc=exc))
        report.extend(failure)

        # Exit code 1 with `<no matches>` means no events were found - not an error.
        if (
            isinstance(exc, tmt.utils.RunError)
            and exc.returncode == 1
            and exc.stderr
            and '<no matches>' in exc.stderr.strip()
        ):
            got_ausearch = True
        else:
            failures.append('\n'.join(failure))

    # If we were able to fetch all components successfully, pick the result based on `ausearch`
    # output. Otherwise, it's an error - we already made all output part of the report.
    if all([got_sestatus, got_rpm, got_ausearch]):
        outcome = ResultOutcome.FAIL if got_denials else ResultOutcome.PASS

    else:
        outcome = ResultOutcome.ERROR

    report_filepath = invocation.check_files_path / TEST_POST_AVC_FILENAME

    invocation.phase.write_report(
        path=report_filepath,
        label='AVC denials check',
        timer=timer,
        body=iter(report),
    )

    paths = [
        report_filepath.relative_to(invocation.phase.step_workdir),
        save_failures(invocation, invocation.check_files_path, failures),
    ]

    return outcome, paths


@container
class AvcCheck(Check):
    """
    Configuration of the ``avc`` check.
    """

    # How ``ausearch`` is limited to events emitted during the test.
    test_method: TestMethod = field(
        default=TestMethod.TIMESTAMP,
        choices=[method.value for method in TestMethod],
        help="""
             Which method to use when calling ``ausearch`` to report new
             AVC denials. With ``checkpoint``, native ``--checkpoint``
             option of ``ausearch`` is used, while ``timestamp`` will
             depend on ``-ts`` option and a date/time recorded before
             the test.
             """,
        normalize=TestMethod.normalize,
        serialize=lambda method: method.value,
        unserialize=lambda serialized: TestMethod.from_spec(serialized),
        exporter=lambda method: method.value,
    )

    # Seconds to sleep after the test before ``ausearch`` is run.
    delay_before_report: int = field(
        default=5,
        metavar='SECONDS',
        help="""
             How many seconds to wait before running ``ausearch`` after
             the test. Increasing it may help when events do not reach
             logs fast enough for ``ausearch`` to report them.
             """,
        normalize=tmt.utils.normalize_int,
    )

    # Compiled patterns; a denial matched by any of them does not fail the check.
    ignore_pattern: list[Pattern[str]] = field(
        default_factory=lambda: DEFAULT_IGNORE_PATTERNS[:],
        help="""
             Optional list of regular expressions to ignore in AVC denials.
             If an AVC denial matches any of these patterns, it will be ignored
             and not cause a failure. Any other denials will still cause the test
             to fail. By default, informative messages about policy reload are
             ignored.
             """,
        metavar="PATTERN",
        normalize=tmt.utils.normalize_pattern_list,
        exporter=lambda patterns: [pattern.pattern for pattern in patterns],
        serialize=lambda patterns: [pattern.pattern for pattern in patterns],
        unserialize=lambda serialized: [re.compile(pattern) for pattern in serialized],
    )

    # TODO: fix `to_spec` of `Check` to support nested serializables
    def to_spec(self) -> _RawCheck:
        """
        Convert the check into its specification.

        The enum and compiled patterns are replaced by their plain string
        forms, which the base implementation does not handle.
        """

        spec = super().to_spec()

        spec['test-method'] = self.test_method.value  # type: ignore[reportGeneralTypeIssues,typeddict-unknown-key,unused-ignore]
        spec['ignore-pattern'] = [  # type: ignore[reportGeneralTypeIssues,typeddict-unknown-key,unused-ignore]
            pattern.pattern for pattern in self.ignore_pattern
        ]

        return spec


@provides_check(
    'avc',
    hints={
        'detection-skipped': """
            The detection of AVC denials was skipped because the guest was not compatible.
            """,
    },
)
class AvcDenials(CheckPlugin[AvcCheck]):
    #
    # This plugin docstring has been reviewed and updated to follow
    # our documentation best practices. When changing it, please make
    # sure new changes are following them as well.
    #
    # https://tmt.readthedocs.io/en/stable/contribute.html#docs
    #
    """
    Check for SELinux AVC denials raised during the test.

    The check collects SELinux AVC denials from the audit log,
    gathers details about them, and together with versions of
    the ``selinux-policy`` and related packages stores them in
    a report file after the test.

    .. code-block:: yaml

        check:
          - name: avc

    .. note::

        To work correctly, the check requires SELinux to be enabled on the
        guest, and ``auditd`` must be running. Without SELinux, the
        check will turn into no-op, reporting
        :ref:`skip</spec/results/outcomes>` result, and
        without ``auditd``, the check will discover no AVC denials,
        reporting :ref:`pass</spec/results/outcomes>`.

        If the test manipulates ``auditd`` or SELinux in general, the
        check may report unexpected results.

    .. versionadded:: 1.28
    """

    _check_class = AvcCheck

    @classmethod
    def essential_requires(
        cls,
        guest: 'Guest',
        test: 'tmt.base.core.Test',
        logger: tmt.log.Logger,
    ) -> list['tmt.base.core.DependencySimple']:
        """
        Packages the check needs on the guest.

        :returns: ``audit`` and ``policycoreutils`` on guests with SELinux,
            an empty list otherwise.
        """

        if guest.facts.has_selinux:
            # Avoid circular imports
            import tmt.base.core

            # Note: yes, this will most likely explode in any distro outside
            # of Fedora, CentOS and RHEL.
            return [
                tmt.base.core.DependencySimple(package)
                for package in ('audit', 'policycoreutils')
            ]

        return []

    @classmethod
    def before_test(
        cls,
        *,
        check: 'AvcCheck',
        invocation: 'TestInvocation',
        environment: Optional[tmt.utils.Environment] = None,
        logger: tmt.log.Logger,
    ) -> list[CheckResult]:
        """
        Record the ``ausearch`` mark, unless the guest lacks SELinux.

        :returns: always an empty list, nothing is reported before the test.
        """

        if not invocation.guest.facts.has_selinux:
            return []

        create_ausearch_mark(invocation, check, logger)

        return []

    @classmethod
    def after_test(
        cls,
        *,
        check: 'AvcCheck',
        invocation: 'TestInvocation',
        environment: Optional[tmt.utils.Environment] = None,
        logger: tmt.log.Logger,
    ) -> list[CheckResult]:
        """
        Evaluate AVC denials raised by the test.

        :returns: a single ``avc`` result - ``skip`` when the guest lacks
            SELinux or is not healthy, otherwise the outcome of the final
            report.
        """

        if not invocation.guest.facts.has_selinux:
            skip_note = hints_as_notes('test-checks/avc/detection-skipped', 'selinux-not-available')

        elif not invocation.is_guest_healthy:
            skip_note = hints_as_notes('guest-not-healthy')

        else:
            outcome, paths = create_final_report(invocation, check, logger)

            return [CheckResult(name='avc', result=outcome, log=paths)]

        return [CheckResult(name='avc', result=ResultOutcome.SKIP, note=skip_note)]
