#!/usr/bin/env python3
"""Analyze one column's value distribution and print a Markdown report."""

from __future__ import annotations

import argparse
import re
import sys
from datetime import UTC, datetime
from pathlib import Path

# duckdb is the only third-party dependency. If it is missing, print an
# actionable one-line message and exit with status 1; `from None` drops the
# ModuleNotFoundError context from the resulting SystemExit.
try:
    import duckdb
except ModuleNotFoundError:  # pragma: no cover - environment guard
    sys.stderr.write("Error: duckdb is required. Install it with `python -m pip install duckdb`.\n")
    raise SystemExit(1) from None


class MarkdownReport:
    def __init__(self, title: str) -> None:
        self.title = title
        self.sections: list[str] = []
        self.metadata: dict[str, str] = {
            "generated_at": datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S UTC"),
        }

    def add_metadata(self, key: str, value: str) -> None:
        self.metadata[key] = value

    def add_section(self, heading: str, content: str = "") -> None:
        self.sections.append(f"## {heading}\n\n{content}".rstrip())

    def add_text(self, text: str) -> None:
        self.sections.append(text)

    def add_summary_stats(self, stats: dict[str, str | int | float]) -> None:
        self.sections.append("\n".join(f"- **{key}**: {value}" for key, value in stats.items()))

    def add_table(
        self,
        headers: list[str],
        rows: list[list[str]],
        alignments: list[str] | None = None,
    ) -> None:
        alignments = alignments or ["left"] * len(headers)
        separators = []
        for alignment in alignments:
            if alignment == "right":
                separators.append("---:")
            elif alignment == "center":
                separators.append(":---:")
            else:
                separators.append("---")

        table = [
            "| " + " | ".join(headers) + " |",
            "| " + " | ".join(separators) + " |",
        ]
        table.extend("| " + " | ".join(str(cell) for cell in row) + " |" for row in rows)
        self.sections.append("\n".join(table))

    def build(self) -> str:
        parts = [f"# {self.title}"]
        if self.metadata:
            parts.append("\n".join(f"- **{key}**: {value}" for key, value in self.metadata.items()))
        parts.extend(self.sections)
        return "\n\n".join(parts)

    def write(self, output: str | None) -> None:
        content = self.build() + "\n"
        if output is None:
            sys.stdout.write(content)
            return
        Path(output).write_text(content, encoding="utf-8")


def validate_identifier(name: str) -> bool:
    """Return True if *name* is a plain ASCII identifier.

    A valid name starts with a letter or underscore and continues with
    letters, digits or underscores only. The whole string must match:
    ``fullmatch`` is used because ``$`` in a ``match`` pattern also matches
    just before a trailing newline, which would let ``"col\\n"`` through.
    """
    return re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name) is not None


def quote_identifier(name: str) -> str:
    """Return *name* wrapped in double quotes for embedding in SQL.

    Raises:
        ValueError: If *name* does not pass ``validate_identifier``.
    """
    if validate_identifier(name):
        return '"' + name + '"'
    raise ValueError(f"invalid identifier: {name}")


def escape_string(value: str) -> str:
    """Double every single quote so *value* can sit inside a SQL string literal."""
    doubled_quote = {ord("'"): "''"}
    return value.translate(doubled_quote)


def infer_source_type(source: str) -> str:
    """Guess the file format from the end of *source*, ignoring case.

    Returns ``"parquet"``, ``"csv"`` or ``"json"`` (the latter for both
    ``.json`` and ``.jsonl``), or ``"unknown"`` when no suffix matches.
    """
    lowered = source.lower()
    suffix_table = (
        ((".parquet",), "parquet"),
        ((".csv",), "csv"),
        ((".json", ".jsonl"), "json"),
    )
    for suffixes, kind in suffix_table:
        if lowered.endswith(suffixes):
            return kind
    return "unknown"


def build_scan_query(source: str, source_type: str) -> str:
    """Return the table-function call used to read *source* in a FROM clause.

    The path is embedded as a single-quoted SQL literal after being passed
    through ``escape_string``.

    Raises:
        ValueError: If *source_type* is not ``"parquet"``, ``"csv"`` or
            ``"json"``.
    """
    literal = escape_string(source)
    readers = {
        "parquet": "parquet_scan",
        "csv": "read_csv_auto",
        "json": "read_json_auto",
    }
    reader = readers.get(source_type)
    if reader is None:
        raise ValueError(f"unknown source type for {source!r}; pass --type")
    return f"{reader}('{literal}')"


def assess_cardinality(unique_count: int, total_count: int) -> str:
    """Classify how many distinct values a column has relative to its size.

    Returns ``"Empty"`` for zero rows, otherwise ``"Low"`` (at most 10
    distinct values), ``"Medium"`` (at most 100 distinct values, or under 1%
    of rows distinct), ``"High"`` (under 50% of rows distinct) or
    ``"Very High (possibly unique identifier)"``.
    """
    if total_count == 0:
        label = "Empty"
    else:
        distinct_share = unique_count / total_count
        if unique_count <= 10:
            label = "Low"
        elif unique_count <= 100 or distinct_share < 0.01:
            label = "Medium"
        elif distinct_share < 0.5:
            label = "High"
        else:
            label = "Very High (possibly unique identifier)"
    return label


def generate_observations(
    total: int,
    nulls: int,
    unique_count: int,
    top_value_pct: float,
) -> list[str]:
    """Turn the column statistics into short human-readable remarks.

    Produces, in order: an optional cardinality remark (only for the "Low"
    and "Very High" classes), exactly one missing-data remark, and an
    optional imbalance remark when the top value exceeds 80% of the data.
    """
    notes: list[str] = []

    level = assess_cardinality(unique_count, total)
    if level == "Very High (possibly unique identifier)":
        notes.append("Very high cardinality - may be an identifier column, consider excluding from features")
    elif level == "Low":
        notes.append(
            f"Low cardinality column ({unique_count} unique values) - suitable for categorical encoding"
        )

    missing_share = nulls / total * 100 if total > 0 else 0
    shown_share = f"{missing_share:.2f}%"
    if missing_share == 0:
        missing_note = "No missing data"
    elif missing_share < 1:
        missing_note = f"Minimal missing data ({shown_share})"
    elif missing_share < 5:
        missing_note = f"Some missing data ({shown_share}) - consider imputation strategy"
    else:
        missing_note = f"Significant missing data ({shown_share}) - investigate missingness pattern"
    notes.append(missing_note)

    if top_value_pct > 80:
        # Anything above 95 is also above 80, so the wording is chosen inside.
        lead = "Extreme imbalance" if top_value_pct > 95 else "Class imbalance detected"
        notes.append(f"{lead}: top value represents {top_value_pct:.1f}% of data")
    return notes


def _non_negative_int(text: str) -> int:
    """argparse ``type=`` callable accepting integers greater than or equal to zero."""
    try:
        value = int(text)
    except ValueError:
        # Keep the wording argparse itself uses for ``type=int`` failures.
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be a non-negative integer, got {value}")
    return value


def parse_args() -> argparse.Namespace:
    """Parse the command line (``sys.argv``) into a namespace.

    Returns:
        Namespace with ``source``, ``column``, ``output`` (``None`` means
        stdout), ``limit`` (default 50) and ``type`` (``None`` means the type
        is not overridden).

    A negative ``--limit`` is rejected here with the usual argparse usage
    error, because it is meaningless for a "top N values" listing.
    """
    parser = argparse.ArgumentParser(description="Analyze column value distribution.")
    parser.add_argument("--source", required=True, help="Path to a Parquet, CSV, JSON, or JSONL file.")
    parser.add_argument("--column", required=True, help="Column name to analyze.")
    parser.add_argument("--output", help="Output file path. Defaults to stdout.")
    parser.add_argument("--limit", type=_non_negative_int, default=50, help="Max unique values to display.")
    parser.add_argument("--type", choices=["parquet", "csv", "json"], help="Override source type detection.")
    return parser.parse_args()


def query_column_stats(args: argparse.Namespace) -> tuple[int, int, int, list[tuple[object, int]]]:
    """Query the source file for the column's counts and value distribution.

    Args:
        args: Parsed command line; reads ``source``, ``column``, ``type`` and
            ``limit``.

    Returns:
        ``(total, nulls, unique_count, distribution)`` where ``distribution``
        is a list of ``(value, count)`` rows for non-null values, most
        frequent first, capped at ``args.limit`` rows.

    Raises:
        ValueError: If the source type cannot be determined or the column
            name is not a valid identifier. Errors raised by duckdb while
            executing the queries propagate unchanged.
    """
    source_type = args.type or infer_source_type(args.source)
    scan = build_scan_query(args.source, source_type)
    safe_col = quote_identifier(args.column)
    conn = duckdb.connect(":memory:")
    try:
        # One pass yields all three counts: COUNT(col) skips NULLs, so the
        # null count is the difference to COUNT(*). This replaces three
        # separate scans of the file.
        total, non_null, unique_count = conn.execute(
            f"""
            SELECT
                COUNT(*),
                COUNT({safe_col}),
                COUNT(DISTINCT {safe_col})
            FROM {scan}
            """  # noqa: S608
        ).fetchone()
        nulls = total - non_null
        distribution = conn.execute(
            f"""
            SELECT
                {safe_col} AS value,
                COUNT(*) AS count
            FROM {scan}
            WHERE {safe_col} IS NOT NULL
            GROUP BY {safe_col}
            ORDER BY count DESC
            LIMIT ?
            """,  # noqa: S608
            (args.limit,),
        ).fetchall()
    finally:
        conn.close()
    return total, nulls, unique_count, distribution


def build_distribution_rows(
    distribution: list[tuple[object, int]],
    total: int,
) -> tuple[list[list[str]], float]:
    """Format ``(value, count)`` pairs as table rows of strings.

    Each row is ``[value, count, percentage, cumulative percentage]``, with
    percentages taken relative to *total* (0 when *total* is not positive).
    Values longer than 50 characters are cut to 47 plus ``"..."``.

    Returns:
        The rows and the percentage of the first entry (0.0 when
        *distribution* is empty).
    """
    table_rows: list[list[str]] = []
    running_pct = 0.0
    first_pct = 0.0

    for position, (raw_value, occurrences) in enumerate(distribution):
        share = occurrences / total * 100 if total > 0 else 0
        running_pct += share
        if position == 0:
            first_pct = share

        label = "<NULL>" if raw_value is None else str(raw_value)
        if len(label) > 50:
            label = f"{label[:47]}..."

        table_rows.append(
            [
                label,
                f"{occurrences:,}",
                f"{share:.2f}%",
                f"{running_pct:.2f}%",
            ]
        )

    return table_rows, first_pct


def build_report(
    args: argparse.Namespace,
    total: int,
    nulls: int,
    unique_count: int,
    distribution_rows: list[list[str]],
    top_value_pct: float,
) -> MarkdownReport:
    """Assemble the Markdown report for one analyzed column.

    The report contains source/column metadata, a "Summary" bullet list, a
    "Distribution" table (preceded by a note when more unique values exist
    than ``args.limit``) and an "Observations" section when there are any.
    """
    report = MarkdownReport(f"Column Distribution: {args.column}")
    for field in ("source", "column"):
        report.add_metadata(field, getattr(args, field))

    missing_share = nulls / total * 100 if total > 0 else 0
    present = total - nulls
    summary = {
        "Total rows": f"{total:,}",
        "Null/missing": f"{nulls:,} ({missing_share:.2f}%)",
        "Non-null rows": f"{present:,}",
        "Unique values": f"{unique_count:,}",
        "Cardinality": assess_cardinality(unique_count, total),
    }
    report.add_section("Summary")
    report.add_summary_stats(summary)

    report.add_section("Distribution")
    is_truncated = unique_count > args.limit
    if is_truncated:
        report.add_text(f"*Showing top {args.limit} of {unique_count:,} unique values*")
    report.add_table(
        ["Value", "Count", "Percentage", "Cumulative"],
        distribution_rows,
        ["left", "right", "right", "right"],
    )

    notes = generate_observations(total, nulls, unique_count, top_value_pct)
    if notes:
        bullet_list = "\n".join(f"- {note}" for note in notes)
        report.add_section("Observations", bullet_list)
    return report


def main() -> int:
    """Run the command-line tool and return the process exit status.

    Returns:
        0 on success; 1 when the source is missing, the column name is
        invalid, the analysis fails, or the report cannot be written. Each
        failure prints a single ``Error ...`` line to stderr.
    """
    args = parse_args()
    if not Path(args.source).exists():
        sys.stderr.write(f"Error: source not found: {args.source}\n")
        return 1
    if not validate_identifier(args.column):
        sys.stderr.write(f"Error: invalid column name: {args.column}\n")
        return 1

    try:
        total, nulls, unique_count, distribution = query_column_stats(args)
    except Exception as exc:  # top-level boundary: report and exit non-zero
        sys.stderr.write(f"Error analyzing column: {exc}\n")
        return 1

    rows, top_value_pct = build_distribution_rows(distribution, total)
    report = build_report(args, total, nulls, unique_count, rows, top_value_pct)
    try:
        report.write(args.output)
    except OSError as exc:
        # E.g. missing output directory or no write permission: fail the same
        # way as the other error paths instead of with a traceback.
        sys.stderr.write(f"Error writing report: {exc}\n")
        return 1
    return 0


# When executed as a script, use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
