import abc
import copy
import functools
import threading
from collections.abc import Iterable, Iterator
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar

from tmt._compat.typing import ParamSpec
from tmt.log import Logger
from tmt.utils import GeneralError

if TYPE_CHECKING:
    from tmt._compat.typing import Self
    from tmt.guest import Guest


#: Generic type of "units" of work a task may be applied to, see
#: :py:meth:`Task._invoke_in_pool`.
T = TypeVar('T')
#: Parameters of callables passed to :py:meth:`Task._extract_task_outcome`.
P = ParamSpec('P')
#: Type of the result produced by a task.
TaskResultT = TypeVar('TaskResultT')
#: Type of tasks held by a :py:class:`Queue`.
TaskT = TypeVar('TaskT', bound='Task')  # type: ignore[type-arg]


@functools.total_ordering
class Task(abc.ABC, Generic[TaskResultT]):
    """
    A base class for queueable actions.

    .. note::

        The class provides both the implementation of the action and a
        container for the outcome of the action: every time the task is
        invoked by :py:class:`Queue`, the queue yields an instance of
        the same class, filled with information related to the result
        of its action.
    """

    #: A logger to use for logging events related to the outcome.
    logger: Logger

    #: Order of this task. Follow the semantics of the
    #: :tmt:story:`/spec/test/order` key, the lower the number, the
    #: earlier the task runs. Tasks with ``order`` left unset will be
    #: invoked last, in no guaranteed order.
    order: Optional[int] = None

    #: Result returned by the task when executed.
    result: Optional[TaskResultT] = None

    #: If set, an exception was raised by the running task, and said
    #: exception is saved in this field.
    exc: Optional[Exception] = None

    #: If set, the task raised :py:class:`SystemExit` exception, and
    #: wants to terminate the run completely. Original exception is
    #: assigned to this field.
    requested_exit: Optional[SystemExit] = None

    def __init__(self, logger: Logger) -> None:
        self.logger = logger

    # Magic methods to support queue sorting: sort tasks by their `order`,
    # `order=None` means "undefined" and comes last. `@total_ordering`
    # should fill in the blanks.
    def __gt__(self, other: Any) -> bool:
        """
        Decide whether this task shall run after ``other``.

        :raises GeneralError: when ``other`` is not a :py:class:`Task`.
        """

        if not isinstance(other, Task):
            raise GeneralError(
                f"Cannot compare task of type '{self.__class__.__name__}'"
                f" with '{type(other).__name__}'."
            )

        # Both tasks do not care about their order, they shall remain
        # in the order in which they are now.
        if self.order is None and other.order is None:
            return False

        # Other task cares about its order while this one does not, that
        # puts this task after the other automatically, no matter what
        # the other `order` value is.
        if self.order is None and other.order is not None:
            return True

        # This task cares about its order while the other does not, that
        # puts the other task after this one automatically, no matter
        # what this task's `order` is.
        if self.order is not None and other.order is None:
            return False

        # Both tasks care about their order, therefore their ordering
        # shall put the one with lower `order` first.
        if self.order is not None and other.order is not None:
            return self.order > other.order

        raise NotImplementedError

    def __eq__(self, other: object) -> bool:
        """
        Decide whether this task and ``other`` share the same ``order``.

        This is an ordering relation used for sorting, not identity: two
        distinct tasks with the same ``order`` compare equal. Because
        ``__hash__`` is not defined, instances are unhashable.

        :raises GeneralError: when ``other`` is not a :py:class:`Task`.
        """

        if not isinstance(other, Task):
            raise GeneralError(
                f"Cannot compare task of type '{self.__class__.__name__}'"
                f" with '{type(other).__name__}'."
            )

        return self.order == other.order

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """
        A name of this task.

        Left for child classes to implement, because the name depends on
        the actual task.
        """

        raise NotImplementedError

    def _extract_task_outcome(
        self, logger: Logger, extract: Callable[P, TaskResultT], *args: P.args, **kwargs: P.kwargs
    ) -> 'Self':
        """
        A helper for extracting the task outcome and recording it.

        :param logger: used for logging, and will be attached to the
            returned instance.
        :param extract: a callable responsible for extracting outcome
            of the task. It will be passed rest of positional and
            keyword arguments.
        :param args: positional arguments for ``extract`` callable.
        :param kwargs: keyword arguments for ``extract`` callable.
        :returns: new instance of this class, with :py:attr:`logger`,
            :py:attr:`result`, :py:attr:`exc` and
            :py:attr:`requested_exit` attributes filled according to the
            result of ``extract``.
        """

        # A shallow copy: the outcome shares everything with this task
        # except the outcome-related attributes reset below.
        task = copy.copy(self)

        task.logger = logger
        task.result = None
        task.exc = None
        task.requested_exit = None

        try:
            task.result = extract(*args, **kwargs)

        except SystemExit as exc:
            task.requested_exit = exc

        except Exception as exc:
            task.exc = exc

        return task

    def _invoke_in_pool(
        self,
        *,
        units: list[T],
        get_label: Callable[['Self', T], str],
        extract_logger: Callable[['Self', T], Logger],
        inject_logger: Callable[['Self', T, Logger], None],
        submit: Callable[['Self', T, Logger, ThreadPoolExecutor], Future[TaskResultT]],
        on_complete: Optional[Callable[['Self', T], 'Self']] = None,
        logger: Logger,
    ) -> Iterator['Self']:
        """
        Execute the task across a list of "units" of work.

        A helper for situations where the task is to be applied at
        multiple guests, phases or other objects at the same time. The
        task is scheduled as a :py:class:`Future` for each unit of
        ``units`` list; results of these futures are then collected, and
        yielded as instances of the task's class, in the order in which
        the futures complete.

        If ``units`` is empty, nothing is executed and nothing is
        yielded. Original loggers of units are restored even when the
        iteration is abandoned early or interrupted by an exception.

        :param units: list of units the task should run for.
        :param get_label: a callable that shall return a logger label for
            the given unit. The label is then added to a custom logger
            passed to ``inject_logger``. Labels are expected to be unique
            among ``units``, as they identify per-unit loggers.
        :param extract_logger: a callable that shall return the current
            logger of the given unit. The logger is saved and then
            restored when the task for the given unit is complete.
        :param inject_logger: a callable that should update the given
            unit with the given unit-specific logger. It will be called
            twice, to inject the custom logger first, and then to restore
            the original logger.
        :param submit: a callable that shall submit the task, with the
            given unit, to the executor, and return the :py:class:`Future`
            instance it receives from the executor.
        :param on_complete: if set, it will be called once the task
            completes for the given unit.
        :param logger: used for logging, and as the parent of per-unit
            loggers.
        """

        # Nothing to run. Bail out early: both `prepare_loggers()` and
        # `ThreadPoolExecutor` would reject an empty list of units.
        if not units:
            return

        multiple_units = len(units) > 1

        # Compute each unit's label just once, labels serve as keys of
        # the logger mappings below.
        labels = [get_label(self, unit) for unit in units]

        new_loggers = prepare_loggers(logger, labels)
        old_loggers: dict[str, Logger] = {}

        # Labels of units whose original logger was already restored.
        restored: set[str] = set()

        try:
            with ThreadPoolExecutor(max_workers=len(units)) as executor:
                futures: dict[Future[TaskResultT], tuple[T, str]] = {}

                for unit, label in zip(units, labels):
                    # Swap unit's logger for the one we prepared, with labels
                    # and stuff.
                    old_loggers[label] = extract_logger(self, unit)
                    new_logger = new_loggers[label]

                    inject_logger(self, unit, new_logger)

                    if multiple_units:
                        new_logger.info('started', color='cyan')

                    # Submit each task/unit combination, and save the unit
                    # and its label for later.
                    futures[submit(self, unit, new_logger, executor)] = (unit, label)

                # ... and then sit and wait as they get delivered to us as
                # they finish. Unpack the unit and its label, so we could
                # preserve logging and prepare the right outcome package.
                for future in as_completed(futures):
                    unit, label = futures[future]

                    old_logger = old_loggers[label]
                    new_logger = new_loggers[label]

                    if multiple_units:
                        new_logger.info('finished', color='cyan')

                    # `Future.result()` will either reraise an exception the
                    # callable raised, if any, or return whatever the
                    # callable returned. Both are recorded by
                    # `_extract_task_outcome()`, in `exc`/`requested_exit`
                    # or `result` respectively.
                    task = self._extract_task_outcome(new_logger, future.result)

                    if on_complete:
                        task = on_complete(task, unit)

                    # Don't forget to restore the original logger.
                    inject_logger(task, unit, old_logger)
                    restored.add(label)

                    yield task

        finally:
            # By now the executor has been shut down, waiting for all
            # futures, so no unit is using its custom logger anymore.
            # Restore original loggers of units whose outcome was never
            # delivered - the consumer stopped iterating, or something
            # raised an exception.
            for unit, label in zip(units, labels):
                if label in old_loggers and label not in restored:
                    inject_logger(self, unit, old_loggers[label])

    @abc.abstractmethod
    def go(self) -> Iterator['Self']:
        """
        Perform the task.

        Called by :py:class:`Queue` machinery to accomplish the task.

        :yields: instances of the same class, describing invocations of
            the task and their outcome. The task might be executed
            multiple times, depending on how exactly it was queued, and
            method would yield corresponding results.
        """

        raise NotImplementedError


def prepare_loggers(logger: Logger, labels: list[str]) -> dict[str, Logger]:
    """
    Create loggers for a set of labels.

    Each label gets its own clone of ``logger``. When there is more than
    one label, e.g. a group of guests a phase would be executed on, each
    label is attached to its logger to provide context. Every logger is
    then told to pad its labels to the longest label span among them,
    so the output stays aligned and readable.

    :param logger: logger to clone.
    :param labels: labels to create loggers for.
    :returns: mapping between labels and their new loggers. Empty when
        ``labels`` is empty.
    """

    loggers: dict[str, Logger] = {}

    # First, spawn all loggers, and set their labels if needed.
    # Don't bother with labels if there's just a single label.
    for label in labels:
        new_logger = logger.clone()

        if len(labels) > 1:
            new_logger.labels.append(label)

        loggers[label] = new_logger

    # Second, find the longest label, and instruct all loggers to pad their
    # labels to match this length. This should create well-indented messages.
    # `default` covers an empty list of labels, where `max()` would raise
    # and there would be no loggers to pad anyway.
    max_label_span = max(
        (new_logger.labels_span for new_logger in loggers.values()),
        default=0,
    )

    for new_logger in loggers.values():
        new_logger.labels_padding = max_label_span

    return loggers


class GuestlessTask(Task[TaskResultT]):
    """
    A task that is not bound to any guest.

    Builds on :py:class:`Task` and serves as the base for tasks which
    do not need a guest to run on. Subclasses implement :py:meth:`run`;
    the :py:meth:`go` driver is already provided.
    """

    @abc.abstractmethod
    def run(self, logger: Logger) -> TaskResultT:
        """
        Perform the task.

        Called once from :py:meth:`go`. Subclasses of this class must
        implement their logic in this method rather than in
        :py:meth:`go`, which is already provided.

        :param logger: logger to use. When called from :py:meth:`go`,
            it is the task's own :py:attr:`logger`.
        :returns: the result of the task.
        """

        raise NotImplementedError

    def go(self) -> Iterator['Self']:
        """
        Perform the task.

        Called by :py:class:`Queue` machinery to accomplish the task.

        The actual work is delegated to :py:meth:`run`, which derived
        classes therefore must implement.

        :yields: instances of the same class, describing invocations of
            the task and their outcome. The task might be executed
            multiple times, depending on how exactly it was queued, and
            method would yield corresponding results.
        """

        own_logger = self.logger

        # `run` gets the very same logger the outcome will carry.
        outcome = self._extract_task_outcome(own_logger, self.run, own_logger)

        yield outcome


class MultiGuestTask(Task[TaskResultT]):
    """
    A task bound to a particular set of guests.

    Builds on :py:class:`Task` and serves as the base for tasks which
    need to run on a set of guests. Subclasses implement
    :py:meth:`run_on_guest`; the :py:meth:`go` driver, which runs it on
    all guests concurrently, is already provided.
    """

    #: List of guests to run the task on.
    guests: list['Guest']

    #: Guest on which the task was executed. Filled in on instances
    #: yielded by :py:meth:`go`, one per guest.
    guest: Optional['Guest'] = None

    def __init__(self, guests: list['Guest'], logger: Logger) -> None:
        super().__init__(logger)

        self.guests = guests

    @property
    def guest_ids(self) -> list[str]:
        """
        Multihost names of all guests of this task, in sorted order.
        """

        names = (guest.multihost_name for guest in self.guests)

        return sorted(names)

    @abc.abstractmethod
    def run_on_guest(self, guest: 'Guest', logger: Logger) -> TaskResultT:
        """
        Perform the task.

        Called once from :py:meth:`go` for each guest. Subclasses of
        this class must implement their logic in this method rather
        than in :py:meth:`go`, which is already provided.

        :param guest: the guest to run the task on.
        :param logger: guest-specific logger to use.
        :returns: the result of the task on the given guest.
        """

        raise NotImplementedError

    def go(self) -> Iterator['Self']:
        """
        Perform the task.

        Called by :py:class:`Queue` machinery to accomplish the task.

        The actual work is delegated to :py:meth:`run_on_guest`, invoked
        once per guest, which derived classes therefore must implement.

        :yields: instances of the same class, describing invocations of
            the task and their outcome. The task might be executed
            multiple times, depending on how exactly it was queued, and
            method would yield corresponding results.
        """

        def _label_of(task: 'Self', guest: 'Guest') -> str:
            # Guests are told apart by their multihost names.
            return guest.multihost_name

        def _logger_of(task: 'Self', guest: 'Guest') -> Logger:
            return guest._logger

        def _swap_logger(task: 'Self', guest: 'Guest', logger: Logger) -> None:
            guest.inject_logger(logger)

        def _submit(
            task: 'Self', guest: 'Guest', logger: Logger, executor: ThreadPoolExecutor
        ) -> 'Future[TaskResultT]':
            # Hand the actual work over to the executor pool.
            return executor.submit(self.run_on_guest, guest, logger)

        def _record_guest(task: 'Self', guest: 'Guest') -> 'Self':
            # Remember which guest this outcome belongs to.
            task.guest = guest

            return task

        yield from self._invoke_in_pool(
            units=self.guests,
            get_label=_label_of,
            extract_logger=_logger_of,
            inject_logger=_swap_logger,
            submit=_submit,
            on_complete=_record_guest,
            logger=self.logger,
        )


class Queue(list[TaskT]):
    """
    Queue class for running tasks.

    The queue is a list of tasks kept sorted by their ordering (see
    :py:class:`Task`). New tasks may be added with
    :py:meth:`enqueue_task` even while :py:meth:`run` is consuming the
    queue. Modifications and inspections of the list content are
    serialized by a lock.
    """

    #: If set, the queue is running and invoking tasks.
    is_running: bool

    #: After yielding all outcomes from a single task, this flag is
    #: checked. If it's set, the next task in line would be started;
    #: otherwise, :py:meth:`run` will quit.
    _keep_running: bool

    #: Number of tasks that were already invoked.
    _invoked_tasks: int

    #: Lock protecting modifications of the queue, i.e. the list of
    #: tasks.
    _queue_lock: threading.Lock

    def __init__(self, name: str, logger: Logger) -> None:
        super().__init__()

        self.name = name
        self._logger = logger
        # The lock must exist before `reset()`, which uses it.
        self._queue_lock = threading.Lock()

        self.reset()

    def reset(self) -> None:
        """
        Reset queue content and properties as if it was just created.
        """

        with self._queue_lock:
            self[:] = []

        self.is_running = False
        self._keep_running = True
        self._invoked_tasks = 0

    # We only need to track the head and tail to properly calculate the
    # task queue number dynamically. Adding task number to each task
    # would require re-defining it whenever the queue get reordered.
    @property
    def _head_task_number(self) -> Optional[int]:
        """
        Task number of the task currently at the beginning of the queue.

        For example, for a queue with three tasks, none invoked yet, the
        head task number would be ``1``. After invoking the first task,
        the head task number would be ``2``.

        :returns: task number of the first task in the queue, or
            ``None`` if the queue is empty.
        """

        if not self:
            return None

        return self._invoked_tasks + 1

    @property
    def _tail_task_number(self) -> Optional[int]:
        """
        Task number of the last task in the queue.

        For example, for a queue with three tasks, none invoked yet, the
        tail task number would be ``3``. After invoking the first task,
        the tail task would still be ``3``. If new task is added, the
        tail task number would be ``4``.

        :returns: task number of the last task in the queue, or
            ``None`` if the queue is empty.
        """

        if not self:
            return None

        return self._invoked_tasks + len(self)

    def show_tasks(self, label: str, logger: Logger) -> None:
        """
        Log ``label`` followed by queued tasks and their task numbers.

        Does not acquire the queue lock itself. Callers that may race
        with modifications of the queue should hold it.
        """

        logger.info(label, color='cyan')

        if not self:
            return

        task_logger = logger.descend()

        # Narrow type: this shall no longer be undefined as we
        # do have at least one item in the queue.
        assert self._head_task_number is not None

        for task_number, task in enumerate(self, start=self._head_task_number):
            task_logger.info(
                f'#{task_number}',
                task.name,
                color='cyan',
            )

    def enqueue_task(self, task: TaskT) -> bool:
        """
        Put new task into a queue.

        :returns: ``True`` if the queue was reordered because of the new
            task, ``False`` otherwise.
        """

        # While `append()` is thread-safe and atomic, sorting certainly
        # isn't.
        with self._queue_lock:
            self.append(task)

            self._logger.debug(
                f'{self.name} queue: added task',
                task.name,
            )

            # Reorder the remaining tasks. Remember the previous order by
            # task identity rather than by name: names need not be
            # unique, and comparing them could miss a reordering.
            previous_order = list(self)

            self.sort()

            reordered = any(
                before is not after for before, after in zip(previous_order, self)
            )

            if reordered and self.is_running:
                self.show_tasks(
                    f'{self.name} queue: reordering after task {task.name}', self._logger
                )

            return reordered

    def run(self) -> Iterator[TaskT]:
        """
        Start crunching the queued tasks.

        Tasks are executed in the order, for each invoked task new
        instance of this class is yielded. The queue stops once it is
        empty, once a task yields an outcome with an exception, or once
        :py:meth:`stop` has been called.
        """

        self.is_running = True

        try:
            with self._queue_lock:
                self.show_tasks(f'queued {self.name} tasks', self._logger)

            while True:
                # Both the emptiness test and `pop()` must be protected by
                # the lock. `pop()` must not collide with 1. addition of
                # tasks - that would be fine, both `append()` and `pop()`
                # are atomic - but also 2. sorting of the queue after
                # addition, which is not atomic. And because CPython makes
                # a list appear empty for the whole duration of `sort()`,
                # an unprotected emptiness test could see an empty queue
                # while `enqueue_task()` sorts it, and quit prematurely.
                with self._queue_lock:
                    if not self:
                        return

                    task_number = self._head_task_number
                    task = self.pop(0)

                    self._invoked_tasks += 1

                self._logger.info('')

                self._logger.info(
                    f'{self.name} task #{task_number}',
                    task.name,
                    color='cyan',
                )

                failed_tasks: list[TaskT] = []

                for outcome in task.go():
                    if outcome.exc:
                        failed_tasks.append(outcome)

                    yield outcome

                # TODO: with the `self._keep_running` this check is probably
                # no longer needed, as the responsibility shifts more to the
                # user of the queue, to decide when to stop. To drop it, we
                # will need to review uses of `Queue`, which will happen as
                # part of https://github.com/teemtee/tmt/issues/4668.
                if failed_tasks:
                    return

                if not self._keep_running:
                    return

        finally:
            # Clear the flag on every way out, including the consumer
            # abandoning this generator early or a task raising, so that
            # `enqueue_task()` does not believe the queue is still running.
            self.is_running = False

    def stop(self) -> Iterable[TaskT]:
        """
        Stop crunching the queue tasks.

        The task being invoked is allowed to finish, :py:meth:`run` then
        quits instead of starting the next one.

        :returns: remaining tasks.
        """

        self._keep_running = False

        # Copy under the lock: a concurrent `sort()` would make the list
        # look empty.
        with self._queue_lock:
            return self[:]
