# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.
"""Scheduler tag infrastructure."""

import ast
from collections.abc import Callable, Collection
from functools import cached_property
from typing import TypeAlias, cast

import pyparsing

# Turn on packrat (memoizing) parsing. This is a class-level switch on
# ParserElement, so it applies to every pyparsing grammar in the process,
# not only to the derivation rule grammar defined below. pyparsing's
# documentation recommends it for infix_notation grammars, which can
# otherwise be slow.
# See https://github.com/pyparsing/pyparsing/issues/204
pyparsing.ParserElement.enable_packrat()


#: Set of provenances allowed to supply a given tag (or tags with a prefix)
ProvenanceAllowList: TypeAlias = set[str]
#: Predicate telling whether a tag set matches a derivation rule
RuleMatcher: TypeAlias = Callable[[set[str]], bool]


class ProvenanceRestrictions:
    """
    Hold a named set of provenance restrictions.

    Allowlists can be registered for an exact tag name (at most one per
    tag) or for a tag prefix (any number, including repeated prefixes).
    :py:meth:`filter_set` keeps a tag only when its provenance is accepted
    by every allowlist that applies to it.
    """

    def __init__(self, name: str) -> None:
        """Create an empty set of restrictions identified by ``name``."""
        #: Name identifying this set of provenance restrictions; it prefixes
        #: error messages
        self.name = name
        #: Allowlists keyed by exact tag name
        self.exact: dict[str, ProvenanceAllowList] = {}
        #: ``(prefix, allowlist)`` pairs; a prefix may appear more than once
        self.prefixes: list[tuple[str, ProvenanceAllowList]] = []

    def add_exact(self, tag: str, allowlist: Collection[str]) -> None:
        """
        Register the provenance allowlist for an exact tag name.

        The allowlist is copied into a new set.

        :raises ValueError: if ``tag`` already has an exact-match allowlist
        """
        if tag in self.exact:
            message = (
                f"{self.name}: exact match for tag {tag}"
                " specified multiple times"
            )
            raise ValueError(message)
        self.exact[tag] = set(allowlist)

    def add_prefix(self, prefix: str, allowlist: Collection[str]) -> None:
        """
        Register a provenance allowlist for every tag starting with ``prefix``.

        The allowlist is copied into a new set.
        """
        entry = (prefix, set(allowlist))
        self.prefixes.append(entry)

    def _check_exact(self, provenance: str, tag: str) -> bool:
        """Tell whether the exact allowlist for ``tag``, if any, admits it."""
        allowlist = self.exact.get(tag)
        return allowlist is None or provenance in allowlist

    def _check_prefix(self, provenance: str, tag: str) -> bool:
        """Tell whether all prefix allowlists matching ``tag`` admit it."""
        return all(
            provenance in allowlist
            for prefix, allowlist in self.prefixes
            if tag.startswith(prefix)
        )

    def filter_set(self, provenance: str, tags: set[str]) -> set[str]:
        """
        Filter a set of tags according to its provenance.

        All rules are applied, so if a tag matches multiple times, it must be
        from a provenance present in all matching allowlists. A new set is
        returned; ``tags`` is not modified.
        """
        return {
            tag
            for tag in tags
            if self._check_exact(provenance, tag)
            and self._check_prefix(provenance, tag)
        }


class DerivationRuleParser:
    """
    Parse a derivation rule into an expression that can be evaluated.

    This parses a simple grammar for boolean expressions on tags::

      expr :=
         TAG
       | NOT expr
       | ( expr )
       | expr AND expr
       | expr OR expr

    Operator precedence is NOT over AND over OR.

    ``TAG`` is a run of letters, digits, ``:``, ``_`` and ``-``. ``NOT``,
    ``AND`` and ``OR`` are the lowercase keywords ``not``, ``and`` and
    ``or``. A keyword is only treated as an operator when it is not part of
    a longer tag, so ``nothing``, ``not-foo`` and ``or:bar`` are all plain
    tags.
    """

    def __init__(self, rule: str) -> None:
        """Initialize a parser for the given rule."""
        #: Source text of the rule to parse
        self.rule = rule

    @staticmethod
    def _ast_tag(tokens: pyparsing.ParseResults) -> ast.expr:
        """Generate a ``tag in tags`` ast node."""
        [tag] = tokens
        return ast.Compare(
            left=ast.Constant(tag, lineno=0, col_offset=0),
            ops=[ast.In()],
            comparators=[
                ast.Name("tags", ctx=ast.Load(), lineno=0, col_offset=0)
            ],
            lineno=0,
            col_offset=0,
        )

    @staticmethod
    def _ast_not(tokens: pyparsing.ParseResults) -> ast.expr:
        """Generate a ``not expr`` ast node."""
        # The unary operator comes grouped with its operand:
        # tokens is [["not", operand]]
        [[_, arg]] = tokens
        assert isinstance(arg, ast.expr)
        return ast.UnaryOp(ast.Not(), arg, lineno=0, col_offset=0)

    @staticmethod
    def _ast_and(tokens: pyparsing.ParseResults) -> ast.expr:
        """Generate an ``expr and expr [and ...]`` ast node."""
        # tokens is [[operand, "and", operand, ...]]: every other element
        # is an operand, the ones in between are the operator keywords
        assert len(tokens) == 1
        args = tokens[0][::2]
        return ast.BoolOp(ast.And(), values=args, lineno=0, col_offset=0)

    @staticmethod
    def _ast_or(tokens: pyparsing.ParseResults) -> ast.expr:
        """Generate an ``expr or expr [or ...]`` ast node."""
        # tokens is [[operand, "or", operand, ...]]: every other element
        # is an operand, the ones in between are the operator keywords
        assert len(tokens) == 1
        args = tokens[0][::2]
        return ast.BoolOp(ast.Or(), values=args, lineno=0, col_offset=0)

    @cached_property
    def _parser(self) -> pyparsing.ParserElement:
        """Build the pyparsing grammar for derivation rules."""
        tag_chars = pyparsing.alphas + pyparsing.nums + ":_-"
        # Keywords only match when not adjacent to identifier characters.
        # pyparsing's default identifier characters do not include ":" or
        # "-", which would make "not-foo" parse as ``not "-foo"`` and make
        # "or:bar" unparsable: use the tag characters instead.
        NOT = pyparsing.Keyword("not", ident_chars=tag_chars)
        AND = pyparsing.Keyword("and", ident_chars=tag_chars)
        OR = pyparsing.Keyword("or", ident_chars=tag_chars)
        tag = ~(NOT | AND | OR) + pyparsing.Word(tag_chars)
        return pyparsing.infix_notation(
            tag.set_parse_action(self._ast_tag),
            [
                (NOT, 1, pyparsing.OpAssoc.RIGHT, self._ast_not),
                (AND, 2, pyparsing.OpAssoc.LEFT, self._ast_and),
                (OR, 2, pyparsing.OpAssoc.LEFT, self._ast_or),
            ],
        )

    def as_expr(self) -> ast.expr:
        """
        Parse the rule into an ast expression.

        :raises pyparsing.ParseException: if the whole rule cannot be parsed
        """
        [expr] = self._parser.parse_string(self.rule, parse_all=True)
        assert isinstance(expr, ast.expr)
        return expr

    def as_function(self, provenance: str) -> Callable[[set[str]], bool]:
        """
        Return the parsed rule as a callable that evaluates tag sets.

        The callable takes a set of tags and returns whether the rule
        matches it; its ``__doc__`` is the rule text.

        :param provenance: provenance to use for error messages
        """
        expr = self.as_expr()
        lambda_expr = ast.Lambda(
            # Every list field of ast.arguments is given explicitly: they
            # only default to empty lists since Python 3.13, and compile()
            # rejects the node on older versions when they are missing.
            args=ast.arguments(
                posonlyargs=[ast.arg(arg="tags", lineno=0, col_offset=0)],
                args=[],
                kwonlyargs=[],
                kw_defaults=[],
                defaults=[],
            ),
            body=expr,
            lineno=0,
            col_offset=0,
        )

        # The tree is built only from the nodes generated above: tag names
        # end up as string constants, never as code, so evaluating it is
        # safe even for untrusted rule text.
        rule_func = eval(
            compile(
                ast.Expression(body=lambda_expr),
                f"{provenance} derivation rule",
                mode="eval",
            )
        )
        rule_func.__doc__ = self.rule
        return cast(Callable[[set[str]], bool], rule_func)


class DerivationRules:
    """
    Collection of derivation rules tracked by provenance.

    When a rule matches a tag set, its associated tags are candidates for being
    added to it.

    Note that when multiple rules are applied, tags added by a rule do not
    count in the input for the next ones. That is: each rule is applied to the
    tag set being matched, without the possible additions of other rules in the
    set.
    """

    def __init__(self, *, provenance: str) -> None:
        """
        Initialize a collection of derivation rules.

        :param provenance: provenance for these derivation rules
        """
        #: Provenance of these rules, also used in rule compilation errors
        self.provenance = provenance
        #: ``(matcher, tags)`` pairs, in the order they were added
        self.rules: list[tuple[RuleMatcher, set[str]]] = []

    def add_rule(self, rule: str, tags: Collection[str]) -> None:
        """
        Add a new rule, and its associated tags.

        The rule is parsed before anything is stored, so an invalid rule
        leaves the collection unchanged.

        :param rule: rule as a boolean expression of tag names, parentheses,
          "not", "and" and "or" operations
        :param tags: tags to add if the rule matches; a copy is stored, so
          later changes to the caller's collection do not affect the rule
        :raises pyparsing.ParseException: if ``rule`` cannot be parsed
        """
        parser = DerivationRuleParser(rule)
        rule_func = parser.as_function(self.provenance)
        # Copy, like ProvenanceRestrictions does for its allowlists, so the
        # rule does not alias a set the caller may keep mutating
        self.rules.append((rule_func, set(tags)))

    def compute(self, tags: set[str]) -> set[str]:
        """
        Compute new tags derived from the given tag set.

        Every rule is evaluated against ``tags`` alone; the union of the tags
        of all matching rules is returned as a new set. The result may
        include tags already present in ``tags``.
        """
        result: set[str] = set()
        for rule_func, new_tags in self.rules:
            if rule_func(tags):
                result.update(new_tags)
        return result
