"""
Data Analyzer - Auto-detect columns from CSV for MMM

Created by: Claude Code
Session ID: Ralph Loop Execution
Date: 2025-11-26
Purpose: Automatically detect date, target, channels, and controls from CSV

Based on: Bayesian MMM principles (NOT PyMC-Marketing)
"""

import pandas as pd
import numpy as np
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, asdict


@dataclass
class DataProfile:
    """Summary of an MMM dataset, used as input for config generation.

    The fields are plain strings, numbers, lists and dicts so the whole
    profile can be dumped to JSON without custom encoders.
    """
    # --- Source file ---
    file_path: str
    n_rows: int
    n_cols: int

    # --- Detected column roles ---
    date_col: str
    target_col: str
    channels: List[str]
    controls: List[str]

    # --- Covered date range (string form) ---
    date_min: str
    date_max: str
    n_days: int

    # --- Summary statistics of the target column ---
    target_mean: float
    target_std: float
    target_min: float
    target_max: float

    # --- Per-channel statistics: channel name -> {stat name -> value} ---
    channel_stats: Dict[str, Dict[str, float]]

    # --- Scaling for priors ---
    scale_y: float  # Factor used to scale the target

    def to_dict(self) -> dict:
        """Return the profile as a (recursively converted) plain dict."""
        return asdict(self)

    def to_json(self, path: Path) -> None:
        """Write the profile to ``path`` as JSON indented by two spaces."""
        with open(path, 'w') as handle:
            payload = self.to_dict()
            json.dump(payload, handle, indent=2)


def detect_date_column(df: pd.DataFrame) -> str:
    """Detect the date column by name, then by dtype / parseability.

    Resolution order:
      1. The first column whose lower-cased name is one of
         ``date``, ``ds``, ``datetime``, ``time``, ``day``.
      2. The first column that already has a datetime dtype, or whose
         non-numeric values can all be parsed by ``pd.to_datetime``.

    Args:
        df: Dataset to inspect.

    Returns:
        The label of the detected date column.

    Raises:
        ValueError: If no column qualifies.
    """
    date_names = {'date', 'ds', 'datetime', 'time', 'day'}

    # Pass 1: well-known column names (str() guards non-string labels).
    for col in df.columns:
        if str(col).lower() in date_names:
            return col

    # Pass 2: probe the column contents.
    for col in df.columns:
        series = df[col]

        if pd.api.types.is_datetime64_any_dtype(series):
            return col

        # Numbers always "parse" (they are read as epoch offsets), so a
        # numeric column must never be accepted as a date here. Columns
        # without any value carry no evidence either way.
        if pd.api.types.is_numeric_dtype(series) or series.isna().all():
            continue

        try:
            pd.to_datetime(series)
        except (ValueError, TypeError, OverflowError):
            # Not parseable as dates; try the next column.
            continue
        return col

    raise ValueError("Could not detect date column")


def detect_target_column(df: pd.DataFrame) -> str:
    """Detect the target (y) column - usually revenue or conversions.

    Resolution order:
      1. The first column whose lower-cased name is one of ``revenue``,
         ``sales``, ``conversions``, ``orders``, ``target``, ``y``.
      2. Otherwise, the numeric column with the largest variance.

    Args:
        df: Dataset to inspect.

    Returns:
        The label of the detected target column.

    Raises:
        ValueError: If no name matches and no numeric column has a
            defined variance.
    """
    target_names = {'revenue', 'sales', 'conversions', 'orders', 'target', 'y'}
    for col in df.columns:
        # str() guards against non-string column labels.
        if str(col).lower() in target_names:
            return col

    # Fallback: numeric column with the highest variance.
    numeric_cols = df.select_dtypes(include=[np.number]).columns
    variances = {col: df[col].var() for col in numeric_cols}

    # NaN variances (all-NaN or single-value columns) must be dropped:
    # every comparison against NaN is False, so max() would otherwise
    # return whichever NaN column happens to come first.
    comparable = {col: var for col, var in variances.items() if pd.notna(var)}
    if not comparable:
        raise ValueError("Could not detect target column")

    return max(comparable, key=comparable.get)


def detect_channel_columns(df: pd.DataFrame, exclude: Optional[List[str]] = None) -> List[str]:
    """
    Detect marketing channel columns.

    A column qualifies when its lower-cased name contains
    - a spend keyword ('spend', 'cost', 'budget', 'investment', 'media'), OR
    - a platform marker (google_, facebook_, bing_, meta_, tiktok_, etc.),
      matched anywhere in the name, not only as a prefix,
    and its values are
    - numeric (any integer/float width; booleans are not counted),
    - non-negative,
    - not constant (standard deviation > 0).

    Args:
        df: Dataset to inspect.
        exclude: Column labels to skip (e.g. date and target columns).

    Returns:
        Matching column labels, in DataFrame column order.
    """
    excluded = set(exclude or [])

    # Direct spend indicators
    spend_keywords = ('spend', 'cost', 'budget', 'investment', 'media')

    # Platform markers (column may not have 'spend' in its name)
    platform_markers = (
        'google_', 'facebook_', 'fb_', 'meta_', 'bing_', 'tiktok_',
        'linkedin_', 'twitter_', 'pinterest_', 'snapchat_',
        'youtube_', 'display_', 'search_', 'shopping_',
        'performance_max', 'pmax_', 'demand_gen', 'dg_',
    )
    name_markers = spend_keywords + platform_markers

    channels = []
    for col in df.columns:
        if col in excluded:
            continue

        # str() guards against non-string column labels.
        col_lower = str(col).lower()
        if not any(marker in col_lower for marker in name_markers):
            continue

        series = df[col]
        # Accept every numeric dtype (int32, float32, unsigned, nullable ...)
        # rather than a fixed whitelist; boolean flags are not spend values.
        if not pd.api.types.is_numeric_dtype(series) or pd.api.types.is_bool_dtype(series):
            continue

        # Spend must be non-negative and must actually vary.
        if series.min() >= 0 and series.std() > 0:
            channels.append(col)

    return channels


def detect_control_columns(df: pd.DataFrame, exclude: List[str] = None) -> List[str]:
    """
    Pick out control-variable columns by name.

    A column is treated as a control when its lower-cased name contains
    any of the markers below (case-insensitive substring match):
    - Holiday / boolean flags: is_, holiday
    - Day of week: dow_, weekday
    - Weather: temperature, precipitation, weather
    - Seasonality: season, summer, winter, spring, fall
    - Trends: trend, index, GT_
    - Averaged metrics: avg_

    Columns listed in ``exclude`` are skipped. The result keeps the
    DataFrame's column order.
    """
    skip = exclude or []

    # Markers are compared against lower-cased names, so normalise them once.
    markers = tuple(
        marker.lower()
        for marker in (
            'is_', 'dow_', 'weekday', 'holiday',
            'temperature', 'precipitation', 'weather',
            'season', 'summer', 'winter', 'spring', 'fall',
            'trend', 'index', 'GT_',
            'avg_',
        )
    )

    return [
        name
        for name in df.columns
        if name not in skip and any(marker in name.lower() for marker in markers)
    ]


def compute_channel_stats(df: pd.DataFrame, channels: List[str]) -> Dict[str, Dict[str, float]]:
    """Compute summary statistics for each channel column.

    Args:
        df: Dataset containing the channel columns.
        channels: Labels of the columns to summarise.

    Returns:
        Mapping of channel name to a dict with the keys ``mean``, ``std``,
        ``min``, ``max``, ``sum``, ``non_zero_pct`` (share of rows with a
        value > 0) and ``cv`` (std / mean, or 0.0 when the mean is not
        positive). Every value is a Python ``float``.
    """
    stats = {}
    for ch in channels:
        series = df[ch]
        # Compute once; both values are reused for the coefficient of variation.
        mean = float(series.mean())
        std = float(series.std())
        stats[ch] = {
            'mean': mean,
            'std': std,
            'min': float(series.min()),
            'max': float(series.max()),
            'sum': float(series.sum()),
            'non_zero_pct': float((series > 0).mean()),
            # The coefficient of variation is only meaningful for a positive
            # mean; fall back to a float zero so every stat has the same type.
            'cv': std / mean if mean > 0 else 0.0,
        }
    return stats


def analyze_dataset(csv_path: str) -> DataProfile:
    """
    Main function: analyze CSV and return DataProfile.

    Steps: load the CSV, detect and parse the date column, detect the
    target, detect channel columns, detect control columns, drop
    mediator-like controls, then compute target and channel statistics.

    Args:
        csv_path: Path to the MMM source CSV. Path-like objects are
            accepted; the profile stores the path as a string.

    Returns:
        DataProfile with all detected columns and statistics

    Raises:
        ValueError: If the date column cannot be detected.
    """
    df = pd.read_csv(csv_path)

    # Step 1: Detect date column and parse it in place
    date_col = detect_date_column(df)
    df[date_col] = pd.to_datetime(df[date_col])

    # Step 2: Detect target column
    target_col = detect_target_column(df)

    # Step 3: Detect channel columns (exclude date and target)
    exclude = [date_col, target_col, 'total_spend']  # total_spend is sum of channels
    channels = detect_channel_columns(df, exclude)

    # Step 4: Detect control columns (anything already classified is skipped)
    exclude_for_controls = exclude + channels
    controls = detect_control_columns(df, exclude_for_controls)

    # Step 5: Remove any mediators from controls (like new_users)
    # These should be handled separately in causal model
    mediator_patterns = ['new_user', 'traffic', 'visitor', 'lead', 'signup']
    controls_filtered = [c for c in controls
                        if not any(p in c.lower() for p in mediator_patterns)]

    # Compute statistics
    target_series = df[target_col]
    channel_stats = compute_channel_stats(df, channels)

    # The target's standard deviation doubles as the prior scale factor.
    target_std = float(target_series.std())

    profile = DataProfile(
        # str() keeps the profile JSON-serializable when a Path is passed in.
        file_path=str(csv_path),
        n_rows=len(df),
        n_cols=len(df.columns),
        date_col=date_col,
        target_col=target_col,
        channels=channels,
        controls=controls_filtered,
        date_min=str(df[date_col].min().date()),
        date_max=str(df[date_col].max().date()),
        # NOTE(review): this is the row count, which equals the number of
        # days only if the data has one row per day - confirm with callers.
        n_days=len(df),
        target_mean=float(target_series.mean()),
        target_std=target_std,
        target_min=float(target_series.min()),
        target_max=float(target_series.max()),
        channel_stats=channel_stats,
        scale_y=target_std
    )

    return profile


def print_profile(profile: DataProfile) -> None:
    """Write a human-readable summary of ``profile`` to stdout.

    Sections, in order: file/shape/date range, target statistics, one line
    per channel, up to ten controls (with a count of the remainder), and
    the scale factor.
    """
    rule = "=" * 60
    max_controls_shown = 10

    # Header
    print(rule)
    print("MMM DATA PROFILE")
    print(rule)
    print(f"\nFile: {profile.file_path}")
    print(f"Shape: {profile.n_rows} rows x {profile.n_cols} columns")
    print(f"Date range: {profile.date_min} → {profile.date_max} ({profile.n_days} days)")

    # Target
    print(f"\n📊 Target: {profile.target_col}")
    print(f"   Mean: {profile.target_mean:,.2f}")
    print(f"   Std: {profile.target_std:,.2f}")
    print(f"   Range: [{profile.target_min:,.2f}, {profile.target_max:,.2f}]")

    # Channels
    channel_names = profile.channels
    print(f"\n📈 Channels ({len(channel_names)}):")
    for name in channel_names:
        ch = profile.channel_stats[name]
        print(f"   - {name}: mean={ch['mean']:,.2f}, sum={ch['sum']:,.2f}, non_zero={ch['non_zero_pct']:.1%}")

    # Controls: list the leading entries only, then summarise the rest.
    n_controls = len(profile.controls)
    print(f"\n🎛️ Controls ({n_controls}):")
    for name in profile.controls[:max_controls_shown]:
        print(f"   - {name}")
    hidden = n_controls - max_controls_shown
    if hidden > 0:
        print(f"   ... and {hidden} more")

    # Footer
    print(f"\n⚙️ Scale factor (scale_y): {profile.scale_y:,.2f}")
    print(rule)


if __name__ == "__main__":
    import sys

    # Command line: <csv_path> is required, [output_json] is optional.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python data_analyzer.py <csv_path> [output_json]")
        sys.exit(1)

    csv_path = cli_args[0]
    output_json = cli_args[1] if len(cli_args) > 1 else None

    # Analyze the CSV and show the summary on stdout.
    profile = analyze_dataset(csv_path)
    print_profile(profile)

    # Persist the profile only when an output path was supplied.
    if output_json:
        profile.to_json(Path(output_json))
        print(f"\nProfile saved to: {output_json}")
