# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Server-side task interface."""

import inspect
import logging
import tempfile
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from pathlib import Path
from typing import (
    Any,
    BinaryIO,
    ClassVar,
    IO,
    Literal,
    TYPE_CHECKING,
    TextIO,
    Union,
    overload,
    override,
)

from debusine.artifacts import WorkRequestDebugLogs
from debusine.artifacts.models import (
    CollectionCategory,
    TaskTypes,
    WorkRequestResults,
)
from debusine.client.models import LookupChildType
from debusine.db.models import (
    Artifact,
    ArtifactRelation,
    Collection,
    WorkRequest,
    Worker,
    Workspace,
)
from debusine.tasks import BaseTask, TaskConfigError
from debusine.tasks.models import (
    ActionRecordInTaskHistory,
    BaseDynamicTaskData,
    BaseTaskData,
    EventReaction,
    WorkerType,
)

# TODO: remove
from debusine.tasks.server import TaskDatabaseInterface
from debusine.utils import extract_generic_type_arguments
from debusine.worker.system_information import native_architecture

if TYPE_CHECKING:
    from _typeshed import OpenBinaryModeWriting, OpenTextModeWriting


class DBTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](metaclass=ABCMeta):
    """
    Base class for server-side task-specific logic.

    A DBTask object supports Debusine in overseeing the task execution, with
    operations like applying task configuration, resolving its task data and
    computing scheduler tags.

    This gives the server a single API to access logic that is specific to all
    supported types of tasks, whether they are run server-side or worker-side.
    """

    #: Class used as the in-memory representation of task data.
    task_data_type: type[TD]
    data: TD

    #: Class used as the in-memory representation of dynamic task data.
    dynamic_task_data_type: type[DTD]
    dynamic_data: DTD | None

    #: Must be overridden by child classes to document the current version of
    #: the task's code. A task will only be scheduled on a worker if its task
    #: version is the same as the one running on the scheduler.
    TASK_VERSION: int | None = None

    #: The worker type must be suitable for the task type.  TaskTypes.WORKER
    #: requires an external worker; TaskTypes.SERVER requires a Celery
    #: worker; TaskTypes.SIGNING requires a signing worker.
    TASK_TYPE: TaskTypes

    name: ClassVar[str]
    _sub_tasks: dict[TaskTypes, dict[str, type["DBTask['Any', 'Any']"]]] = (
        defaultdict(dict)
    )

    #: Work request backing this task in the DB
    work_request: WorkRequest

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """
        Register the subclass into DBTask._sub_tasks.

        Used by DBTask.class_from_name() to return the class given the name.
        """
        super().__init_subclass__(**kwargs)

        # The name of the task. It is computed by converting the class name
        # to lowercase.
        cls.name = getattr(cls, "TASK_NAME", cls.__name__.lower())

        # The task data types, computed by introspecting the type arguments
        # used to specialize this generic class.
        [
            cls.task_data_type,
            cls.dynamic_task_data_type,
        ] = extract_generic_type_arguments(cls, DBTask)

        if inspect.isabstract(cls):
            # Don't list abstract base classes as tasks.
            return

        registry = cls._sub_tasks[cls.TASK_TYPE]

        # The same sub-task could register twice
        # (but assert that is the *same* class, not a different
        # subtask with a name with a different capitalisation)
        if cls.name in registry and registry[cls.name] != cls:
            raise AssertionError(f'Two Tasks with the same name: {cls.name!r}')

        # Make sure SERVER and WORKER do not have conflicting task names
        match cls.TASK_TYPE:
            case TaskTypes.SERVER:
                if cls.name in cls._sub_tasks[TaskTypes.WORKER]:
                    raise AssertionError(
                        f'{cls.name!r} already registered as a Worker task'
                    )
            case TaskTypes.WORKER:
                if cls.name in cls._sub_tasks[TaskTypes.SERVER]:
                    raise AssertionError(
                        f'{cls.name!r} already registered as a Server task'
                    )

        registry[cls.name] = cls

    def __init__(self, work_request: WorkRequest) -> None:
        """Initialize the task."""
        self.work_request = work_request

        #: A :py:class:`logging.Logger` instance that can be used in child
        #: classes when you override methods to implement the task.
        self.logger = logging.getLogger("debusine.tasks")

        #: Validate task data and dynamic data
        try:
            self.data = self.task_data_type(**work_request.used_task_data)
            self.dynamic_data = (
                None
                if work_request.dynamic_task_data is None
                else self.dynamic_task_data_type(
                    **work_request.dynamic_task_data
                )
            )
        except ValueError as exc:
            raise TaskConfigError(None, original_exception=exc)

        # Task is aborted: the task does not need to be executed, and can be
        # stopped if it is already running
        self._aborted = False

        # fetch_input() add the downloaded artifacts. Used by
        # `DBTask._upload_work_request_debug_logs()` and maybe by
        # required method `upload_artifacts()`.
        #
        # This is distinct from get_input_artifacts_ids, which is used to
        # extract IDs from dynamic_data for use by UI views
        self._source_artifacts_ids: list[int] = []

        self._debug_log_files_directory: (
            None | (tempfile.TemporaryDirectory[str])
        ) = None

        self.post_init()

    @property
    def workspace(self) -> Workspace:
        """
        Return the workspace for this task.

        This is a shortcut for ``self.work_request.workspace``.
        """
        return self.work_request.workspace

    def set_worker(self, worker: Worker) -> None:
        """Set the worker for this task."""
        # Nothing to do by default

    def post_init(self) -> None:
        """Specific post-init code."""

    def append_to_log_file(self, filename: str, lines: list[str]) -> None:
        """
        Open log file and write contents into it.

        :param filename: use self.open_debug_log_file(filename)
        :param lines: write contents to the logfile
        """
        with self.open_debug_log_file(filename) as file:
            file.writelines([line + "\n" for line in lines])

    @classmethod
    def prefix_with_task_name(cls, text: str) -> str:
        """:return: the ``text`` prefixed with the task name and a colon."""
        if cls.TASK_TYPE is TaskTypes.WORKER:
            # Worker tasks are left unprefixed for compatibility
            return f"{cls.name}:{text}"
        else:
            return f"{cls.TASK_TYPE.lower()}:{cls.name}:{text}"

    def build_architecture(self) -> str | None:
        """
        Return the architecture to run this task on.

        Tasks where build_architecture is not determined by
        self.data.build_architecture should re-implement this method.
        """
        return native_architecture()

    def can_run_on(self, worker_metadata: dict[str, Any]) -> bool:
        """
        Check if the specified worker can run the task.

        This method shall take its decision solely based on the supplied
        ``worker_metadata`` and on the configured task data (``self.data``).

        The default implementation always returns True unless
        :py:attr:`TASK_TYPE` doesn't match the worker type or there's a
        mismatch between the :py:attr:`TASK_VERSION` on the scheduler side
        and on the worker side.

        Derived objects can implement further checks by overriding the method
        in the following way::

            if not super().can_run_on(worker_metadata):
                return False

            if ...:
                return False

            return True

        :param dict worker_metadata: The metadata collected from the worker by
            running :py:meth:`analyze_worker` on all the tasks on the worker
            under consideration.
        :return: the boolean result of the check.
        :rtype: bool.
        """
        worker_type = worker_metadata.get("system:worker_type")
        if (self.TASK_TYPE, worker_type) not in {
            (TaskTypes.WORKER, WorkerType.EXTERNAL),
            (TaskTypes.SERVER, WorkerType.CELERY),
            (TaskTypes.SIGNING, WorkerType.SIGNING),
        }:
            return False

        version_key_name = self.prefix_with_task_name("version")
        if worker_metadata.get(version_key_name) != self.TASK_VERSION:
            return False

        # Some tasks might not have "build_architecture"
        task_architecture = self.build_architecture()
        if (
            task_architecture is not None
            and task_architecture
            not in worker_metadata.get("system:architectures", [])
        ):
            return False

        return True

    def compute_system_required_tags(self) -> set[str]:
        """Compute the system set of task-required tags."""
        # Prevent circular import
        import debusine.worker.tags as wtags

        required_tags: set[str] = set()

        # The worker must have code that can handle this task version.
        task_version_tag = (
            wtags.TASK_PREFIX + f"{self.TASK_TYPE.lower()}:{self.name}:"
            f"version:{self.TASK_VERSION}"
        )

        # The worker must be of a type that can run this task.
        match self.TASK_TYPE:
            case TaskTypes.WORKER:
                required_tags.add(wtags.WORKER_TYPE_EXTERNAL)
                required_tags.add(task_version_tag)
            case TaskTypes.SERVER:
                required_tags.add(wtags.WORKER_TYPE_CELERY)
                required_tags.add(task_version_tag)
            case TaskTypes.SIGNING:
                required_tags.add(wtags.WORKER_TYPE_SIGNING)
                required_tags.add(task_version_tag)
            case _:
                # Other tasks should never be dispatched to workers, but
                # make sure of that by requiring a tag that will never
                # exist.
                required_tags.add(wtags.WORKER_TYPE_NOT_ASSIGNABLE)

        return required_tags

    @abstractmethod
    def build_dynamic_data(self, task_database: TaskDatabaseInterface) -> DTD:
        """
        Build a dynamic task data structure for this task.

        :param task_database: TaskDatabaseInterface to use for lookups
        :returns: the newly created dynamic task data
        """

    def compute_dynamic_data(self, task_database: TaskDatabaseInterface) -> DTD:
        """
        Compute dynamic data for this task.

        This may involve resolving artifact lookups.
        """
        return self.build_dynamic_data(task_database)

    def execute_logging_exceptions(self) -> WorkRequestResults:
        """Execute self.execute() logging any raised exceptions."""
        try:
            return self.execute()
        except Exception as exc:
            self.logger.info("Exception in Task %s", self.name, exc_info=True)
            raise exc

    def execute(self) -> WorkRequestResults:
        """
        Call the _execute() method, upload debug artifacts.

        See _execute() for more information.

        :return: result of the _execute() method.
        """  # noqa: D402
        result = self._execute()
        self._upload_work_request_debug_logs()
        return result

    @abstractmethod
    def _execute(self) -> WorkRequestResults:
        """
        Execute the requested task.

        The task must first have been configured. It is allowed to take
        as much time as required. This method will only be run on a worker. It
        is thus allowed to access resources local to the worker.

        It is recommended to fail early by raising a :py:exc:TaskConfigError if
        the parameters of the task let you anticipate that it has no chance of
        completing successfully.

        :return: SUCCESS to indicate success, FAILURE for a failure, ERROR for
            an internal error, SKIPPED if the task turned out to be a noop.
        :raises TaskConfigError: if the parameters of the work request are
            incompatible with the worker.
        """

    def abort(self) -> None:
        """Task does not need to be executed. Once aborted cannot be changed."""
        self._aborted = True

    @property
    def aborted(self) -> bool:
        """
        Return if the task is aborted.

        Tasks cannot transition from aborted -> not-aborted.
        """
        return self._aborted

    @staticmethod
    def class_from_name(
        task_type: TaskTypes, task_name: str
    ) -> type["DBTask['Any', 'Any']"]:
        """
        Return class for :param task_name (case-insensitive).

        :param task_type: type of task to look up

        __init_subclass__() registers DBTask subclasses' into
        DBTask._sub_tasks.
        """
        if (registry := DBTask._sub_tasks.get(task_type)) is None:
            raise ValueError(f"{task_type!r} is not a registered task type")

        task_name_lowercase = task_name.lower()
        if (cls := registry.get(task_name_lowercase)) is None:
            raise ValueError(
                f"{task_name_lowercase!r} is not a registered"
                f" {task_type} task_name"
            )

        return cls

    @staticmethod
    def is_valid_task_name(task_type: TaskTypes, task_name: str) -> bool:
        """Return True if task_name is registered (its class is imported)."""
        if (registry := DBTask._sub_tasks.get(task_type)) is None:
            return False
        return task_name.lower() in registry

    @staticmethod
    def task_names(task_type: TaskTypes) -> list[str]:
        """Return list of sub-task names."""
        return sorted(DBTask._sub_tasks[task_type])

    @staticmethod
    def is_worker_task(task_name: str) -> bool:
        """Check if task_name is a task that can run on external workers."""
        return task_name.lower() in DBTask._sub_tasks[TaskTypes.WORKER]

    @staticmethod
    def worker_task_names() -> list[str]:
        """Return list of sub-task names not of type TaskTypes.SERVER."""
        return sorted(DBTask._sub_tasks[TaskTypes.WORKER].keys())

    def get_input_artifacts_ids(self) -> list[int]:
        """
        Return the list of input artifact IDs used by this task.

        This refers to the artifacts actually used by the task. If
        dynamic_data is empty, this returns the empty list.

        This is used by views to show what artifacts were used by a task.
        `_source_artifacts_ids` cannot be used for this purpose because it is
        only set during task execution.
        """
        if self.dynamic_data is None:
            return []
        return self.dynamic_data.get_input_artifacts_ids()

    def get_label(self) -> str:
        """Return a short human-readable label for the task."""
        return self.work_request.get_label()

    def get_event_reactions(
        self,
        event_name: Literal[
            "on_creation",
            "on_unblock",
            "on_assignment",
            "on_success",
            "on_failure",
        ],
    ) -> list[EventReaction]:
        """
        Return event reactions for this task.

        This allows tasks to provide actions that are processed by the
        server at various points in the lifecycle of the work request.
        """
        event_reactions: list[EventReaction] = []
        if event_name in {"on_success", "on_failure"}:
            event_reactions.append(ActionRecordInTaskHistory())
        return event_reactions

    @overload
    def open_debug_log_file(
        self, filename: str, *, mode: "OpenTextModeWriting" = "a"
    ) -> TextIO: ...

    @overload
    def open_debug_log_file(
        self, filename: str, *, mode: "OpenBinaryModeWriting"
    ) -> BinaryIO: ...

    def open_debug_log_file(
        self,
        filename: str,
        *,
        mode: Union["OpenTextModeWriting", "OpenBinaryModeWriting"] = "a",
    ) -> IO[Any]:
        """
        Open a temporary file and return it.

        The files are always for the same temporary directory, calling it twice
        with the same file name will open the same file.

        The caller must call .close() when finished writing.
        """
        if self._debug_log_files_directory is None:
            self._debug_log_files_directory = tempfile.TemporaryDirectory(
                prefix="debusine-task-debug-log-files-"
            )

        debug_file = Path(self._debug_log_files_directory.name) / filename
        return debug_file.open(mode)

    def _upload_work_request_debug_logs(self) -> None:
        """
        Create a WorkRequestDebugLogs artifact and upload the logs.

        The logs might exist in self._debug_log_files_directory and were
        added via self.open_debug_log_file() or self.create_debug_log_file().

        For each self._source_artifacts_ids: create a relation from
        WorkRequestDebugLogs to source_artifact_id.
        """
        if self._debug_log_files_directory is None:
            return

        work_request_debug_logs = WorkRequestDebugLogs.create(
            files=Path(self._debug_log_files_directory.name).glob("*")
        )

        artifact = Artifact.objects.create_from_local_artifact(
            work_request_debug_logs,
            self.work_request.workspace,
            created_by_work_request=self.work_request,
        )

        for source_artifact_id in self._source_artifacts_ids:
            ArtifactRelation.objects.create(
                artifact=artifact,
                target_id=source_artifact_id,
                type=ArtifactRelation.Relations.RELATES_TO,
            )

        self._debug_log_files_directory.cleanup()
        self._debug_log_files_directory = None

    def compute_system_provided_tags(self) -> set[str]:
        """Compute the system set of task-provided tags."""
        # Prevent circular import
        import debusine.tasks.tags as ttags

        ws = self.work_request.workspace

        provided_tags: set[str] = set()
        provided_tags.add(ttags.SCOPE_PREFIX + ws.scope.name)
        provided_tags.add(ttags.WORKSPACE_PREFIX + f"{ws.scope.name}:{ws.name}")

        for group in self.work_request.created_by.debusine_groups.values(
            "scope__name", "workspace__name", "name"
        ):
            provided_tags.add(
                ttags.GROUP_PREFIX
                + group["scope__name"]
                + ":"
                + (group["workspace__name"] or "")
                + ":"
                + group["name"]
            )

        return provided_tags

    def compute_user_provided_tags(
        self, dynamic_data: BaseDynamicTaskData
    ) -> set[str]:
        """Compute the user set of task-provided tags."""
        # Prevent circular import
        import debusine.tasks.tags as ttags

        provided_tags: set[str] = set()
        if package_name := dynamic_data.get_source_package_name():
            provided_tags.add(ttags.SOURCE_PACKAGE_PREFIX + package_name)
        return provided_tags

    def apply_task_configuration(self) -> None:
        """Apply task configuration to this task and work request."""
        import debusine.server.tags as tags
        import debusine.tasks.tags as ttags
        import debusine.worker.tags as wtags
        from debusine.db.models.task_database import TaskDatabase
        from debusine.server.collections.debusine_task_configuration import (
            apply_configuration,
        )

        self.work_request.configured_task_data = None
        self.work_request.dynamic_task_data = None

        # Lookup the task-configuration collection to use
        config_collection: Collection | None = None
        if (task_configuration := self.data.task_configuration) is not None:
            # Look up the debusine:task-configuration collection
            try:
                lookup_result = self.work_request.lookup_single(
                    lookup=task_configuration,
                    default_category=CollectionCategory.TASK_CONFIGURATION,
                    expect_type=LookupChildType.COLLECTION,
                )
            except KeyError:
                if (
                    task_configuration
                    != BaseTaskData.DEFAULT_TASK_CONFIGURATION_LOOKUP
                ):
                    raise
            else:
                config_collection = lookup_result.collection

        # Compute the configured task data
        self.work_request.configured_task_data = self.data.model_dump(
            mode="json", exclude_unset=True
        )

        # Compute dynamic task data once to extract subject and
        # configuration_context
        # TODO: these values can be computed in a separate methods if the rest
        # TODO: of compute_dynamic_data becomes too expensive to run twice
        # TODO: this may be replaced with using tags computed as system tags
        task_db = TaskDatabase(self.work_request)
        early_dynamic_task_data = self.compute_dynamic_data(task_db)

        if config_collection is not None:
            apply_configuration(
                self.work_request.configured_task_data,
                config_collection,
                self.TASK_TYPE,
                self.name,
                early_dynamic_task_data.subject,
                early_dynamic_task_data.configuration_context,
            )

        # Build the set of task-provided tags
        tagset = tags.make_task_provided_tagset()
        tagset.add(ttags.PROVENANCE_SYSTEM, self.compute_system_provided_tags())
        tagset.add(
            ttags.PROVENANCE_USER,
            self.compute_user_provided_tags(early_dynamic_task_data),
        )
        tagset.finalize()
        self.work_request.scheduler_tags_provided = sorted(tagset.tags)
        del tagset

        # Build the set of task-required tags
        tagset = tags.make_task_required_tagset()
        tagset.add(wtags.PROVENANCE_SYSTEM, self.compute_system_required_tags())
        tagset.finalize()
        self.work_request.scheduler_tags_required = sorted(tagset.tags)


class DefaultDynamicData[TD: BaseTaskData](
    DBTask[TD, BaseDynamicTaskData], metaclass=ABCMeta
):
    """Base class for tasks that do not add to dynamic task data."""

    @override
    def build_dynamic_data(
        self,
        task_database: TaskDatabaseInterface,  # noqa: U100
    ) -> BaseDynamicTaskData:
        """Return default dynamic data."""
        return BaseDynamicTaskData()


class DBWorkerTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](
    DBTask[TD, DTD], metaclass=ABCMeta
):
    """Base class for DB tasks of type WORKER."""

    TASK_TYPE = TaskTypes.WORKER

    @override
    def _execute(self) -> WorkRequestResults:  # noqa: C901
        raise AssertionError("Worker tasks cannot be executed server-side")


class DBProxyTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](
    DBTask[TD, DTD], metaclass=ABCMeta
):
    """A :py:class:`DBTask` that wraps a non-db task."""

    #: Wrapped BaseTask subclass
    task_class: type[BaseTask[TD, DTD]]
    #: Wrapped BaseTask
    task: BaseTask[TD, DTD]

    def __init_subclass__(cls, **kwargs: Any) -> None:
        """Disable automatic registration."""
        # Do not register proxies in the type system automatically: it is done
        # by populate_task_registry

    @classmethod
    def populate_task_registry(cls) -> None:
        """Populate DBWorkerTask types from the BaseTask registry."""
        registry = DBTask._sub_tasks[cls.TASK_TYPE]
        for name, task_cls in BaseTask._sub_tasks[cls.TASK_TYPE].items():
            # Skip names already registered, so that the method can be run
            # twice to catch newly registered classes
            if name in registry:
                continue
            db_task_cls = type(
                f"DB{name}",
                (cls,),
                {
                    "TASK_VERSION": task_cls.TASK_VERSION,
                    "name": task_cls.name,
                    "task_data_type": task_cls.task_data_type,
                    "dynamic_task_data_type": task_cls.dynamic_task_data_type,
                    "task_class": task_cls,
                    "__doc__": task_cls.__doc__,
                },
            )
            registry[name] = db_task_cls

    def __init__(self, work_request: WorkRequest) -> None:
        """Initialize the task with the wrapped BaseTask."""
        super().__init__(work_request)
        self.task = self.task_class(
            task_data=self.work_request.used_task_data,
            dynamic_task_data=self.work_request.dynamic_task_data,
        )

    @override
    def build_architecture(self) -> str | None:
        return self.task.build_architecture()

    @override
    def compute_system_required_tags(self) -> set[str]:
        tags = super().compute_system_required_tags()
        tags |= self.task.compute_system_required_tags()
        return tags

    @override
    def build_dynamic_data(self, task_database: TaskDatabaseInterface) -> DTD:
        raise NotImplementedError(
            "build_dynamic_data cannot be called directly,"
            " only via compute_dynamic_data"
        )

    @override
    def compute_dynamic_data(self, task_database: TaskDatabaseInterface) -> DTD:
        task_database.resolve_inputs(self.task)
        return self.task.build_dynamic_data()

    @override
    def can_run_on(self, worker_metadata: dict[str, Any]) -> bool:
        if not super().can_run_on(worker_metadata):
            return False
        return self.task.can_run_on(worker_metadata)

    @override
    def set_worker(self, worker: Worker) -> None:
        worker_metadata = worker.metadata()
        # TODO: Drop system:host_architecture fallback once all workers
        # have been updated.
        self.task.worker_native_architecture = worker_metadata.get(
            "system:native_architecture",
            worker_metadata.get("system:host_architecture"),
        )


class DBWorkerProxyTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](
    DBWorkerTask[TD, DTD], DBProxyTask[TD, DTD], metaclass=ABCMeta
):
    """A :py:class:`DBTask` that wraps a worker task."""

    @override
    @classmethod
    def populate_task_registry(cls) -> None:
        # Ensure worker tasks are registered before creating their proxy
        # classes
        import debusine.tasks  # noqa: F401

        super().populate_task_registry()


class DBSigningProxyTask[TD: BaseTaskData, DTD: BaseDynamicTaskData](
    DBProxyTask[TD, DTD], metaclass=ABCMeta
):
    """A :py:class:`DBTask` that wraps a worker task."""

    TASK_TYPE = TaskTypes.SIGNING

    @override
    @classmethod
    def populate_task_registry(cls) -> None:
        # Ensure signing tasks are registered before creating their proxy
        # classes
        import debusine.signing.tasks  # noqa: F401

        super().populate_task_registry()

    @override
    def _execute(self) -> WorkRequestResults:  # noqa: C901
        raise AssertionError("Signing tasks cannot be executed server-side")


# Generate the DB proxies for worker tasks first, then for signing tasks.
for _proxy_cls in (DBWorkerProxyTask, DBSigningProxyTask):
    _proxy_cls.populate_task_registry()
del _proxy_cls
