#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Shapley Value Attribution Analysis
Game theory-based fair attribution using Shapley values
"""

import math
import warnings
from itertools import combinations

import numpy as np
import pandas as pd
# NOTE(review): with no category/module arguments this installs a process-wide
# "ignore everything" filter as soon as the module is imported, hiding warnings
# raised by any other code (pandas/numpy included). Consider scoping it with
# warnings.catch_warnings() around the specific noisy calls instead.
warnings.filterwarnings('ignore')

class ShapleyValueAttributor:
    """Shapley value-based channel attribution analysis.

    The cooperative game values a coalition S of channels as the total
    conversion value of the converting journeys whose channels are all
    contained in S (the conversions S could have produced on its own).

    With that characteristic function:
    - v(S) never decreases when a channel is added.
    - v(all channels) is the total value of converting journeys that touched
      at least one channel.
    - The Shapley values of all channels sum to v(all) - v(empty set).

    The exact computation enumerates all 2**n channel subsets, so it is only
    practical for a modest number of channels.
    """

    # Journey markers that denote journey states rather than marketing channels
    _NON_CHANNEL_STATES = ('开始', '成功转化', '未转化')

    # Sort order for optimization recommendations (higher rank first)
    _PRIORITY_RANK = {'high': 2, 'medium': 1}

    def __init__(self):
        """Initialize the Shapley value attributor with empty result containers."""
        self.shapley_values = {}           # channel -> raw Shapley value
        self.attribution_weights = {}      # channel -> normalized weight (sums to 1)
        self.channel_combinations = {}     # index -> list of channels in the coalition
        self.combination_performance = {}  # index -> {'combination', 'total_value', 'conversions'}
        self._performance_index = {}       # frozenset(coalition) -> total_value, for O(1) lookup

    @classmethod
    def _path_channels(cls, path):
        """Return the frozenset of real channels touched by one journey path."""
        return frozenset(tp for tp in path if tp not in cls._NON_CHANNEL_STATES)

    def calculate_shapley_values(self, paths_df, value_column='conversion_value'):
        """
        Calculate Shapley values for channel attribution

        Args:
            paths_df (pd.DataFrame): Customer journey paths; needs a 'path'
                column of touchpoint sequences and the ``value_column``
            value_column (str): Column name for conversion values

        Returns:
            dict: Normalized attribution weight per channel (sums to 1).
                The raw Shapley values are stored in ``self.shapley_values``.
                Empty when no channel appears in any path.
        """
        print("\n=== 计算Shapley值归因 ===")

        # Collect every real channel (journey state markers excluded)
        all_channels = set()
        for path in paths_df['path']:
            all_channels |= self._path_channels(path)

        channels = sorted(all_channels)
        print(f"识别了 {len(channels)} 个渠道: {channels}")

        if not channels:
            # Nothing to attribute; also avoids dividing by zero channels below.
            # Clear stale results so later steps do not reuse a previous run.
            print("警告: 未识别到任何渠道，无法计算Shapley值")
            self.channel_combinations = {}
            self.combination_performance = {}
            self._performance_index = {}
            self.shapley_values = {}
            self.attribution_weights = {}
            return {}

        # Compute v(S) for every channel coalition S
        self._calculate_combination_performance(paths_df, channels, value_column)

        total_value = paths_df[value_column].sum()

        if total_value == 0:
            print("警告: 总转化价值为0，使用平均权重")
            avg_value = 1.0 / len(channels)
            shapley_values = {channel: avg_value for channel in channels}
        else:
            shapley_values = {}
            for channel in channels:
                print(f"\n计算 {channel} 的Shapley值...")
                shapley_value = self._calculate_channel_shapley_value(
                    channel, channels, total_value
                )
                shapley_values[channel] = shapley_value
                print(f"  {channel}: {shapley_value:.6f}")

        # Normalize so the weights sum to 1; fall back to equal weights when
        # nothing could be attributed (e.g. all converting paths had no channel)
        total_shapley = sum(shapley_values.values())
        if total_shapley > 0:
            attribution_weights = {
                channel: value / total_shapley
                for channel, value in shapley_values.items()
            }
        else:
            attribution_weights = {channel: 1.0 / len(channels) for channel in channels}

        self.shapley_values = shapley_values
        self.attribution_weights = attribution_weights

        print(f"\nShapley值归因权重:")
        for channel, weight in sorted(attribution_weights.items(),
                                      key=lambda x: x[1], reverse=True):
            print(f"  {channel}: {weight:.4f} ({weight*100:.1f}%)")

        return attribution_weights

    def _calculate_combination_performance(self, paths_df, channels, value_column):
        """
        Compute the coalition value v(S) for every subset S of ``channels``.

        A converting journey (value > 0) is credited to coalition S when every
        channel it touched belongs to S.

        Args:
            paths_df (pd.DataFrame): Customer journey paths
            channels (list): List of all channels
            value_column (str): Column name for conversion values
        """
        print("\n=== 计算渠道组合性能 ===")

        # All non-empty subsets by increasing size, then the empty subset last
        channel_combinations = [
            list(combo)
            for size in range(1, len(channels) + 1)
            for combo in combinations(channels, size)
        ]
        channel_combinations.append([])

        print(f"需要计算 {len(channel_combinations)} 个组合的性能...")

        # Collapse converting journeys into channel-set -> (value, count) once,
        # instead of rescanning the DataFrame for every one of the 2**n subsets.
        journeys = {}
        for path, value in zip(paths_df['path'], paths_df[value_column]):
            if value > 0:  # Only consider converting paths
                key = self._path_channels(path)
                prev_value, prev_count = journeys.get(key, (0, 0))
                journeys[key] = (prev_value + value, prev_count + 1)

        self.channel_combinations = dict(enumerate(channel_combinations))
        self.combination_performance = {}
        self._performance_index = {}

        for i, combination in enumerate(channel_combinations):
            coalition = frozenset(combination)
            total_value = 0
            total_conversions = 0

            for journey_channels, (value, count) in journeys.items():
                # The coalition can produce this conversion on its own
                if journey_channels <= coalition:
                    total_value += value
                    total_conversions += count

            self.combination_performance[i] = {
                'combination': combination,
                'total_value': total_value,
                'conversions': total_conversions
            }
            self._performance_index[coalition] = total_value

            if i % 100 == 0 and i > 0:
                print(f"  已计算 {i}/{len(channel_combinations)} 个组合...")

        print("渠道组合性能计算完成")

    def _calculate_channel_shapley_value(self, channel, all_channels, total_value):
        """
        Calculate the exact Shapley value of one channel.

        phi(i) = sum over S not containing i of
                 |S|! (n - |S| - 1)! / n! * (v(S + {i}) - v(S))

        Args:
            channel (str): Target channel
            all_channels (list): All channels
            total_value (float): Total conversion value (unused; kept for
                backward compatibility)

        Returns:
            float: Shapley value for the channel
        """
        n = len(all_channels)
        others = [c for c in all_channels if c != channel]
        n_factorial = math.factorial(n)
        shapley_value = 0.0

        for size in range(n):
            # Fraction of orderings in which exactly this coalition of `size`
            # channels precedes `channel`
            weight = math.factorial(size) * math.factorial(n - size - 1) / n_factorial

            for combo in combinations(others, size):
                without_channel_performance = self._get_combination_performance(combo)
                with_channel_performance = self._get_combination_performance(
                    list(combo) + [channel]
                )
                shapley_value += weight * (with_channel_performance - without_channel_performance)

        return shapley_value

    def _get_combination_performance(self, combination):
        """
        Get the coalition value v(S) for a channel combination.

        Args:
            combination (iterable): Channels in the coalition (order ignored)

        Returns:
            float: Total conversion value credited to the coalition, or 0.0
                when that combination was never evaluated
        """
        index = self._performance_index
        # Rebuild the lookup if combination_performance was replaced externally
        if len(index) != len(self.combination_performance):
            index = {
                frozenset(perf['combination']): perf['total_value']
                for perf in self.combination_performance.values()
            }
            self._performance_index = index
        return index.get(frozenset(combination), 0.0)

    def calculate_channel_synergy(self):
        """
        Calculate pairwise channel synergy effects.

        For each channel pair, v({a, b}) is compared with v({a}) + v({b}).
        The ratio is labelled as follows:
        - above 1.1: 'positive'
        - above 0.9: 'neutral'
        - otherwise: 'negative'

        Returns:
            dict: '<channel1>_<channel2>' -> synergy details, sorted by synergy
                ratio (descending). Empty if combination performance has not
                been computed yet.
        """
        print("\n=== 计算渠道协同效应 ===")

        if not self.combination_performance:
            print("错误: 需要先计算组合性能")
            return {}

        # v({c}) for every single channel
        individual_performance = {
            perf['combination'][0]: perf['total_value']
            for perf in self.combination_performance.values()
            if len(perf['combination']) == 1
        }

        synergy_analysis = {}
        for channel1, channel2 in combinations(list(individual_performance), 2):
            pair_performance = self._get_combination_performance([channel1, channel2])
            expected_additive = (individual_performance[channel1] +
                                 individual_performance[channel2])

            # Without any standalone value there is no baseline to compare to,
            # so the pair is reported as neutral (ratio 1.0).
            if expected_additive > 0:
                synergy_ratio = pair_performance / expected_additive
            else:
                synergy_ratio = 1.0

            if synergy_ratio > 1.1:
                synergy_type = 'positive'
            elif synergy_ratio > 0.9:
                synergy_type = 'neutral'
            else:
                synergy_type = 'negative'

            synergy_analysis[f"{channel1}_{channel2}"] = {
                'channel1': channel1,
                'channel2': channel2,
                'individual1': individual_performance[channel1],
                'individual2': individual_performance[channel2],
                'combined': pair_performance,
                'expected_additive': expected_additive,
                'synergy_ratio': synergy_ratio,
                'synergy_type': synergy_type
            }

        sorted_synergy = dict(sorted(synergy_analysis.items(),
                                     key=lambda x: x[1]['synergy_ratio'],
                                     reverse=True))

        print("渠道协同效应分析:")
        for analysis in sorted_synergy.values():
            synergy_type = analysis['synergy_type']
            print(f"  {analysis['channel1']} + {analysis['channel2']}: "
                  f"协同比={analysis['synergy_ratio']:.3f} ({synergy_type})")

        return sorted_synergy

    def analyze_marginal_contributions(self):
        """
        Summarize each channel's Shapley value, weight and performance tier.

        Returns:
            pd.DataFrame: One row per channel sorted by Shapley value
                (descending); empty if Shapley values were not computed
        """
        print("\n=== 分析边际贡献 ===")

        if not self.shapley_values:
            print("错误: 需要先计算Shapley值")
            return pd.DataFrame()

        total_shapley = sum(self.shapley_values.values())
        marginal_analysis = []
        for channel, shapley_value in self.shapley_values.items():
            weight = shapley_value / total_shapley if total_shapley > 0 else 0
            marginal_analysis.append({
                'channel': channel,
                'shapley_value': shapley_value,
                'attribution_weight': weight,
                'performance_tier': self._classify_channel_performance(weight)
            })

        marginal_df = pd.DataFrame(marginal_analysis)
        marginal_df = marginal_df.sort_values('shapley_value', ascending=False)

        print("边际贡献排名:")
        for _, row in marginal_df.iterrows():
            print(f"  {row['channel']}: Shapley值={row['shapley_value']:.6f}, "
                  f"权重={row['attribution_weight']:.4f}, 层级={row['performance_tier']}")

        return marginal_df

    def _classify_channel_performance(self, attribution_weight):
        """Map an attribution weight to a performance tier label."""
        if attribution_weight >= 0.3:
            return 'top_performer'
        elif attribution_weight >= 0.15:
            return 'strong_performer'
        elif attribution_weight >= 0.05:
            return 'moderate_performer'
        else:
            return 'low_performer'

    def optimize_channel_mix(self, budget_constraints=None):
        """
        Optimize channel mix based on Shapley values

        A normalized budget of 1.0 is allocated in two passes:
        1. Each channel receives its attribution weight, capped by its entry
           in ``budget_constraints``.
        2. Budget freed by those caps is spread proportionally to weight over
           the channels that still have headroom. Channels without a
           constraint always have headroom.

        Args:
            budget_constraints (dict): Optional maximum allocation per channel

        Returns:
            dict: Current weights, optimal allocation, efficiencies, expected
                improvement and recommendations; empty if weights are missing
        """
        print("\n=== 优化渠道组合 ===")

        if not self.attribution_weights:
            print("错误: 需要先计算归因权重")
            return {}

        sorted_channels = sorted(self.attribution_weights.items(),
                                 key=lambda x: x[1], reverse=True)
        caps = budget_constraints or {}
        tolerance = 1e-12  # ignore floating-point dust left by subtraction

        total_budget = 1.0  # Assume normalized budget
        optimal_allocation = {}
        remaining_budget = total_budget
        for channel, weight in sorted_channels:
            allocation = min(weight, remaining_budget)
            if channel in caps:
                allocation = min(allocation, caps[channel])
            optimal_allocation[channel] = allocation
            remaining_budget -= allocation

        # Distribute remaining budget proportionally among channels below their
        # cap. Each pass either places all of it or saturates at least one
        # capped channel, so len(channels) + 1 passes always suffice.
        for _ in range(len(sorted_channels) + 1):
            if remaining_budget <= tolerance:
                break
            open_channels = [
                (channel, weight) for channel, weight in sorted_channels
                if weight > 0 and (channel not in caps or
                                   optimal_allocation[channel] < caps[channel] - tolerance)
            ]
            open_weight = sum(weight for _, weight in open_channels)
            if open_weight <= 0:
                break  # every channel is capped; leave the rest unallocated
            distributed = 0.0
            for channel, weight in open_channels:
                extra = remaining_budget * weight / open_weight
                if channel in caps:
                    extra = min(extra, caps[channel] - optimal_allocation[channel])
                optimal_allocation[channel] += extra
                distributed += extra
            remaining_budget -= distributed

        current_efficiency = sum(self.attribution_weights.values())
        optimal_efficiency = sum(optimal_allocation.values())
        if current_efficiency > 0:
            improvement = (optimal_efficiency - current_efficiency) / current_efficiency * 100
        else:
            improvement = 0

        optimization_results = {
            'current_weights': self.attribution_weights,
            'optimal_allocation': optimal_allocation,
            'budget_constraints': budget_constraints,
            'current_efficiency': current_efficiency,
            'optimal_efficiency': optimal_efficiency,
            'expected_improvement': improvement,
            'recommendations': self._generate_optimization_recommendations(
                self.attribution_weights, optimal_allocation
            )
        }

        print("渠道组合优化结果:")
        print(f"当前效率: {current_efficiency:.4f}")
        print(f"优化效率: {optimal_efficiency:.4f}")
        print(f"预期改善: {improvement:.1f}%")

        return optimization_results

    def _generate_optimization_recommendations(self, current_weights, optimal_allocation):
        """
        Build increase/decrease recommendations for large allocation changes.

        Only channels whose allocation moves by more than 0.05 are reported.
        A move above 0.1 gets 'high' priority, otherwise 'medium'.

        Returns:
            list: Recommendation dicts, 'high' priority first. The sort is
                stable, so the original channel order is kept within a
                priority.
        """
        recommendations = []

        for channel in current_weights:
            current = current_weights.get(channel, 0)
            optimal = optimal_allocation.get(channel, 0)
            difference = optimal - current

            if abs(difference) > 0.05:  # Significant difference
                if difference > 0:
                    recommendations.append({
                        'channel': channel,
                        'action': 'increase',
                        'reason': f"建议增加投资，从 {current:.3f} 增加到 {optimal:.3f}",
                        'priority': 'high' if difference > 0.1 else 'medium'
                    })
                else:
                    recommendations.append({
                        'channel': channel,
                        'action': 'decrease',
                        'reason': f"建议减少投资，从 {current:.3f} 减少到 {optimal:.3f}",
                        'priority': 'high' if abs(difference) > 0.1 else 'medium'
                    })

        # Priorities are strings, so rank them explicitly (abs() on str fails)
        return sorted(recommendations,
                      key=lambda rec: self._PRIORITY_RANK.get(rec['priority'], 0),
                      reverse=True)

    def run_complete_shapley_analysis(self, paths_df, value_column='conversion_value'):
        """
        Run complete Shapley value attribution analysis

        Args:
            paths_df (pd.DataFrame): Customer journey paths. An optional
                'converted' column is summed for the conversion count; without
                it, rows with a positive ``value_column`` are counted.
            value_column (str): Column name for conversion values

        Returns:
            dict: Complete Shapley value analysis results
        """
        print("🎮 开始Shapley值归因分析")
        print("=" * 50)

        # 1. Calculate Shapley values
        attribution_weights = self.calculate_shapley_values(paths_df, value_column)

        # 2. Calculate channel synergy
        synergy_analysis = self.calculate_channel_synergy()

        # 3. Analyze marginal contributions
        marginal_analysis = self.analyze_marginal_contributions()

        # 4. Optimize channel mix
        optimization_results = self.optimize_channel_mix()

        if 'converted' in paths_df.columns:
            total_conversions = paths_df['converted'].sum()
        else:
            total_conversions = int((paths_df[value_column] > 0).sum())

        results = {
            'attribution_weights': attribution_weights,
            'shapley_values': self.shapley_values,
            'channel_synergy': synergy_analysis,
            'marginal_analysis': marginal_analysis,
            'optimization': optimization_results,
            'total_conversions': total_conversions,
            'total_value': paths_df[value_column].sum()
        }

        print(f"\n✅ Shapley值分析完成！")
        print(f"总转化数: {results['total_conversions']:,}")
        print(f"总转化价值: {results['total_value']:,.2f}")
        print(f"分析了 {len(attribution_weights)} 个渠道的Shapley值")

        return results

def main():
    """Run the full Shapley attribution pipeline on a tiny hand-built dataset.

    Every journey starts at '开始' and ends in either '成功转化' (converted)
    or '未转化' (not converted); the touchpoints in between are channels.
    """
    attributor = ShapleyValueAttributor()

    journeys = [
        ['开始', '付费搜索', '社交媒体', '邮件营销', '成功转化'],
        ['开始', '社交媒体', '付费搜索', '成功转化'],
        ['开始', '邮件营销', '社交媒体', '成功转化'],
        ['开始', '付费搜索', '未转化'],
        ['开始', '社交媒体', '未转化'],
        ['开始', '付费搜索', '社交媒体', '邮件营销', '成功转化'],
    ]
    converted_flags = [1, 1, 1, 0, 0, 1]
    journey_values = [100, 150, 80, 0, 0, 200]

    demo_df = pd.DataFrame({
        'user_id': range(len(journeys)),
        'path': journeys,
        'converted': converted_flags,
        'conversion_value': journey_values,
    })

    outcome = attributor.run_complete_shapley_analysis(demo_df)

    print(f"\nShapley值归因权重:")
    for name, share in outcome['attribution_weights'].items():
        print(f"  {name}: {share:.4f}")

if __name__ == '__main__':
    # Only run the demo when executed as a script, never on import.
    main()