"""
Out-of-Sample Validation Script for MMM

Created by: Claude Code
Date: 2025-11-30
Purpose: Calculate true out-of-sample R² using saved normalization parameters
"""

import pandas as pd
import numpy as np
import arviz as az
import yaml
import json
from pathlib import Path
import argparse


def load_normalization_params(artifacts_dir: str) -> dict:
    """Load the normalization parameters saved during training.

    Args:
        artifacts_dir: Training artifacts directory. The parameters are read
            from ``<artifacts_dir>/preprocessing/normalization_params.json``.

    Returns:
        The decoded JSON content of the parameters file.

    Raises:
        FileNotFoundError: If the parameters file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    params_path = Path(artifacts_dir) / 'preprocessing' / 'normalization_params.json'
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # locale encoding, which differs between operating systems.
    with open(params_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def preprocess_data(df: pd.DataFrame, channels: list, params: dict) -> tuple:
    """Apply the training-time scaling to ``df`` using saved parameters.

    Args:
        df: Data with a ``revenue`` column and one column per entry in
            ``channels``.
        channels: Channel column names; the output columns follow this order.
        params: Saved parameters. Reads ``scale_y`` and ``normalize_channels``
            and, when normalizing, ``channel_maxes[<channel>]`` for every
            channel.

    Returns:
        Tuple ``(y, y_scaled, X_channels)``: the raw revenue as float64, the
        revenue divided by ``scale_y``, and a float64 ``(rows, channels)``
        array of channel values (divided by the saved per-channel maximum
        when ``normalize_channels`` is truthy). ``df`` is not modified.

    Raises:
        KeyError: If a required column or parameter is missing.
        ValueError: If ``scale_y`` is zero.
    """
    import warnings

    # Extract target
    y = df['revenue'].values.astype(np.float64)

    # Scale target using TRAINING scale_y; a zero scale would silently turn
    # every scaled value into inf/NaN, so fail loudly instead.
    scale_y = params['scale_y']
    if scale_y == 0:
        raise ValueError("params['scale_y'] is zero; cannot scale the target")
    y_scaled = y / scale_y

    # Extract channels (astype returns a copy, so df itself is never touched)
    X_channels = df[channels].values.astype(np.float64)

    # Normalize channels using TRAINING channel_maxes
    if params['normalize_channels']:
        channel_maxes = np.array(
            [params['channel_maxes'][channel] for channel in channels],
            dtype=np.float64,
        )
        zero_max = channel_maxes == 0
        if zero_max.any():
            # Dividing by a zero maximum would yield inf/NaN and poison every
            # downstream metric; leave such channels unscaled and say so.
            skipped = [name for name, is_zero in zip(channels, zero_max) if is_zero]
            warnings.warn(
                f"Saved channel max is zero for {skipped}; "
                "leaving these channels unscaled",
                RuntimeWarning,
                stacklevel=2,
            )
            channel_maxes = np.where(zero_max, 1.0, channel_maxes)
        # Broadcasting divides each column by its own maximum
        X_channels = X_channels / channel_maxes

    return y, y_scaled, X_channels


def predict(X_channels: np.ndarray, trace: "az.InferenceData", scale_y: float) -> tuple:
    """Predict with the posterior-mean intercept and channel coefficients.

    The prediction is ``intercept + X_channels @ coefs`` on the scaled target,
    then multiplied by ``scale_y`` to return to the original scale.
    NOTE(review): no other transform of the inputs is applied here; confirm
    this matches the model structure used at training time.

    Args:
        X_channels: ``(rows, channels)`` array of preprocessed channel values.
        trace: Inference data whose ``posterior`` group holds ``intercept``
            and ``channel_coefs`` (averaged over ``chain`` and ``draw``).
        scale_y: Training target scale used to undo the target scaling.

    Returns:
        Tuple ``(y_pred_raw, y_pred_scaled)`` of 1-D arrays, one entry per row.

    Raises:
        ValueError: If ``X_channels`` is not 2-D, or its column count does not
            match the number of channel coefficients.
    """
    # Get posterior mean parameters
    intercept = float(trace.posterior['intercept'].mean().values)
    # atleast_1d keeps a single-channel model (0-d mean) working
    coefs = np.atleast_1d(
        trace.posterior['channel_coefs'].mean(dim=['chain', 'draw']).values
    ).astype(np.float64)

    X = np.asarray(X_channels, dtype=np.float64)
    # A silent mismatch would drop columns or mis-pair coefficients with
    # channels, so validate shapes before multiplying.
    if coefs.ndim != 1:
        raise ValueError(
            f"Expected 1-D channel coefficients, got shape {coefs.shape}"
        )
    if X.ndim != 2 or X.shape[1] != coefs.shape[0]:
        raise ValueError(
            f"X_channels has shape {X.shape} but the model has "
            f"{coefs.shape[0]} channel coefficients"
        )

    # Predict (scaled)
    y_pred_scaled = intercept + X @ coefs

    # Convert back to original scale
    y_pred_raw = y_pred_scaled * scale_y

    return y_pred_raw, y_pred_scaled


def calculate_metrics(y_actual: np.ndarray, y_pred: np.ndarray, y_scaled: np.ndarray, y_pred_scaled: np.ndarray) -> dict:
    """Calculate R², MAPE and MAE for a set of predictions.

    Args:
        y_actual: Observed target on the original scale.
        y_pred: Predicted target on the original scale.
        y_scaled: Observed target on the scaled (training) scale.
        y_pred_scaled: Predicted target on the scaled scale.

    Returns:
        Dict with float entries:
        ``r2``   -- 1 - SS_res / SS_tot on the scaled data; NaN when SS_tot is
                    zero (constant or single-sample target), where R² is
                    undefined.
        ``mape`` -- mean absolute percentage error (in percent) on the
                    original scale, computed over rows whose actual value is
                    non-zero; NaN when there is no such row.
        ``mae``  -- mean absolute error on the original scale.
        All three are NaN for empty input.
    """
    y_actual = np.asarray(y_actual, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    y_scaled = np.asarray(y_scaled, dtype=np.float64)
    y_pred_scaled = np.asarray(y_pred_scaled, dtype=np.float64)

    nan = float('nan')
    if y_actual.size == 0 or y_scaled.size == 0:
        # Nothing to score; avoid numpy's mean-of-empty warnings
        return {'r2': nan, 'mape': nan, 'mae': nan}

    # R² on scaled data
    ss_res = np.sum((y_scaled - y_pred_scaled) ** 2)
    ss_tot = np.sum((y_scaled - y_scaled.mean()) ** 2)
    # With zero total variance the ratio is a division by zero, not a score
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else nan

    errors = y_pred - y_actual

    # MAPE on original scale; rows with a zero actual have no defined
    # percentage error and would otherwise turn the whole mean into inf/NaN.
    nonzero = y_actual != 0
    if nonzero.any():
        mape = np.mean(np.abs(errors[nonzero] / y_actual[nonzero])) * 100
    else:
        mape = nan

    # MAE
    mae = np.mean(np.abs(errors))

    return {
        'r2': float(r2),
        'mape': float(mape),
        'mae': float(mae)
    }


def validate_out_of_sample(
    train_artifacts_dir: str,
    test_csv_path: str,
    config_path: str = None
) -> dict:
    """
    Score a trained model on a held-out CSV using the training normalization.

    Args:
        train_artifacts_dir: Training artifacts directory; ``trace.nc`` and
            ``preprocessing/normalization_params.json`` are read from it.
        test_csv_path: Path to the test CSV file.
        config_path: Optional YAML config path. When omitted, the config is
            looked up next to the artifacts directory as
            ``config_<suffix>.yaml``, where ``<suffix>`` is the part of the
            directory name after its last underscore.

    Returns:
        Dictionary with ``metrics`` (output of ``calculate_metrics``) and
        ``predictions`` (DataFrame with ``week``, ``actual``, ``predicted``
        and ``error_pct`` columns; ``week`` falls back to the row position
        when the CSV has no such column).
    """
    artifacts_dir = Path(train_artifacts_dir)

    # Resolve the config location when the caller did not supply one
    if config_path is None:
        run_suffix = artifacts_dir.name.rpartition("_")[2]
        config_path = artifacts_dir.parent / f'config_{run_suffix}.yaml'

    with open(config_path, 'r') as config_file:
        config = yaml.safe_load(config_file)
    channels = config['data']['channels']

    # Normalization parameters saved by the training run
    print(f"Loading normalization params from: {artifacts_dir}/preprocessing/")
    norm_params = load_normalization_params(str(artifacts_dir))

    # Held-out data, preprocessed with the TRAINING normalization
    print(f"Loading test data from: {test_csv_path}")
    test_df = pd.read_csv(test_csv_path)
    y_actual, y_scaled, X_channels = preprocess_data(test_df, channels, norm_params)

    # Trained posterior
    print(f"Loading trained model from: {artifacts_dir}/trace.nc")
    trace_file = artifacts_dir / 'trace.nc'
    trace = az.from_netcdf(str(trace_file))

    # Predict, then score
    y_pred, y_pred_scaled = predict(X_channels, trace, norm_params['scale_y'])
    metrics = calculate_metrics(y_actual, y_pred, y_scaled, y_pred_scaled)

    # Per-row prediction table
    if 'week' in test_df.columns:
        week_values = test_df['week']
    else:
        week_values = range(len(test_df))
    pct_error = ((y_pred - y_actual) / y_actual) * 100
    predictions = pd.DataFrame({
        'week': week_values,
        'actual': y_actual,
        'predicted': y_pred,
        'error_pct': pct_error
    })

    return {'metrics': metrics, 'predictions': predictions}


def main():
    """Command-line entry point: validate, print a report, optionally save.

    Parses the artifacts directory, test CSV, optional config path and
    optional output path from the command line, runs
    ``validate_out_of_sample``, prints the metrics and the first ten
    predictions, and writes the predictions to CSV when ``--output`` is given.

    Returns:
        The results dictionary produced by ``validate_out_of_sample``.
    """
    cli = argparse.ArgumentParser(description='Validate MMM model on out-of-sample data')
    cli.add_argument('train_artifacts', help='Path to training artifacts directory')
    cli.add_argument('test_csv', help='Path to test CSV file')
    cli.add_argument('-c', '--config', help='Path to config file (optional)')
    cli.add_argument('-o', '--output', help='Output CSV path for predictions')
    args = cli.parse_args()

    # Run validation
    results = validate_out_of_sample(args.train_artifacts, args.test_csv, args.config)

    # Metrics report, framed by an 80-column rule
    rule = '=' * 80
    print('\n' + rule)
    print('OUT-OF-SAMPLE VALIDATION RESULTS')
    print(rule)
    metrics = results['metrics']
    print(f"R² (variance explained): {metrics['r2']:.4f}")
    print(f"MAPE (prediction error):  {metrics['mape']:.1f}%")
    print(f"MAE (mean absolute error): ${metrics['mae']:,.0f}")
    print(rule)

    # Preview of the prediction table
    print('\nSample predictions (first 10 rows):')
    preview = results['predictions'].head(10)
    print(preview.to_string(index=False))

    # Persist the full table only when an output path was requested
    output_path = args.output
    if output_path:
        results['predictions'].to_csv(output_path, index=False)
        print(f'\n✅ Predictions saved to: {output_path}')

    return results


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
