#!/usr/bin/env python3
"""A script to automate the creation and landing of a stack of Pull Requests."""

import argparse
import json
import os
import re
import subprocess
import sys
import time
import urllib.error
import urllib.request

# TODO: Remove typing workarounds when we use a newer python.
from typing import List, Optional, Tuple
from http.client import HTTPResponse
from dataclasses import dataclass

# --- Constants --- #
# Default branch that PRs target and that the local stack is rebased onto
# (overridable with --base).
BASE_BRANCH = "main"
# Default remote the temporary PR branches are pushed to (--remote).
GITHUB_REMOTE_NAME = "origin"
# Default remote the base branch is fetched from before rebasing
# (--upstream-remote).
UPSTREAM_REMOTE_NAME = "upstream"

# Label added to every created PR unless --labels overrides it.
DEFAULT_LABEL = "skip-precommit-approval"

# Environment variable that holds the GitHub token; it is read by main() and
# also exported to the inline git credential helper.
LLVM_GITHUB_TOKEN_VAR = "LLVM_GITHUB_TOKEN"
# "owner/name" of the repository every API call targets.
LLVM_REPO = "llvm/llvm-project"
GITHUB_API = "https://api.github.com"

# Merge polling: how many times to try, and how long to wait between tries.
MERGE_MAX_RETRIES = 10
MERGE_RETRY_DELAY = 5  # seconds
# Timeout passed to every GitHub API request.
REQUEST_TIMEOUT = 30  # seconds


class LlvmPrError(Exception):
    """Error raised by the PR automator for failures reported to the user.

    Carries a human-readable message; no extra attributes are defined.
    """


@dataclass()
class PRAutomatorConfig:
    """Settings for a single automator run, resolved from the command line."""

    # GitHub login; used as the owner part of each PR's "head" reference.
    user_login: str
    # GitHub token exposed to git through the inline credential helper.
    token: str
    # Branch every PR targets and the local stack is rebased onto.
    base_branch: str
    # Remote the base branch is fetched from before rebasing.
    upstream_remote: str
    # Prefix prepended to every temporary branch name (ends with "/").
    prefix: str
    # Open pull requests as drafts.
    draft: bool
    # Labels applied to every created pull request.
    labels: List[str]
    # Only create the pull requests; do not merge them.
    no_merge: bool
    # Enable auto-merge instead of merging immediately.
    auto_merge: bool


class CommandRunner:
    """Handles command execution and output.

    Supports dry runs (commands not marked read-only are printed instead of
    executed) and two verbosity levels: ``verbose`` echoes every command,
    ``quiet`` suppresses non-error output written to stdout.
    """

    def __init__(
        self, dry_run: bool = False, verbose: bool = False, quiet: bool = False
    ):
        self.dry_run = dry_run
        self.verbose = verbose
        self.quiet = quiet

    def print(self, message: str, file=None) -> None:
        """Print ``message`` to ``file`` (the current ``sys.stdout`` by default).

        Output destined for stdout is dropped in quiet mode; output for any
        other stream (e.g. ``sys.stderr``) is always printed.
        """
        # Resolve the default at call time so a replaced sys.stdout (output
        # capture / redirection) is honoured, including by the quiet check.
        if file is None:
            file = sys.stdout
        if self.quiet and file is sys.stdout:
            return
        print(message, file=file)

    def verbose_print(self, message: str, file=None) -> None:
        """Print ``message`` to ``file`` (stdout by default) in verbose mode only."""
        if self.verbose:
            # The stream must be passed by keyword: positionally, print()
            # would emit it as a second value instead of writing to it.
            print(message, file=sys.stdout if file is None else file)

    def run_command(
        self,
        command: List[str],
        check: bool = True,
        capture_output: bool = False,
        text: bool = False,
        stdin_input: Optional[str] = None,
        read_only: bool = False,
        env: Optional[dict] = None,
    ) -> subprocess.CompletedProcess:
        """Run ``command`` (an argv list, no shell) and return its result.

        In dry-run mode, commands not flagged ``read_only`` are only printed
        and a successful ``CompletedProcess`` with empty output is returned.

        Raises:
            LlvmPrError: if the executable cannot be found.
            subprocess.CalledProcessError: if ``check`` is set and the command
                exits non-zero (its captured output is printed to stderr first).
        """
        if self.dry_run and not read_only:
            self.print(f"[Dry Run] Would run: {' '.join(command)}")
            return subprocess.CompletedProcess(command, 0, "", "")

        self.verbose_print(f"Running: {' '.join(command)}")

        try:
            return subprocess.run(
                command,
                check=check,
                capture_output=capture_output,
                text=text,
                input=stdin_input,
                env=env,
            )
        except FileNotFoundError as e:
            raise LlvmPrError(
                f"Command '{command[0]}' not found. Is it installed and in your PATH?"
            ) from e
        except subprocess.CalledProcessError as e:
            self.print(f"Error running command: {' '.join(command)}", file=sys.stderr)
            if e.stdout:
                self.print(f"--- stdout ---\n{e.stdout}", file=sys.stderr)
            if e.stderr:
                self.print(f"--- stderr ---\n{e.stderr}", file=sys.stderr)
            raise


class GitHubAPI:
    """A small wrapper around the GitHub REST API, scoped to ``LLVM_REPO``.

    Methods that change state on GitHub check ``runner.dry_run`` and only
    print what they would do when it is set.
    """

    def __init__(self, runner: "CommandRunner", token: str):
        self.runner = runner
        # Sent with every request; the token authenticates the user.
        self.headers = {
            "Authorization": f"token {token}",
            "Accept": "application/vnd.github.v3+json",
            "User-Agent": "llvm-push-pr",
        }
        self.opener = urllib.request.build_opener(
            urllib.request.HTTPHandler(), urllib.request.HTTPSHandler()
        )

    def _request(
        self, method: str, endpoint: str, json_payload: Optional[dict] = None
    ) -> HTTPResponse:
        """Send ``method`` to ``GITHUB_API + endpoint`` and return the open response.

        ``json_payload``, when non-empty, is JSON-encoded as the request body.
        The caller must close the returned response (use it in a ``with``).

        Raises:
            urllib.error.HTTPError: for error responses, after logging them.
        """
        url = f"{GITHUB_API}{endpoint}"
        self.runner.verbose_print(f"API Request: {method.upper()} {url}")

        data = None
        headers = self.headers.copy()
        if json_payload:
            self.runner.verbose_print(f"Payload: {json_payload}")
            data = json.dumps(json_payload).encode("utf-8")
            headers["Content-Type"] = "application/json"

        req = urllib.request.Request(
            url, data=data, headers=headers, method=method.upper()
        )

        try:
            return self.opener.open(req, timeout=REQUEST_TIMEOUT)
        except urllib.error.HTTPError as e:
            self.runner.print(
                f"Error making API request to {url}: {e}", file=sys.stderr
            )
            # errors="replace": a non-UTF-8 error body must not replace the
            # HTTPError with a UnicodeDecodeError.
            self.runner.verbose_print(
                f"Error response body: {e.read().decode(errors='replace')}",
                file=sys.stderr,
            )
            raise

    def _request_and_parse_json(
        self, method: str, endpoint: str, json_payload: Optional[dict] = None
    ) -> dict:
        """Perform a request and return its decoded JSON body ({} if empty)."""
        with self._request(method, endpoint, json_payload) as response:
            # Expect a 200 'OK' or 201 'Created' status on success and JSON body.
            self._log_unexpected_status([200, 201], response.status)

            response_text = response.read().decode("utf-8")
            if response_text:
                return json.loads(response_text)
            return {}

    def _request_no_content(
        self, method: str, endpoint: str, json_payload: Optional[dict] = None
    ) -> None:
        """Perform a request whose successful response has no body."""
        with self._request(method, endpoint, json_payload) as response:
            # Expected a 204 No Content status on success, indicating the
            # operation was successful but there is no body.
            self._log_unexpected_status([204], response.status)

    def _log_unexpected_status(
        self, expected_statuses: List[int], actual_status: int
    ) -> None:
        """Warn on stderr when ``actual_status`` is not one of the expected ones."""
        if actual_status not in expected_statuses:
            self.runner.print(
                f"Warning: Expected status {', '.join(map(str, expected_statuses))}, but got {actual_status}",
                file=sys.stderr,
            )

    def get_user_login(self) -> str:
        """Return the login of the user that owns the token."""
        return self._request_and_parse_json("GET", "/user")["login"]

    def create_pr(
        self,
        head_branch: str,
        base_branch: str,
        title: str,
        body: str,
        draft: bool,
    ) -> int:
        """Open a pull request and return its number (0 in dry-run mode)."""
        if self.runner.dry_run:
            self.runner.print(f"[Dry Run] Would create pull request for '{head_branch}'...")
            return 0

        self.runner.print(f"Creating pull request for '{head_branch}'...")
        data = {
            "title": title,
            "body": body,
            "head": head_branch,
            "base": base_branch,
            "draft": draft,
        }
        response_data = self._request_and_parse_json(
            "POST", f"/repos/{LLVM_REPO}/pulls", json_payload=data
        )
        # Single quotes inside the replacement field: reusing the enclosing
        # double quote is a SyntaxError before Python 3.12.
        self.runner.print(f"Pull request created: {response_data.get('html_url')}")
        return response_data.get("number")

    def get_repo_settings(self) -> dict:
        """Return the repository metadata for ``LLVM_REPO``."""
        return self._request_and_parse_json("GET", f"/repos/{LLVM_REPO}")

    def _get_pr_details(self, pr_number: int) -> dict:
        """Fetches the JSON details for a given pull request number."""
        return self._request_and_parse_json(
            "GET", f"/repos/{LLVM_REPO}/pulls/{pr_number}"
        )

    def add_labels(
        self,
        pr_number: int,
        labels: List[str],
    ) -> None:
        """Add ``labels`` to pull request ``pr_number`` (via the issues API)."""
        if self.runner.dry_run:
            self.runner.print(f"[Dry Run] Would set labels for #{pr_number}: {' '.join(labels)}")
            return None

        self.runner.print(f"Setting labels for #{pr_number}: {' '.join(labels)}")

        self._request_and_parse_json(
            "POST", f"/repos/{LLVM_REPO}/issues/{pr_number}/labels",
            json_payload={"labels": labels},
        )

    def _attempt_squash_merge(self, pr_number: int) -> bool:
        """Attempts to squash merge a PR, returning True on success.

        Returns False when GitHub answers 405 (PR not mergeable); other HTTP
        errors propagate.
        """
        try:
            self._request_and_parse_json(
                "PUT",
                f"/repos/{LLVM_REPO}/pulls/{pr_number}/merge",
                json_payload={"merge_method": "squash"},
            )
            return True
        except urllib.error.HTTPError as e:
            # A 405 status code means the PR is not in a mergeable state.
            if e.code == 405:
                return False
            raise

    def merge_pr(self, pr_number: int) -> Optional[str]:
        """Squash-merge ``pr_number``, polling until it becomes mergeable.

        Returns:
            The PR's head branch name on success, or None in dry-run mode.

        Raises:
            LlvmPrError: on a merge conflict, or if the PR is still not
                mergeable after ``MERGE_MAX_RETRIES`` attempts.
        """
        if self.runner.dry_run:
            self.runner.print(f"[Dry Run] Would merge #{pr_number}")
            return None

        for i in range(MERGE_MAX_RETRIES):
            self.runner.print(
                f"Attempting to merge #{pr_number} (attempt {i + 1}/{MERGE_MAX_RETRIES})..."
            )

            pr_data = self._get_pr_details(pr_number)
            head_branch = pr_data["head"]["ref"]

            if pr_data.get("mergeable_state") == "dirty":
                raise LlvmPrError("Merge conflict.")

            # "mergeable" may be falsy while GitHub is still computing it,
            # hence the polling.
            if pr_data.get("mergeable"):
                if self._attempt_squash_merge(pr_number):
                    self.runner.print("Successfully merged.")
                    time.sleep(2)  # Allow GitHub's API to update.
                    return head_branch

            self.runner.print(
                f"PR not mergeable yet (state: {pr_data.get('mergeable_state', 'unknown')}). Retrying in {MERGE_RETRY_DELAY} seconds..."
            )
            time.sleep(MERGE_RETRY_DELAY)

        raise LlvmPrError(f"PR was not mergeable after {MERGE_MAX_RETRIES} attempts.")

    def enable_auto_merge(self, pr_number: int) -> None:
        """Request squash auto-merge for ``pr_number``."""
        if self.runner.dry_run:
            self.runner.print(f"[Dry Run] Would enable auto-merge for #{pr_number}")
            return

        self.runner.print(f"Enabling auto-merge for #{pr_number}...")
        data = {
            "enabled": True,
            "merge_method": "squash",
        }
        # NOTE(review): GitHub documents enabling auto-merge through the
        # GraphQL `enablePullRequestAutoMerge` mutation; confirm that this
        # REST endpoint actually exists before relying on --auto-merge.
        self._request_no_content(
            "PUT",
            f"/repos/{LLVM_REPO}/pulls/{pr_number}/auto-merge",
            json_payload=data,
        )
        self.runner.print("Auto-merge enabled.")

    def delete_branch(
        self, branch_name: str, default_branch: Optional[str] = None
    ) -> None:
        """Delete ``refs/heads/<branch_name>`` in ``LLVM_REPO``.

        Refuses to delete ``default_branch``; a 422 response is treated as
        "already deleted". In dry-run mode nothing is sent.
        """
        if default_branch and branch_name == default_branch:
            self.runner.print(
                f"Error: Refusing to delete the default branch '{branch_name}'.",
                file=sys.stderr,
            )
            return
        if self.runner.dry_run:
            # Deleting is destructive: a dry run must never touch GitHub.
            self.runner.print(f"[Dry Run] Would delete remote branch '{branch_name}'")
            return
        try:
            self._request_no_content(
                "DELETE", f"/repos/{LLVM_REPO}/git/refs/heads/{branch_name}"
            )
        except urllib.error.HTTPError as e:
            if e.code == 422:
                self.runner.print(
                    f"Warning: Remote branch '{branch_name}' was already deleted, skipping deletion.",
                    file=sys.stderr,
                )
            else:
                raise


class LLVMPRAutomator:
    """Automates the process of creating and landing a stack of GitHub Pull Requests.

    Each commit between ``<upstream_remote>/<base_branch>`` and ``HEAD`` is
    pushed to its own temporary branch and turned into a PR. By default the PR
    is squash-merged; ``no_merge`` only creates it and ``auto_merge`` requests
    auto-merge instead. After each commit, the local branch is rebased onto
    the freshly fetched base.
    """

    def __init__(
        self,
        runner: "CommandRunner",
        github_api: "GitHubAPI",
        config: "PRAutomatorConfig",
        remote: str,
    ):
        self.runner = runner
        self.github_api = github_api
        self.config = config
        # Remote the temporary PR branches are pushed to.
        self.remote = remote
        self.original_branch: str = ""
        # Temporary remote branches still to be deleted by _cleanup(). A
        # branch is dropped from this list once its PR is merged or kept open.
        self.created_branches: List[str] = []
        self.repo_settings: dict = {}

    def _get_git_env(self) -> dict:
        """Return an environment in which git authenticates with the token.

        The token is exported as ``LLVM_GITHUB_TOKEN_VAR``, and a one-off
        ``credential.helper`` (configured through ``GIT_CONFIG_*`` variables)
        prints it as the password. Interactive prompts are disabled.
        """
        git_env = os.environ.copy()
        git_env[LLVM_GITHUB_TOKEN_VAR] = self.config.token
        git_env["GIT_TERMINAL_PROMPT"] = "0"
        git_env["GIT_CONFIG_COUNT"] = "1"
        git_env["GIT_CONFIG_KEY_0"] = "credential.helper"
        git_env[
            "GIT_CONFIG_VALUE_0"
        ] = f"!{sys.executable} -c \"import os; print('username=x'); print('password=' + os.environ['{LLVM_GITHUB_TOKEN_VAR}']);\""
        return git_env

    def _get_current_branch(self) -> str:
        """Return the checked-out branch name ("HEAD" when detached)."""
        result = self.runner.run_command(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            capture_output=True,
            text=True,
            read_only=True,
        )
        return result.stdout.strip()

    def _check_work_tree(self) -> None:
        """Raise LlvmPrError if ``git status --porcelain`` reports any change."""
        result = self.runner.run_command(
            ["git", "status", "--porcelain"],
            capture_output=True,
            text=True,
            read_only=True,
        )
        if result.stdout.strip():
            raise LlvmPrError(
                "Your working tree is dirty. Please stash or commit your changes."
            )

    def _rebase_current_branch(self) -> None:
        """Fetch the upstream base branch and rebase the current branch onto it.

        Raises:
            LlvmPrError: if the work tree is dirty or the rebase fails. A
                rebase that stopped mid-way is aborted before raising.
        """
        self._check_work_tree()

        target = f"{self.config.upstream_remote}/{self.config.base_branch}"
        self.runner.print(
            f"Fetching from '{self.config.upstream_remote}' and rebasing '{self.original_branch}' on top of '{target}'..."
        )

        git_env = self._get_git_env()

        # Fetch only the base branch into its remote-tracking ref.
        refspec = f"refs/heads/{self.config.base_branch}:refs/remotes/{self.config.upstream_remote}/{self.config.base_branch}"
        self.runner.run_command(
            ["git", "fetch", self.config.upstream_remote, refspec], env=git_env
        )

        try:
            self.runner.run_command(["git", "rebase", target], env=git_env)
        except subprocess.CalledProcessError as e:
            self.runner.print(
                "Error: The rebase operation failed, likely due to a merge conflict.",
                file=sys.stderr,
            )
            if e.stdout:
                self.runner.print(f"--- stdout ---\n{e.stdout}", file=sys.stderr)
            if e.stderr:
                self.runner.print(f"--- stderr ---\n{e.stderr}", file=sys.stderr)

            # REBASE_HEAD is present while a rebase is stopped mid-way.
            # `git status` cannot test for it; `rev-parse --verify --quiet`
            # exits 0 exactly when the ref resolves.
            rebase_head_result = self.runner.run_command(
                ["git", "rev-parse", "--verify", "--quiet", "REBASE_HEAD"],
                check=False,
                capture_output=True,
                text=True,
                read_only=True,
                env=git_env,
            )

            if rebase_head_result.returncode == 0:
                self.runner.print("Aborting rebase...", file=sys.stderr)
                self.runner.run_command(
                    ["git", "rebase", "--abort"], check=False, env=git_env
                )
            raise LlvmPrError("rebase operation failed.") from e

    def _get_commit_stack(self) -> List[str]:
        """Return hashes of commits on HEAD not in the upstream base, oldest first."""
        target = f"{self.config.upstream_remote}/{self.config.base_branch}"
        result = self.runner.run_command(
            ["git", "rev-list", "--reverse", f"{target}..HEAD"],
            capture_output=True,
            text=True,
            read_only=True,
        )
        return result.stdout.strip().splitlines()

    def _get_commit_details(self, commit_hash: str) -> Tuple[str, str]:
        """Return ``(title, body)`` of a commit's message.

        The title is the first line of the raw message (``%B``); the body is
        everything after it, stripped of surrounding whitespace.
        """
        result = self.runner.run_command(
            ["git", "show", "-s", "--format=%B", commit_hash],
            capture_output=True,
            text=True,
            read_only=True,
        )
        parts = [item.strip() for item in result.stdout.split("\n", 1)]
        title = parts[0]
        body = parts[1] if len(parts) > 1 else ""
        return title, body

    def _sanitize_branch_name(self, text: str) -> str:
        """Turn free text into a lowercase, dash-separated branch-name fragment."""
        sanitized = re.sub(r"[^\w\s-]", "", text).strip().lower()
        sanitized = re.sub(r"[-\s]+", "-", sanitized)
        # Use "auto-pr" as a fallback.
        return sanitized or "auto-pr"

    def _validate_merge_config(self, num_commits: int) -> None:
        """Reject --auto-merge/--no-merge for multi-commit stacks."""
        if num_commits > 1:
            if self.config.auto_merge:
                raise LlvmPrError("--auto-merge is only supported for a single commit.")

            if self.config.no_merge:
                raise LlvmPrError(
                    "--no-merge is only supported for a single commit. "
                    "For stacks, the script must merge sequentially."
                )

        self.runner.print(f"Found {num_commits} commit(s) to process.")

    def _create_and_push_branch_for_commit(
        self, commit_hash: str, base_branch_name: str, index: int
    ) -> str:
        """Push ``commit_hash`` to ``<prefix><base_branch_name>-<index+1>``.

        The branch is recorded in ``created_branches`` and its name returned.
        """
        branch_name = f"{self.config.prefix}{base_branch_name}-{index + 1}"
        commit_title, _ = self._get_commit_details(commit_hash)
        self.runner.print(f"Processing commit {commit_hash[:7]}: {commit_title}")
        self.runner.print(f"Pushing commit to temporary branch '{branch_name}'")

        git_env = self._get_git_env()

        push_command = [
            "git",
            "push",
            self.remote,
            f"{commit_hash}:refs/heads/{branch_name}",
        ]
        self.runner.run_command(push_command, env=git_env)
        self.created_branches.append(branch_name)
        return branch_name

    def _forget_created_branch(self, branch: str) -> None:
        """Stop tracking ``branch`` so that _cleanup() does not delete it."""
        if branch in self.created_branches:
            self.created_branches.remove(branch)

    def _process_commit(
        self, commit_hash: str, base_branch_name: str, index: int
    ) -> None:
        """Push one commit, open its PR, label it and merge / auto-merge it."""
        commit_title, commit_body = self._get_commit_details(commit_hash)

        temp_branch = self._create_and_push_branch_for_commit(
            commit_hash, base_branch_name, index
        )
        pr_number = self.github_api.create_pr(
            head_branch=f"{self.config.user_login}:{temp_branch}",
            base_branch=self.config.base_branch,
            title=commit_title,
            body=commit_body,
            draft=self.config.draft,
        )

        # TODO: There's a possibility of a race with PR labelers workflow.
        # To avoid it, we could create the PR as a draft, set the labels
        # and change the status, but that requires the use of GraphQL API.

        if self.config.labels:
            self.github_api.add_labels(pr_number, self.config.labels)

        if self.config.no_merge:
            # The branch now backs an open PR; deleting it during cleanup
            # would close the PR that was just created.
            self._forget_created_branch(temp_branch)
            return

        if self.config.auto_merge:
            self.github_api.enable_auto_merge(pr_number)
        else:
            merged_branch = self.github_api.merge_pr(pr_number)
            if merged_branch and not self.repo_settings.get("delete_branch_on_merge"):
                # After a merge, the branch should be deleted.
                self.github_api.delete_branch(merged_branch)

        # If the branch was successfully merged, it should not be deleted
        # again during cleanup.
        self._forget_created_branch(temp_branch)

    def run(self) -> None:
        """Process every commit in the stack, then restore the original branch."""
        self.repo_settings = self.github_api.get_repo_settings()
        self.original_branch = self._get_current_branch()
        self.runner.print(f"On branch: {self.original_branch}")

        try:
            commits = self._get_commit_stack()
            if not commits:
                self.runner.print("No new commits to process.")
                return

            self._validate_merge_config(len(commits))
            branch_base_name = self.original_branch
            # The branch name is not descriptive on the base branch or a
            # detached HEAD, so derive one from the first commit's title.
            if self.original_branch in ("main", self.config.base_branch, "HEAD"):
                first_commit_title, _ = self._get_commit_details(commits[0])
                branch_base_name = self._sanitize_branch_name(first_commit_title)

            # Bounded by the initial stack size so a stack that never shrinks
            # (--no-merge, --auto-merge) cannot loop forever.
            for i in range(len(commits)):
                if not commits:
                    break
                self._process_commit(commits[0], branch_base_name, i)
                self._rebase_current_branch()
                if self.runner.dry_run:
                    # Nothing was merged, so re-reading the stack would return
                    # the same commits; advance through the list instead.
                    commits = commits[1:]
                else:
                    # After a rebase, the commit hashes can change, so we need
                    # to get the latest commit stack.
                    commits = self._get_commit_stack()

            if not commits and not self.runner.dry_run:
                self.runner.print("Success! All commits have been landed.")

        finally:
            self._cleanup()

    def _cleanup(self) -> None:
        """Check out the original branch and delete leftover temporary branches.

        A failed deletion is reported as a warning and does not stop the
        remaining deletions.
        """
        self.runner.print(f"Returning to original branch: {self.original_branch}")
        self.runner.run_command(
            ["git", "checkout", self.original_branch], capture_output=True
        )
        if self.created_branches:
            self.runner.print("Cleaning up temporary remote branches...")
            for branch in list(self.created_branches):
                try:
                    self.github_api.delete_branch(branch)
                except urllib.error.HTTPError as e:
                    self.runner.print(
                        f"Warning: Failed to delete remote branch '{branch}': {e}",
                        file=sys.stderr,
                    )


def check_prerequisites(runner: CommandRunner) -> None:
    """Make sure git can be run and that the CWD is inside a git work tree.

    Raises:
        LlvmPrError: when the current directory is not inside a work tree, or
            (via ``runner.run_command``) when the ``git`` executable is missing.
    """
    runner.print("Checking prerequisites...")

    # Only run for its failure modes; the version string itself is unused.
    runner.run_command(["git", "--version"], capture_output=True, read_only=True)

    probe = runner.run_command(
        ["git", "rev-parse", "--is-inside-work-tree"],
        check=False,
        capture_output=True,
        text=True,
        read_only=True,
    )
    inside_work_tree = probe.returncode == 0 and probe.stdout.strip() == "true"
    if not inside_work_tree:
        raise LlvmPrError("This script must be run inside a git repository.")

    runner.print("Prerequisites met.")


def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Create and land a stack of Pull Requests."
    )
    parser.add_argument(
        "--base",
        default=BASE_BRANCH,
        help=f"Base branch to target (default: {BASE_BRANCH})",
    )
    parser.add_argument(
        "--remote",
        default=GITHUB_REMOTE_NAME,
        help=f"Remote for your fork to push to (default: {GITHUB_REMOTE_NAME})",
    )
    parser.add_argument(
        "--upstream-remote",
        default=UPSTREAM_REMOTE_NAME,
        help=f"Remote for the upstream repository (default: {UPSTREAM_REMOTE_NAME})",
    )
    parser.add_argument(
        "--login",
        default=None,
        help="The user login to use. If not provided this will be queried from the TOKEN",
    )
    parser.add_argument(
        "--prefix",
        default=None,
        help="Prefix for temporary branches (default: users/<username>)",
    )
    parser.add_argument(
        "--draft", action="store_true", help="Create pull requests as drafts."
    )
    parser.add_argument(
        "--labels",
        nargs="*",
        default=[DEFAULT_LABEL],
        help=f"Set the PR labels (default: {DEFAULT_LABEL}).",
    )
    merging = parser.add_mutually_exclusive_group()
    merging.add_argument(
        "--no-merge", action="store_true", help="Create PRs but do not merge them."
    )
    merging.add_argument(
        "--auto-merge",
        action="store_true",
        help="Enable auto-merge for each PR instead of attempting to merge immediately.",
    )
    parser.add_argument(
        "--dry-run", action="store_true", help="Print commands without executing them."
    )
    verbosity = parser.add_mutually_exclusive_group()
    verbosity.add_argument(
        "-v", "--verbose", action="store_true", help="Print all commands being run."
    )
    verbosity.add_argument(
        "-q",
        "--quiet",
        action="store_true",
        help="Print only essential output and errors.",
    )
    return parser


def main() -> None:
    """Parse arguments, resolve the GitHub login and run the automator.

    Any LlvmPrError is reported as ``Error: <message>`` on stderr and the
    process exits with status 1 (via ``sys.exit``).
    """
    # Parse first so that --help works even when no token is configured.
    args = _build_arg_parser().parse_args()

    command_runner = CommandRunner(
        dry_run=args.dry_run, verbose=args.verbose, quiet=args.quiet
    )

    try:
        token = os.getenv(LLVM_GITHUB_TOKEN_VAR)
        if not token:
            raise LlvmPrError(f"{LLVM_GITHUB_TOKEN_VAR} environment variable not set.")

        check_prerequisites(command_runner)

        github_api = GitHubAPI(command_runner, token)
        if not args.login:
            try:
                args.login = github_api.get_user_login()
            except urllib.error.HTTPError as e:
                raise LlvmPrError(f"Could not fetch user login from GitHub: {e}") from e

        if not args.prefix:
            args.prefix = f"users/{args.login}/"

        if not args.prefix.endswith("/"):
            args.prefix += "/"

        config = PRAutomatorConfig(
            user_login=args.login,
            token=token,
            base_branch=args.base,
            upstream_remote=args.upstream_remote,
            prefix=args.prefix,
            draft=args.draft,
            labels=args.labels,
            no_merge=args.no_merge,
            auto_merge=args.auto_merge,
        )
        automator = LLVMPRAutomator(
            runner=command_runner,
            github_api=github_api,
            config=config,
            remote=args.remote,
        )
        automator.run()
    except LlvmPrError as e:
        sys.exit(f"Error: {e}")


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
