#!/usr/bin/env python3
"""
parallel_run_html.py

Enhanced parallel runner for HTML collection with improved error handling and monitoring.

Key Features:
- Parallel HTML fetching with configurable workers
- Retry mechanism with exponential backoff for failed fetches
- Progress tracking with tqdm progress bars (plain-text fallback if tqdm is unavailable)
- Memory and CPU monitoring of the controller process
- Graceful interrupt handling
- Detailed logging and statistics
- Resume capability for interrupted runs
- Memory-efficient processing (one isolated subprocess per page)

Examples:
  python scripts/parallel_run_html.py --name morizon --workers 4 --max 100
  python scripts/parallel_run_html.py --name morizon --workers 8 --headless
  python scripts/parallel_run_html.py --name morizon --resume --include-fetched
  python scripts/parallel_run_html.py --name otodom,gethome --workers 6 --max 500
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import signal
import sys
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen, PIPE, TimeoutExpired
from typing import List, Optional, Set
import threading
import psutil

# Progress bar support (optional)
try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    class tqdm:
        """Minimal fallback mimicking the subset of the tqdm API used below;
        prints plain-text progress and ignores extra kwargs such as desc and disable."""

        def __init__(self, *args, **kwargs):
            self.total = kwargs.get('total', 0)
            self.current = 0
            
        def update(self, n=1):
            self.current += n
            if self.total > 0:
                pct = (self.current / self.total) * 100
                print(f"\rProgress: {self.current}/{self.total} ({pct:.1f}%)", end='', flush=True)
        
        def close(self):
            print()
            
        def __enter__(self):
            return self
            
        def __exit__(self, *args):
            self.close()

# --- Setup paths and Django ---
THIS_DIR = Path(__file__).parent.absolute()
REPO_ROOT = THIS_DIR.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

MANAGE_PY = REPO_ROOT / "manage.py"

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('parallel_html.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Django setup
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "NetworkMonitoring.settings")
try:
    import django
    django.setup()
    from django.db import connections
    from django.db.models import Q
    from extractly.models import SourceHtml, NetworkMonitoredPage
    logger.info("Django initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Django: {e}")
    sys.stderr.write(f"[ERROR] Could not initialize Django: {e}\n")
    raise


@dataclass
class FetchResult:
    """Result of fetching HTML for a single page ID."""
    page_id: int
    exit_code: int
    duration: float
    stdout: str = ""
    stderr: str = ""
    retries: int = 0
    timestamp: float = 0.0
    url: str = ""


@dataclass
class FetchStats:
    """Statistics for the entire HTML fetching run."""
    total_pages: int = 0
    completed: int = 0
    failed: int = 0
    skipped: int = 0
    retried: int = 0
    start_time: float = 0.0
    end_time: float = 0.0
    avg_duration: float = 0.0
    peak_memory_mb: float = 0.0
    
    def success_rate(self) -> float:
        return (self.completed / max(1, self.total_pages)) * 100
    
    def total_duration(self) -> float:
        return self.end_time - self.start_time if self.end_time > self.start_time else 0.0


@dataclass
class HTMLConfig:
    """Configuration for the parallel HTML fetcher."""
    source_names: List[str]
    workers: int = 4
    max_pages: int = 0
    batch_size: int = 1
    headless: bool = True
    include_fetched: bool = False
    python_exe: str = sys.executable
    manage_py: str = str(MANAGE_PY)
    cwd: str = str(REPO_ROOT)
    timeout: int = 600  # 10 minutes per page (HTML fetching can be slow)
    max_retries: int = 2
    retry_delay: float = 5.0
    state_file: str = ".html_fetch_state.json"
    log_level: str = "INFO"
    memory_limit_mb: int = 4096
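    # NOTE: batch_size and memory_limit_mb are accepted for forward compatibility but
    # are not currently enforced anywhere in the run loop below.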


class ProcessMonitor:
    """Monitor system resources during processing."""
    
    def __init__(self):
        self.peak_memory = 0.0
        self.current_memory = 0.0
        self.cpu_percent = 0.0
        self.running = False
        self.thread = None
        
    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._monitor)
        self.thread.daemon = True
        self.thread.start()
        
    def stop(self):
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)
            
    def _monitor(self):
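        # psutil.Process() with no pid argument refers to the current (controller) process,
        # so only this process's RSS/CPU is sampled here.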
        process = psutil.Process()
        while self.running:
            try:
                mem_info = process.memory_info()
                self.current_memory = mem_info.rss / 1024 / 1024  # MB
                self.peak_memory = max(self.peak_memory, self.current_memory)
                self.cpu_percent = process.cpu_percent()
                time.sleep(1.0)
            except Exception:
                pass


class StateManager:
    """Manage fetcher state for resume capability."""
    
    def __init__(self, state_file: str):
        self.state_file = Path(state_file)
        self.state = {
            'completed_ids': set(),
            'failed_ids': set(),
            'stats': {},
            'config': {},
            'last_update': time.time()
        }
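        # Serialized by save() as JSON, e.g. (illustrative values):
        #   {"completed_ids": [101, 102], "failed_ids": [103], "stats": {}, "config": {},
        #    "last_update": 1700000000.0}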
        self.lock = threading.Lock()
        
    def load(self) -> bool:
        """Load state from file. Returns True if loaded successfully."""
        try:
            if self.state_file.exists():
                with open(self.state_file, 'r') as f:
                    data = json.load(f)
                    self.state['completed_ids'] = set(data.get('completed_ids', []))
                    self.state['failed_ids'] = set(data.get('failed_ids', []))
                    self.state['stats'] = data.get('stats', {})
                    self.state['config'] = data.get('config', {})
                    self.state['last_update'] = data.get('last_update', time.time())
                logger.info(f"Loaded state: {len(self.state['completed_ids'])} completed, "
                           f"{len(self.state['failed_ids'])} failed")
                return True
        except Exception as e:
            logger.warning(f"Could not load state file: {e}")
        return False
        
    def save(self):
        """Save current state to file."""
        with self.lock:
            try:
                data = {
                    'completed_ids': list(self.state['completed_ids']),
                    'failed_ids': list(self.state['failed_ids']),
                    'stats': self.state['stats'],
                    'config': self.state['config'],
                    'last_update': time.time()
                }
                with open(self.state_file, 'w') as f:
                    json.dump(data, f, indent=2)
            except Exception as e:
                logger.error(f"Could not save state: {e}")
                
    def mark_completed(self, page_id: int):
        with self.lock:
            self.state['completed_ids'].add(page_id)
            self.state['failed_ids'].discard(page_id)
            
    def mark_failed(self, page_id: int):
        with self.lock:
            self.state['failed_ids'].add(page_id)
            
    def is_completed(self, page_id: int) -> bool:
        return page_id in self.state['completed_ids']
        
    def get_pending_ids(self, all_ids: List[int]) -> List[int]:
        """Filter out already completed IDs."""
        return [pid for pid in all_ids if not self.is_completed(pid)]


class InterruptHandler:
    """Handle graceful shutdown on interrupt signals."""
    
    def __init__(self):
        self.interrupted = False
        self.original_handlers = {}
        
    def __enter__(self):
        self.original_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self._handle_interrupt)
        if hasattr(signal, 'SIGTERM'):
            self.original_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self._handle_interrupt)
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        for sig, handler in self.original_handlers.items():
            signal.signal(sig, handler)
            
    def _handle_interrupt(self, signum, frame):
        logger.info("Received interrupt signal, shutting down gracefully...")
        self.interrupted = True


def find_sources_by_name(names: Set[str]) -> List[str]:
    """Find source IDs by name."""
    if not names:
        sources = list(SourceHtml.objects.all())
        logger.info(f"Found {len(sources)} sources")
        return [s.name for s in sources if s.name]
        
    lowered = {n.strip().lower() for n in names if n and n.strip()}
    qs = SourceHtml.objects.all()
    found: List[str] = []
    
    for source in qs:
        candidates = {
            (source.name or "").lower(),
            (source.title or "").lower(),
        }
        if candidates & lowered:
            found.append(source.name)
            logger.debug(f"Matched source: {source.name}")
            
    if not found:
        available = [s.name for s in qs if s.name]
        logger.error(f"No sources found for {sorted(names)}. Available: {available}")
    else:
        logger.info(f"Found {len(found)} matching sources")
        
    return found


def get_pending_page_ids(
    source_name: str,
    max_count: Optional[int],
    include_fetched: bool
) -> List[int]:
    """Get list of page IDs that need HTML fetching."""
    try:
        source = SourceHtml.objects.filter(name=source_name).first()
        if not source:
            logger.warning(f"Source not found: {source_name}")
            return []
        
        # Base query for pages of this source
        if include_fetched:
            # Include already fetched pages (for re-fetching)
            qs = NetworkMonitoredPage.objects.filter(source_id=source.source_id)
        else:
            # Only unfetched or errored pages
            qs = NetworkMonitoredPage.objects.filter(
                source_id=source.source_id
            ).filter(
                Q(html__isnull=True) |
                Q(html__exact="") |
                Q(html__exact="error") |
                Q(html__exact="{}") |
                Q(html__exact="[]") |
                Q(sliced_html__isnull=True) |
                Q(sliced_html__exact="") |
                Q(sliced_html__exact="error") |
                Q(sliced_html__exact="{}") |
                Q(sliced_html__exact="[]")
            )
        
        qs = qs.order_by('id')
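        # Ordering by primary key keeps the LIMIT slice deterministic across runs,
        # which matters when combining --max with --resume.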
        
        if max_count and max_count > 0:
            page_ids = list(qs.values_list('id', flat=True)[:max_count])
        else:
            page_ids = list(qs.values_list('id', flat=True))
        
        logger.info(f"Found {len(page_ids)} pages for {source_name}")
        return page_ids
        
    except Exception as e:
        logger.error(f"Error getting page IDs for {source_name}: {e}")
        return []


def run_html_fetch_with_retry(
    page_id: int,
    config: HTMLConfig,
    max_retries: int = 2
) -> FetchResult:
    """Run manage.py run_html for a page ID with retry logic."""
    start_time = time.time()
    result = FetchResult(page_id=page_id, exit_code=1, duration=0.0, timestamp=start_time)
    
    # Get page URL for logging
    try:
        page = NetworkMonitoredPage.objects.get(id=page_id)
        result.url = page.url
    except Exception:
        pass
    
    for attempt in range(max_retries + 1):
        try:
            cmd = build_html_command(page_id, config)
            
            env = os.environ.copy()
            env.setdefault("PYTHONUNBUFFERED", "1")
            
            proc = Popen(
                cmd,
                stdout=PIPE,
                stderr=PIPE,
                cwd=config.cwd,
                env=env,
                text=True
            )
            
            try:
                stdout, stderr = proc.communicate(timeout=config.timeout)
                result.exit_code = proc.returncode
                result.stdout = stdout
                result.stderr = stderr
                result.retries = attempt
                
                if proc.returncode == 0:
                    logger.debug(f"[OK] HTML fetched: page_id={page_id} (attempt {attempt + 1})")
                    break
                else:
                    logger.warning(f"[FAIL] HTML fetch failed: page_id={page_id} rc={proc.returncode} (attempt {attempt + 1})")
                    if attempt < max_retries:
                        time.sleep(config.retry_delay * (2 ** attempt))  # Exponential backoff
                        
            except TimeoutExpired:
                proc.kill()
                proc.communicate()
                logger.error(f"[TIMEOUT] page_id={page_id} after {config.timeout}s (attempt {attempt + 1})")
                if attempt < max_retries:
                    time.sleep(config.retry_delay)
                    
        except Exception as e:
            logger.error(f"Exception fetching HTML for page_id={page_id}: {e}")
            if attempt < max_retries:
                time.sleep(config.retry_delay)
    
    result.duration = time.time() - start_time
    return result


def build_html_command(page_id: int, config: HTMLConfig) -> List[str]:
    """Build the command to run manage.py run_html."""
    cmd = [config.python_exe, config.manage_py, "run_html", "--id", str(page_id)]
    
    if config.headless:
        cmd.append("--headless")
    if config.include_fetched:
        cmd.append("--include-fetched")
        
    return cmd


def fetch_page_worker(args_tuple) -> FetchResult:
    """Worker function to fetch HTML for a single page ID."""
    page_id, config = args_tuple
    return run_html_fetch_with_retry(page_id, config, config.max_retries)


def main(argv: Optional[List[str]] = None) -> int:
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Parallel HTML fetcher using run_html command",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --name morizon --workers 4 --max 100
  %(prog)s --name morizon --workers 8 --headless
  %(prog)s --name morizon --resume --include-fetched
  %(prog)s --name otodom,gethome --workers 6 --max 500 --timeout 900
        """
    )
    
    # Required arguments
    parser.add_argument("--name", required=True, help="Source name(s), comma-separated")
    
    # Optional arguments
    parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers (default: 4)")
    parser.add_argument("--max", type=int, default=0, help="Maximum pages to fetch (0 = all)")
    parser.add_argument("--batch-size", type=int, default=1, help="Pages per batch (default: 1)")
    parser.add_argument("--headless", action="store_true", default=True, help="Run browser in headless mode (default)")
    parser.add_argument("--no-headless", action="store_false", dest="headless", help="Show browser window")
    parser.add_argument("--include-fetched", action="store_true", help="Include already fetched pages")
    parser.add_argument("--timeout", type=int, default=600, help="Timeout per page in seconds (default: 600)")
    parser.add_argument("--max-retries", type=int, default=2, help="Max retries per page (default: 2)")
    parser.add_argument("--retry-delay", type=float, default=5.0, help="Delay between retries in seconds (default: 5.0)")
    parser.add_argument("--state-file", default=".html_fetch_state.json", help="State file for resume")
    parser.add_argument("--resume", action="store_true", help="Resume from previous state")
    parser.add_argument("--log-level", default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR"], help="Logging level")
    parser.add_argument("--python", default=sys.executable, help="Python executable path")
    parser.add_argument("--manage", default=str(MANAGE_PY), help="Path to manage.py")
    parser.add_argument("--cwd", default=str(REPO_ROOT), help="Working directory")
    
    args = parser.parse_args(argv)
    
    # Set logging level
    logging.getLogger().setLevel(getattr(logging, args.log_level.upper()))
    
    # Build configuration
    config = HTMLConfig(
        source_names=[s.strip() for s in args.name.split(",") if s.strip()],
        workers=args.workers,
        max_pages=args.max,
        batch_size=args.batch_size,
        headless=args.headless,
        include_fetched=args.include_fetched,
        python_exe=args.python,
        manage_py=args.manage,
        cwd=args.cwd,
        timeout=args.timeout,
        max_retries=args.max_retries,
        retry_delay=args.retry_delay,
        state_file=args.state_file,
        log_level=args.log_level
    )
    
    logger.info("="*60)
    logger.info("PARALLEL HTML FETCHER")
    logger.info("="*60)
    logger.info(f"Workers:          {config.workers}")
    logger.info(f"Target sources:   {', '.join(config.source_names)}")
    logger.info(f"Headless:         {config.headless}")
    logger.info(f"Include fetched:  {config.include_fetched}")
    logger.info(f"Max pages:        {config.max_pages if config.max_pages > 0 else 'all'}")
    logger.info(f"Timeout:          {config.timeout}s")
    logger.info(f"Max retries:      {config.max_retries}")
    logger.info("="*60)
    
    # Initialize state manager
    state_manager = StateManager(config.state_file)
    if args.resume:
        state_manager.load()
    
    # Find sources
    source_names = find_sources_by_name(set(config.source_names))
    if not source_names:
        logger.error("No matching sources found")
        return 1
    
    # Initialize monitoring
    monitor = ProcessMonitor()
    stats = FetchStats(start_time=time.time())
    durations: List[float] = []  # defined before the try so the finally block can always reference it
    
    try:
        with InterruptHandler() as interrupt_handler:
            monitor.start()
            
            # Collect all page IDs
            logger.info("Collecting page IDs...")
            all_ids = []
            
            for source_name in source_names:
                page_ids = get_pending_page_ids(
                    source_name=source_name,
                    max_count=config.max_pages if len(source_names) == 1 else None,
                    include_fetched=config.include_fetched
                )
                all_ids.extend(page_ids)
                
                if interrupt_handler.interrupted:
                    break
            
            # Apply max limit across all sources
            if config.max_pages > 0:
                all_ids = all_ids[:config.max_pages]
            
            # Filter out completed IDs if resuming
            pending_ids = state_manager.get_pending_ids(all_ids) if args.resume else all_ids
            
            # Remove duplicates while preserving order
            seen = set()
            unique_ids = []
            for pid in pending_ids:
                if pid not in seen:
                    seen.add(pid)
                    unique_ids.append(pid)
            
            stats.total_pages = len(unique_ids)
            
            if not unique_ids:
                logger.info("No pages to fetch")
                return 0
            
            logger.info(f"Fetching HTML for {len(unique_ids)} pages with {config.workers} workers")
            logger.info("="*60)
            
            # Prepare work items
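            # HTMLConfig is a plain dataclass of simple field types, so each
            # (page_id, config) pair pickles cleanly when handed to worker processes.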
            work_items = [(page_id, config) for page_id in unique_ids]
            
            # Process with progress tracking
            with tqdm(total=len(unique_ids), desc="Fetching HTML", disable=not HAS_TQDM) as pbar:
                # Close inherited DB connections so each forked worker opens its own;
                # the Django ORM is used inside the workers and must not share a connection.
                connections.close_all()
                with ProcessPoolExecutor(max_workers=config.workers) as executor:
                    # Submit all tasks
                    future_to_id = {
                        executor.submit(fetch_page_worker, item): item[0]
                        for item in work_items
                    }
                    
                    # Collect results
                    for future in as_completed(future_to_id):
                        if interrupt_handler.interrupted:
                            logger.info("Cancelling remaining tasks...")
                            for f in future_to_id:
                                f.cancel()
                            break
                            
                        try:
                            result = future.result()
                            durations.append(result.duration)
                            
                            if result.exit_code == 0:
                                stats.completed += 1
                                state_manager.mark_completed(result.page_id)
                                logger.debug(f"[OK] {result.page_id} ({result.duration:.1f}s)")
                            else:
                                stats.failed += 1
                                state_manager.mark_failed(result.page_id)
                                logger.warning(f"[FAIL] {result.page_id} (rc={result.exit_code}, {result.duration:.1f}s)")
                                if result.stderr:
                                    error_preview = result.stderr.strip()[:200]
                                    logger.debug(f"  Error: {error_preview}")
                            
                            if result.retries > 0:
                                stats.retried += 1
                            
                            pbar.update(1)
                            
                            # Save state periodically
                            if (stats.completed + stats.failed) % 10 == 0:
                                state_manager.save()
                                
                        except Exception as e:
                            logger.error(f"Task processing error: {e}")
                            logger.debug(traceback.format_exc())
                            stats.failed += 1
                            pbar.update(1)
            
            # Final state save
            state_manager.save()
            
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        return 130
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return 1
    finally:
        monitor.stop()
        stats.end_time = time.time()
        stats.peak_memory_mb = monitor.peak_memory
        if durations:
            stats.avg_duration = sum(durations) / len(durations)
    
    # Print final statistics
    print("\n" + "="*60)
    print("HTML FETCHING COMPLETE")
    print("="*60)
    print(f"Total pages:      {stats.total_pages:,}")
    print(f"Completed:        {stats.completed:,}")
    print(f"Failed:           {stats.failed:,}")
    print(f"Success rate:     {stats.success_rate():.1f}%")
    print(f"Retried:          {stats.retried:,}")
    print(f"Duration:         {stats.total_duration():.1f}s")
    print(f"Avg per page:     {stats.avg_duration:.2f}s")
    print(f"Peak memory:      {stats.peak_memory_mb:.1f} MB")
    print(f"Pages/hour:       {(stats.completed / max(1, stats.total_duration() / 3600)):.1f}")
    print("="*60)
    
    if stats.failed > 0:
        print(f"\nFailed IDs saved to state file: {config.state_file}")
        print("Run with --resume to retry failed pages")
    
    return 0 if stats.failed == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())
