import os
from io import BytesIO
from typing import Optional

from PIL import Image, ImageOps

# ─────────────────────── Tunables for miniature output ───────────────────────
# Geometry
TARGET_PX = 384  # max length of the thumbnail's long side in px (try 256–512)

# WebP encoder settings
WEBP_METHOD = 6  # encoder effort 0..6; 6 = best compression, slowest
Q_PHOTO = 50  # lossy quality for photographs (40–55 works well)
Q_GRAPHIC = 55  # lossy quality for graphics (used by the lossy fallback path)
ALPHA_Q = 50  # quality of the alpha channel for lossy WebP

# Photo-vs-graphic heuristic
COLOR_SAMPLE = 2  # each side is divided by this before counting colours
GRAPHIC_COLOR_THRESHOLD = 512  # at most this many distinct colours => "graphic/UI"

# ───────────────────────────── Helpers ─────────────────────────────

def estimate_input_size_bytes(image_input) -> Optional[int]:
    """Best-effort size of *image_input* in bytes, intended for logging only.

    Recognised inputs, tried in this order:

    * a filesystem path (``str`` or ``os.PathLike``) -> ``os.path.getsize``;
    * an object with ``getbuffer()`` (e.g. ``io.BytesIO``) -> buffer length;
    * an object with ``getvalue()`` -> ``len()`` of the returned value;
    * an object with ``seek``/``tell`` -> offset of the end of the stream.
      The caller's current position is restored afterwards, even if
      measuring fails part-way.

    Returns:
        The size in bytes, or ``None`` if it cannot be determined.  Errors
        raised while measuring are swallowed on purpose: a diagnostic helper
        must never break the conversion it is reporting on.
    """
    try:
        if isinstance(image_input, (str, os.PathLike)):
            return os.path.getsize(image_input)
        if hasattr(image_input, "getbuffer"):
            return len(image_input.getbuffer())
        if hasattr(image_input, "getvalue"):
            return len(image_input.getvalue())
        if hasattr(image_input, "seek") and hasattr(image_input, "tell"):
            pos = image_input.tell()
            try:
                image_input.seek(0, os.SEEK_END)
                return image_input.tell()
            finally:
                # Always put the stream back where the caller left it.
                image_input.seek(pos)
    except Exception:
        # Deliberate best-effort: non-seekable streams, closed files and
        # missing paths all mean "size unknown", not "conversion failed".
        return None
    return None

def ensure_color_mode(img: Image.Image) -> Image.Image:
    """Return *img* in RGB or RGBA mode, keeping transparency when present.

    Images already in RGB or RGBA are returned unchanged (the same object).
    Any other mode (e.g. ``P``, ``L``, ``LA``, ``CMYK``) is converted: to RGBA
    when the image carries transparency, otherwise to RGB.

    Transparency is detected from an explicit alpha band (``LA``, ``PA``, ...)
    *or* from a ``transparency`` entry in ``img.info``.  Palette and greyscale
    images (typically from GIF/PNG) store it in ``info`` rather than as a band.
    Converting those straight to RGB would silently drop it.
    """
    if img.mode in ("RGB", "RGBA"):
        return img
    has_alpha = "A" in img.getbands() or "transparency" in img.info
    return img.convert("RGBA" if has_alpha else "RGB")

def downscale_long_side(img: Image.Image, target_px: int) -> Image.Image:
    """Shrink *img* so its longer side is at most *target_px*, keeping aspect ratio.

    Images that already fit are returned unchanged (the same object).  Larger
    ones are resized with the LANCZOS filter.  Never upscales.

    Args:
        img: source image.
        target_px: maximum length, in pixels, of the longer side.

    Returns:
        The original image, or a new resized image.  For ``target_px >= 1``
        the resized image's longer side equals *target_px*.  The shorter
        side is scaled proportionally, to at least 1 px.
    """
    w, h = img.size
    long_side = max(w, h)
    if long_side <= target_px:
        return img
    scale = target_px / long_side
    # round() rather than int(): truncation turns float results such as
    # 100 * (29 / 100) == 28.999999999999996 into 28 (one pixel short).
    # It also biases the short side downward, skewing the aspect ratio.
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    return img.resize(new_size, Image.LANCZOS)

def is_graphic_like(img: Image.Image, sample: int = COLOR_SAMPLE,
                    color_threshold: int = GRAPHIC_COLOR_THRESHOLD) -> bool:
    """Guess whether *img* is a flat-colour graphic rather than a photograph.

    Few distinct colours (logos, icons, UI screenshots) => True.
    Many colours (photographs, gradients) => False.

    Args:
        img: image to inspect; colours are counted in RGBA.
        sample: factor by which each side is divided before counting.
            Values below 1 are treated as 1 (no subsampling).
        color_threshold: an image whose sample has at most this many
            distinct colours is considered graphic-like.  Values below 1
            always yield False.

    Returns:
        True if the sampled image has ``<= color_threshold`` distinct colours.
        False otherwise, or if the colours cannot be counted.
    """
    if color_threshold < 1:
        return False
    step = max(1, sample)
    w, h = img.size
    # NEAREST keeps existing pixel values.  A smoothing filter (e.g. LANCZOS)
    # blends neighbouring flat areas and invents intermediate colours at
    # every edge, inflating the count for exactly the images we want to catch.
    small = img.resize((max(1, w // step), max(1, h // step)), Image.NEAREST)
    try:
        # getcolors() returns None once the count exceeds maxcolors.  Bounding
        # it by the threshold answers the question without listing every
        # colour in the image.
        colors = small.convert("RGBA").getcolors(maxcolors=color_threshold)
    except Exception:
        # Heuristic only: never abort the conversion; treat as a photograph.
        return False
    return colors is not None

# ─────────────────────────── Main API ────────────────────────────

def _encode_webp(im: Image.Image, graphic: bool) -> BytesIO:
    """Encode *im* as WEBP with the graphics or photo strategy; return a rewound buffer.

    Graphics are encoded losslessly, which keeps flat colours and edges crisp.
    The lossy ``Q_GRAPHIC`` encode is only a fallback for when the lossless
    encode raises; it is not chosen on output size.  Photos are always lossy
    at ``Q_PHOTO``.  ``exact`` is enabled for images with an alpha band, so
    RGB values under fully transparent pixels are kept rather than discarded.
    """
    has_alpha = "A" in im.getbands()
    out = BytesIO()
    if graphic:
        try:
            im.save(out, format="WEBP", lossless=True, method=WEBP_METHOD)
        except Exception:
            out = BytesIO()  # discard anything the failed encode wrote
            im.save(
                out,
                format="WEBP",
                quality=Q_GRAPHIC,
                method=WEBP_METHOD,
                exact=has_alpha,
                alpha_quality=ALPHA_Q,
            )
    else:
        im.save(
            out,
            format="WEBP",
            quality=Q_PHOTO,
            method=WEBP_METHOD,
            alpha_quality=ALPHA_Q,
            exact=has_alpha,
        )
    out.seek(0)
    return out


def _log_webp_result(size, new_size: int, orig_size_bytes: Optional[int]) -> None:
    """Print a one-line report of the thumbnail dimensions and byte savings."""
    if orig_size_bytes:
        saved = orig_size_bytes - new_size
        pct = saved / orig_size_bytes * 100.0
        print(
            f"✅ WEBP thumbnail {size} | {new_size} B "
            f"(saved ~{saved} B, {pct:.1f}%)"
        )
    else:
        print(f"✅ WEBP thumbnail {size} | {new_size} B")


def convert_image_to_webp(image_input):
    """Compress *image_input* into a small WEBP thumbnail.

    Pipeline:
      1. open the image and apply its EXIF orientation (``ImageOps.exif_transpose``);
      2. normalise to RGB/RGBA (``ensure_color_mode``);
      3. shrink the long side to ``TARGET_PX`` (``downscale_long_side``);
      4. choose a strategy with ``is_graphic_like``: lossless for graphics,
         lossy ``Q_PHOTO`` for photos (see ``_encode_webp``);
      5. print a one-line size report.

    EXIF and ICC data are not passed to the encoder explicitly.

    Args:
        image_input: anything ``PIL.Image.open`` accepts, i.e. a path or a
            binary file-like object.  A caller-supplied file object is left
            open; a file opened from a path is closed before returning.

    Returns:
        BytesIO positioned at 0 containing the WEBP data.

    Raises:
        Any exception from decoding or encoding is printed together with its
        traceback and then re-raised.
    """
    try:
        # The context manager closes the file Pillow opened itself when given
        # a path.  Without it, the handle would stay open until GC.
        with Image.open(image_input) as src:
            im = ImageOps.exif_transpose(src)
            im = ensure_color_mode(im)
            orig_size_bytes = estimate_input_size_bytes(image_input)
            # Downscale first: the biggest size reduction, and it makes the
            # colour heuristic and the encoder cheaper.
            im = downscale_long_side(im, TARGET_PX)
            out = _encode_webp(im, is_graphic_like(im))

        _log_webp_result(im.size, out.getbuffer().nbytes, orig_size_bytes)
        return out

    except Exception as e:
        print("❌ Błąd podczas konwersji do WebP:", str(e))
        import traceback
        print(traceback.format_exc())
        raise
