#!/usr/bin/python3
#
# /// script
# dependencies = [
#   "click",
#   "requests",
#   "jinja2",
#   "ruamel.yaml",
# ]
# ///

"""
List issues and pull requests with story points for the given sprint
"""

import dataclasses
import os
import re
import sys
from collections import defaultdict
from collections.abc import Iterator
from typing import Any, Optional

import click
import jinja2
import requests
from ruamel.yaml import YAML

from common import Item  # isort: skip

# https://docs.github.com/en/rest/about-the-rest-api/breaking-changes
GITHUB_API_URL = "https://api.github.com"
GITHUB_API_VERSION = "2026-03-10"
GITHUB_ORGANIZATION = "teemtee"
GITHUB_PROJECT_NUMBER = 1

# https://docs.github.com/en/rest/using-the-rest-api/using-pagination-in-the-rest-api
GITHUB_ITEMS_PER_PAGE = 100
# Extracts the URL of the next page from the ``Link`` response header.
GITHUB_NEXT_PAGE_PATTERN = re.compile(r'<([^>]+)>;\s*rel="next"')

# https://docs.github.com/en/rest/projects/fields
# IDs of the project fields requested together with each item: the story
# point estimate ("Size") and the sprint.
GITHUB_FIELD_SIZE = 252508798
GITHUB_FIELD_SPRINT = 223204946

# Sections of the sprint report: a ``<status>-<type>`` category key and the
# heading printed above the items of that category.
DISPLAY_SECTIONS = [
    ("open-issue", "Open Issues"),
    ("open-pull", "Open Pull Requests"),
    ("closed-issue", "Closed Issues"),
    ("merged-pull", "Merged Pull Requests"),
    ("closed-pull", "Closed Pull Requests"),
]

# Sections of the comparison report: a category key and its heading.
DISPLAY_COMPARISON_SECTIONS = [
    ('completed', 'Completed'),
    ('pending', 'Pending'),
    ('added', 'Added'),
    ('removed', 'Removed'),
    ('completed-original', 'Completed original items'),
    ('incomplete-original', 'Incomplete original items'),
    ('complete-added', 'Completed added items'),
    ('incomplete-added', 'Incomplete added items'),
]

# Plain-text sprint report rendered by display_items(). It expects
# ``categories`` (section key -> list of items), ``sections`` (pairs of key
# and heading), per-category ``points`` and the total/open/closed item counts
# and story point sums. Sections with no items are skipped entirely.
#
# The trailing backslash in the summary line is a Python line continuation
# inside this (non-raw) string literal, so the summary renders on one line.
DISPLAY_TEMPLATE = jinja2.Template(
    trim_blocks=True,
    source="""
================================================================================
  Sprint: {{ sprint_name }}
  Total items: {{ total_count }}
  Total story points: {{ total_points }}
================================================================================
{% for key, label in sections %}
{% if categories.get(key) %}

  {{ label }} ({{ categories[key] | length }}) - {{ points[key] }} points
--------------------------------------------------------------------------------
{% for item in categories[key] %}
  {{ item }}
{% endfor %}
{% endif %}
{% endfor %}

================================================================================
  Summary: {{ open_count }} open ({{ open_points }} points), \
{{ closed_count }} closed/merged ({{ closed_points }} points)
================================================================================
""",
)


# Plain-text comparison report rendered by display_comparison(). It expects
# the base/current item counts and story point sums, the completed/pending
# and added/removed counts and sums, ``sections`` (pairs of key and heading)
# and ``categories`` (section key -> list of items). Sections with no items
# are skipped; per-section points are summed from each item's ``safe_size``.
# NOTE(review): ``safe_size`` is not defined in this file - presumably an
# attribute of ``common.Item`` that treats a missing size as 0; confirm.
DISPLAY_COMPARISON_TEMPLATE = jinja2.Template(
    trim_blocks=True,
    source="""
================================================================================
  Sprint:                     {{ sprint_name }}
  Total items:                {{ count_base }} => {{ count_current }}
  Total story points:         {{ size_base }} => {{ size_current }}
  Completed items:            {{ count_completed }} vs {{ count_pending }}
  Completed story points:     {{ size_completed }} vs {{ size_pending }}
  Added/removed items:        {{ count_added }} vs {{ count_removed }}
  Added/removed story points: {{ size_added }} vs {{ size_removed }}
================================================================================
{% for key, label in sections %}
{% if categories.get(key) %}

  {{ label }} ({{ categories[key] | length }}) - {{ categories[key] | sum(attribute="safe_size") }} points
--------------------------------------------------------------------------------
{% for item in categories[key] %}
  {{ item }}
{% endfor %}
{% endif %}
{% endfor %}
""",  # noqa: E501
)


def load_items(filepath: str) -> list[Item]:
    """
    Load items from a YAML file, e.g. one produced by the ``--yaml`` option.

    :param filepath: path to a YAML file holding a list of mappings, each
        providing the keyword arguments of :py:class:`Item`.
    :returns: a list of :py:class:`Item` instances, empty when the file
        contains no YAML document at all.
    """

    # YAML streams are Unicode; do not depend on the locale's default encoding.
    with open(filepath, encoding="utf-8") as f:
        data = YAML().load(f)

    # An empty file parses to ``None`` rather than to an empty list.
    if data is None:
        return []

    return [Item(**item) for item in data]


def github_api_get(
    path: str,
    params: Optional[dict[str, str]] = None,
) -> Iterator[list[dict[str, Any]]]:
    """
    Issue GET requests against the project endpoint of the GitHub REST API,
    walking through every page of the result.

    :param path: API path relative to the project endpoint,
        e.g. ``fields`` or ``items``.
    :param params: optional query parameters. They are sent with the first
        request only; follow-up requests use the URL from the ``Link`` header
        as-is.
    :returns: an iterator yielding one parsed JSON page (a list of items)
        per request.
    :raises requests.HTTPError: when a response carries an error status.
    """

    request_headers: dict[str, str] = {
        "Accept": "application/vnd.github+json",
        "X-GitHub-Api-Version": GITHUB_API_VERSION,
    }

    token = os.environ.get("GITHUB_PERSONAL_ACCESS_TOKEN")
    if token:
        request_headers["Authorization"] = f"Bearer {token}"

    next_url: Optional[str] = (
        f"{GITHUB_API_URL}/orgs/{GITHUB_ORGANIZATION}/projectsV2/{GITHUB_PROJECT_NUMBER}/{path}"
    )
    query = params
    page_number = 0

    while next_url:
        page_number += 1

        response = requests.get(next_url, headers=request_headers, params=query, timeout=30)
        response.raise_for_status()

        page: list[dict[str, Any]] = response.json()
        click.echo(f"Page {page_number}: {len(page)} items", err=True)
        yield page

        # The next page, if any, is advertised via the Link header; the
        # original query parameters are not re-sent with that URL.
        link_match = GITHUB_NEXT_PAGE_PATTERN.search(response.headers.get("Link", ""))
        next_url = link_match.group(1) if link_match else None
        query = None


def _item_from_project_item(project_item: dict[str, Any]) -> Optional[Item]:
    """
    Convert one project item returned by the GitHub API into an :py:class:`Item`.

    :param project_item: a single element of a page yielded by
        :py:func:`github_api_get`.
    :returns: the converted item, or ``None`` when the project item has no
        content or is neither an issue nor a pull request (e.g. a draft
        issue, which lacks the number, state, URL and repository read here).
    """

    content = project_item.get("content")
    if not content:
        return None

    content_type = project_item.get("content_type")

    if content_type == "PullRequest":
        item_type = "pull"
        status = "merged" if content.get("merged_at") else content["state"]
        repo = content["base"]["repo"]["name"]

    elif content_type == "Issue":
        item_type = "issue"
        status = content["state"]
        repo = content["repository"]["name"]

    else:
        # Anything else (e.g. "DraftIssue") cannot be reported as an issue or
        # pull request; treating it as an issue would fail on missing keys.
        return None

    size_field = next(
        (field for field in project_item.get("fields", []) if field.get("name") == "Size"),
        None,
    )

    return Item(
        id=content["number"],
        type=item_type,
        repo=repo,
        status=status,
        size=size_field.get("value") if size_field else None,
        url=content["html_url"],
        title=content["title"],
    )


def fetch_sprint_items(sprint_name: str) -> list[Item]:
    """
    Fetch all issues and pull requests in the given sprint.

    Project items that are neither issues nor pull requests (e.g. draft
    issues) are skipped.

    :param sprint_name: name of the sprint to fetch items for.
    :returns: a list of :py:class:`Item` instances sorted by id.
    """

    click.echo(
        f"Fetching items from {GITHUB_ORGANIZATION}/project/{GITHUB_PROJECT_NUMBER}"
        f" for sprint '{sprint_name}'...",
        err=True,
    )

    params = {
        "per_page": str(GITHUB_ITEMS_PER_PAGE),
        # Request the story point and sprint fields along with each item.
        "fields": f"{GITHUB_FIELD_SIZE},{GITHUB_FIELD_SPRINT}",
        "q": f"Sprint:'{sprint_name}'",
    }

    all_items: list[Item] = []

    for page_items in github_api_get("items", params=params):
        for project_item in page_items:
            item = _item_from_project_item(project_item)

            if item is not None:
                all_items.append(item)

    all_items.sort(key=lambda item: item.id)
    click.echo(f"Total: {len(all_items)} items", err=True)
    return all_items


def display_items(sprint_name: str, items: list[Item]) -> None:
    """
    Print a human-readable sprint report: items grouped by status and kind,
    with item counts and story point sums per group and overall.

    :param sprint_name: name of the sprint shown in the report header.
    :param items: list of :py:class:`Item` instances as returned by
        :py:func:`fetch_sprint_items`.
    """

    grouped: dict[str, list[Item]] = defaultdict(list)
    for entry in items:
        grouped[f'{entry.status}-{entry.type}'].append(entry)

    # Story points per group; items without an estimate count as zero.
    points = {
        key: sum(member.size or 0 for member in members)
        for key, members in grouped.items()
    }

    open_keys = ("open-issue", "open-pull")

    total_count = sum(map(len, grouped.values()))
    open_count = sum(len(grouped.get(key, [])) for key in open_keys)

    total_points = sum(points.values())
    open_points = sum(points.get(key, 0) for key in open_keys)

    report = DISPLAY_TEMPLATE.render(
        sprint_name=sprint_name,
        categories=grouped,
        sections=DISPLAY_SECTIONS,
        points=points,
        total_count=total_count,
        total_points=total_points,
        open_count=open_count,
        closed_count=total_count - open_count,
        open_points=open_points,
        closed_points=total_points - open_points,
    )
    click.echo(report)


def _compare_items(base: list[Item], current: list[Item]) -> dict[str, list[Item]]:
    """
    Sort base and current sprint items into the comparison categories.

    Items are matched by repository *and* number: issue and pull request
    numbers are only unique within a single repository, so matching by the
    number alone would conflate unrelated items from different repositories.

    :param base: items of the past sprint state. The list is not modified.
    :param current: items of the current sprint state.
    :returns: a mapping from the category keys used by
        ``DISPLAY_COMPARISON_SECTIONS`` to lists of items. Items present in
        both states are reported as they were in ``base``; added items are
        reported as they are in ``current``.
    """

    done_statuses = ('closed', 'merged')

    base_by_key = {(item.repo, item.id): item for item in sorted(base, key=lambda item: item.id)}
    current_by_key = {(item.repo, item.id): item for item in current}

    categories: dict[str, list[Item]] = {
        'completed': [],
        'pending': [],
        'added': [],
        'removed': [],
        'completed-original': [],
        'incomplete-original': [],
        'complete-added': [],
        'incomplete-added': [],
    }

    for key, base_item in base_by_key.items():
        current_item = current_by_key.pop(key, None)

        if current_item is None:
            categories['removed'].append(base_item)

        elif current_item.status in done_statuses:
            categories['completed'].append(base_item)
            categories['completed-original'].append(base_item)

        else:
            categories['pending'].append(base_item)
            categories['incomplete-original'].append(base_item)

    # Whatever is left was not part of the base state, i.e. it was added.
    for current_item in current_by_key.values():
        categories['added'].append(current_item)

        if current_item.status in done_statuses:
            categories['completed'].append(current_item)
            categories['complete-added'].append(current_item)

        else:
            categories['pending'].append(current_item)
            categories['incomplete-added'].append(current_item)

    return categories


def display_comparison(sprint_name: str, base: list[Item], current: list[Item]) -> None:
    """
    Display difference between two sets of items in a readable format.

    :param sprint_name: name of the sprint shown in the report header.
    :param base: list of the "base", initial items, representing a past
        state of the sprint. The list is left unmodified.
    :param current: list of the current items, representing the current
        state of the sprint.
    """

    categories = _compare_items(base, current)

    def size(items: list[Item]) -> int:
        # Items without an estimate count as zero story points.
        return sum(item.size or 0 for item in items)

    completed, pending = categories['completed'], categories['pending']
    added, removed = categories['added'], categories['removed']

    click.echo(
        DISPLAY_COMPARISON_TEMPLATE.render(
            sections=DISPLAY_COMPARISON_SECTIONS,
            categories=categories,
            sprint_name=sprint_name,
            count_base=len(base),
            count_current=len(current),
            size_base=size(base),
            size_current=size(current),
            count_completed=len(completed),
            count_pending=len(pending),
            size_completed=size(completed),
            size_pending=size(pending),
            count_added=len(added),
            size_added=size(added),
            count_removed=len(removed),
            size_removed=size(removed),
        )
    )


@click.command()
@click.option(
    '--sprint',
    metavar='NAME',
    default='@current',
    help="Name of the sprint, e.g. 'Sprint 11'.",
)
@click.option(
    '--yaml',
    'output_yaml',
    is_flag=True,
    help='Output in YAML format for machine-readable processing.',
)
@click.option(
    '--base',
    'base_path',
    metavar='PATH',
    default=None,
    help='Instead of reporting, compare the current sprint against this saved state.',
)
@click.option(
    '--current',
    'current_path',
    metavar='PATH',
    default=None,
    help='If set, compare --base against --current instead of the live sprint state.',
)
def main(
    sprint: str,
    output_yaml: bool,
    base_path: Optional[str] = None,
    current_path: Optional[str] = None,
) -> None:
    """
    List all issues and pull requests in a GitHub Project sprint with story points.

    Set the GITHUB_PERSONAL_ACCESS_TOKEN environment variable to avoid rate limits.
    """

    # The current state comes either from a saved snapshot or from GitHub.
    if current_path is not None:
        items = load_items(current_path)

    else:
        items = fetch_sprint_items(sprint)

    # An empty sprint is not special-cased up front: YAML output must stay
    # machine-readable (an empty list), and a comparison against an empty
    # current state is still meaningful (every base item was removed).
    if output_yaml:
        YAML().dump([dataclasses.asdict(item) for item in items], sys.stdout)

    elif base_path is not None:
        base = load_items(base_path)

        display_comparison(sprint, base, items)

    elif items:
        display_items(sprint, items)

    else:
        click.echo(f"No items found in sprint '{sprint}'.", err=True)


if __name__ == '__main__':
    # Run the click command only when executed as a script, not on import.
    main()
