"""
Diagnostics Checker - Validate MMM model metrics

Created by: Claude Code
Session ID: Ralph Loop Execution
Date: 2025-11-26
Purpose: Check if model meets convergence and fit criteria
"""

import json
from pathlib import Path
from typing import Dict, Tuple
from dataclasses import dataclass


@dataclass
class DiagnosticThresholds:
    """
    Acceptance limits applied when grading a fitted model.

    Three convergence limits (R-hat, ESS, divergences) plus an accepted
    band for R². Override individual fields to tighten or relax the check.
    """

    # Upper limit compared against the worst (largest) reported R-hat.
    max_rhat: float = 1.02
    # Lower limit compared against the smallest reported effective sample size.
    min_ess: float = 100
    # Largest tolerated count of divergent transitions.
    max_divergences: int = 0
    # Bottom of the accepted R² band.
    min_r2: float = 0.55
    # Top of the accepted R² band; exceeding it is flagged as possible overfitting.
    max_r2: float = 0.70


@dataclass
class DiagnosticResult:
    """
    Outcome of a single diagnostic check.

    Carries the overall verdict, the metric values that were evaluated, and
    the human-readable findings produced for every failed criterion.
    """

    # True only when no issue was recorded.
    passed: bool
    # R-hat value that was compared against the max_rhat limit.
    worst_rhat: float
    # Effective sample size that was compared against the min_ess limit.
    min_ess: float
    # Divergent-transition count that was checked.
    divergences: int
    # R² value that was checked against the accepted band.
    r2: float
    # One message per failed criterion.
    issues: list
    # Suggested remedies accompanying the issues.
    recommendations: list


def _parse_scalar(text: str):
    """Convert a summary-file value to int, else float, else keep the text."""
    import math

    try:
        return int(text)
    except ValueError:
        pass
    try:
        number = float(text)  # also handles exponent forms like "1e-05"
    except ValueError:
        return text
    # NaN is kept as its original text: a float NaN compares False against
    # every threshold and could therefore silently pass a check.
    return text if math.isnan(number) else number


def load_diagnostics(artifacts_dir: str) -> Dict:
    """
    Load diagnostics and fit metrics from an artifacts directory.

    Two optional files are read; a missing file is simply skipped:

    * ``diagnostics/diagnostics_summary.txt`` -- one ``key: value`` pair per
      line, split on the first colon. Values become ``int`` when possible,
      then ``float`` (including scientific notation), otherwise the stripped
      string. Lines without a colon are ignored.
    * ``metrics/summary.json`` -- its ``r2`` and ``n_days`` entries are
      copied, each defaulting to 0 when absent.

    Args:
        artifacts_dir: Root artifacts directory (``str`` or ``Path``).

    Returns:
        Flat dict of everything parsed; empty when neither file exists.
    """
    artifacts_dir = Path(artifacts_dir)

    result = {}

    # Load diagnostics_summary.txt
    diag_file = artifacts_dir / 'diagnostics' / 'diagnostics_summary.txt'
    if diag_file.exists():
        with open(diag_file, 'r', encoding='utf-8') as f:
            for line in f:
                if ':' not in line:
                    continue
                key, value = line.strip().split(':', 1)
                result[key.strip()] = _parse_scalar(value.strip())

    # Load metrics summary.json
    metrics_file = artifacts_dir / 'metrics' / 'summary.json'
    if metrics_file.exists():
        with open(metrics_file, 'r', encoding='utf-8') as f:
            metrics = json.load(f)
        result['r2'] = metrics.get('r2', 0)
        result['n_days'] = metrics.get('n_days', 0)

    return result


def _as_number(value):
    """
    Coerce a diagnostic value to a number for threshold comparison.

    ints and floats pass through unchanged; strings are tried as int, then
    float; anything still unconvertible (None, free text) becomes NaN so the
    NaN-safe checks in ``check_diagnostics`` flag it as a failure.
    """
    if isinstance(value, (int, float)):
        return value
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            pass
    try:
        return float(value)
    except (TypeError, ValueError):
        return float('nan')


def check_diagnostics(
    diagnostics: Dict,
    thresholds: DiagnosticThresholds = None
) -> DiagnosticResult:
    """
    Check whether diagnostics meet the given thresholds.

    Args:
        diagnostics: Mapping such as the one returned by ``load_diagnostics``.
            Only ``worst_rhat``, ``min_ess``, ``divergences`` and ``r2`` are
            read.
        thresholds: Limits to apply; defaults to ``DiagnosticThresholds()``.

    Returns:
        DiagnosticResult with ``passed`` True only when no issue was found,
        the numeric values that were checked, and one issue message plus
        remedies for each failed criterion.

    A value that is NaN or cannot be read as a number fails its check rather
    than passing silently (NaN compares False against every limit).
    """
    if thresholds is None:
        thresholds = DiagnosticThresholds()

    # NOTE(review): missing convergence keys fall back to values that pass
    # the default thresholds, while a missing R² falls back to 0.5, which
    # fails the default min_r2 -- confirm this asymmetry is intended.
    worst_rhat = _as_number(diagnostics.get('worst_rhat', 1.0))
    min_ess = _as_number(diagnostics.get('min_ess', 1000))
    divergences = _as_number(diagnostics.get('divergences', 0))
    r2 = _as_number(diagnostics.get('r2', 0.5))

    issues = []
    recommendations = []

    # Every check is phrased as "not within the limit" so that NaN fails it.

    # Check rhat
    if not worst_rhat <= thresholds.max_rhat:
        issues.append(f"rhat={worst_rhat:.4f} > {thresholds.max_rhat}")
        recommendations.append("Increase tune samples (e.g., tune += 500)")
        recommendations.append("Increase target_accept (e.g., 0.95 → 0.98)")

    # Check ESS
    if not min_ess >= thresholds.min_ess:
        issues.append(f"ESS={min_ess:.1f} < {thresholds.min_ess}")
        recommendations.append("Increase draws (e.g., draws += 500)")

    # Check divergences
    if not divergences <= thresholds.max_divergences:
        issues.append(f"divergences={divergences} > {thresholds.max_divergences}")
        # The standard first remedy for divergences is a higher target_accept
        # (smaller step size); max_treedepth only caps trajectory length.
        recommendations.append("Increase target_accept (e.g., 0.95 → 0.99)")
        recommendations.append("Tighten priors (reduce sigma)")
        recommendations.append("Check for prior-likelihood conflict")

    # Check R²
    if not r2 >= thresholds.min_r2:
        issues.append(f"R²={r2:.4f} < {thresholds.min_r2}")
        recommendations.append("Add more controls or channels")
        recommendations.append("Consider different model structure")
    elif r2 > thresholds.max_r2:
        issues.append(f"R²={r2:.4f} > {thresholds.max_r2} (possible overfitting)")
        recommendations.append("Add regularization")
        recommendations.append("Simplify model")

    passed = len(issues) == 0

    return DiagnosticResult(
        passed=passed,
        worst_rhat=worst_rhat,
        min_ess=min_ess,
        divergences=divergences,
        r2=r2,
        issues=issues,
        recommendations=recommendations,
    )


def print_diagnostic_result(
    result: 'DiagnosticResult',
    thresholds: 'DiagnosticThresholds' = None,
) -> None:
    """
    Pretty-print a diagnostic result to stdout.

    Args:
        result: Result produced by ``check_diagnostics``.
        thresholds: Limits used for the per-metric ✅/❌ marks. Defaults to
            ``DiagnosticThresholds()``; pass the same object that was given
            to ``check_diagnostics`` so the marks agree with ``result.passed``.
    """
    if thresholds is None:
        thresholds = DiagnosticThresholds()

    status = "✅ PASS" if result.passed else "❌ FAIL"
    rule = '=' * 60

    # Marks are derived from the thresholds rather than hard-coded limits.
    rhat_ok = result.worst_rhat <= thresholds.max_rhat
    ess_ok = result.min_ess >= thresholds.min_ess
    div_ok = result.divergences <= thresholds.max_divergences
    r2_ok = thresholds.min_r2 <= result.r2 <= thresholds.max_r2

    print(f"\n{rule}")
    print(f"DIAGNOSTIC CHECK: {status}")
    print(rule)

    print("\nMetrics:")
    print(f"  worst_rhat: {result.worst_rhat:.6f} {'✅' if rhat_ok else '❌'}")
    print(f"  min_ess: {result.min_ess:.1f} {'✅' if ess_ok else '❌'}")
    print(f"  divergences: {result.divergences} {'✅' if div_ok else '❌'}")
    print(f"  R²: {result.r2:.4f} {'✅' if r2_ok else '⚠️'}")

    if result.issues:
        print("\nIssues found:")
        for issue in result.issues:
            print(f"  ❌ {issue}")

    if result.recommendations:
        print("\nRecommendations:")
        for rec in result.recommendations:
            print(f"  → {rec}")

    print(rule)


if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        print("Usage: python diagnostics_checker.py <artifacts_dir>")
        sys.exit(1)

    artifacts_dir = sys.argv[1]

    # Without this guard a mistyped path loads an empty dict, the check runs
    # purely on its fallback defaults, and the report looks like a real result.
    if not Path(artifacts_dir).is_dir():
        print(f"Error: artifacts directory not found: {artifacts_dir}", file=sys.stderr)
        sys.exit(1)

    print(f"Loading diagnostics from: {artifacts_dir}")
    diagnostics = load_diagnostics(artifacts_dir)

    print(f"\nLoaded diagnostics: {diagnostics}")

    result = check_diagnostics(diagnostics)
    print_diagnostic_result(result)

    # Exit code based on pass/fail
    sys.exit(0 if result.passed else 1)
