"""
Schema Analyzer v6 (full)
=========================
Recursive DOM visualiser with ASCII preview + fingerprints.

Changelog vs v5
---------------
* KV_RE relaxed – matches labels of up to 80 chars, allows whitespace before
  “:” and requires at least one whitespace character after it.
* New detector `_detect_kv_list`:
  - if a container has **≥4 direct element children** and every child’s
    stripped text matches the KV regex, it is treated as a *label/value*
    **table of N × 2** (N rows, 2 columns).
  - Such blocks get a `[KV]` header in the preview (`size:N×2`) and the
    token `KV@d<depth>` in the fingerprints.
* Minor: helper `_block_label()` for nicer headers, consolidated ASCII builder.

Public API
----------
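>>> sa = SchemaAnalyzer(html)
>>> print(sa.pretty_ascii())        # nice tree with snippets
>>> print(sa.full_fingerprint)      # one token per block, " | "-separated
>>> print(sa.compact_fingerprint)   # consecutive repeats collapsed (e.g. R3C1@d2×3)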
"""

from __future__ import annotations

import argparse
import itertools
import re
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import List, Sequence

from bs4 import BeautifulSoup, Tag

# ═════════════ visuals ═════════════
# Glyphs of the ASCII preview: one table cell, one paragraph blob, and the
# thin/thick rules drawn between consecutive blocks (both the same width).
CELL: str = "____"
PARA: str = "-----"
_SEP_WIDTH = 23
SEP_THIN: str = "-" * _SEP_WIDTH
SEP_THICK: str = "=" * _SEP_WIDTH

# ═══════════ heuristics ════════════
# "label : value" – a label of up to 80 chars, optional whitespace before
# the colon, at least one whitespace char after it, then a non-blank value.
KV_RE = re.compile(r".{1,80}\s*:\s+\S+")
# Class names that mark a grid row/column (row, col-*, grid*, g-<n>).
GRID_CLASS_RE = re.compile(r"^(row|col|grid|g-\d+)\b")
# Any class containing "card" marks a card item.
CARD_CLASS_RE = re.compile(r"card")
# Attributes used by data-grid libraries to tag their containers.
DATA_GRID_ATTRS = {"data-row", "data-col", "data-grid"}

# Minimum number of repeated li/dd/… children for a container to count as a block.
MIN_REPEAT_DEFAULT: int = 3
# Minimum normalised text length for a paragraph blob.
MIN_BLOB_DEFAULT: int = 120


# ────────────────────────────────────────────────────────────────────
@dataclass
class BlockShape:
    """Abstract description of one visual block.

    ``rows`` holds one entry per visual row giving its number of cells; the
    special value ``[-1]`` marks a free-text paragraph blob.  The class is
    deliberately not slotted: the analyser attaches the source element to
    each instance as an extra ``_node`` attribute.
    """
    rows:  List[int]   # [-1] blob | [1,1,…] list | [n,n,…] table
    depth: int         # depth of the source element (the crawl root is 0)
    type_: str         # "LIST" / "KV" / "TABLE" / "BLOB"

    # ------------- ascii lines -------------
    def ascii(self, *, max_rows: int | None = None) -> List[str]:
        """Draw the block as ASCII lines, one line per row.

        Args:
            max_rows: if the block has more rows than this, only the first 4
                and the last 2 rows are drawn, separated by a ``… ×N …``
                marker counting the N hidden rows.  ``None`` disables
                compression.

        Blocks of 6 rows or fewer are always drawn in full: compressing them
        would repeat rows and report a zero or negative hidden count.
        """
        if self.rows == [-1]:
            return [PARA]

        def draw(cells: int) -> str:
            return " ".join(itertools.repeat(CELL, cells))

        head_n, tail_n = 4, 2
        total = len(self.rows)
        if max_rows is not None and total > max_rows and total > head_n + tail_n:
            head = [draw(n) for n in self.rows[:head_n]]
            tail = [draw(n) for n in self.rows[-tail_n:]]
            return head + [f"… ×{total - head_n - tail_n} …"] + tail

        return [draw(n) for n in self.rows]

    # ------------- fingerprint -------------
    @property
    def fingerprint(self) -> str:
        """Short ``<label>@d<depth>`` token identifying the block's shape.

        The label is ``P1`` for a blob, ``KV`` for label/value blocks, and
        otherwise ``R<rows>C<cols>``, where cols is the most frequent row
        width (0 for a block without rows).
        """
        if self.rows == [-1]:
            return f"P1@d{self.depth}"
        if self.type_ == "KV":
            label = "KV"
        else:
            # An empty row list (possible when min_repeat <= 0) must not crash.
            cols = Counter(self.rows).most_common(1)[0][0] if self.rows else 0
            label = f"R{len(self.rows)}C{cols}"
        return f"{label}@d{self.depth}"


# ===================================================================
class SchemaAnalyzer:
    """Describe the repeating visual blocks of an HTML document.

    The constructor parses the document with the ``lxml`` parser and visits
    every element under ``<body>`` in document (pre-)order.  Each element is
    run through the detectors listed in :meth:`_analyse_node`; the first one
    that recognises it contributes a :class:`BlockShape` to :attr:`shapes`.
    Descendants are analysed independently, so a block and blocks nested
    inside it may all be reported.
    """

    # Tags whose content is never rendered; their raw text (JS, CSS,
    # template markup) would otherwise be reported as paragraph blobs.
    _SKIP_TAGS = frozenset({"script", "style", "template"})
    # A KV list needs at least this many direct element children.
    _KV_MIN_CHILDREN = 4
    # Inline ``display: grid`` / ``display: inline-grid``, any spacing or case.
    _DISPLAY_GRID_RE = re.compile(r"display\s*:\s*(?:inline-)?grid\b", re.IGNORECASE)
    # Leading integer of a colspan value (HTML reads "2;" as 2).
    _COLSPAN_RE = re.compile(r"\s*\+?(\d+)")
    # HTML clamps colspan to this value; it also bounds the drawn row width.
    _MAX_COLSPAN = 1000

    # ------------------------------------------------ init ----------
    def __init__(
        self,
        html: str | bytes,
        *,
        max_depth: int | None       = None,           # None == unlimited
        min_repeat: int             = MIN_REPEAT_DEFAULT,
        min_blob_len: int           = MIN_BLOB_DEFAULT,
        ascii_max_rows: int | None  = 50,
        ascii_indent: str           = "  ",
        snippet_len: int            = 70,
    ):
        """Parse *html* and analyse it immediately.

        Args:
            html: document source, handed unchanged to BeautifulSoup.
            max_depth: deepest level to analyse (``<body>`` is 0);
                ``None`` means unlimited.
            min_repeat: how many similar children make a list/grid block.
            min_blob_len: minimum normalised text length of a paragraph blob.
            ascii_max_rows: blocks with more rows are compressed in
                :meth:`pretty_ascii`; ``None`` disables compression.
            ascii_indent: indentation unit per depth level in the preview.
            snippet_len: maximum number of text characters shown per block.
        """
        self.max_depth     = max_depth
        self.min_repeat    = min_repeat
        self.min_blob_len  = min_blob_len
        self.ascii_max_rows = ascii_max_rows
        self.ascii_indent  = ascii_indent
        self.snippet_len   = snippet_len

        self.soup = BeautifulSoup(html, "lxml")
        self.shapes: List[BlockShape] = []
        # The parsed tree may have no <body>; there is then nothing to analyse
        # (previously the detectors crashed on None).
        if self.soup.body is not None:
            self._crawl(self.soup.body, depth=0)

    # ==============================================================
    # public helpers
    # ==============================================================
    def pretty_ascii(self) -> str:
        """Return full drawing with block headers & text snippets.

        Each block gets a header line (label, depth, size), its ASCII rows and,
        when the source element has text, a quoted snippet.  Blocks are
        separated by a thick rule, or a thin one before a paragraph blob.
        """
        out: List[str] = []
        for i, blk in enumerate(self.shapes):
            if i:
                out.append(SEP_THICK if blk.rows != [-1] else SEP_THIN)

            # header
            indent = self.ascii_indent * blk.depth
            if blk.rows == [-1]:
                rows_repr = "blob"
            else:
                rows_repr = f"size:{len(blk.rows)}×{self._dominant_cols(blk.rows)}"
            hdr = f"[{self._block_label(blk)}]".ljust(4)
            out.append(f"{indent}{hdr}  depth:{blk.depth}  {rows_repr}")

            # ascii
            out.extend(f"{indent}{line}" for line in blk.ascii(max_rows=self.ascii_max_rows))

            # text snippet (blocks built outside _crawl carry no source node)
            node = getattr(blk, "_node", None)
            snippet = self._node_snippet(node) if node is not None else ""
            if snippet:
                out.append(f"{indent}snippet: \"{snippet}\"")
        return "\n".join(out)

    @property
    def full_fingerprint(self) -> str:
        """Fingerprint of every block, in document order, joined by `` | ``."""
        return " | ".join(b.fingerprint for b in self.shapes)

    @property
    def compact_fingerprint(self) -> str:
        """Like :attr:`full_fingerprint`, with runs of equal tokens collapsed.

        A token repeated *n* > 1 times in a row is written once as
        ``<token>×<n>``.  Returns ``""`` when no blocks were found.
        """
        runs: List[str] = []
        for fp, group in itertools.groupby(b.fingerprint for b in self.shapes):
            cnt = sum(1 for _ in group)
            runs.append(f"{fp}×{cnt}" if cnt > 1 else fp)
        return " | ".join(runs)

    # ==============================================================
    # internal traversal
    # ==============================================================
    def _crawl(self, node: Tag, *, depth: int):
        """Analyse *node* and its element descendants in document pre-order.

        An explicit stack replaces recursion so that very deeply nested
        markup cannot exhaust Python's recursion limit.  Elements deeper than
        ``max_depth`` and tags in ``_SKIP_TAGS`` are pruned together with
        their subtrees.
        """
        stack = [(node, depth)]
        while stack:
            current, level = stack.pop()
            if self.max_depth is not None and level > self.max_depth:
                continue
            if current.name in self._SKIP_TAGS:
                continue

            blk = self._analyse_node(current, level)
            if blk is not None:
                # keep reference to original Tag for snippet extraction
                blk._node = current     # type: ignore
                self.shapes.append(blk)

            # Reversed so the first child is popped (and reported) first,
            # reproducing recursive pre-order.
            children = self._element_children(current)
            stack.extend((child, level + 1) for child in reversed(children))

    # --------------------------------------------------------------
    # detector dispatcher
    # --------------------------------------------------------------
    def _analyse_node(self, node: Tag, depth: int) -> BlockShape | None:
        """Return the shape from the first detector that matches, else None.

        Order matters: earlier detectors take precedence over later ones.
        """
        return (
            self._detect_kv_list(node, depth)
            or self._detect_table(node, depth)
            or self._detect_list_like(node, depth)
            or self._detect_row_cols(node, depth)
            or self._detect_css_grid(node, depth)
            or self._detect_cards(node, depth)
            or self._detect_data_grid(node, depth)
            or self._detect_paragraph_blob(node, depth)
        )

    # ==============================================================
    # detectors
    # ==============================================================

    # ---------- KV list detector ----------------------------------
    def _detect_kv_list(self, node: Tag, depth: int):
        """KV block: ≥4 element children whose text all starts as ``label: value``."""
        children = self._element_children(node)
        if len(children) < self._KV_MIN_CHILDREN:
            return None
        if all(KV_RE.match(c.get_text(" ", strip=True)) for c in children):
            return BlockShape([2] * len(children), depth, "KV")
        return None

    # ---------- table --------------------------------------------
    def _detect_table(self, node: Tag, depth: int):
        """TABLE block: one row per ``<tr>`` of this table with td/th cells.

        Each row's width is the sum of its cells' colspans.  Rows belonging
        to nested tables are excluded; those tables are reported on their own
        when the crawl reaches them.
        """
        if node.name != "table":
            return None
        rows: List[int] = []
        for tr in node.find_all("tr"):
            if tr.find_parent("table") is not node:
                continue  # row of a nested table
            cells = tr.find_all(["td", "th"], recursive=False)
            if cells:
                rows.append(sum(self._colspan(c) for c in cells))
        return BlockShape(rows, depth, "TABLE") if rows else None

    # ---------- UL/OL/DL -----------------------------------------
    def _detect_list_like(self, node: Tag, depth: int):
        """LIST/KV block from ``<ul>``/``<ol>`` items or ``<dl>`` pairs.

        A list with ≥min_repeat direct ``<li>`` children becomes KV if any item
        matches the KV regex (those rows are 2 wide), else LIST.  A ``<dl>``
        with equally many direct ``<dt>`` and ``<dd>`` children is KV.
        """
        if node.name in {"ul", "ol"}:
            lis = node.find_all("li", recursive=False)
            if len(lis) >= self.min_repeat:
                rows = [2 if KV_RE.match(li.get_text(" ", strip=True)) else 1 for li in lis]
                tp = "KV" if any(r == 2 for r in rows) else "LIST"
                return BlockShape(rows, depth, tp)

        if node.name == "dl":
            dts = node.find_all("dt", recursive=False)
            dds = node.find_all("dd", recursive=False)
            if dts and len(dts) == len(dds):
                return BlockShape([2] * len(dts), depth, "KV")
        return None

    # ---------- row/col repetition -------------------------------
    def _detect_row_cols(self, node: Tag, depth: int):
        """LIST block from a bootstrap ``.row`` or repeated same-tag children.

        A ``div.row`` qualifies with ≥min_repeat direct ``div`` children that
        carry a grid class.  Otherwise the most common tag among the direct
        element children must occur ≥min_repeat times.
        """
        # bootstrap “row”
        if node.name == "div" and "row" in (node.get("class") or []):
            cols = [
                c for c in node.find_all("div", recursive=False)
                if any(GRID_CLASS_RE.match(cls) for cls in (c.get("class") or []))
            ]
            if len(cols) >= self.min_repeat:
                return BlockShape([1] * len(cols), depth, "LIST")

        # generic: ≥min_repeat identical tags.  The emptiness check guards
        # most_common() on an empty list when min_repeat <= 0.
        children = self._element_children(node)
        if children and len(children) >= self.min_repeat:
            _, qty = Counter(c.name for c in children).most_common(1)[0]
            if qty >= self.min_repeat:
                return BlockShape([1] * qty, depth, "LIST")
        return None

    # ---------- CSS grid heuristics ------------------------------
    def _detect_css_grid(self, node: Tag, depth: int):
        """LIST block for a CSS grid container with ≥min_repeat children.

        A container is an inline ``display: grid``/``inline-grid`` style (any
        spacing or case) or any class name containing ``grid``.
        """
        style = node.get("style") or ""
        is_grid = bool(self._DISPLAY_GRID_RE.search(style)) or any(
            "grid" in cls for cls in (node.get("class") or [])
        )
        if is_grid:
            children = self._element_children(node)
            if len(children) >= self.min_repeat:
                return BlockShape([1] * len(children), depth, "LIST")
        return None

    # ---------- cards block --------------------------------------
    def _detect_cards(self, node: Tag, depth: int):
        """LIST block when ≥min_repeat children all have a "card" class."""
        children = self._element_children(node)
        if len(children) >= self.min_repeat and all(
            any(CARD_CLASS_RE.search(cls) for cls in (c.get("class") or []))
            for c in children
        ):
            return BlockShape([1] * len(children), depth, "LIST")
        return None

    # ---------- data-grid libs -----------------------------------
    def _detect_data_grid(self, node: Tag, depth: int):
        """LIST block for an element with a data-grid attribute and ≥min_repeat children."""
        if any(attr in node.attrs for attr in DATA_GRID_ATTRS):
            children = self._element_children(node)
            if len(children) >= self.min_repeat:
                return BlockShape([1] * len(children), depth, "LIST")
        return None

    # ---------- paragraph blob -----------------------------------
    def _detect_paragraph_blob(self, node: Tag, depth: int):
        """BLOB when the text is ≥min_blob_len long and contains no table/list."""
        txt = self._normalized_text(node)
        # find() stops at the first match; only its existence matters here.
        if len(txt) >= self.min_blob_len and node.find(["table", "ul", "ol", "dl"]) is None:
            return BlockShape([-1], depth, "BLOB")
        return None

    # ==============================================================
    # helpers
    # ==============================================================

    @staticmethod
    def _element_children(node: Tag) -> List[Tag]:
        """Direct children of *node* that are elements (strings are dropped)."""
        return [c for c in node.children if isinstance(c, Tag)]

    @staticmethod
    def _normalized_text(node: Tag) -> str:
        """Text of *node* with every whitespace run collapsed to one space."""
        return re.sub(r"\s+", " ", node.get_text(" ", strip=True))

    @staticmethod
    def _dominant_cols(rows: Sequence[int]) -> int:
        """Most frequent row width; 0 for a block without rows."""
        return Counter(rows).most_common(1)[0][0] if rows else 0

    @classmethod
    def _colspan(cls, cell: Tag) -> int:
        """Column span of *cell*, parsed leniently as HTML does.

        Leading digits are used (``"2;"`` -> 2).  Missing, unparsable or zero
        values count as 1, and values are clamped to ``_MAX_COLSPAN``, so a
        malformed attribute no longer raises ValueError.
        """
        m = cls._COLSPAN_RE.match(str(cell.get("colspan", "")))
        if m is None:
            return 1
        return min(max(int(m.group(1)), 1), cls._MAX_COLSPAN)

    def _node_snippet(self, node: Tag) -> str:
        """First ``snippet_len`` chars of the node's text, "…" if truncated."""
        txt = self._normalized_text(node)
        return txt[: self.snippet_len].strip() + ("…" if len(txt) > self.snippet_len else "")

    @staticmethod
    def _block_label(blk: BlockShape) -> str:
        """One/two-letter header label: P (blob), KV, L (1 column) or T (table)."""
        if blk.type_ == "BLOB":
            return "P"
        if blk.type_ == "KV":
            return "KV"
        if blk.rows == [-1]:
            return "P"
        return "L" if SchemaAnalyzer._dominant_cols(blk.rows) == 1 else "T"


# ===================================================================
# CLI helper
# ===================================================================
if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Schema Analyzer v6 – full")
    ap.add_argument("html_file", type=Path)
    ap.add_argument("--depth",    type=int, default=None,
                    help="deepest level to analyse (body = 0; default: unlimited)")
    ap.add_argument("--repeat",   type=int, default=MIN_REPEAT_DEFAULT,
                    help="minimum repeated children for a list/grid block")
    ap.add_argument("--blob",     type=int, default=MIN_BLOB_DEFAULT,
                    help="minimum text length of a paragraph blob")
    ap.add_argument("--max-rows", type=int, default=50,
                    help="compress preview blocks with more rows than this")
    ap.add_argument("--snippet",  type=int, default=70,
                    help="maximum snippet length in the preview")
    ap.add_argument("--compact", action="store_true", help="print only compact fingerprint")
    args = ap.parse_args()

    try:
        html_src = args.html_file.read_text(encoding="utf8")
    except UnicodeDecodeError:
        # Not valid UTF-8: hand the raw bytes to BeautifulSoup, which detects
        # the document encoding itself instead of aborting here.
        html_src = args.html_file.read_bytes()

    sa = SchemaAnalyzer(
        html_src,
        max_depth=args.depth,
        min_repeat=args.repeat,
        min_blob_len=args.blob,
        ascii_max_rows=args.max_rows,
        snippet_len=args.snippet,
    )

    if args.compact:
        print(sa.compact_fingerprint)
    else:
        print("\n=== ASCII preview ===\n" + sa.pretty_ascii())
        print("\n=== full fingerprint ===\n"    + sa.full_fingerprint)
        print("\n=== compact fingerprint ===\n" + sa.compact_fingerprint)
