"""
Config Generator - Generate YAML configs for MMM training

Created by: Claude Code
Session ID: Ralph Loop Execution
Date: 2025-11-26
Purpose: Generate PyMC model configuration based on data profile

Prior Selection Methodology:
- Source #1: Bayesian theory (weakly-informative priors)
- Source #2: Open MMM practice (Robyn, LightweightMMM, PyMC-Marketing docs)
- Source #3: Client data context
"""

import yaml
import json
from pathlib import Path
from typing import Dict, List, Optional
from dataclasses import dataclass

from data_analyzer import DataProfile


@dataclass
class ChannelConfig:
    """Prior settings for a single marketing channel.

    ModelConfig.to_yaml serialises these fields into the per-channel
    'adstock', 'saturation' and 'type' entries. Field order is part of the
    positional constructor signature, so new fields should be appended
    with defaults.
    """
    name: str  # Channel name

    # Adstock parameters (geometric decay)
    adstock_alpha_prior_a: float = 3.0  # Beta(a, b) prior on adstock alpha: a
    adstock_alpha_prior_b: float = 3.0  # Beta(a, b) prior on adstock alpha: b
    # NOTE(review): whether a higher alpha means faster or slower decay depends on
    # the trainer's adstock implementation (PyMC-Marketing treats alpha as the
    # retention rate, i.e. higher = slower decay) — confirm before tuning a/b.
    adstock_max_lag: int = 8  # Max carryover lag in periods (presumably days); overridden per channel type

    # Saturation parameters (Hill function)
    saturation_lam_prior_mu: float = 0.0  # LogNormal(mu, sigma) prior on lambda (half-saturation point): mu
    saturation_lam_prior_sigma: float = 1.0  # LogNormal prior on lambda: sigma
    saturation_beta_prior_mu: float = 0.0  # Effect-coefficient prior: location
    saturation_beta_prior_sigma: float = 1.0  # Effect-coefficient prior: scale

    # Channel type for grouping
    channel_type: str = "digital"  # One of: digital, social, search, display, video (see classify_channel_type)


@dataclass
class ModelConfig:
    """Full model configuration, serialisable to the YAML file used for MMM training.

    Fields are grouped into data location/columns, preprocessing, model
    structure, per-channel priors, sampler settings and the noise model.
    """
    # Data
    data_path: str
    date_col: str
    target_col: str
    channels: List[str]
    controls: List[str]

    # Preprocessing
    scale_y: float
    normalize_channels: bool = True

    # Model structure
    include_intercept: bool = True
    include_trend: bool = False  # Can add if needed
    include_seasonality: bool = False  # Day-of-week handled in controls

    # Per-channel priors keyed by channel name; None yields an empty 'channels'
    # section. The string forward reference means the annotation does not need
    # ChannelConfig to be resolvable when this class body executes.
    channel_configs: Optional[Dict[str, "ChannelConfig"]] = None

    # Sampler settings
    sampler_draws: int = 2000
    sampler_tune: int = 1500
    sampler_chains: int = 4
    sampler_target_accept: float = 0.95
    sampler_max_treedepth: int = 12
    sampler_cores: int = 4

    # Noise model
    likelihood: str = "normal"  # normal or student_t
    sigma_prior: float = 1.0  # HalfNormal prior scale

    def to_dict(self) -> Dict:
        """Return the config as a nested dict of plain Python types.

        Section order (data, preprocess, model, channels, sampler) is the order
        written to YAML. Sequences are copied into fresh lists and ``scale_y``
        is coerced to ``float``. This prevents ``yaml.dump`` from emitting
        Python-specific tags that ``yaml.safe_load`` cannot read back, such as
        ``!!python/tuple`` for a tuple or an object tag if ``scale_y`` is a
        numpy scalar.

        Returns:
            A freshly built dict; mutating it does not affect this config.
        """
        config_dict = {
            'data': {
                'path': self.data_path,
                'date_col': self.date_col,
                'target_col': self.target_col,
                'channels': list(self.channels) if self.channels is not None else None,
                'controls': list(self.controls) if self.controls is not None else None,
            },
            'preprocess': {
                'scale_y': float(self.scale_y) if self.scale_y is not None else None,
                'normalize_channels': self.normalize_channels,
            },
            'model': {
                'include_intercept': self.include_intercept,
                'include_trend': self.include_trend,
                'include_seasonality': self.include_seasonality,
                'likelihood': self.likelihood,
                'sigma_prior': self.sigma_prior,
            },
            'channels': {},
            'sampler': {
                'draws': self.sampler_draws,
                'tune': self.sampler_tune,
                'chains': self.sampler_chains,
                'target_accept': self.sampler_target_accept,
                'max_treedepth': self.sampler_max_treedepth,
                'cores': self.sampler_cores,
            },
        }

        # One entry per channel: adstock Beta(a, b) + max lag, saturation
        # (mu, sigma) pairs for lambda and beta, and the channel type.
        for ch_name, ch_config in (self.channel_configs or {}).items():
            config_dict['channels'][ch_name] = {
                'adstock': {
                    'alpha_prior': [ch_config.adstock_alpha_prior_a, ch_config.adstock_alpha_prior_b],
                    'max_lag': ch_config.adstock_max_lag,
                },
                'saturation': {
                    'lam_prior': [ch_config.saturation_lam_prior_mu, ch_config.saturation_lam_prior_sigma],
                    'beta_prior': [ch_config.saturation_beta_prior_mu, ch_config.saturation_beta_prior_sigma],
                },
                'type': ch_config.channel_type,
            }

        return config_dict

    def to_yaml(self, path: Path) -> None:
        """Write ``to_dict()`` to ``path`` as block-style YAML, preserving key order.

        Args:
            path: Destination file; overwritten if it already exists.
        """
        with open(path, 'w', encoding='utf-8') as f:
            yaml.dump(self.to_dict(), f, default_flow_style=False, sort_keys=False)


def classify_channel_type(channel_name: str) -> str:
    """
    Map a channel name to a coarse channel type by keyword matching.

    The name is lower-cased and tested for plain substring matches against
    keyword groups in a fixed priority order: search, social, display, video.
    The first group with a hit wins, so e.g. 'youtube_search' is 'search'.
    Because matching is by substring, 'tv' also matches names like 'ctv'.
    Names that match no group are classified as 'digital'.

    The type selects the prior set in get_channel_priors, whose adstock
    max_lag is 7 for search, 14 for social and digital, 21 for display and
    28 for video.
    """
    keyword_groups = (
        ('search', ('search', 'bof', 'mof', 'shopping')),
        ('social', ('facebook', 'meta', 'instagram', 'tiktok', 'social')),
        ('display', ('display', 'demand_gen', 'tof', 'performance_max', 'pmax')),
        ('video', ('youtube', 'video', 'tv')),
    )
    lowered = channel_name.lower()

    for label, keywords in keyword_groups:
        if any(word in lowered for word in keywords):
            return label

    # Nothing matched: fall back to the generic bucket.
    return 'digital'


def get_channel_priors(channel_type: str, channel_stats: Dict) -> ChannelConfig:
    """
    Build a ChannelConfig whose priors reflect the channel's type and activity.

    Priors written into the config:
    - Adstock alpha ~ Beta(a, b). The Beta mean is a / (a + b), so raising a
      relative to b pushes alpha upward.
      NOTE(review): the per-type comments below assume "higher alpha = faster
      decay". If the trainer follows PyMC-Marketing, where alpha is the
      retention rate, higher alpha means slower decay instead. Confirm the
      convention before relying on these priors.
    - Saturation lambda ~ LogNormal(mu, sigma), the half-saturation point.
      Only sigma varies by type; mu keeps the ChannelConfig default.
    - Effect beta ~ Normal(0, sigma). Sigma is 0.5 for channels active in
      fewer than half of the observations and 1.0 otherwise.

    Args:
        channel_type: 'search', 'social', 'display' or 'video'. Any other
            value gets the generic 'digital' priors.
        channel_stats: Per-channel statistics. Reads 'name' (default
            'unknown') and 'non_zero_pct' (default 1.0).

    Returns:
        A new ChannelConfig with channel_type set to ``channel_type``.
    """
    # channel_type -> (alpha_a, alpha_b, max_lag, lam_sigma)
    priors_by_type = {
        # Search: fast response, shorter memory; tighter saturation prior.
        'search': (4.0, 2.0, 7, 0.8),
        # Social: medium response.
        'social': (3.0, 3.0, 14, 1.0),
        # Display/Awareness: slower response, longer memory.
        'display': (2.0, 3.0, 21, 1.2),
        # Video/TV: slowest response.
        'video': (2.0, 4.0, 28, 1.5),
    }
    digital_default = (3.0, 3.0, 14, 1.0)
    alpha_a, alpha_b, max_lag, lam_sigma = priors_by_type.get(channel_type, digital_default)

    cfg = ChannelConfig(name=channel_stats.get('name', 'unknown'))
    cfg.adstock_alpha_prior_a = alpha_a
    cfg.adstock_alpha_prior_b = alpha_b
    cfg.adstock_max_lag = max_lag
    cfg.saturation_lam_prior_sigma = lam_sigma
    cfg.channel_type = channel_type

    # Sparsely active channels get a more regularising effect prior.
    is_sparse = channel_stats.get('non_zero_pct', 1.0) < 0.5
    cfg.saturation_beta_prior_sigma = 0.5 if is_sparse else 1.0

    return cfg


def generate_config_from_profile(profile: DataProfile, iteration: int = 0) -> ModelConfig:
    """
    Generate model configuration from data profile.

    Each channel is classified by name (classify_channel_type) and given
    type-specific priors (get_channel_priors). The sampler starts at
    draws=2000, tune=1500 and target_accept=0.95. For iteration > 0 these
    grow linearly per iteration, capped at 4000 draws, 3000 tune steps and
    target_accept 0.99.

    Args:
        profile: DataProfile from data_analyzer. Reads file_path, date_col,
            target_col, channels, controls, scale_y and channel_stats. It is
            not modified.
        iteration: Iteration number for adjustment (0 = initial).

    Returns:
        ModelConfig ready for training
    """
    # Generate channel configs
    channel_configs = {}
    for ch_name in profile.channels:
        # Copy before adding 'name' so the caller's profile.channel_stats
        # entries are not mutated as a side effect.
        ch_stats = dict(profile.channel_stats.get(ch_name, {}))
        ch_stats['name'] = ch_name
        ch_type = classify_channel_type(ch_name)
        channel_configs[ch_name] = get_channel_priors(ch_type, ch_stats)

    # Initial sampler settings (conservative for convergence)
    sampler_draws = 2000
    sampler_tune = 1500
    target_accept = 0.95

    # Later iterations buy more sampling effort, up to fixed caps.
    if iteration > 0:
        sampler_draws = min(2000 + iteration * 400, 4000)
        sampler_tune = min(1500 + iteration * 300, 3000)
        target_accept = min(0.95 + iteration * 0.01, 0.99)

    return ModelConfig(
        data_path=profile.file_path,
        date_col=profile.date_col,
        target_col=profile.target_col,
        channels=profile.channels,
        controls=profile.controls,
        scale_y=profile.scale_y,
        channel_configs=channel_configs,
        sampler_draws=sampler_draws,
        sampler_tune=sampler_tune,
        sampler_target_accept=target_accept,
    )


def adjust_config_for_diagnostics(
    config: "ModelConfig",
    diagnostics: Dict,
    iteration: int
) -> "ModelConfig":
    """
    Return a copy of ``config`` with settings adjusted for diagnostic results.

    The input is deep-copied first. Neither ``config`` nor its per-channel
    prior objects are modified; previously they were mutated in place
    because the "new" config was just an alias.

    Diagnostic issues and fixes (each applied independently):
    - worst_rhat > 1.02: tune += 500 (cap 5000), target_accept += 0.01 (cap 0.995)
    - min_ess < 100: draws += 500 (cap 5000)
    - divergences > 0: max_treedepth -= 2 (floor 8), and every channel's
      saturation lam/beta prior sigma is multiplied by 0.8
    - r2 < 0.50: a message is printed; nothing is changed

    Args:
        config: Config used for the previous run.
        diagnostics: Reads 'worst_rhat' (default 1.0), 'min_ess' (default
            1000), 'divergences' (default 0) and 'r2' (default 0.6).
        iteration: Currently unused; kept for interface compatibility.

    Returns:
        The adjusted copy of ``config``.
    """
    import copy  # local import keeps this function self-contained

    new_config = copy.deepcopy(config)

    worst_rhat = diagnostics.get('worst_rhat', 1.0)
    min_ess = diagnostics.get('min_ess', 1000)
    divergences = diagnostics.get('divergences', 0)
    r2 = diagnostics.get('r2', 0.6)

    # Fix convergence issues
    if worst_rhat > 1.02:
        print(f"  [Adjust] rhat={worst_rhat:.4f} > 1.02 → increasing tune and target_accept")
        new_config.sampler_tune = min(new_config.sampler_tune + 500, 5000)
        new_config.sampler_target_accept = min(new_config.sampler_target_accept + 0.01, 0.995)

    # Fix ESS issues
    if min_ess < 100:
        print(f"  [Adjust] ESS={min_ess:.1f} < 100 → increasing draws")
        new_config.sampler_draws = min(new_config.sampler_draws + 500, 5000)

    # Fix divergences
    if divergences > 0:
        print(f"  [Adjust] divergences={divergences} > 0 → reducing treedepth")
        # NOTE(review): Stan/PyMC guidance treats divergences as a step-size /
        # geometry problem (raise target_accept or reparameterise). A lower
        # max treedepth does not remove them, so consider revisiting this rule.
        new_config.sampler_max_treedepth = max(new_config.sampler_max_treedepth - 2, 8)

        # Also tighten priors slightly (tolerates a config with no channel priors).
        for ch_config in (new_config.channel_configs or {}).values():
            ch_config.saturation_lam_prior_sigma *= 0.8
            ch_config.saturation_beta_prior_sigma *= 0.8

    # Low R² might need model changes (less actionable here)
    if r2 < 0.50:
        print(f"  [Adjust] R²={r2:.4f} < 0.50 → consider model structure changes")

    return new_config


if __name__ == "__main__":
    # CLI entry point: profile a CSV, derive the initial config, write it as YAML.
    import sys
    from data_analyzer import analyze_dataset

    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python config_generator.py <csv_path> [output_yaml]")
        sys.exit(1)

    csv_path = cli_args[0]
    output_yaml = cli_args[1] if len(cli_args) >= 2 else "mmm_config.yaml"

    print("Analyzing dataset...")
    profile = analyze_dataset(csv_path)

    print("Generating config...")
    config = generate_config_from_profile(profile)

    config.to_yaml(Path(output_yaml))
    print(f"\nConfig saved to: {output_yaml}")

    # Short human-readable recap of what was written.
    print("\nConfig summary:")
    summary_lines = (
        f"  Channels: {len(config.channels)}",
        f"  Controls: {len(config.controls)}",
        f"  Sampler: draws={config.sampler_draws}, tune={config.sampler_tune}",
        f"  Target accept: {config.sampler_target_accept}",
    )
    for line in summary_lines:
        print(line)
