#!/usr/bin/env nix-shell
#!nix-shell -i python3 -p "python3.withPackages(ps: with ps; [ sqlcipher3 pycryptodome ])"

"""
signal-search-fast — Direct SQL search against Signal Desktop's encrypted database.

Bypasses sigexport entirely for sub-second searches against live data.
"""

import argparse
import json
import os
import re
import subprocess
import sys
from datetime import datetime
from pathlib import Path

# Third-party packages; these are provided by the nix-shell environment declared in the shebang
from sqlcipher3 import dbapi2
from Crypto.Cipher import AES
from Crypto.Hash import SHA1
from Crypto.Protocol.KDF import PBKDF2
from Crypto.Util.Padding import unpad


def get_signal_dir() -> Path:
    """Return the Signal Desktop data directory for the current platform.

    - macOS: ``~/Library/Application Support/Signal``.
    - Linux: the first existing path among ``$XDG_CONFIG_HOME/Signal`` (when
      that variable is an absolute path), ``~/.config/Signal`` and the Flatpak
      location. If none exists, the first candidate is returned.
    - Anything else is treated as Windows: ``%APPDATA%/Signal``, falling back
      to ``~/AppData/Roaming/Signal``.

    The returned path is not guaranteed to exist; the caller checks it and
    reports the expected location.
    """
    home = Path.home()
    if sys.platform == "darwin":
        return home / "Library" / "Application Support" / "Signal"

    if sys.platform == "linux":
        candidates = []
        # Electron puts userData under $XDG_CONFIG_HOME when it is set; the
        # XDG spec says relative values must be ignored.
        xdg_config = os.environ.get("XDG_CONFIG_HOME")
        if xdg_config and os.path.isabs(xdg_config):
            candidates.append(Path(xdg_config) / "Signal")
        candidates.append(home / ".config" / "Signal")
        # Flatpak install
        candidates.append(home / ".var" / "app" / "org.signal.Signal" / "config" / "Signal")
        for candidate in candidates:
            if candidate.exists():
                return candidate
        # Nothing found: return the primary location for the error message.
        return candidates[0]

    # Windows: honour %APPDATA% (it may be redirected away from the home dir).
    appdata = os.environ.get("APPDATA")
    base = Path(appdata) if appdata else home / "AppData" / "Roaming"
    return base / "Signal"


def get_keychain_password() -> str:
    """Return the "Signal Safe Storage" password from the macOS keychain.

    Runs ``security find-generic-password -ws "Signal Safe Storage"`` and
    returns its stdout with surrounding whitespace stripped.

    Raises:
        RuntimeError: if the ``security`` tool is not installed (i.e. not
            running on macOS), or if the lookup exits with a non-zero status.
    """
    try:
        result = subprocess.run(
            ["security", "find-generic-password", "-ws", "Signal Safe Storage"],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError as e:
        # Without this, non-macOS users only see "No such file or directory: 'security'".
        raise RuntimeError(
            "The macOS 'security' tool was not found; decrypting Signal's "
            "encryptedKey via the keychain is only supported on macOS"
        ) from e
    if result.returncode != 0:
        raise RuntimeError(f"Failed to get keychain password: {result.stderr}")
    return result.stdout.strip()


def decrypt_key(password: str, encrypted_key_hex: str, *, iterations: int = 1003) -> str:
    """Decrypt Signal's ``encryptedKey`` into the SQLCipher database key.

    The blob is ``b"v10"`` followed by AES-128-CBC ciphertext. The AES key is
    PBKDF2-HMAC-SHA1(password, salt=b"saltysalt", count=iterations), the IV is
    16 spaces, and the plaintext is PKCS#7-padded ASCII.

    Args:
        password: The Safe Storage password (e.g. from the macOS keychain).
        encrypted_key_hex: Hex-encoded ``encryptedKey`` from config.json.
        iterations: PBKDF2 iteration count. 1003 is the macOS value. Other
            safeStorage backends may use a different count (Linux reportedly
            uses 1; confirm before relying on it).

    Returns:
        The decrypted database key as an ASCII string.

    Raises:
        ValueError: if the input is not valid hex or lacks the ``v10`` prefix.
            Also raised if the ciphertext does not decrypt to correctly padded
            ASCII, which is the usual symptom of a wrong password.
    """
    encrypted_key = bytes.fromhex(encrypted_key_hex)

    prefix = b"v10"
    if not encrypted_key.startswith(prefix):
        raise ValueError(f"Unexpected key prefix, expected {prefix}")
    ciphertext = encrypted_key[len(prefix):]

    key = PBKDF2(
        password,
        salt=b"saltysalt",
        dkLen=128 // 8,  # AES-128
        count=iterations,
        hmac_hash_module=SHA1,
    )
    iv = b" " * 16
    decrypted = AES.new(key, AES.MODE_CBC, iv).decrypt(ciphertext)
    return unpad(decrypted, block_size=16).decode("ascii")


def get_db_key(signal_dir: Path) -> str:
    """Return the SQLCipher key for Signal's database.

    Older Signal Desktop versions store the key directly under ``"key"`` in
    ``config.json``. Newer versions store an ``"encryptedKey"``, which is
    decrypted with the password read from the macOS keychain.

    Args:
        signal_dir: Signal Desktop data directory.

    Raises:
        FileNotFoundError: if ``config.json`` does not exist.
        ValueError: if neither key field is present, or decryption fails.
        RuntimeError: if the keychain password cannot be obtained.
    """
    config_path = signal_dir / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"Signal config not found at {config_path}")

    # JSON is UTF-8; don't depend on the platform's locale-default encoding.
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Old versions stored the key in plain text
    if "key" in config:
        return config["key"]

    # New versions encrypt the key
    if "encryptedKey" not in config:
        raise ValueError("No key or encryptedKey found in config.json")

    password = get_keychain_password()
    return decrypt_key(password, config["encryptedKey"])


def open_db(signal_dir: Path, key: str):
    """Open Signal's SQLCipher database and verify that ``key`` unlocks it.

    Args:
        signal_dir: Signal Desktop data directory.
        key: Raw database key as a hex string (e.g. from ``get_db_key``).

    Returns:
        An open ``sqlcipher3`` connection. The caller is responsible for closing it.

    Raises:
        FileNotFoundError: if ``sql/db.sqlite`` does not exist.
        ValueError: if ``key`` is not a hex string.
        sqlcipher3.dbapi2.DatabaseError: if SQLCipher cannot read the
            database with this key (typically "file is not a database").
    """
    db_path = signal_dir / "sql" / "db.sqlite"
    if not db_path.exists():
        raise FileNotFoundError(f"Signal database not found at {db_path}")

    # PRAGMA statements can't take bound parameters, so the key is spliced
    # into the SQL text; only accept plain hex to keep that safe.
    if not re.fullmatch(r"[0-9A-Fa-f]+", key or ""):
        raise ValueError("Database key is not a hex string")

    db = dbapi2.connect(str(db_path))
    try:
        cursor = db.cursor()

        # Configure SQLCipher: a raw hex key plus the cipher parameters Signal uses.
        cursor.execute(f"PRAGMA KEY = \"x'{key}'\"")
        cursor.execute("PRAGMA cipher_page_size = 4096")
        cursor.execute("PRAGMA kdf_iter = 64000")
        cursor.execute("PRAGMA cipher_hmac_algorithm = HMAC_SHA512")
        cursor.execute("PRAGMA cipher_kdf_algorithm = PBKDF2_HMAC_SHA512")

        # SQLCipher does not validate the key until the first read. Force one
        # now so a wrong key fails here rather than in the middle of a command.
        cursor.execute("SELECT count(*) FROM sqlite_master").fetchone()
    except Exception:
        db.close()
        raise

    return db


def get_contacts(db) -> "tuple[dict, dict]":
    """Load every row of ``conversations`` into lookup tables.

    Args:
        db: Open DB-API connection to Signal's database.

    Returns:
        A pair ``(contacts, by_service_id)``:

        - ``contacts`` maps conversation id to a dict with the keys ``id``,
          ``serviceId``, ``phone``, ``name``, ``profileName``, ``displayName``,
          ``isGroup`` and ``members``.
        - ``by_service_id`` maps each non-empty ``serviceId`` to the same dict,
          for resolving message senders.

        ``displayName`` falls back through name, profile name, phone number
        and finally the conversation id.
    """
    cursor = db.cursor()
    cursor.execute("""
        SELECT id, serviceId, e164, name, profileName, type, members
        FROM conversations
    """)

    contacts = {}
    for cid, service_id, phone, name, profile_name, conv_type, members in cursor:
        contacts[cid] = {
            "id": cid,
            "serviceId": service_id,
            "phone": phone,
            "name": name,
            "profileName": profile_name,
            "displayName": name or profile_name or phone or cid,
            "isGroup": conv_type == "group",
            # members is a space-separated list; split() drops empty tokens
            # that split(" ") would produce on repeated spaces.
            "members": members.split() if members else [],
        }

    # serviceId -> contact, for resolving message senders
    by_service_id = {c["serviceId"]: c for c in contacts.values() if c["serviceId"]}

    return contacts, by_service_id


def list_chats(db) -> None:
    """Print each conversation as ``<name> (group)`` or ``<name> (direct)``.

    Conversations are ordered case-insensitively by display name; a missing
    display name sorts as the empty string.
    """
    contacts, _by_service_id = get_contacts(db)

    def by_lowered_name(entry):
        return (entry["displayName"] or "").lower()

    for entry in sorted(contacts.values(), key=by_lowered_name):
        kind = "group" if entry["isGroup"] else "direct"
        print(f"{entry['displayName']} ({kind})")


def dump_day(
    db,
    date_str: str,
    chat_filter: str = None,
) -> None:
    """Print every non-empty message sent on one calendar day, oldest first.

    Args:
        db: Open database connection (from ``open_db``).
        date_str: Day to dump in ``YYYY-MM-DD`` form, interpreted in local time.
        chat_filter: If given, restrict output to conversations whose display
            name contains this substring (case-insensitive).

    Each message is printed as ``[timestamp] chat | sender: body``. Errors and
    the final message count go to stderr.
    """
    # Validate the date before doing any database work.
    try:
        target_date = datetime.strptime(date_str, "%Y-%m-%d").date()
    except ValueError:
        print(f"Invalid date format: {date_str}. Use YYYY-MM-DD", file=sys.stderr)
        return

    contacts, by_service_id = get_contacts(db)

    # Local-time day bounds in epoch milliseconds (sent_at is in ms). The end
    # bound is the last millisecond of the day and is inclusive.
    start_ts = int(datetime(target_date.year, target_date.month, target_date.day, 0, 0, 0).timestamp() * 1000)
    end_ts = int(datetime(target_date.year, target_date.month, target_date.day, 23, 59, 59, 999999).timestamp() * 1000)

    # Find matching conversation IDs if chat filter specified
    target_conv_ids = None
    if chat_filter:
        chat_filter_lower = chat_filter.lower()
        target_conv_ids = [
            cid for cid, c in contacts.items()
            if chat_filter_lower in (c["displayName"] or "").lower()
        ]
        if not target_conv_ids:
            print(f"No chats matching '{chat_filter}'", file=sys.stderr)
            return

    # Fetch only the columns used below; the per-message `json` column is skipped.
    sql = """
        SELECT conversationId, body, sent_at, sourceServiceId, type
        FROM messages
        WHERE body IS NOT NULL AND body != ''
        AND sent_at >= ? AND sent_at <= ?
    """
    params = [start_ts, end_ts]

    if target_conv_ids:
        placeholders = ",".join("?" * len(target_conv_ids))
        sql += f" AND conversationId IN ({placeholders})"
        params.extend(target_conv_ids)

    sql += " ORDER BY sent_at ASC"

    cursor = db.cursor()
    cursor.execute(sql, params)

    msg_count = 0
    # Empty bodies are already excluded by the WHERE clause.
    for conv_id, body, sent_at, source_service_id, msg_type in cursor:
        ts = datetime.fromtimestamp(sent_at / 1000) if sent_at else None
        ts_str = ts.strftime("%Y-%m-%d %H:%M:%S") if ts else "unknown"

        contact = contacts.get(conv_id, {})
        chat_name = contact.get("displayName", conv_id)

        # Prefer a known sender; otherwise label our own outgoing messages.
        if source_service_id and source_service_id in by_service_id:
            sender = by_service_id[source_service_id].get("displayName", "Unknown")
        elif msg_type == "outgoing":
            sender = "Me"
        else:
            sender = "Unknown"

        print(f"[{ts_str}] {chat_name} | {sender}: {body}")
        msg_count += 1

    if msg_count == 0:
        print(f"No messages found for {date_str}", file=sys.stderr)
    else:
        print(f"\n-- {msg_count} message(s) from {date_str} --", file=sys.stderr)


def search_messages(
    db,
    query: str,
    chat_filter: str = None,
    context_lines: int = 0,
    max_results: int = None,
    case_sensitive: bool = False,
    literal: bool = False,
) -> None:
    """Print messages whose body matches ``query``, grep-style.

    Args:
        db: Open database connection (from ``open_db``).
        query: Regular expression, or a plain string when ``literal`` is true.
        chat_filter: If given, only search conversations whose display name
            contains this substring (case-insensitive).
        context_lines: Number of neighbouring messages to show before and
            after each match. Neighbours come from the time-ordered list of
            all searched messages, so they may belong to another conversation.
            Negative values are treated as 0.
        max_results: Stop after this many matches; ``None`` or 0 means no limit.
        case_sensitive: Match case exactly instead of ignoring it.
        literal: Escape ``query`` so it matches as a literal string.

    Matching lines are prefixed with ``>`` and context lines with a space.
    Non-adjacent groups are separated by ``--``. Errors and the summary count
    go to stderr.
    """
    # Compile before any database work so a malformed regex fails fast with a
    # readable message instead of a traceback.
    flags = 0 if case_sensitive else re.IGNORECASE
    try:
        pattern = re.compile(re.escape(query) if literal else query, flags)
    except re.error as e:
        print(f"Invalid search pattern '{query}': {e}", file=sys.stderr)
        return

    # A negative window would hide matches entirely; treat it as no context.
    context_lines = max(0, context_lines or 0)

    contacts, by_service_id = get_contacts(db)

    # Find matching conversation IDs if chat filter specified
    target_conv_ids = None
    if chat_filter:
        chat_filter_lower = chat_filter.lower()
        target_conv_ids = [
            cid for cid, c in contacts.items()
            if chat_filter_lower in (c["displayName"] or "").lower()
        ]
        if not target_conv_ids:
            print(f"No chats matching '{chat_filter}'", file=sys.stderr)
            return

    # Fetch only the columns used below; the per-message `json` column is skipped.
    sql = """
        SELECT conversationId, body, sent_at, sourceServiceId, type
        FROM messages
        WHERE body IS NOT NULL AND body != ''
    """
    params = []
    if target_conv_ids:
        placeholders = ",".join("?" * len(target_conv_ids))
        sql += f" AND conversationId IN ({placeholders})"
        params.extend(target_conv_ids)
    sql += " ORDER BY sent_at ASC"

    cursor = db.cursor()
    cursor.execute(sql, params)

    # Keep every message, not just matches, so context can be printed.
    # Empty bodies are already excluded by the WHERE clause.
    all_messages = []
    for conv_id, body, sent_at, source_service_id, msg_type in cursor:
        ts = datetime.fromtimestamp(sent_at / 1000) if sent_at else None
        ts_str = ts.strftime("%Y-%m-%d %H:%M:%S") if ts else "unknown"

        contact = contacts.get(conv_id, {})
        chat_name = contact.get("displayName", conv_id)

        # Prefer a known sender; otherwise label our own outgoing messages.
        if source_service_id and source_service_id in by_service_id:
            sender = by_service_id[source_service_id].get("displayName", "Unknown")
        elif msg_type == "outgoing":
            sender = "Me"
        else:
            sender = "Unknown"

        all_messages.append({
            "chat": chat_name,
            "sender": sender,
            "timestamp": ts_str,
            "body": body,
            "matches": bool(pattern.search(body)),
        })

    # Print matches with context. Windows only move forward, so tracking the
    # highest printed index is enough to skip overlap and detect gaps.
    match_count = 0
    last_printed = -1

    for i, msg in enumerate(all_messages):
        if not msg["matches"]:
            continue

        if max_results and match_count >= max_results:
            break

        start = max(0, i - context_lines)
        end = min(len(all_messages), i + context_lines + 1)

        # Separator between non-adjacent groups
        if match_count > 0 and start > last_printed + 1:
            print("--")

        for j in range(max(start, last_printed + 1), end):
            m = all_messages[j]
            prefix = ">" if m["matches"] else " "
            print(f"{prefix} [{m['timestamp']}] {m['chat']} | {m['sender']}: {m['body']}")
        last_printed = max(last_printed, end - 1)

        match_count += 1

    if match_count == 0:
        print(f"No messages matching '{query}'", file=sys.stderr)
    else:
        print(f"\n-- {match_count} match(es) found --", file=sys.stderr)


def main() -> None:
    """Command-line entry point: parse arguments, open the database, and run a subcommand.

    Exits with status 1 in these cases:

    - the Signal directory, database key, or database cannot be accessed;
    - a database error occurs while running the command;
    - stdout is closed early by the reader.

    Invalid arguments are reported through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        description="Fast search against Signal Desktop's encrypted database"
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # List command (no options)
    subparsers.add_parser("list", help="List available chats")

    # Dump command
    dump_parser = subparsers.add_parser("dump", help="Dump all messages from a specific day")
    dump_parser.add_argument(
        "--date", "-d",
        required=True,
        help="Date to dump (YYYY-MM-DD format)"
    )
    dump_parser.add_argument(
        "--chat", "-c",
        help="Filter to chats containing this string"
    )

    # Search command
    search_parser = subparsers.add_parser("search", help="Search messages")
    search_parser.add_argument(
        "--query", "-q",
        required=True,
        help="Search pattern (regex by default)"
    )
    search_parser.add_argument(
        "--chat", "-c",
        help="Filter to chats containing this string"
    )
    search_parser.add_argument(
        "--context", "-C",
        type=int,
        default=0,
        help="Show N messages of context before and after matches"
    )
    search_parser.add_argument(
        "--max-count", "-m",
        type=int,
        help="Stop after N matches"
    )
    search_parser.add_argument(
        "--case-sensitive", "-s",
        action="store_true",
        help="Case-sensitive search"
    )
    search_parser.add_argument(
        "--literal", "-F",
        action="store_true",
        help="Treat query as literal string, not regex"
    )

    args = parser.parse_args()

    # A negative context produces an empty print window, so matches would be
    # counted but never shown; reject it up front.
    if args.command == "search" and args.context < 0:
        parser.error("--context must not be negative")

    # Get Signal directory
    signal_dir = get_signal_dir()
    if not signal_dir.exists():
        print(f"Signal Desktop not found at {signal_dir}", file=sys.stderr)
        sys.exit(1)

    # Get database key
    try:
        key = get_db_key(signal_dir)
    except Exception as e:
        print(f"Failed to get database key: {e}", file=sys.stderr)
        sys.exit(1)

    # Open database
    try:
        db = open_db(signal_dir, key)
    except Exception as e:
        print(f"Failed to open database: {e}", file=sys.stderr)
        sys.exit(1)

    # Execute command; the database is closed on every path.
    try:
        if args.command == "list":
            list_chats(db)
        elif args.command == "dump":
            dump_day(
                db,
                date_str=args.date,
                chat_filter=args.chat,
            )
        elif args.command == "search":
            search_messages(
                db,
                query=args.query,
                chat_filter=args.chat,
                context_lines=args.context,
                max_results=args.max_count,
                case_sensitive=args.case_sensitive,
                literal=args.literal,
            )
        # Flush inside the try so a reader that exits early (e.g. `| head`)
        # raises here rather than during interpreter shutdown.
        sys.stdout.flush()
    except BrokenPipeError:
        # Point stdout at devnull so Python's final flush at exit doesn't raise again.
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
        sys.exit(1)
    except dbapi2.DatabaseError as e:
        # e.g. a wrong key that only surfaces on the first real query
        print(f"Database error: {e}", file=sys.stderr)
        sys.exit(1)
    finally:
        db.close()


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported;
    # main() returns None, so the exit status is 0 unless main() exits itself.
    sys.exit(main())
