"""
Iteration Engine - Orchestrate MMM training iterations with Knowledge Framework output

Created by: Claude Code
Session ID: Ralph Loop Execution
Date: 2025-11-26
Updated: 2025-11-26 - Added adstock/saturation support and KF documentation
Purpose: Run training iterations until diagnostics and out-of-sample validation pass, documenting the thinking process
"""

import yaml
import json
import shutil
import numpy as np
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional

from data_analyzer import analyze_dataset, DataProfile
from config_generator import generate_config_from_profile, adjust_config_for_diagnostics, ModelConfig
from diagnostics_checker import load_diagnostics, check_diagnostics, DiagnosticResult
from oos_validator import OOSValidator, run_oos_validation


class IterationEngine:
    """
    Orchestrates MMM model training iterations with Knowledge Framework documentation.

    Workflow:
    1. Analyze data → DataProfile
    2. Generate initial config
    3. Train model (v3 with adstock + saturation) in the configured conda env
    4. Check diagnostics; if they pass, run the MANDATORY out-of-sample (OOS) validation
    5. If diagnostics or OOS FAIL: adjust config and retry (with reasoning)
    6. If both PASS or budget exhausted: return best result
    7. Generate Knowledge Framework documentation

    Best-result ranking: an iteration whose diagnostics passed beats one that
    failed; ties are broken by higher in-sample R². The iteration that passes
    both diagnostics and OOS validation always becomes the best result.

    KF Output (written to ``output_dir``):
    - iteration_log.json: per-iteration diagnostics and pass/fail flags
    - thinking_history.json: every decision, logged with its reasoning
    - ITERATION_ANALYSIS.md: human-readable Knowledge Framework report
    """

    # Thresholds used for the reasoning text and the ✅/❌ columns of the report.
    # NOTE(review): the authoritative pass/fail decisions come from
    # check_diagnostics() and OOSValidator; keep these values in sync with them.
    R2_MIN = 0.55
    R2_MAX = 0.70
    RHAT_MAX = 1.02
    ESS_MIN = 100
    OOS_R2_MIN = 0.40
    OOS_MAPE_MAX = 20.0  # percent
    OVERFIT_MAX = 0.25
    OOS_TEST_SIZE = 0.20  # fraction of rows held out (temporal split)
    TRAINING_TIMEOUT_S = 1800  # 30 min per training run

    def __init__(
        self,
        csv_path: str,
        output_dir: str,
        max_iterations: int = 10,
        conda_env: str = 'pymc_gpu_015',
        model_type: str = 'v3'
    ):
        self.csv_path = csv_path
        self.output_dir = Path(output_dir)
        self.max_iterations = max_iterations
        self.conda_env = conda_env
        self.model_type = model_type  # Default to v3 with adstock+saturation

        self.output_dir.mkdir(parents=True, exist_ok=True)

        # State
        self.profile: Optional['DataProfile'] = None
        self.config: Optional['ModelConfig'] = None
        self.iteration_log: List[Dict] = []
        self.best_result: Optional[Dict] = None
        self.best_iteration: Optional[int] = None  # iteration index that produced best_result
        self._best_passed: bool = False  # whether best_result's diagnostics passed
        self.thinking_history: List[Dict] = []  # Track all thinking across iterations

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _metric(source: Optional[Dict], key: str, default=None):
        """Return ``source[key]``; ``default`` if the dict is empty, the key is absent, or the value is None."""
        if not source:
            return default
        value = source.get(key)
        return default if value is None else value

    @staticmethod
    def _fmt(value, spec: str, suffix: str = '') -> str:
        """Format ``value`` with ``spec`` (+ ``suffix``); 'N/A' for None or non-numeric values."""
        if value is None:
            return 'N/A'
        try:
            return f"{value:{spec}}{suffix}"
        except (TypeError, ValueError):
            return 'N/A'

    @staticmethod
    def _status(ok: bool) -> str:
        """Map a boolean check to the ✅/❌ marker used throughout the report."""
        return '✅' if ok else '❌'

    @staticmethod
    def _json_default(obj):
        """
        ``json.dump`` fallback for values that may come out of diagnostics/OOS results.

        Converts numpy scalars to Python scalars, numpy arrays to lists and paths to
        strings; anything else raises ``TypeError`` exactly like the stock encoder.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, Path):
            return str(obj)
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _write_json(self, path: Path, payload) -> None:
        """Write ``payload`` as indented JSON, tolerating numpy values (see ``_json_default``)."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, default=self._json_default)

    def _python_command(self) -> List[str]:
        """
        Return the argv prefix that runs Python inside ``self.conda_env``.

        ``subprocess.run`` with an argument list does not use a shell, so a literal
        ``$CONDA_PREFIX`` in a path is never expanded. Instead the env's interpreter is
        looked up under the conda roots advertised by ``CONDA_PREFIX`` / ``CONDA_EXE``
        (fast path); if none exists, fall back to ``conda run -n <env> python``.
        """
        import os

        candidates = []
        prefix = os.environ.get('CONDA_PREFIX')
        if prefix:
            prefix_path = Path(prefix)
            if prefix_path.name == self.conda_env:
                # The target env is the currently active one.
                candidates.append(prefix_path / 'bin' / 'python')
            # Active env is the base install: named envs live under <base>/envs/.
            candidates.append(prefix_path / 'envs' / self.conda_env / 'bin' / 'python')
            if prefix_path.parent.name == 'envs':
                # Active env is some other named env: its siblings are other envs.
                candidates.append(prefix_path.parent / self.conda_env / 'bin' / 'python')
        conda_exe = os.environ.get('CONDA_EXE')
        if conda_exe:
            # CONDA_EXE usually points at <root>/bin/conda.
            candidates.append(Path(conda_exe).parent.parent / 'envs' / self.conda_env / 'bin' / 'python')

        for candidate in candidates:
            if candidate.is_file():
                return [str(candidate)]

        # Slower, but lets conda resolve the env by name (also works on Windows).
        conda = conda_exe or shutil.which('conda') or 'conda'
        return [conda, 'run', '-n', self.conda_env, 'python']

    def _set_best(self, diagnostics: Dict, artifacts_dir: str, iteration: int, passed: bool) -> None:
        """Unconditionally make this iteration the best result (copying ``diagnostics`` to avoid aliasing)."""
        self.best_result = dict(diagnostics)
        self.best_result['artifacts_dir'] = artifacts_dir
        self.best_iteration = iteration
        self._best_passed = bool(passed)
        self.thinking_history.append({
            'iteration': iteration,
            'phase': 'new_best',
            'reasoning': f"New best result: R²={self._metric(diagnostics, 'r2', 0):.4f}"
        })

    def _update_best(self, diagnostics: Dict, artifacts_dir: str, iteration: int, passed: bool) -> bool:
        """
        Promote this iteration to ``best_result`` if it ranks strictly higher.

        Ranking key is ``(diagnostics passed, in-sample R²)``: a passing iteration
        always beats a failing one, so an overfit or unconverged high-R² run cannot
        displace a healthy model.

        Returns:
            True if this iteration became the new best.
        """
        if self.best_result is not None:
            candidate_key = (bool(passed), self._metric(diagnostics, 'r2', 0))
            best_key = (self._best_passed, self._metric(self.best_result, 'r2', 0))
            if candidate_key <= best_key:
                return False
        self._set_best(diagnostics, artifacts_dir, iteration, passed)
        return True

    def _diagnostic_cells(self, diag: Optional[Dict]) -> Dict[str, str]:
        """Formatted value / ✅❌ strings for the in-sample diagnostics table."""
        diag = diag or {}
        r2 = diag.get('r2')
        rhat = diag.get('worst_rhat')
        ess = diag.get('min_ess')
        div = diag.get('divergences')
        return {
            'r2': self._fmt(r2, '.4f'),
            'r2_status': self._status(r2 is not None and self.R2_MIN <= r2 <= self.R2_MAX),
            'rhat': self._fmt(rhat, '.4f'),
            'rhat_status': self._status(rhat is not None and rhat <= self.RHAT_MAX),
            'ess': self._fmt(ess, '.0f'),
            'ess_status': self._status(ess is not None and ess >= self.ESS_MIN),
            'div': str(div) if div is not None else 'N/A',
            'div_status': self._status(div is not None and div == 0),
        }

    def _oos_cells(self, oos: Optional[Dict]) -> Dict[str, str]:
        """Formatted value / ✅❌ strings for the OOS table (all 'N/A'/❌ when OOS was not run)."""
        oos = oos or {}
        test_metrics = oos.get('test_metrics') or {}
        r2 = test_metrics.get('r2')
        mape = test_metrics.get('mape')
        overfit = oos.get('overfitting_index')
        # Compare against None explicitly: 0.0 MAPE / overfitting index is a perfect score.
        return {
            'r2': self._fmt(r2, '.4f'),
            'r2_status': self._status(r2 is not None and r2 >= self.OOS_R2_MIN),
            'mape': self._fmt(mape, '.1f', '%'),
            'mape_status': self._status(mape is not None and mape <= self.OOS_MAPE_MAX),
            'overfit': self._fmt(overfit, '.4f'),
            'overfit_status': self._status(overfit is not None and overfit <= self.OVERFIT_MAX),
        }

    # ------------------------------------------------------------------ pipeline steps

    def analyze_data(self) -> 'DataProfile':
        """Step 1: Analyze dataset and cache the resulting profile on ``self.profile``."""
        print(f"\n[Iteration Engine] Analyzing data: {self.csv_path}")
        self.profile = analyze_dataset(self.csv_path)
        return self.profile

    def generate_initial_config(self) -> str:
        """Step 2: Generate the iteration-0 config from the data profile and write it as YAML.

        Returns:
            Path (as str) of ``config_iter_0.yaml``.
        """
        if self.profile is None:
            self.analyze_data()

        print("[Iteration Engine] Generating initial config...")
        self.config = generate_config_from_profile(self.profile, iteration=0)

        config_path = self.output_dir / 'config_iter_0.yaml'
        self.config.to_yaml(config_path)
        print(f"[Iteration Engine] Config saved to: {config_path}")

        return str(config_path)

    def run_training(self, config_path: str, iteration: int) -> str:
        """
        Step 3: Run PyMC training with adstock + saturation (v3) in a subprocess.

        Executes ``model_trainer.py <config_path> <artifacts_dir> <model_type>`` with
        the interpreter from ``_python_command()``. Non-zero exits, timeouts and
        launch errors are recorded in ``thinking_history`` rather than raised, so the
        caller can still inspect whatever artifacts were written.

        Returns:
            Path (as str) of this iteration's artifacts directory.
        """
        import subprocess

        artifacts_dir = self.output_dir / f'artifacts_iter_{iteration}'
        artifacts_dir.mkdir(exist_ok=True)

        print(f"\n[Iteration Engine] Running training iteration {iteration}...")
        print(f"  Config: {config_path}")
        print(f"  Output: {artifacts_dir}")
        print(f"  Model type: {self.model_type} (with adstock + saturation)")

        # Record thinking
        self.thinking_history.append({
            'iteration': iteration,
            'phase': 'training_start',
            'timestamp': datetime.now().isoformat(),
            'reasoning': f'Starting iteration {iteration} with {self.model_type} model (adstock + Hill saturation)'
        })

        trainer_script = Path(__file__).parent / 'model_trainer.py'
        cmd = [
            *self._python_command(),
            str(trainer_script),
            str(config_path),
            str(artifacts_dir),
            self.model_type  # Use v3 by default
        ]

        print(f"  Command: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self.TRAINING_TIMEOUT_S,
                cwd=str(Path(__file__).parent)
            )
        except subprocess.TimeoutExpired:
            print("  Error: Training timed out!")
            self.thinking_history.append({
                'iteration': iteration,
                'phase': 'training_timeout',
                'timestamp': datetime.now().isoformat(),
                'reasoning': f'Training exceeded {self.TRAINING_TIMEOUT_S // 60} minute timeout - model may be too complex'
            })
        except Exception as e:  # e.g. interpreter/conda not found; logged, not fatal
            print(f"  Error: {e}")
            self.thinking_history.append({
                'iteration': iteration,
                'phase': 'training_exception',
                'timestamp': datetime.now().isoformat(),
                'error': str(e),
                'reasoning': f'Unexpected error during training: {e}'
            })
        else:
            print(result.stdout)
            if result.returncode != 0:
                print(f"  Warning: Training returned non-zero exit code {result.returncode}")
                print(f"  stderr: {result.stderr}")
                self.thinking_history.append({
                    'iteration': iteration,
                    'phase': 'training_error',
                    'timestamp': datetime.now().isoformat(),
                    # Keep the tail: the actual exception is at the end of a traceback.
                    'error': result.stderr[-500:],
                    'reasoning': 'Training failed - will check if artifacts exist'
                })

        return str(artifacts_dir)

    def check_results(self, artifacts_dir: str, iteration: int) -> 'DiagnosticResult':
        """
        Step 4: Load and check diagnostics, logging the outcome and reasoning.

        Appends to ``iteration_log`` and ``thinking_history`` and updates the best
        result (see ``_update_best`` for the ranking).
        """
        print(f"\n[Iteration Engine] Checking results: {artifacts_dir}")
        diagnostics = load_diagnostics(artifacts_dir)
        result = check_diagnostics(diagnostics)

        # Log iteration
        self.iteration_log.append({
            'iteration': iteration,
            'timestamp': datetime.now().isoformat(),
            'artifacts_dir': artifacts_dir,
            'diagnostics': diagnostics,
            'passed': result.passed,
            'issues': result.issues,
        })

        # Record thinking about results
        self.thinking_history.append({
            'iteration': iteration,
            'phase': 'diagnostics_check',
            'timestamp': datetime.now().isoformat(),
            'diagnostics': diagnostics,
            'passed': result.passed,
            'reasoning': self._generate_diagnostics_reasoning(diagnostics, result)
        })

        self._update_best(diagnostics, artifacts_dir, iteration, result.passed)

        return result

    def run_oos_validation(self, artifacts_dir: str, iteration: int) -> Dict:
        """
        Step 4b: Run Out-of-Sample validation (MANDATORY).

        This step validates model predictive power on held-out data.
        A model CANNOT be considered production-ready without passing OOS validation.

        Reads ``trace.nc``, ``metrics/summary.json`` and
        ``diagnostics/predicted_vs_actual_series.csv`` from ``artifacts_dir`` plus
        ``config_iter_<iteration>.yaml`` from ``output_dir``, then delegates to
        ``OOSValidator`` with a temporal hold-out of ``OOS_TEST_SIZE``.

        Returns:
            The validator result dict (``test_metrics``, ``overfitting_index``,
            ``passed``, ...), or ``{'passed': False, 'error': ...}`` if the trace is missing.
        """
        import pandas as pd
        import arviz as az

        print("\n[Iteration Engine] Running OUT-OF-SAMPLE VALIDATION...")

        artifacts_path = Path(artifacts_dir)

        # Load trace
        trace_path = artifacts_path / 'trace.nc'
        if not trace_path.exists():
            print(f"  ERROR: trace.nc not found at {trace_path}")
            return {'passed': False, 'error': 'trace.nc not found'}

        trace = az.from_netcdf(str(trace_path))

        # Load config
        config_path = self.output_dir / f'config_iter_{iteration}.yaml'
        with open(config_path, 'r', encoding='utf-8') as f:
            config = yaml.safe_load(f)
        data_cfg = config['data']
        scale_y = config['preprocess']['scale_y']

        # Load data in chronological order (required for a temporal split)
        date_col = data_cfg['date_col']
        df = pd.read_csv(self.csv_path)
        df[date_col] = pd.to_datetime(df[date_col])
        df = df.sort_values(date_col).reset_index(drop=True)

        # Get in-sample R² from diagnostics
        metrics_path = artifacts_path / 'metrics' / 'summary.json'
        in_sample_r2 = 0.0
        if metrics_path.exists():
            with open(metrics_path, 'r', encoding='utf-8') as f:
                in_sample_r2 = json.load(f).get('r2', 0.0)

        validator = OOSValidator(
            df=df,
            y_col=data_cfg['target_col'],
            channel_cols=data_cfg['channels'],
            control_cols=data_cfg['controls'],
            date_col=date_col,
            scale_y=scale_y,
            test_size=self.OOS_TEST_SIZE
        )

        # Get train predictions from diagnostics
        pred_path = artifacts_path / 'diagnostics' / 'predicted_vs_actual_series.csv'
        if pred_path.exists():
            pred_df = pd.read_csv(pred_path)
            y_pred_train = pred_df['predicted'].values / scale_y
        else:
            # Train-side metrics are meaningless without predictions; make that visible.
            print(f"  WARNING: {pred_path} not found - using zero train predictions")
            y_pred_train = np.zeros(len(df))

        oos_result = validator.validate_with_trace(trace, y_pred_train, in_sample_r2)

        validator.save_results(str(artifacts_path))
        validator.save_predictions_csv(str(artifacts_path))

        self.thinking_history.append({
            'iteration': iteration,
            'phase': 'oos_validation',
            'timestamp': datetime.now().isoformat(),
            'oos_r2': oos_result['test_metrics']['r2'],
            'oos_mape': oos_result['test_metrics']['mape'],
            'overfitting_index': oos_result['overfitting_index'],
            'passed': oos_result['passed'],
            'reasoning': self._generate_oos_reasoning(oos_result)
        })

        return oos_result

    def _generate_oos_reasoning(self, oos_result: Dict) -> str:
        """Summarize OOS R², MAPE and overfitting index against their targets."""
        r2 = oos_result['test_metrics']['r2']
        mape = oos_result['test_metrics']['mape']
        overfit = oos_result['overfitting_index']

        parts = []
        if r2 >= self.OOS_R2_MIN:
            parts.append(f"OOS R²={r2:.4f} ≥ {self.OOS_R2_MIN:.2f} ✅ (good predictive power)")
        else:
            parts.append(f"OOS R²={r2:.4f} < {self.OOS_R2_MIN:.2f} ❌ (weak predictive power)")

        if mape <= self.OOS_MAPE_MAX:
            parts.append(f"MAPE={mape:.1f}% ≤ {self.OOS_MAPE_MAX:g}% ✅")
        else:
            parts.append(f"MAPE={mape:.1f}% > {self.OOS_MAPE_MAX:g}% ❌ (high prediction error)")

        if overfit <= self.OVERFIT_MAX:
            parts.append(f"Overfitting={overfit:.4f} ≤ {self.OVERFIT_MAX:.2f} ✅")
        else:
            parts.append(f"Overfitting={overfit:.4f} > {self.OVERFIT_MAX:.2f} ❌ (model overfitting)")

        return " | ".join(parts)

    def _generate_diagnostics_reasoning(self, diagnostics: Dict, result: 'DiagnosticResult') -> str:
        """
        Summarize R², rhat, ESS and divergences against their targets.

        Missing or None metrics are treated pessimistically (R² 0, rhat 2.0, ESS 0,
        999 divergences) so the text always renders. ``result`` is currently unused.
        """
        r2 = self._metric(diagnostics, 'r2', 0)
        rhat = self._metric(diagnostics, 'worst_rhat', 2.0)
        ess = self._metric(diagnostics, 'min_ess', 0)
        div = self._metric(diagnostics, 'divergences', 999)

        parts = []

        # R² analysis
        if r2 < self.R2_MIN:
            parts.append(f"R²={r2:.4f} below target - model underfitting")
        elif r2 > self.R2_MAX:
            parts.append(f"R²={r2:.4f} above target - possible overfitting")
        else:
            parts.append(f"R²={r2:.4f} in target range ✅")

        # Convergence
        if rhat > self.RHAT_MAX:
            parts.append(f"rhat={rhat:.4f} > {self.RHAT_MAX:g} - not converged")
        else:
            parts.append(f"rhat={rhat:.4f} converged ✅")

        # ESS
        if ess < self.ESS_MIN:
            parts.append(f"ESS={ess:.0f} < {self.ESS_MIN} - insufficient samples")
        else:
            parts.append(f"ESS={ess:.0f} sufficient ✅")

        # Divergences
        if div > 0:
            parts.append(f"{div} divergences - sampling issues")
        else:
            parts.append("No divergences ✅")

        return " | ".join(parts)

    def adjust_config(self, diagnostics: Dict, iteration: int, oos_result: Optional[Dict] = None) -> str:
        """
        Step 5: Adjust config based on diagnostics with reasoning.

        Args:
            diagnostics: Diagnostics of the previous iteration.
            iteration: Index of the iteration the new config is for.
            oos_result: Optional OOS result of the previous iteration; only used to
                document the reasoning (the config generator sees diagnostics only).

        Returns:
            Path (as str) of the written ``config_iter_<iteration>.yaml``.
        """
        print(f"\n[Iteration Engine] Adjusting config for iteration {iteration}...")

        r2 = self._metric(diagnostics, 'r2', 0)
        rhat = self._metric(diagnostics, 'worst_rhat', 2.0)
        ess = self._metric(diagnostics, 'min_ess', 0)
        div = self._metric(diagnostics, 'divergences', 0)

        adjustments = []
        if rhat > self.RHAT_MAX:
            adjustments.append("Increasing tune iterations for better convergence")
        if ess < self.ESS_MIN:
            adjustments.append("Increasing draws for more effective samples")
        if div > 0:
            adjustments.append("Adjusting target_accept to reduce divergences")
        if r2 < self.R2_MIN:
            adjustments.append("Model underfitting - may need more features or looser priors")
        if r2 > self.R2_MAX:
            adjustments.append("Model overfitting - tightening regularization")
        if oos_result is not None and not oos_result.get('passed', False):
            adjustments.append(
                "OOS validation failed on held-out data - note: config adjustment is driven by in-sample diagnostics only"
            )

        self.thinking_history.append({
            'iteration': iteration,
            'phase': 'config_adjustment',
            'timestamp': datetime.now().isoformat(),
            'adjustments': adjustments,
            'reasoning': f"Based on iteration {iteration-1} results: {' | '.join(adjustments) if adjustments else 'Minor tuning'}"
        })

        # NOTE(review): when diagnostics passed but OOS failed, the generator may
        # return an unchanged config; consider passing OOS info to it.
        self.config = adjust_config_for_diagnostics(self.config, diagnostics, iteration)

        config_path = self.output_dir / f'config_iter_{iteration}.yaml'
        self.config.to_yaml(config_path)

        return str(config_path)

    def run_full_pipeline(self) -> Dict:
        """
        Run full iteration pipeline until success or budget exhausted.

        Success requires both passing diagnostics AND passing OOS validation; the
        successful iteration then becomes ``best_result`` (with ``oos_validation``).

        Returns:
            Dictionary with final status, best result, log/thinking paths and KF documentation path
        """
        print(f"\n{'='*60}")
        print("MMM ITERATION ENGINE (with Adstock + Saturation)")
        print(f"{'='*60}")
        print(f"Data: {self.csv_path}")
        print(f"Output: {self.output_dir}")
        print(f"Max iterations: {self.max_iterations}")
        print(f"Model type: {self.model_type}")
        print(f"Conda env: {self.conda_env}")

        self.thinking_history.append({
            'phase': 'pipeline_start',
            'timestamp': datetime.now().isoformat(),
            'reasoning': f'Starting MMM pipeline with {self.model_type} model (adstock + Hill saturation)'
        })

        # Step 1 + 2: analyze data and generate the initial config
        self.analyze_data()
        config_path = self.generate_initial_config()

        result = None  # Track last diagnostics result

        for iteration in range(self.max_iterations):
            print(f"\n{'='*60}")
            print(f"ITERATION {iteration + 1}/{self.max_iterations}")
            print(f"{'='*60}")

            # Step 3: Train
            artifacts_dir = self.run_training(config_path, iteration)

            # Step 4a: Check diagnostics
            result = self.check_results(artifacts_dir, iteration)
            oos_result = None

            if result.passed:
                # Step 4b: Run MANDATORY Out-of-Sample validation
                print("\n📊 Diagnostics PASSED - Running OOS Validation...")
                oos_result = self.run_oos_validation(artifacts_dir, iteration)
                oos_ok = bool(oos_result.get('passed', False))
                self.iteration_log[-1]['oos_passed'] = oos_ok

                if oos_ok:
                    print(f"\n✅ FULL SUCCESS! Model converged AND passed OOS validation at iteration {iteration + 1}")
                    self.thinking_history.append({
                        'phase': 'full_success',
                        'iteration': iteration,
                        'timestamp': datetime.now().isoformat(),
                        'reasoning': f'All metrics + OOS validation achieved at iteration {iteration + 1}'
                    })
                    # The validated iteration is the deliverable, even if an earlier
                    # one had a higher R²; attach OOS results to *this* iteration.
                    if self.best_iteration != iteration:
                        self._set_best(self.iteration_log[-1]['diagnostics'], artifacts_dir, iteration, True)
                    self.best_result['oos_validation'] = oos_result
                    break

                print("\n⚠️ Diagnostics passed but OOS validation FAILED - continuing iterations")
                self.thinking_history.append({
                    'phase': 'oos_failed',
                    'iteration': iteration,
                    'timestamp': datetime.now().isoformat(),
                    'reasoning': 'Diagnostics OK but OOS validation failed - model may be overfitting'
                })

            # Step 5: Adjust (if not last iteration)
            if iteration < self.max_iterations - 1:
                diagnostics = load_diagnostics(artifacts_dir)
                config_path = self.adjust_config(diagnostics, iteration + 1, oos_result=oos_result)
            else:
                print("\n⚠️ Budget exhausted. Returning best result.")
                self.thinking_history.append({
                    'phase': 'budget_exhausted',
                    'iteration': iteration,
                    'timestamp': datetime.now().isoformat(),
                    'reasoning': f'Max iterations ({self.max_iterations}) reached without full convergence'
                })

        log_path = self.output_dir / 'iteration_log.json'
        self._write_json(log_path, self.iteration_log)

        thinking_path = self.output_dir / 'thinking_history.json'
        self._write_json(thinking_path, self.thinking_history)

        kf_path = self.generate_kf_documentation()

        # Status is 'full_success' only if both diagnostics AND OOS validation passed
        best_oos = (self.best_result or {}).get('oos_validation') or {}
        oos_passed = bool(best_oos.get('passed', False))
        if result and result.passed and oos_passed:
            status = 'full_success'
        elif result and result.passed:
            status = 'diagnostics_only'  # Diagnostics OK but OOS failed
        else:
            status = 'budget_exhausted'

        final_result = {
            'status': status,
            'iterations': len(self.iteration_log),
            'best_result': self.best_result,
            'oos_validation_passed': oos_passed,
            'log_path': str(log_path),
            'thinking_path': str(thinking_path),
            'kf_documentation': str(kf_path),
        }

        print(f"\n{'='*60}")
        print("FINAL RESULT")
        print(f"{'='*60}")
        print(f"Status: {final_result['status']}")
        print(f"Iterations: {final_result['iterations']}")
        if self.best_result:
            print(f"In-sample R²: {self.best_result.get('r2', 'N/A')}")
            if best_oos:
                test_metrics = best_oos.get('test_metrics') or {}
                print(f"Out-of-sample R²: {self._fmt(test_metrics.get('r2'), '.4f')}")
                print(f"Out-of-sample MAPE: {self._fmt(test_metrics.get('mape'), '.1f', '%')}")
                print(f"OOS Validation: {'✅ PASSED' if best_oos.get('passed') else '❌ FAILED'}")
            print(f"Best artifacts: {self.best_result.get('artifacts_dir', 'N/A')}")
        print(f"KF Documentation: {kf_path}")

        return final_result

    # ------------------------------------------------------------------ KF documentation

    def generate_kf_documentation(self) -> Path:
        """
        Generate Knowledge Framework documentation of the iteration process.

        Writes ``ITERATION_ANALYSIS.md`` (UTF-8, it contains emoji) documenting:
        - Data profile and channel list
        - Iteration-by-iteration diagnostics and thinking
        - Final in-sample and OOS results of ``best_result``

        Returns:
            Path of the written markdown file.
        """
        kf_path = self.output_dir / 'ITERATION_ANALYSIS.md'

        profile = self.profile
        n_days = profile.n_days if profile else 'N/A'
        n_channels = len(profile.channels) if profile else 'N/A'
        n_controls = len(profile.controls) if profile else 'N/A'
        date_range = f"{profile.date_min} to {profile.date_max}" if profile else 'N/A'

        diag = self._diagnostic_cells(self.best_result)
        best_oos = (self.best_result or {}).get('oos_validation') or {}
        oos = self._oos_cells(best_oos)
        oos_passed = bool(best_oos.get('passed', False))

        # Target column text derived from the same thresholds used for the checks.
        r2_target = f"{self.R2_MIN:.2f}-{self.R2_MAX:.2f}"
        rhat_target = f"≤{self.RHAT_MAX:g}"
        ess_target = f"≥{self.ESS_MIN}"
        oos_r2_target = f"≥{self.OOS_R2_MIN:.2f}"
        mape_target = f"≤{self.OOS_MAPE_MAX:g}%"
        overfit_target = f"≤{self.OVERFIT_MAX:.2f}"
        test_pct = f"{self.OOS_TEST_SIZE:.0%}"

        content = f"""# MMM Model Building - Iteration Analysis

**Thesis:** This document captures the iterative model building process, documenting each decision, adjustment, and the reasoning behind moving toward optimal model performance.

## Overview

This Knowledge Framework document follows the MMM skill execution from data analysis through model convergence. Each iteration is documented with diagnostics, reasoning, and adjustments made.

```mermaid
graph LR
    subgraph Pipeline["MMM Pipeline (v3 with Adstock + Saturation + OOS)"]
        A[Data Analysis] --> B[Config Generation]
        B --> C[Model Training]
        C --> D{{Diagnostics Check}}
        D -->|FAIL| F[Config Adjustment]
        F --> C
        D -->|PASS| G{{OOS Validation}}
        G -->|PASS| E[✅ Full Success]
        G -->|FAIL| F
    end

    style Pipeline fill:#e1f5ff
    style E fill:#c8f7dc
    style G fill:#fff4e1
```

---

## §1.0 Data Profile

| Attribute | Value |
|-----------|-------|
| Data Path | `{self.csv_path}` |
| Observations | {n_days} |
| Channels | {n_channels} |
| Controls | {n_controls} |
| Date Range | {date_range} |

### Channel List
{self._format_channel_list()}

---

## §2.0 Iteration History

{self._format_iteration_history()}

---

## §3.0 Thinking Process

This section documents the reasoning at each step of the model building process.

{self._format_thinking_history()}

---

## §4.0 Final Results

### In-Sample Diagnostics

| Metric | Value | Target | Status |
|--------|-------|--------|--------|
| R² | {diag['r2']} | {r2_target} | {diag['r2_status']} |
| worst_rhat | {diag['rhat']} | {rhat_target} | {diag['rhat_status']} |
| min_ess | {diag['ess']} | {ess_target} | {diag['ess_status']} |
| divergences | {diag['div']} | 0 | {diag['div_status']} |

### Out-of-Sample Validation (MANDATORY)

| Metric | Value | Target | Status |
|--------|-------|--------|--------|
| OOS R² | {oos['r2']} | {oos_r2_target} | {oos['r2_status']} |
| OOS MAPE | {oos['mape']} | {mape_target} | {oos['mape_status']} |
| Overfitting Index | {oos['overfit']} | {overfit_target} | {oos['overfit_status']} |

**OOS Validation Status:** {'✅ PASSED' if oos_passed else '❌ FAILED/NOT RUN'}

> ⚠️ **IMPORTANT:** A model is NOT production-ready unless OOS validation passes.
> OOS validation tests predictive power on held-out data ({test_pct} test set).

### Model Configuration

- **Model Type:** {self.model_type}
- **Features:**
  - ✅ Geometric Adstock (carryover effects)
  - ✅ Hill Saturation (diminishing returns)
  - ✅ Per-channel priors based on channel type
  - ✅ Out-of-Sample Validation (temporal split)

---

## §5.0 Key Learnings

{self._generate_key_learnings()}

---

**Generated:** {datetime.now().isoformat()}
**Skill Version:** 0.3.0 (with adstock + saturation + OOS validation)
"""

        # Explicit UTF-8: the report contains emoji that non-UTF-8 locales cannot encode.
        kf_path.write_text(content, encoding='utf-8')

        return kf_path

    def _format_channel_list(self) -> str:
        """Markdown bullet list of profiled channels."""
        if not self.profile:
            return "No profile available"
        bullets = [f"- `{channel}`" for channel in self.profile.channels]
        return "\n".join(bullets) if bullets else "No channels detected"

    def _format_iteration_history(self) -> str:
        """One markdown section (diagnostics table + result) per logged iteration."""
        if not self.iteration_log:
            return "No iterations recorded"

        sections = []
        for entry in self.iteration_log:
            cells = self._diagnostic_cells(entry.get('diagnostics'))
            result_str = '✅ PASS' if entry.get('passed', False) else '❌ FAIL - adjustments needed'
            oos_line = ''
            if 'oos_passed' in entry:
                oos_line = f"\n**OOS Validation:** {'✅ PASSED' if entry['oos_passed'] else '❌ FAILED'}\n"

            sections.append(f"""### Iteration {entry.get('iteration', '?')}

| Metric | Value | Status |
|--------|-------|--------|
| R² | {cells['r2']} | {cells['r2_status']} |
| worst_rhat | {cells['rhat']} | {cells['rhat_status']} |
| min_ess | {cells['ess']} | {cells['ess_status']} |
| divergences | {cells['div']} | {cells['div_status']} |

**Result:** {result_str}
{oos_line}""")

        return "\n".join(sections)

    def _format_thinking_history(self) -> str:
        """Markdown bullet per thinking entry, prefixed with its iteration when known."""
        if not self.thinking_history:
            return "No thinking recorded"

        lines = []
        for entry in self.thinking_history:
            phase = entry.get('phase', 'unknown')
            reasoning = entry.get('reasoning', 'No reasoning provided')
            iteration = entry.get('iteration', '')
            # Iteration 0 is a real iteration, so compare against '' rather than truthiness.
            prefix = f"**[Iter {iteration}]** " if iteration != '' else ""
            lines.append(f"- {prefix}**{phase}:** {reasoning}")

        return "\n".join(lines)

    def _generate_key_learnings(self) -> str:
        """Numbered list of takeaways derived from the iteration log and best result."""
        learnings = []

        if self.iteration_log:
            n_iters = len(self.iteration_log)
            # The loop stops on full success, so the last entry tells whether we converged.
            if self.iteration_log[-1].get('passed'):
                learnings.append(f"Model converged in **{n_iters} iteration(s)**")
            else:
                learnings.append(f"Model did not meet all diagnostic targets within **{n_iters} iteration(s)**")

            if self.iteration_log[0].get('passed'):
                learnings.append("Initial configuration was well-tuned - first attempt succeeded")
            else:
                learnings.append("Initial configuration did not pass diagnostics - iterative refinement was required")

        if self.best_result:
            r2 = self._metric(self.best_result, 'r2', 0)
            if r2 > 0.65:
                learnings.append(f"Strong model fit (R²={r2:.4f}) indicates good explanatory power")
            elif r2 > self.R2_MIN:
                learnings.append(f"Moderate model fit (R²={r2:.4f}) is acceptable for MMM")
            else:
                learnings.append(f"Weak model fit (R²={r2:.4f}) may need additional features")

        learnings.append(f"Model type **{self.model_type}** with adstock + saturation captures marketing dynamics")

        return "\n".join(f"{number}. {text}" for number, text in enumerate(learnings, start=1))


def run_iteration_engine(
    csv_path: str,
    output_dir: str,
    max_iterations: int = 10,
    conda_env: str = 'pymc_gpu_015',
    model_type: str = 'v3'
) -> Dict:
    """
    Build an ``IterationEngine`` and run its full pipeline in a single call.

    Defaults to the v3 model (adstock + saturation).

    Returns:
        The summary dict produced by ``IterationEngine.run_full_pipeline()``.
    """
    engine_kwargs = dict(
        csv_path=csv_path,
        output_dir=output_dir,
        max_iterations=max_iterations,
        conda_env=conda_env,
        model_type=model_type,
    )
    return IterationEngine(**engine_kwargs).run_full_pipeline()


if __name__ == "__main__":
    import sys

    # Positional CLI: <csv_path> <output_dir> [max_iterations] [model_type] [conda_env]
    if len(sys.argv) < 3:
        print("Usage: python iteration_engine.py <csv_path> <output_dir> [max_iterations] [model_type] [conda_env]")
        print("  model_type: v3 (default with adstock+saturation), simple, v1, v2, v3_proper")
        print("  conda_env: conda environment used for training (default: pymc_gpu_015)")
        sys.exit(1)

    csv_path = sys.argv[1]
    output_dir = sys.argv[2]
    try:
        max_iterations = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    except ValueError:
        print(f"max_iterations must be an integer, got: {sys.argv[3]!r}")
        sys.exit(1)
    if max_iterations < 1:
        print(f"max_iterations must be >= 1, got: {max_iterations}")
        sys.exit(1)
    model_type = sys.argv[4] if len(sys.argv) > 4 else 'v3'
    conda_env = sys.argv[5] if len(sys.argv) > 5 else 'pymc_gpu_015'

    result = run_iteration_engine(
        csv_path, output_dir, max_iterations, conda_env=conda_env, model_type=model_type
    )

    print(f"\nKF Documentation: {result.get('kf_documentation', 'N/A')}")

    print(f"\nFinal result: {result}")
