# monitoring/views_celery.py
# DRF endpoint reporting the live Celery worker/queue topology for monitoring.

import logging
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List

from celery import current_app as celery_app
from rest_framework import status
from rest_framework.permissions import AllowAny, IsAdminUser, IsAuthenticated
from rest_framework.response import Response
from rest_framework.views import APIView


def _safe(d: dict | None) -> dict:
    return d or {}


class CeleryTopologyView(APIView):
    """
    Returns current Celery topology:
    - workers: their queues, activity, basic stats
    - queues: reverse index queue -> workers
    - summary: counts

    All data comes from Celery remote-control broadcasts (``inspect`` probes
    and ``ping``). Every probe is best-effort: a probe that raises is logged
    and treated as "no data" rather than failing the whole request.

    NOTE(review): each probe is a separate broadcast that waits up to its
    timeout, so when no worker answers a request can take roughly
    ``5 * inspect_timeout + ping_timeout`` seconds.
    """

    # The response exposes hostnames, PIDs and broker transport, so restrict
    # it to staff users. This must be a sequence: DRF iterates it, and the
    # previous ``(AllowAny)`` (a bare class, not a tuple) raised TypeError on
    # every request. Swap in IsAuthenticated / AllowAny for broader access.
    permission_classes = (IsAdminUser,)

    # Seconds to wait for worker replies to each inspect broadcast / the ping.
    inspect_timeout: float = 3
    ping_timeout: float = 2

    _logger = logging.getLogger(__name__)

    def get(self, request, *args, **kwargs):
        """Probe the Celery cluster and return the topology as JSON.

        Returns 503 only if the inspect handle cannot be created; individual
        probe failures just leave their section empty.
        """
        try:
            insp = celery_app.control.inspect(timeout=self.inspect_timeout)
        except Exception as e:
            self._logger.warning("Celery inspect init failed", exc_info=True)
            return Response({"detail": f"inspect init error: {e}"}, status=status.HTTP_503_SERVICE_UNAVAILABLE)

        # Raw probes, each keyed by node name; {} when nobody answers.
        active_queues = self._probe("active_queues", insp.active_queues)
        active = self._probe("active", insp.active)
        reserved = self._probe("reserved", insp.reserved)
        registered = self._probe("registered", insp.registered)
        stats = self._probe("stats", insp.stats)
        alive_nodes = self._alive_nodes()

        workers: dict[str, dict[str, Any]] = {}
        # A set, so a queue reported twice for one node lists that node once.
        queues_to_workers: defaultdict[str, set[str]] = defaultdict(set)

        # Include ping-only nodes so a worker whose inspect replies were lost
        # still appears (as alive, with empty details).
        all_nodes = set().union(active_queues, active, reserved, registered, stats, alive_nodes)
        for node in sorted(all_nodes):
            node_queues = self._queue_names(active_queues.get(node))
            for queue_name in node_queues:
                queues_to_workers[queue_name].add(node)

            workers[node] = {
                "alive": node in alive_nodes,
                "queues": node_queues,
                "active_count": len(active.get(node) or []),
                "reserved_count": len(reserved.get(node) or []),
                "registered_count": len(registered.get(node) or []),
                **self._stats_fields(stats.get(node)),
            }

        data = {
            "workers": workers,
            "queues": {q: sorted(ws) for q, ws in sorted(queues_to_workers.items())},
            "summary": {
                "workers_total": len(workers),
                "alive_workers": sum(1 for w in workers.values() if w["alive"]),
                "queues_total": len(queues_to_workers),
                "tasks_active_total": sum(w["active_count"] for w in workers.values()),
                "tasks_reserved_total": sum(w["reserved_count"] for w in workers.values()),
            },
        }
        return Response(data, status=status.HTTP_200_OK)

    def _probe(self, label: str, call) -> dict:
        """Invoke one zero-argument inspect method; return its reply or ``{}``.

        *label* only names the probe in the log message. Failures are logged
        (not raised) so one broken probe does not turn the endpoint into a 500.
        """
        try:
            return _safe(call())
        except Exception:
            self._logger.warning("Celery %s probe failed", label, exc_info=True)
            return {}

    def _alive_nodes(self) -> set[str]:
        """Return the names of nodes that answered a ping broadcast."""
        try:
            replies = celery_app.control.ping(timeout=self.ping_timeout)
        except Exception:
            self._logger.warning("Celery ping failed", exc_info=True)
            return set()

        alive: set[str] = set()
        for reply in replies or []:
            # Each reply is a single-entry dict: {node_name: {'ok': 'pong'}}.
            if isinstance(reply, dict):
                alive.update(reply)
        return alive

    @staticmethod
    def _queue_names(raw_queues) -> list[str]:
        """Return the sorted, de-duplicated queue names from one node's reply.

        Entries are expected to be dicts with a ``"name"`` key; anything else
        is stringified. ``None`` yields an empty list.
        """
        names = set()
        for q in raw_queues or []:
            # q example: {'name': 'check', 'exchange': {...}, 'binding_count': 1, ...}
            name = q.get("name") if isinstance(q, dict) else str(q)
            if name:
                names.add(name)
        return sorted(names)

    @staticmethod
    def _stats_fields(node_stats: Any) -> dict[str, Any]:
        """Pick the reported per-worker fields out of one ``stats()`` entry.

        A missing or non-dict entry yields ``None`` for every field.
        """
        if not isinstance(node_stats, dict):
            node_stats = {}
        pool = node_stats.get("pool")
        if not isinstance(pool, dict):
            pool = {}
        broker = node_stats.get("broker")
        return {
            "pid": node_stats.get("pid"),
            "hostname": node_stats.get("hostname"),
            "platform": node_stats.get("platform"),
            # NOTE(review): the value is read from "prefetch_count" first, which
            # is not the same thing as a prefetch multiplier; the response key
            # is kept unchanged for existing clients.
            "prefetch_multiplier": node_stats.get("prefetch_count") or node_stats.get("prefetch_multiplier"),
            "concurrency": pool.get("max-concurrency") or node_stats.get("concurrency"),
            "version": node_stats.get("celery_version") or node_stats.get("version"),
            "broker": broker.get("transport") if isinstance(broker, dict) else None,
        }
