"""
Out-of-Sample Validator - Mandatory validation step for MMM models

Created by: Claude Code
Session ID: MMM Skill Enhancement
Date: 2025-11-26
Purpose: Validate MMM model predictive power on held-out data

CRITICAL: This is a MANDATORY step in the MMM skill pipeline.
A model CANNOT be considered production-ready without passing OOS validation.
"""

import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import json
from pathlib import Path
from typing import Dict, Tuple, Optional
from datetime import datetime


class OOSValidator:
    """
    Out-of-Sample Validator for MMM models.

    Performs a temporal train/test split (rows are sorted by ``date_col`` and
    the last ``test_size`` fraction is held out) and scores the fitted model
    on the held-out rows using posterior-mean parameters.

    Key Metrics:
    - OOS R² (out-of-sample coefficient of determination)
    - OOS MAPE (Mean Absolute Percentage Error)
    - OOS RMSE (Root Mean Squared Error)
    - Overfitting Index (in-sample R² minus OOS R²)

    Passing Criteria (defaults; override via ``thresholds``):
    - OOS R² >= 0.40 (model explains 40%+ variance on unseen data)
    - OOS MAPE <= 20% (average prediction error < 20%)
    - Overfitting Index <= 0.25 (R² drop from train to test)

    All values stored in ``validation_result`` are builtin Python types so the
    dict can be written with ``json.dump``.
    """

    DEFAULT_THRESHOLDS = {
        'oos_r2_min': 0.40,       # Minimum OOS R²
        'oos_mape_max': 20.0,     # Maximum MAPE in %
        'overfitting_max': 0.25,  # Max R² drop from train to test
        # Informational only: the split uses the ``test_size`` constructor
        # argument, not this entry.
        'test_size': 0.20,
    }

    # Fallback transform parameters used when the posterior has no
    # channel-specific variable (same values as the historical defaults).
    _DEFAULT_ALPHA = 0.5
    _DEFAULT_LAM = 0.5
    _DEFAULT_BETA = 0.1

    def __init__(
        self,
        df: pd.DataFrame,
        y_col: str,
        channel_cols: list,
        control_cols: list,
        date_col: str,
        scale_y: float = 1.0,
        test_size: float = 0.20,
        thresholds: Optional[Dict] = None
    ):
        """
        Initialize OOS Validator.

        Args:
            df: Full dataset (sorted by ``date_col`` internally; not mutated)
            y_col: Target column name
            channel_cols: List of channel spend columns
            control_cols: List of control variable columns
            date_col: Date column name
            scale_y: Target scaling factor (targets are divided by it)
            test_size: Fraction of data to hold out (default 20%)
            thresholds: Custom thresholds dict (optional), merged over
                ``DEFAULT_THRESHOLDS``

        Raises:
            ValueError: if ``test_size`` is not strictly between 0 and 1, or
                the split would leave the train or test set empty.
        """
        self.df = df.sort_values(date_col).reset_index(drop=True)
        self.y_col = y_col
        self.channel_cols = channel_cols
        self.control_cols = control_cols
        self.date_col = date_col
        self.scale_y = scale_y
        self.test_size = test_size
        self.thresholds = {**self.DEFAULT_THRESHOLDS, **(thresholds or {})}

        # Split data
        self._split_data()

        # Results storage
        self.train_metrics: Dict = {}
        self.test_metrics: Dict = {}
        self.validation_result: Optional[Dict] = None

    def _split_data(self):
        """Perform the temporal train/test split and prepare model inputs.

        Raises:
            ValueError: if ``test_size`` is outside (0, 1) or the resulting
                train or test partition would be empty.
        """
        # `not a < x < b` is also True for NaN, which is rejected as well.
        if not 0 < self.test_size < 1:
            raise ValueError(f"test_size must be in (0, 1), got {self.test_size!r}")

        n = len(self.df)
        split_idx = int(n * (1 - self.test_size))
        if not 0 < split_idx < n:
            raise ValueError(
                f"test_size={self.test_size} on {n} observations leaves an empty "
                f"train or test set (split index {split_idx})"
            )

        self.df_train = self.df.iloc[:split_idx].copy()
        self.df_test = self.df.iloc[split_idx:].copy()

        # Prepare arrays
        self.y_train = self.df_train[self.y_col].values / self.scale_y
        self.y_test = self.df_test[self.y_col].values / self.scale_y

        self.X_channels_train = self.df_train[self.channel_cols].values
        self.X_channels_test = self.df_test[self.channel_cols].values

        self.X_controls_train = self.df_train[self.control_cols].values
        self.X_controls_test = self.df_test[self.control_cols].values

        # Normalize channels using TRAIN statistics only
        self.channel_maxes = self.X_channels_train.max(axis=0, keepdims=True)
        self.channel_maxes[self.channel_maxes == 0] = 1

        self.X_channels_train_norm = self.X_channels_train / self.channel_maxes
        self.X_channels_test_norm = self.X_channels_test / self.channel_maxes

        print(f"[OOS Validator] Data split:")
        print(f"  Train: {len(self.df_train)} observations ({self.df_train[self.date_col].min()} to {self.df_train[self.date_col].max()})")
        print(f"  Test:  {len(self.df_test)} observations ({self.df_test[self.date_col].min()} to {self.df_test[self.date_col].max()})")

    def validate_with_trace(
        self,
        trace: "az.InferenceData",
        y_pred_train: np.ndarray,
        in_sample_r2: float
    ) -> Dict:
        """
        Validate model using existing trace on test data.

        Args:
            trace: ArviZ InferenceData from training. ``trace.posterior`` must
                contain ``intercept``. ``beta_<ch>``, ``lam_<ch>``,
                ``alpha_<ch>`` and ``control_coefs`` are used when present.
            y_pred_train: Predicted values on training data. Currently unused;
                kept for interface compatibility.
            in_sample_r2: R² from training

        Returns:
            Validation result dict with pass/fail status. Every value is a
            builtin Python type, so the dict is JSON-serializable.
        """
        print(f"\n{'='*60}")
        print("OUT-OF-SAMPLE VALIDATION")
        print(f"{'='*60}")

        # Store in-sample metrics
        self.train_metrics = {
            'r2': in_sample_r2,
            'n_obs': len(self.y_train),
        }

        posterior = trace.posterior

        # Posterior means of the shared parameters
        intercept = float(posterior['intercept'].mean())

        # Per-channel parameters; channels missing a variable fall back to
        # defaults inside _predict_test.
        channel_betas = {}
        channel_lambdas = {}
        channel_alphas = {}

        for ch in self.channel_cols:
            if f'beta_{ch}' in posterior:
                channel_betas[ch] = float(posterior[f'beta_{ch}'].mean())
            if f'lam_{ch}' in posterior:
                channel_lambdas[ch] = float(posterior[f'lam_{ch}'].mean())
            if f'alpha_{ch}' in posterior:
                channel_alphas[ch] = float(posterior[f'alpha_{ch}'].mean())

        # Control coefficients (zeros when the model has none)
        if 'control_coefs' in posterior:
            control_coefs = posterior['control_coefs'].mean(dim=['chain', 'draw']).values
        else:
            control_coefs = np.zeros(len(self.control_cols))

        y_pred_test = self._predict_test(
            intercept=intercept,
            channel_betas=channel_betas,
            channel_lambdas=channel_lambdas,
            channel_alphas=channel_alphas,
            control_coefs=control_coefs
        )

        # Calculate OOS metrics
        oos_r2 = float(self._calc_r2(self.y_test, y_pred_test))
        oos_mape = float(self._calc_mape(self.y_test, y_pred_test))
        oos_rmse = float(self._calc_rmse(self.y_test, y_pred_test))

        # Overfitting index
        overfitting_index = float(in_sample_r2) - oos_r2

        self.test_metrics = {
            'r2': oos_r2,
            'mape': oos_mape,
            'rmse': oos_rmse,
            'n_obs': len(self.y_test),
        }

        # bool() matters: comparisons involving NumPy scalars yield numpy.bool_,
        # which json.dump rejects.
        r2_pass = bool(oos_r2 >= self.thresholds['oos_r2_min'])
        mape_pass = bool(oos_mape <= self.thresholds['oos_mape_max'])
        overfit_pass = bool(overfitting_index <= self.thresholds['overfitting_max'])

        all_passed = r2_pass and mape_pass and overfit_pass

        # Print results
        print(f"\n{'─'*60}")
        print("RESULTS")
        print(f"{'─'*60}")
        print(f"  IN-SAMPLE  R²:  {in_sample_r2:.4f}")
        print(f"  OUT-OF-SAMPLE R²: {oos_r2:.4f} {'✅' if r2_pass else '❌'} (min: {self.thresholds['oos_r2_min']})")
        print(f"  OUT-OF-SAMPLE MAPE: {oos_mape:.2f}% {'✅' if mape_pass else '❌'} (max: {self.thresholds['oos_mape_max']}%)")
        print(f"  OUT-OF-SAMPLE RMSE: {oos_rmse:.4f}")
        print(f"  Overfitting Index: {overfitting_index:.4f} {'✅' if overfit_pass else '❌'} (max: {self.thresholds['overfitting_max']})")
        print(f"{'─'*60}")
        print(f"  OVERALL: {'✅ PASS' if all_passed else '❌ FAIL'}")
        print(f"{'='*60}")

        self.validation_result = {
            'passed': all_passed,
            'timestamp': datetime.now().isoformat(),
            'train_metrics': {
                'r2': float(in_sample_r2),
                'n_obs': len(self.y_train),
            },
            'test_metrics': {
                'r2': oos_r2,
                'mape': oos_mape,
                'rmse': oos_rmse,
                'n_obs': len(self.y_test),
            },
            'overfitting_index': overfitting_index,
            'thresholds': dict(self.thresholds),
            'checks': {
                'oos_r2_pass': r2_pass,
                'oos_mape_pass': mape_pass,
                'overfitting_pass': overfit_pass,
            },
            'predictions': {
                'y_test_actual': self.y_test.tolist(),
                'y_test_predicted': y_pred_test.tolist(),
            }
        }

        return self.validation_result

    def _predict_test(
        self,
        intercept: float,
        channel_betas: Dict,
        channel_lambdas: Dict,
        channel_alphas: Dict,
        control_coefs: np.ndarray
    ) -> np.ndarray:
        """
        Generate predictions on test data using posterior means.

        Per channel:
        1. Geometric adstock, run over the full train+test normalized series
           so the first test periods inherit carryover from late-train spend.
        2. Rescale the adstocked series by its maximum over the TRAIN portion
           only, so no test-period statistic leaks into the prediction.
        3. Saturation ``x / (lam + x)``, then scale by ``beta``.

        Control effects are added as ``X_controls_test @ control_coefs``.
        """
        n_train = len(self.y_train)
        n_test = len(self.y_test)

        # Start with intercept
        y_pred = np.full(n_test, intercept, dtype=float)

        # Stack train and test so adstock state carries across the boundary.
        x_full_norm = np.vstack([self.X_channels_train_norm, self.X_channels_test_norm])

        for i, ch in enumerate(self.channel_cols):
            missing = [
                name
                for name, params in (('alpha', channel_alphas), ('lam', channel_lambdas), ('beta', channel_betas))
                if ch not in params
            ]
            if missing:
                print(f"[OOS Validator] WARNING: no posterior {', '.join(missing)} for '{ch}'; using defaults")

            alpha = channel_alphas.get(ch, self._DEFAULT_ALPHA)
            lam = channel_lambdas.get(ch, self._DEFAULT_LAM)
            beta = channel_betas.get(ch, self._DEFAULT_BETA)

            x_adstocked = self._apply_adstock(x_full_norm[:, i], alpha)

            ref_max = x_adstocked[:n_train].max()
            if ref_max <= 0:
                # Channel never active in training: fall back to the full-series
                # max so the rescaled values stay bounded in [0, 1].
                ref_max = x_adstocked.max()
            x_test_scaled = x_adstocked[n_train:] / (ref_max + 1e-8)

            # Hill-type saturation (exponent 1)
            x_saturated = x_test_scaled / (lam + x_test_scaled + 1e-8)

            y_pred += beta * x_saturated

        # Add control effects
        if len(self.control_cols) > 0:
            y_pred += np.dot(self.X_controls_test, control_coefs)

        return y_pred

    def _apply_adstock(self, x: np.ndarray, alpha: float) -> np.ndarray:
        """Apply geometric adstock: ``a[0] = x[0]``, ``a[t] = x[t] + alpha * a[t-1]``.

        Returns an empty array for empty input.
        """
        adstocked = np.zeros(len(x))
        carry = 0.0
        for t, value in enumerate(x):
            carry = value + alpha * carry
            adstocked[t] = carry
        return adstocked

    @staticmethod
    def _calc_r2(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate R² (coefficient of determination); 0.0 for constant targets."""
        ss_res = np.sum((y_true - y_pred) ** 2)
        ss_tot = np.sum((y_true - y_true.mean()) ** 2)
        return float(1 - (ss_res / ss_tot)) if ss_tot > 0 else 0.0

    @staticmethod
    def _calc_mape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate Mean Absolute Percentage Error in percent.

        Near-zero targets (|y| <= 1e-8) are skipped. Returns 100.0 when every
        target is near zero.
        """
        mask = np.abs(y_true) > 1e-8
        if not mask.any():
            return 100.0
        return float(100 * np.mean(np.abs((y_true[mask] - y_pred[mask]) / y_true[mask])))

    @staticmethod
    def _calc_rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
        """Calculate Root Mean Squared Error."""
        return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

    def save_results(self, output_dir: str) -> Path:
        """Save validation results to ``<output_dir>/oos_validation.json``.

        Creates ``output_dir`` (and parents) if it does not exist.

        Raises:
            RuntimeError: if called before ``validate_with_trace``.
        """
        if self.validation_result is None:
            raise RuntimeError("validate_with_trace() must be called before save_results()")

        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        output_path = out_dir / 'oos_validation.json'

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(self.validation_result, f, indent=2)

        print(f"[OOS Validator] Results saved to: {output_path}")
        return output_path

    def save_predictions_csv(self, output_dir: str) -> Path:
        """Save test predictions (original target scale) to ``<output_dir>/oos_predictions.csv``.

        ``error_pct`` is NaN for rows whose actual value is zero. Creates
        ``output_dir`` if needed.

        Raises:
            RuntimeError: if called before ``validate_with_trace``.
        """
        if self.validation_result is None:
            raise RuntimeError("validate_with_trace() must be called before save_predictions_csv()")

        out_dir = Path(output_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        output_path = out_dir / 'oos_predictions.csv'

        pred_df = pd.DataFrame({
            'date': self.df_test[self.date_col],
            'actual': self.y_test * self.scale_y,
            'predicted': np.array(self.validation_result['predictions']['y_test_predicted']) * self.scale_y,
        })
        pred_df['error'] = pred_df['predicted'] - pred_df['actual']
        # Avoid ±inf for zero actuals; NaN is written as an empty CSV field.
        pred_df['error_pct'] = 100 * pred_df['error'] / pred_df['actual'].replace(0, np.nan)

        pred_df.to_csv(output_path, index=False)
        print(f"[OOS Validator] Predictions saved to: {output_path}")
        return output_path


def run_oos_validation(
    csv_path: str,
    trace_path: str,
    config: dict,
    y_pred_train: np.ndarray,
    in_sample_r2: float,
    output_dir: str,
    test_size: float = 0.20
) -> Dict:
    """
    Execute the out-of-sample check as one step of the MMM pipeline.

    Reads the raw CSV, parses and sorts it by the configured date column,
    loads the saved ArviZ trace, and scores the held-out tail with
    ``OOSValidator``. Both the JSON summary and the per-row prediction CSV
    are written into ``output_dir``.

    Args:
        csv_path: Path to original CSV data
        trace_path: Path to saved trace.nc file
        config: Model config dict; reads ``data.date_col``,
            ``data.target_col``, ``data.channels``, ``data.controls`` and
            ``preprocess.scale_y``
        y_pred_train: Predicted values from training
        in_sample_r2: R² from training
        output_dir: Directory for output files
        test_size: Fraction of data to hold out

    Returns:
        The dict produced by ``OOSValidator.validate_with_trace``.
    """
    frame = pd.read_csv(csv_path)

    data_cfg = config['data']
    date_col = data_cfg['date_col']
    frame[date_col] = pd.to_datetime(frame[date_col])
    frame = frame.sort_values(date_col).reset_index(drop=True)

    inference_data = az.from_netcdf(trace_path)

    validator = OOSValidator(
        df=frame,
        y_col=data_cfg['target_col'],
        channel_cols=data_cfg['channels'],
        control_cols=data_cfg['controls'],
        date_col=date_col,
        scale_y=config['preprocess']['scale_y'],
        test_size=test_size,
    )

    outcome = validator.validate_with_trace(inference_data, y_pred_train, in_sample_r2)

    # Persist both artifacts: JSON summary first, then the prediction table.
    for persist in (validator.save_results, validator.save_predictions_csv):
        persist(output_dir)

    return outcome


if __name__ == "__main__":
    import sys

    # Positional CLI: csv_path trace_path config_path output_dir [test_size]
    cli_args = sys.argv[1:]
    if len(cli_args) < 4:
        print("Usage: python oos_validator.py <csv_path> <trace_path> <config_path> <output_dir> [test_size]")
        sys.exit(1)

    csv_path, trace_path, config_path, output_dir = cli_args[:4]
    test_size = 0.20 if len(cli_args) <= 4 else float(cli_args[4])

    # yaml is only needed for the CLI path, so it is imported lazily here.
    import yaml
    with open(config_path, 'r') as cfg_file:
        config = yaml.safe_load(cfg_file)

    # Standalone mode cannot run validation: y_pred_train and in_sample_r2
    # come from the iteration engine, which calls run_oos_validation() itself.
    print("Note: Standalone mode requires pre-computed train predictions")
    print("Use run_oos_validation() from iteration_engine instead")
