#!/usr/bin/env python3
"""
parallel_manual_parser_v3.py

Enhanced parallel runner with improved error handling, monitoring, and performance.

Key Improvements:
- Retry mechanism for failed processes
- Progress tracking with rich progress bars
- Better resource management and monitoring
- Graceful interrupt handling
- Detailed logging and statistics
- Resume capability for interrupted runs
- Memory-efficient batch processing
- Configuration file support

Examples:
  python scripts/parallel_manual_parser_v3.py --name dobry --workers 8 --max 1000
  python scripts/parallel_manual_parser_v3.py --config config/parser.json
  python scripts/parallel_manual_parser_v3.py --name dobry --resume --state-file .parser_state.json
"""

from __future__ import annotations

import argparse
import json
import logging
import os
import signal
import sys
import time
import traceback
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen, PIPE, TimeoutExpired
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
import threading

import psutil

# Progress bar support (optional)
try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
    class tqdm:
        """Minimal text-only fallback used when tqdm is not installed."""
        def __init__(self, *args, **kwargs):
            self.total = kwargs.get('total', 0)
            self.current = 0
            
        def update(self, n=1):
            self.current += n
            if self.total > 0:
                pct = (self.current / self.total) * 100
                print(f"\rProgress: {self.current}/{self.total} ({pct:.1f}%)", end='', flush=True)
        
        def close(self):
            print()
            
        def __enter__(self):
            return self
            
        def __exit__(self, *args):
            self.close()

# --- Setup paths and Django ---
THIS_DIR = Path(__file__).parent.absolute()
REPO_ROOT = THIS_DIR.parent
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

MANAGE_PY = REPO_ROOT / "manage.py"

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('parallel_parser.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Django setup with better error handling
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "NetworkMonitoring.settings")
try:
    import django
    django.setup()
    from django.db.models import Q
    from extractly.models import SourceManual, NetworkMonitoredPage
    logger.info("Django initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize Django: {e}")
    sys.stderr.write(
        f"[ERROR] Could not initialize Django: {e}\n"
        "Make sure DJANGO_SETTINGS_MODULE is correct and dependencies are installed.\n"
    )
    raise


@dataclass
class TaskResult:
    """Result of processing a single page ID."""
    page_id: int
    exit_code: int
    duration: float
    stdout: str = ""
    stderr: str = ""
    retries: int = 0
    timestamp: float = 0.0


@dataclass
class ProcessingStats:
    """Statistics for the entire processing run."""
    total_pages: int = 0
    completed: int = 0
    failed: int = 0
    skipped: int = 0
    retried: int = 0
    start_time: float = 0.0
    end_time: float = 0.0
    avg_duration: float = 0.0
    peak_memory_mb: float = 0.0
    
    def success_rate(self) -> float:
        return (self.completed / max(1, self.total_pages)) * 100
    
    def total_duration(self) -> float:
        return self.end_time - self.start_time if self.end_time > self.start_time else 0.0


@dataclass
class ParserConfig:
    """Configuration for the parallel parser."""
    manuals: List[str]
    workers: int = 8
    max_pages: int = 0
    batch_size: int = 1
    since_id: Optional[int] = None
    until_id: Optional[int] = None
    filter_name: str = ""
    dry_run: bool = False
    force: bool = False
    force_names: Optional[List[str]] = None
    python_exe: str = sys.executable
    manage_py: str = str(MANAGE_PY)
    cwd: str = str(REPO_ROOT)
    timeout: int = 300  # 5 minutes per page
    max_retries: int = 2
    retry_delay: float = 1.0
    state_file: str = ".parser_state.json"
    log_level: str = "INFO"
    memory_limit_mb: int = 4096
    
    def __post_init__(self):
        if self.force_names is None:
            self.force_names = []
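
# The optional --config file is plain JSON whose keys mirror the ParserConfig
# fields above; unknown keys are ignored, and file values are applied after the
# CLI flags. A minimal illustrative example (values are placeholders, not
# recommendations):
#
#   {
#     "manuals": ["dobry"],
#     "workers": 6,
#     "max_pages": 3000,
#     "timeout": 600,
#     "state_file": ".parser_state.json"
#   }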


class ProcessMonitor:
    """Monitor system resources during processing."""
    
    def __init__(self):
        self.peak_memory = 0.0
        self.current_memory = 0.0
        self.cpu_percent = 0.0
        self.running = False
        self.thread = None
        
    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._monitor)
        self.thread.daemon = True
        self.thread.start()
        
    def stop(self):
        self.running = False
        if self.thread:
            self.thread.join(timeout=1.0)
            
    def _monitor(self):
        process = psutil.Process()
        while self.running:
            try:
                mem_info = process.memory_info()
                self.current_memory = mem_info.rss / 1024 / 1024  # MB
                self.peak_memory = max(self.peak_memory, self.current_memory)
                self.cpu_percent = process.cpu_percent()
            except Exception:
                # Ignore transient psutil errors but keep the sampling interval
                pass
            time.sleep(1.0)


class StateManager:
    """Manage parser state for resume capability."""
    
    def __init__(self, state_file: str):
        self.state_file = Path(state_file)
        self.state = {
            'completed_ids': set(),
            'failed_ids': set(),
            'stats': {},
            'config': {},
            'last_update': time.time()
        }
        self.lock = threading.Lock()
        
    def load(self) -> bool:
        """Load state from file. Returns True if loaded successfully."""
        try:
            if self.state_file.exists():
                with open(self.state_file, 'r') as f:
                    data = json.load(f)
                    self.state['completed_ids'] = set(data.get('completed_ids', []))
                    self.state['failed_ids'] = set(data.get('failed_ids', []))
                    self.state['stats'] = data.get('stats', {})
                    self.state['config'] = data.get('config', {})
                    self.state['last_update'] = data.get('last_update', time.time())
                logger.info(f"Loaded state: {len(self.state['completed_ids'])} completed, "
                           f"{len(self.state['failed_ids'])} failed")
                return True
        except Exception as e:
            logger.warning(f"Could not load state file: {e}")
        return False
        
    def save(self):
        """Save current state to file."""
        with self.lock:
            try:
                data = {
                    'completed_ids': list(self.state['completed_ids']),
                    'failed_ids': list(self.state['failed_ids']),
                    'stats': self.state['stats'],
                    'config': self.state['config'],
                    'last_update': time.time()
                }
                with open(self.state_file, 'w') as f:
                    json.dump(data, f, indent=2)
            except Exception as e:
                logger.error(f"Could not save state: {e}")
                
    def mark_completed(self, page_id: int):
        with self.lock:
            self.state['completed_ids'].add(page_id)
            self.state['failed_ids'].discard(page_id)
            
    def mark_failed(self, page_id: int):
        with self.lock:
            self.state['failed_ids'].add(page_id)
            
    def is_completed(self, page_id: int) -> bool:
        return page_id in self.state['completed_ids']
        
    def get_pending_ids(self, all_ids: List[int]) -> List[int]:
        """Filter out already completed IDs."""
        return [pid for pid in all_ids if not self.is_completed(pid)]
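
# The state file written by save() is plain JSON along these lines (IDs are
# illustrative):
#
#   {
#     "completed_ids": [101, 102, 105],
#     "failed_ids": [103],
#     "stats": {},
#     "config": {},
#     "last_update": 1700000000.0
#   }
#
# On --resume, completed IDs are skipped while failed IDs remain pending and are
# retried, because get_pending_ids() only filters out completed pages.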


class InterruptHandler:
    """Handle graceful shutdown on interrupt signals."""
    
    def __init__(self):
        self.interrupted = False
        self.original_handlers = {}
        
    def __enter__(self):
        self.original_handlers[signal.SIGINT] = signal.signal(signal.SIGINT, self._handle_interrupt)
        if hasattr(signal, 'SIGTERM'):
            self.original_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, self._handle_interrupt)
        return self
        
    def __exit__(self, exc_type, exc_val, exc_tb):
        for sig, handler in self.original_handlers.items():
            signal.signal(sig, handler)
            
    def _handle_interrupt(self, signum, frame):
        logger.info("Received interrupt signal, shutting down gracefully...")
        self.interrupted = True


def chunked(seq: List[int], n: int) -> Iterable[List[int]]:
    """Split sequence into chunks of size n."""
    if n <= 0:
        n = 1
    for i in range(0, len(seq), n):
        yield seq[i:i + n]
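
# A quick illustration of chunked() (values are arbitrary):
#
#   list(chunked([1, 2, 3, 4, 5], 2))  ->  [[1, 2], [3, 4], [5]]
#   list(chunked([1, 2, 3], 0))        ->  [[1], [2], [3]]   # n <= 0 falls back to 1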


def find_manuals(names: Set[str]) -> List[SourceManual]:
    """Find manuals by name with better error reporting."""
    if not names:
        manuals = list(SourceManual.objects.filter(enable=True))
        logger.info(f"Found {len(manuals)} enabled manuals")
        return manuals
        
    lowered = {n.strip().lower() for n in names if n and n.strip()}
    qs = SourceManual.objects.filter(enable=True)
    found: List[SourceManual] = []
    
    for manual in qs:
        candidates = {
            (manual.name or "").lower(),
            (manual.title or "").lower(),
            (getattr(manual.source, "title", "") or "").lower(),
            (getattr(manual.source, "name", "") or "").lower(),
        }
        if candidates & lowered:
            found.append(manual)
            logger.debug(f"Matched manual: {manual.name} ({manual.title})")
            
    if not found:
        available = [m.name for m in qs if m.name]
        logger.error(f"No manuals found for {sorted(names)}. Available: {available}")
    else:
        logger.info(f"Found {len(found)} matching manuals")
        
    return found


def select_page_ids_batch(
    manual: SourceManual,
    max_count: Optional[int],
    since_id: Optional[int],
    until_id: Optional[int],
    forced: bool,
    name_filter: Optional[str],
    batch_size: int = 1000
) -> Iterable[List[int]]:
    """Memory-efficient batch selection of page IDs."""
    base_qs = build_base_queryset(manual, forced, since_id, until_id, name_filter)
    
    processed = 0
    offset = 0
    
    while True:
        if max_count and processed >= max_count:
            break
            
        current_batch_size = min(batch_size, (max_count - processed) if max_count else batch_size)
        batch_ids = list(base_qs.values_list("id", flat=True)[offset:offset + current_batch_size])
        
        if not batch_ids:
            break
            
        yield batch_ids
        processed += len(batch_ids)
        offset += current_batch_size
        
        logger.debug(f"Selected batch: {len(batch_ids)} IDs (total: {processed})")


def build_base_queryset(manual, forced, since_id, until_id, name_filter):
    """Build the base queryset for page selection."""
    if forced:
        qs = (
            NetworkMonitoredPage.objects.filter(source=manual.source)
            .exclude(
                Q(sliced_html__isnull=True)
                | Q(sliced_html__exact="")
                | Q(sliced_html__exact="{}")
                | Q(sliced_html__exact="[]")
                | Q(sliced_html__exact=" ")
                | Q(html__isnull=True)
                | Q(html__exact="")
                | Q(html__exact="error")
                | Q(html__exact="{}")
                | Q(html__exact="[]")
                | Q(html__exact=" ")
            )
            .order_by("id")
        )
    else:
        qs = (
            NetworkMonitoredPage.objects.filter(
                network_ad_manual__isnull=True, source=manual.source
            )
            .exclude(
                Q(sliced_html__isnull=True)
                | Q(sliced_html__exact="")
                | Q(sliced_html__exact="{}")
                | Q(sliced_html__exact="[]")
                | Q(sliced_html__exact=" ")
                | Q(html__isnull=True)
                | Q(html__exact="")
                | Q(html__exact="{}")
                | Q(html__exact="[]")
                | Q(html__exact=" ")
            )
            .order_by("id")
        )

    if since_id:
        qs = qs.filter(id__gte=since_id)
    if until_id:
        qs = qs.filter(id__lte=until_id)
    if name_filter:
        qs = qs.filter(name__icontains=name_filter)
        
    return qs


def run_manage_for_id_with_retry(
    page_id: int,
    config: ParserConfig,
    max_retries: int = 2
) -> TaskResult:
    """Run manage.py for a page ID with retry logic."""
    start_time = time.time()
    result = TaskResult(page_id=page_id, exit_code=1, duration=0.0, timestamp=start_time)
    
    for attempt in range(max_retries + 1):
        try:
            cmd = build_command(page_id, config)
            
            env = os.environ.copy()
            env.setdefault("PYTHONUNBUFFERED", "1")
            
            proc = Popen(
                cmd, 
                stdout=PIPE, 
                stderr=PIPE, 
                cwd=config.cwd, 
                env=env, 
                text=True
            )
            
            try:
                stdout, stderr = proc.communicate(timeout=config.timeout)
                result.exit_code = proc.returncode
                result.stdout = stdout
                result.stderr = stderr
                result.retries = attempt
                
                if proc.returncode == 0:
                    logger.debug(f"Success: page_id={page_id} (attempt {attempt + 1})")
                    break
                else:
                    logger.warning(f"Failed: page_id={page_id} rc={proc.returncode} (attempt {attempt + 1})")
                    if attempt < max_retries:
                        time.sleep(config.retry_delay * (2 ** attempt))  # Exponential backoff
                        
            except TimeoutExpired:
                proc.kill()
                proc.communicate()
                result.retries = attempt
                result.stderr = f"Timed out after {config.timeout}s"
                logger.error(f"Timeout: page_id={page_id} (attempt {attempt + 1})")
                if attempt < max_retries:
                    time.sleep(config.retry_delay)

        except Exception as e:
            result.retries = attempt
            result.stderr = str(e)
            logger.error(f"Exception processing page_id={page_id}: {e}")
            if attempt < max_retries:
                time.sleep(config.retry_delay)
    
    result.duration = time.time() - start_time
    return result
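
# With the defaults (retry_delay=1.0, max_retries=2), a page that keeps returning
# a non-zero exit code waits retry_delay * 2 ** attempt between attempts, i.e.
# roughly 1s after the first failure and 2s after the second; timeouts and
# exceptions back off with the flat retry_delay instead.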


def build_command(page_id: int, config: ParserConfig) -> List[str]:
    """Build the command to run manage.py."""
    cmd = [config.python_exe, config.manage_py, "manual_parser", "--id", str(page_id)]
    
    if config.dry_run:
        cmd.append("--dry-run")
    if config.force:
        cmd.append("--force")
    if config.force_names:
        cmd.extend(["--force-name", ",".join(config.force_names)])
        
    return cmd
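
# For a hypothetical page id 123 with dry_run and force enabled, build_command()
# yields something like (the interpreter and manage.py paths depend on the local
# checkout):
#
#   ["/usr/bin/python3", "/path/to/repo/manage.py", "manual_parser",
#    "--id", "123", "--dry-run", "--force"]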


def process_batch_worker(batch_and_config: Tuple[List[int], ParserConfig]) -> List[TaskResult]:
    """Worker function to process a batch of page IDs."""
    batch, config = batch_and_config
    results = []
    
    for page_id in batch:
        result = run_manage_for_id_with_retry(page_id, config, config.max_retries)
        results.append(result)
        
    return results


def load_config_file(config_path: str) -> Dict[str, Any]:
    """Load configuration from JSON file."""
    try:
        with open(config_path, 'r') as f:
            return json.load(f)
    except Exception as e:
        logger.error(f"Could not load config file {config_path}: {e}")
        return {}


def main(argv: Optional[List[str]] = None) -> int:
    """Main entry point."""
    parser = argparse.ArgumentParser(
        description="Enhanced parallel orchestrator for manual_parser",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s --name dobry --workers 8 --max 1000
  %(prog)s --name dobry,otodom --workers 6 --max 3000 --timeout 600
  %(prog)s --config config/parser.json
  %(prog)s --name dobry --resume --state-file .parser_state.json
        """
    )
    
    # Configuration options
    parser.add_argument("--config", help="Load config from JSON file")
    parser.add_argument("--name", required=True, help="Manual name(s), comma-separated")
    parser.add_argument("--workers", type=int, default=8, help="Number of worker processes")
    parser.add_argument("--max", type=int, default=0, help="Maximum pages to process")
    parser.add_argument("--batch-size", type=int, default=1, help="Pages per batch")
    parser.add_argument("--since-id", type=int, help="Lower bound page ID")
    parser.add_argument("--until-id", type=int, help="Upper bound page ID")
    parser.add_argument("--filter-name", default="", help="Filter by page name")
    parser.add_argument("--dry-run", action="store_true", help="Dry run mode")
    parser.add_argument("--force", action="store_true", help="Force processing")
    parser.add_argument("--force-name", default="", help="Force specific manuals")
    parser.add_argument("--timeout", type=int, default=300, help="Timeout per page (seconds)")
    parser.add_argument("--max-retries", type=int, default=2, help="Max retries per page")
    parser.add_argument("--retry-delay", type=float, default=1.0, help="Delay between retries")
    parser.add_argument("--state-file", default=".parser_state.json", help="State file for resume")
    parser.add_argument("--resume", action="store_true", help="Resume from previous state")
    parser.add_argument("--log-level", default="INFO", help="Logging level")
    parser.add_argument("--memory-limit", type=int, default=4096, help="Memory limit in MB")
    parser.add_argument("--python", default=sys.executable, help="Python executable")
    parser.add_argument("--manage", default=str(MANAGE_PY), help="Path to manage.py")
    parser.add_argument("--cwd", default=str(REPO_ROOT), help="Working directory")
    
    args = parser.parse_args(argv)
    
    # Set logging level (fall back to INFO if an unknown level name is given)
    logging.getLogger().setLevel(getattr(logging, args.log_level.upper(), logging.INFO))
    
    # Load config file if specified
    file_config = {}
    if args.config:
        file_config = load_config_file(args.config)
    
    # Build configuration
    config = ParserConfig(
        manuals=[s.strip() for s in args.name.split(",") if s.strip()],
        workers=args.workers,
        max_pages=args.max,
        batch_size=args.batch_size,
        since_id=args.since_id,
        until_id=args.until_id,
        filter_name=args.filter_name,
        dry_run=args.dry_run,
        force=args.force,
        force_names=[s.strip() for s in args.force_name.split(",") if s.strip()],
        python_exe=args.python,
        manage_py=args.manage,
        cwd=args.cwd,
        timeout=args.timeout,
        max_retries=args.max_retries,
        retry_delay=args.retry_delay,
        state_file=args.state_file,
        log_level=args.log_level,
        memory_limit_mb=args.memory_limit
    )
    
    # Apply config file overrides (file values take precedence over CLI flags)
    for key, value in file_config.items():
        if hasattr(config, key):
            setattr(config, key, value)

    if not config.manuals:
        parser.error("No manuals specified: pass --name or list 'manuals' in the --config file")
    
    logger.info(f"Starting parallel parser with {config.workers} workers")
    logger.info(f"Target manuals: {', '.join(config.manuals)}")
    
    # Initialize state manager
    state_manager = StateManager(config.state_file)
    if args.resume:
        state_manager.load()
    
    # Find manuals
    manual_set = set(config.manuals) if config.manuals else set()
    manuals = find_manuals(manual_set)
    if not manuals:
        logger.error("No matching manuals found")
        return 1
    
    # Initialize monitoring and stats
    monitor = ProcessMonitor()
    stats = ProcessingStats(start_time=time.time())

    # Defined up front so the finally block can compute averages even if processing never starts
    durations: List[float] = []
    
    try:
        with InterruptHandler() as interrupt_handler:
            monitor.start()
            
            # Collect all page IDs
            logger.info("Collecting page IDs...")
            all_ids: List[int] = []
            
            for manual in manuals:
                for batch in select_page_ids_batch(
                    manual=manual,
                    max_count=config.max_pages if len(manuals) == 1 else None,
                    since_id=config.since_id,
                    until_id=config.until_id,
                    forced=config.force,
                    name_filter=config.filter_name if config.filter_name else None,
                    batch_size=1000
                ):
                    all_ids.extend(batch)
                    if interrupt_handler.interrupted:
                        break
                if interrupt_handler.interrupted:
                    break
            
            # Apply max limit across all manuals
            if config.max_pages > 0:
                all_ids = all_ids[:config.max_pages]
            
            # Filter out completed IDs if resuming
            pending_ids = state_manager.get_pending_ids(all_ids) if args.resume else all_ids
            
            # Remove duplicates while preserving order
            seen = set()
            unique_ids: List[int] = []
            for pid in pending_ids:
                if pid not in seen:
                    seen.add(pid)
                    unique_ids.append(pid)

            # Pages dropped as duplicates or already completed in a previous run
            stats.skipped = len(all_ids) - len(unique_ids)
            stats.total_pages = len(unique_ids)
            
            if not unique_ids:
                # No early return here so the final statistics are still printed
                logger.info("No pages to process")
                # Save state anyway to keep --resume runs consistent
                state_manager.save()
            else:
                logger.info(f"Processing {len(unique_ids)} pages in batches of {config.batch_size}")
                
                # Create batches
                batches = list(chunked(unique_ids, config.batch_size))
                batch_configs = [(batch, config) for batch in batches]
                
                # Process with progress tracking
                # Real tqdm renders a rich bar; the fallback class prints plain-text progress
                with tqdm(total=len(unique_ids), desc="Processing pages") as pbar:
                    with ProcessPoolExecutor(max_workers=config.workers) as executor:
                        # Submit all batches
                        future_to_batch = {
                            executor.submit(process_batch_worker, batch_config): i
                            for i, batch_config in enumerate(batch_configs)
                        }
                        
                        # Collect results
                        for future in as_completed(future_to_batch):
                            if interrupt_handler.interrupted:
                                logger.info("Cancelling remaining tasks...")
                                for f in future_to_batch:
                                    f.cancel()
                                break
                            
                            try:
                                batch_results = future.result()
                                
                                for result in batch_results:
                                    durations.append(result.duration)
                                    
                                    if result.exit_code == 0:
                                        stats.completed += 1
                                        state_manager.mark_completed(result.page_id)
                                        logger.debug(f"✓ {result.page_id}")
                                    else:
                                        stats.failed += 1
                                        state_manager.mark_failed(result.page_id)
                                        logger.warning(f"✗ {result.page_id} (rc={result.exit_code})")
                                        if result.stderr:
                                            logger.debug(f"  stderr: {result.stderr.strip()}")
                                    
                                    if result.retries > 0:
                                        stats.retried += 1
                                    
                                    pbar.update(1)
                                
                                # Periodic state save
                                if stats.completed and (stats.completed % 100 == 0):
                                    state_manager.save()
                            
                            except Exception as e:
                                logger.error(f"Batch processing error: {e}")
                                logger.debug(traceback.format_exc())
                                # Use the actual (possibly short) batch, not the nominal batch size
                                failed_batch = batches[future_to_batch[future]]
                                for pid in failed_batch:
                                    state_manager.mark_failed(pid)
                                stats.failed += len(failed_batch)
                                pbar.update(len(failed_batch))
                
                # Final state save after processing
                state_manager.save()
            
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
        return 130
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        logger.debug(traceback.format_exc())
        return 1
    finally:
        monitor.stop()
        stats.end_time = time.time()
        stats.peak_memory_mb = monitor.peak_memory
        # Safe even if no pages were processed
        if durations:
            stats.avg_duration = sum(durations) / len(durations)
    
    # Print final statistics
    print("\n" + "="*60)
    print("PROCESSING COMPLETE")
    print("="*60)
    print(f"Total pages:      {stats.total_pages:,}")
    print(f"Completed:        {stats.completed:,}")
    print(f"Failed:           {stats.failed:,}")
    print(f"Success rate:     {stats.success_rate():.1f}%")
    print(f"Retried:          {stats.retried:,}")
    print(f"Duration:         {stats.total_duration():.1f}s")
    print(f"Avg per page:     {stats.avg_duration:.2f}s")
    print(f"Peak memory:      {stats.peak_memory_mb:.1f} MB")
    
    if stats.failed > 0:
        print(f"\nFailed IDs saved to state file: {config.state_file}")
        print("Run with --resume to retry failed pages")
    
    return 0 if stats.failed == 0 else 1


if __name__ == "__main__":
    raise SystemExit(main())
