#!/usr/bin/env python3
"""
search.py - Hybrid Grep-to-AST Retrieval Tool
Locates matches with ripgrep, then expands them into full code blocks using tree-sitter
or pre-generated _MAP.md files from mapping-codebases.

v0.3.0: Added --use-maps mode to eliminate redundant tree-sitter parsing (#276).
When _MAP.md files exist (generated by mapping-codebases), tree-sitter is not needed
at runtime. Maps provide symbol names and start line numbers; the script reads source
files to extract function/class bodies using line-based heuristics.
"""

import subprocess
import json
import sys
import os
import re
from dataclasses import dataclass
from typing import List, Optional, Dict, Set
from pathlib import Path
import argparse

# tree-sitter is optional when using --use-maps mode
_tree_sitter_available = False
try:
    from tree_sitter_language_pack import get_parser
    _tree_sitter_available = True
except ImportError:
    pass

# Configuration: Relevant node types for "context" in each language
# These are the nodes we want to capture (Functions, Classes, etc.)
# Keys are the language names produced by EXT_TO_LANG; values are tree-sitter node
# type names. HybridRetriever._get_node_at_line returns (roughly) the innermost node
# of one of these types whose line range covers the matched line.
# NOTE(review): generic types such as 'variable_declarator' (JS/TS) or
# 'struct_specifier' (C/C++) can presumably also match small nodes inside a function
# body (a local `const x = ...`, a `struct foo *p` type reference); in that case the
# smaller node is returned instead of the enclosing function — confirm intended.
RELEVANT_NODE_TYPES = {
    'python': {'function_definition', 'class_definition'},
    'javascript': {'function_declaration', 'class_declaration', 'method_definition', 'variable_declarator', 'arrow_function', 'function_expression'},
    'typescript': {'function_declaration', 'class_declaration', 'method_definition', 'variable_declarator', 'arrow_function', 'function_expression', 'interface_declaration', 'enum_declaration'},
    'go': {'function_declaration', 'type_declaration', 'method_declaration'},
    'rust': {'function_item', 'struct_item', 'impl_item', 'trait_item', 'enum_item', 'mod_item'},
    'ruby': {'method', 'class', 'module'},
    'java': {'class_declaration', 'interface_declaration', 'method_declaration', 'constructor_declaration'},
    'c': {'function_definition', 'struct_specifier'},
    'cpp': {'function_definition', 'class_specifier', 'struct_specifier'},
    'php': {'function_definition', 'class_declaration', 'method_declaration'},
    'c_sharp': {'class_declaration', 'method_declaration', 'interface_declaration', 'struct_declaration', 'enum_declaration', 'namespace_declaration'},
}

# Mapping extensions to languages (values are tree-sitter-language-pack grammar names
# and must also be keys of RELEVANT_NODE_TYPES). Keys are lower-case because callers
# lower-case the file suffix before the lookup.
# NOTE(review): '.tsx' is parsed with the plain 'typescript' grammar; the language pack
# presumably also ships a dedicated 'tsx' grammar that handles JSX — confirm.
EXT_TO_LANG = {
    '.py': 'python',
    '.pyi': 'python',      # type stubs are ordinary Python syntax
    '.js': 'javascript',
    '.jsx': 'javascript',
    '.mjs': 'javascript',  # ES module
    '.cjs': 'javascript',  # CommonJS module
    '.ts': 'typescript',
    '.tsx': 'typescript',
    '.mts': 'typescript',  # ES-module TypeScript
    '.cts': 'typescript',  # CommonJS TypeScript
    '.go': 'go',
    '.rs': 'rust',
    '.rb': 'ruby',
    '.java': 'java',
    '.c': 'c',
    '.h': 'c',
    '.cpp': 'cpp',
    '.hpp': 'cpp',
    '.cc': 'cpp',
    '.cxx': 'cpp',
    '.hh': 'cpp',
    '.hxx': 'cpp',
    '.php': 'php',
    '.cs': 'c_sharp',
}

@dataclass
class CodeContext:
    """One search hit expanded to its enclosing structural block.

    Built both by tree-sitter expansion (HybridRetriever._expand_context) and by
    map-based expansion (_expand_from_maps). All line numbers are 1-indexed.
    """
    file_path: str  # path as reported by ripgrep
    start_line: int  # first line of the enclosing block (1-indexed)
    end_line: int  # last line of the enclosing block (1-indexed, inclusive)
    match_line: int  # line ripgrep matched (1-indexed)
    node_type: str  # tree-sitter node type, or the _KIND_TO_NODE_TYPE name for map hits
    name: str  # symbol name; map-based method hits are reported as "Parent.method"
    source: str  # full text of the enclosing block
    language: str  # language key from EXT_TO_LANG
    signature: Optional[str] = None  # For progressive disclosure: declaration with body elided, or None


# --- Map-based context expansion (#276) ---

@dataclass
class MapSymbol:
    """A symbol parsed from a _MAP.md file.

    One instance per ``- **Name** (K) `(sig)` :LINE`` entry in the map (the
    signature may be wrapped over several lines).
    """
    name: str
    kind: str  # 'C' (class), 'f' (function), 'm' (method), 'I' (interface)
    line: int  # 1-indexed start line
    signature: Optional[str] = None  # text between the backticks, e.g. "(self, x)"; None if absent
    parent: Optional[str] = None  # Parent class name for methods (entries indented >= 2 spaces)


# Translate the single-letter kind codes used in _MAP.md files into tree-sitter-style
# node type names, so map-derived hits report a node_type comparable to parsed ones.
_KIND_TO_NODE_TYPE = dict(
    C='class_definition',
    f='function_definition',
    m='method_definition',
    I='interface_declaration',
)


# Line patterns of the _MAP.md format written by mapping-codebases.
_MAP_FILE_RE = re.compile(r'^### (.+\.\w+)$')  # ### filename.py
_MAP_SYMBOL_RE = re.compile(  # - **Name** (K) `(sig)` :LINE   (signature optional)
    r'^(\s*)- \*\*(\w+)\*\*\s+\((\w)\)\s*(?:`([^`]*)`\s*)?:(\d+)'
)
# - **Name** (K) `(sig...   -- opening backtick with NO closing one on this line, so
# the signature continues on following lines. Requiring no second backtick keeps a
# complete-but-unnumbered entry from swallowing the lines after it.
_MAP_MULTI_START_RE = re.compile(r'^(\s*)- \*\*(\w+)\*\*\s+\((\w)\)\s*`([^`]*)$')
_MAP_SIG_END_RE = re.compile(r'`\s*:(\d+)\s*$')  # ...sig)` :LINE  (closes a multi-line entry)


def _append_sig_fragment(sig: str, fragment: str) -> str:
    """Append one wrapped line of a multi-line signature, restoring the line-break space.

    No space is inserted right after an opening bracket or right before a closing
    bracket or comma, so "(" + "a: int," + "b" + ")" becomes "(a: int, b)".
    """
    if not fragment:
        return sig
    if not sig or sig[-1] in '([{' or fragment[0] in ')]},':
        return sig + fragment
    return sig + ' ' + fragment


def _parse_map_file(map_path: str) -> "Dict[str, List[MapSymbol]]":
    """Parse a _MAP.md file and return symbols indexed by filename.

    Recognised lines (everything else is ignored):
      ``### name.ext``                  starts the section for one source file
      ``- **Name** (K) `(sig)` :LINE``   a symbol on one line (signature optional)
      ``- **Name** (K) `(sig...``        a signature wrapped over several lines, closed
                                        by a later line ending in `` ` :LINE``
    Entries indented by two or more spaces get ``parent`` set to the most recently
    seen class, unless a top-level non-class entry came in between.

    Returns:
        Dict mapping filename -> list of MapSymbol, sorted by line number.
        A missing or unreadable map yields an empty (or partially filled) dict.
    """
    symbols_by_file: Dict[str, list] = {}
    current_file: Optional[str] = None
    current_class: Optional[str] = None
    # Multi-line signature being accumulated: {'indent', 'name', 'kind', 'sig'}
    pending: Optional[Dict] = None

    def record(indent, name, kind, sig, line_num):
        """Append one symbol to the current file and update class-parent tracking."""
        nonlocal current_class
        is_method = len(indent) >= 2
        symbols_by_file[current_file].append(MapSymbol(
            name=name,
            kind=kind,
            line=int(line_num),
            signature=sig,
            parent=current_class if is_method else None,
        ))
        if kind == 'C':
            current_class = name
        elif not is_method:
            current_class = None

    try:
        # errors='replace': a stray non-UTF-8 byte must not abort the whole search.
        with open(map_path, 'r', encoding='utf-8', errors='replace') as f:
            for raw_line in f:
                line = raw_line.rstrip()

                header = _MAP_FILE_RE.match(line)
                if header:
                    current_file = header.group(1)
                    current_class = None
                    pending = None
                    symbols_by_file.setdefault(current_file, [])
                    continue

                if not current_file:
                    continue

                if pending is not None:
                    closing = _MAP_SIG_END_RE.search(line)
                    if closing:
                        full_sig = _append_sig_fragment(
                            pending['sig'], line[:closing.start()].strip())
                        record(pending['indent'], pending['name'], pending['kind'],
                               full_sig, closing.group(1))
                        pending = None
                    else:
                        pending['sig'] = _append_sig_fragment(pending['sig'], line.strip())
                    continue

                single = _MAP_SYMBOL_RE.match(line)
                if single:
                    record(*single.groups())
                    continue

                multi = _MAP_MULTI_START_RE.match(line)
                if multi:
                    indent, name, kind, sig_start = multi.groups()
                    pending = {'indent': indent, 'name': name, 'kind': kind, 'sig': sig_start}
    except OSError:
        # Best-effort: missing, unreadable or directory paths simply yield no symbols.
        pass

    # Sort symbols by line number within each file
    for file_symbols in symbols_by_file.values():
        file_symbols.sort(key=lambda s: s.line)

    return symbols_by_file


def _find_map_for_file(file_path: str, search_root: str) -> Optional[str]:
    """Return the path of the _MAP.md next to ``file_path``, or None if there is none.

    Only the source file's own directory is checked; ``search_root`` is accepted
    for interface compatibility but is not consulted.
    """
    candidate = os.path.join(os.path.dirname(file_path), '_MAP.md')
    return candidate if os.path.exists(candidate) else None


def _find_containing_index(symbols: "List[MapSymbol]", match_line: int) -> Optional[int]:
    """Index of the last symbol (symbols sorted by line) starting at or before match_line."""
    found = None
    for idx, sym in enumerate(symbols):
        if sym.line > match_line:
            break
        found = idx
    return found


def _next_boundary_line(symbols: "List[MapSymbol]", idx: int) -> Optional[int]:
    """Start line of the first later symbol that ends symbols[idx]'s span, or None (EOF).

    For a class, its own methods (entries whose parent is the class name) lie inside
    the class and are skipped.
    """
    owner = symbols[idx]
    for sym in symbols[idx + 1:]:
        if owner.kind == 'C' and sym.parent == owner.name:
            continue
        return sym.line
    return None


def _symbol_end_line(lines: List[str], start_line: int, next_line: Optional[int]) -> int:
    """1-indexed inclusive end line of a symbol starting at start_line.

    The span runs up to the line before next_line (or EOF), clamped to the file
    length because a stale map may record lines past the current end of file, and
    trailing blank lines are trimmed.
    """
    if next_line is None:
        end_line = len(lines)
    else:
        end_line = min(next_line - 1, len(lines))
    while end_line > start_line and not lines[end_line - 1].strip():
        end_line -= 1
    return end_line


def _map_signature(symbol: "MapSymbol", lang: str) -> Optional[str]:
    """Render a condensed declaration for a map symbol in the style of ``lang``.

    Returns None when no template applies (e.g. an 'I' entry in a Python file); the
    caller then shows the full source instead.
    """
    name = symbol.name
    kind = symbol.kind
    params = symbol.signature or ''
    if lang == 'python':
        if kind == 'C':
            return f"class {name}:\n    ..."
        if kind in ('f', 'm'):
            return f"def {name}{params}:\n    ..."
        return None
    if lang in ('javascript', 'typescript'):
        if kind == 'C':
            return f"class {name} {{ ... }}"
        if kind in ('f', 'm'):
            return f"function {name}{params} {{ ... }}"
        if kind == 'I':
            return f"interface {name} {{ ... }}"
        return None
    if lang == 'go':
        return f"func {name}{params} {{ ... }}"
    return f"{name}{params}"


def _expand_from_maps(file_path: str, line_number: int, search_root: str,
                      signatures_only: bool = True) -> "Optional[CodeContext]":
    """Expand a match line into its containing function/class using _MAP.md data.

    Uses pre-generated maps instead of tree-sitter for structural context.
    Reads the source file to extract the actual code between symbol boundaries.

    Args:
        file_path: Path to the matched file (as reported by ripgrep)
        line_number: 0-indexed line number of the match
        search_root: Root directory of the search (for finding maps)
        signatures_only: If True, also build a condensed signature

    Returns:
        CodeContext or None if there is no map, no symbol at or before the match,
        or the source file cannot be read.
    """
    map_path = _find_map_for_file(file_path, search_root)
    if not map_path:
        return None

    lang = EXT_TO_LANG.get(os.path.splitext(file_path)[1].lower())
    if not lang:
        return None

    symbols = _parse_map_file(map_path).get(os.path.basename(file_path), [])
    if not symbols:
        return None

    match_line_1indexed = line_number + 1  # maps and CodeContext use 1-indexed lines
    idx = _find_containing_index(symbols, match_line_1indexed)
    if idx is None:
        return None
    symbol = symbols[idx]  # use the real index: duplicates would fool list.index()

    try:
        # errors='replace' mirrors the tree-sitter path, which decodes with replacement.
        with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
            lines = f.readlines()
    except OSError:
        return None

    start_line = max(symbol.line, 1)
    end_line = _symbol_end_line(lines, start_line, _next_boundary_line(symbols, idx))
    source = ''.join(lines[start_line - 1:end_line])

    signature = _map_signature(symbol, lang) if signatures_only else None
    display_name = f"{symbol.parent}.{symbol.name}" if symbol.parent else symbol.name

    return CodeContext(
        file_path=file_path,
        start_line=start_line,
        end_line=end_line,
        match_line=match_line_1indexed,
        node_type=_KIND_TO_NODE_TYPE.get(symbol.kind, symbol.kind),
        name=display_name,
        source=source,
        language=lang,
        signature=signature
    )


# --- Tree-sitter based context expansion (original) ---

class HybridRetriever:
    """Grep-then-expand code search.

    Phase 1 runs ripgrep to find matching lines; phase 2 expands each match into its
    enclosing function/class, either by parsing the file with tree-sitter or, when
    ``use_maps`` is set, from pre-generated ``_MAP.md`` files.
    """

    # Child node types that hold the body of a JS/TS declaration; the signature is
    # the declaration text before the first such child.
    _JS_BODY_TYPES = ('statement_block', 'class_body', 'interface_body', 'enum_body', 'object_type')

    def __init__(self, use_maps: bool = False, search_root: str = "."):
        self.parsers = {}  # language -> tree-sitter parser (None if it failed to load)
        self.use_maps = use_maps
        self.search_root = os.path.abspath(search_root)

    def _get_language(self, file_path: str) -> Optional[str]:
        """Return the EXT_TO_LANG key for file_path's extension, or None if unsupported."""
        ext = os.path.splitext(file_path)[1].lower()
        return EXT_TO_LANG.get(ext)

    def _get_parser(self, language: str):
        """Return a cached tree-sitter parser for language, or None if it cannot be loaded."""
        if language not in self.parsers:
            try:
                self.parsers[language] = get_parser(language)
            except Exception as e:
                print(f"Warning: Could not load parser for {language}: {e}", file=sys.stderr)
                self.parsers[language] = None
        return self.parsers[language]

    @staticmethod
    def _install_ripgrep() -> None:
        """Try to install ripgrep via apt-get; print an error and exit(1) on failure."""
        print("ripgrep not found, installing...", file=sys.stderr)
        try:
            install_result = subprocess.run(
                ["apt-get", "install", "-y", "-qq", "ripgrep"],
                capture_output=True, text=True
            )
        except OSError as e:
            # apt-get itself is missing (non-Debian system) or cannot be executed.
            print(f"Error: Failed to install ripgrep: {e}", file=sys.stderr)
            sys.exit(1)
        if install_result.returncode != 0:
            print(f"Error: Failed to install ripgrep: {install_result.stderr}", file=sys.stderr)
            sys.exit(1)

    def _run_ripgrep(self, query: str, path: str, glob: Optional[str] = None) -> List[Dict]:
        """
        Phase 1: The Dragnet. Fast, text-based search.

        Returns the ``data`` payload of every ``match`` event emitted by ``rg --json``,
        or [] if ripgrep cannot be run. Exits if ripgrep is missing and cannot be installed.
        """
        command = [
            "rg", "--json", "-e", query,
            "--path-separator", "/",
            path
        ]
        if glob:
            command.extend(["--glob", glob])

        # rg's JSON Lines output is UTF-8; decode it as such regardless of locale.
        run_kwargs = dict(capture_output=True, text=True, encoding="utf-8",
                          errors="replace", check=False)
        try:
            result = subprocess.run(command, **run_kwargs)
        except FileNotFoundError:
            self._install_ripgrep()
            try:
                result = subprocess.run(command, **run_kwargs)
            except OSError as e:
                print(f"Error running ripgrep: {e}", file=sys.stderr)
                return []
        except Exception as e:
            print(f"Error running ripgrep: {e}", file=sys.stderr)
            return []

        # rg exits 0 (matches), 1 (no matches) or 2 (error, e.g. an invalid regex).
        # Surface errors instead of letting them look like "no matches".
        if result.returncode not in (0, 1) and result.stderr:
            print(f"Warning: ripgrep reported errors: {result.stderr.strip()}", file=sys.stderr)

        matches = []
        for line in result.stdout.splitlines():
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            if data.get("type") == "match":
                matches.append(data["data"])
        return matches

    @staticmethod
    def _node_text(node, source_bytes: bytes, end_byte: Optional[int] = None) -> str:
        """Decode node's source text (optionally only up to end_byte), replacing bad UTF-8."""
        end = node.end_byte if end_byte is None else end_byte
        return source_bytes[node.start_byte:end].decode('utf-8', errors='replace')

    @staticmethod
    def _find_child(node, types):
        """Return the first direct child of node whose type is in types, or None."""
        for child in node.children:
            if child.type in types:
                return child
        return None

    def _get_node_name(self, node, source_bytes: bytes) -> str:
        """Attempt to extract a name from a node."""
        # Generic heuristic: look for "name" or "identifier" child
        child = node.child_by_field_name("name")
        if child:
            return self._node_text(child, source_bytes)

        # Fallback: scan children for identifier
        for child in node.children:
            if child.type in ('identifier', 'type_identifier', 'property_identifier', 'name'):
                return self._node_text(child, source_bytes)

        return "anonymous"

    def _extract_signature(self, node, source_bytes: bytes, language: str) -> Optional[str]:
        """
        Extract just the signature from a function/class/method node.
        Returns the declaration line(s) (plus docstring for Python) with the body elided,
        or None for languages without an extractor.
        """
        if language == 'python':
            return self._extract_python_signature(node, source_bytes)
        if language in ('javascript', 'typescript'):
            return self._extract_js_signature(node, source_bytes)
        if language == 'go':
            return self._extract_go_signature(node, source_bytes)
        # Add more language-specific extractors as needed
        return None

    def _extract_python_signature(self, node, source_bytes: bytes) -> Optional[str]:
        """Extract a Python function/class header plus docstring, with the body elided.

        The header is sliced verbatim from the source up to the body block, so the
        original spacing between tokens (``def foo(x) -> int:``) is preserved.
        """
        if node.type not in ('class_definition', 'function_definition'):
            return None

        body = self._find_child(node, ('block',))
        header_end = body.start_byte if body is not None else node.end_byte
        parts = [self._node_text(node, source_bytes, header_end).rstrip()]

        docstring = self._python_docstring(body, source_bytes) if body is not None else None
        if docstring:
            parts.append(f"\n    {docstring}")
        parts.append("\n    ...")
        return ''.join(parts)

    def _python_docstring(self, body, source_bytes: bytes) -> Optional[str]:
        """Return the string literal that is the first statement of body, else None."""
        for stmt in body.children:
            if stmt.type == 'comment':
                continue
            if stmt.type == 'string':
                return self._node_text(stmt, source_bytes)
            # tree-sitter-python wraps a bare string statement in expression_statement.
            if (stmt.type == 'expression_statement' and len(stmt.children) == 1
                    and stmt.children[0].type == 'string'):
                return self._node_text(stmt.children[0], source_bytes)
            return None
        return None

    def _signature_before_body(self, node, source_bytes: bytes, body_types) -> Optional[str]:
        """Return node's text up to its first body child, followed by ' { ... }'.

        When node has no body child, its whole text is returned (None if empty).
        """
        body = self._find_child(node, body_types)
        if body is None:
            return self._node_text(node, source_bytes) or None
        header = self._node_text(node, source_bytes, body.start_byte).rstrip()
        return f"{header} {{ ... }}"

    def _extract_js_signature(self, node, source_bytes: bytes) -> Optional[str]:
        """Extract JavaScript/TypeScript function/class signature (text up to the opening brace)."""
        return self._signature_before_body(node, source_bytes, self._JS_BODY_TYPES)

    def _extract_go_signature(self, node, source_bytes: bytes) -> Optional[str]:
        """Extract Go function signature (text up to the body block)."""
        return self._signature_before_body(node, source_bytes, ('block',))

    def _get_node_at_line(self, root_node, line_number: int, language: str):
        """
        Finds the smallest relevant node (Function/Class) containing the line number.

        Depth-first walk that only descends into nodes whose line span covers
        line_number (0-indexed); the last relevant node reached on the way down wins.
        """
        target_node = None
        relevant_types = RELEVANT_NODE_TYPES.get(language, set())

        cursor = root_node.walk()
        visited_children = False

        while True:
            # Check if current node covers the line
            if cursor.node.start_point[0] <= line_number and cursor.node.end_point[0] >= line_number:

                # If this node is one of our target types, it's a candidate
                # ONLY update if we are visiting for the first time (going down)
                # This prevents backtracking from overwriting a more specific match with a parent.
                if not visited_children and cursor.node.type in relevant_types:
                    target_node = cursor.node

                if not visited_children:
                    if cursor.goto_first_child():
                        continue

            visited_children = True
            if cursor.goto_next_sibling():
                visited_children = False
                continue

            if cursor.goto_parent():
                visited_children = True
                continue
            else:
                break

        return target_node

    def _expand_context(self, file_path: str, line_number: int,
                        signatures_only: bool = True) -> "Optional[CodeContext]":
        """
        Phase 2: The Scalpel. Syntax-aware context expansion.

        Parses file_path with tree-sitter and returns the relevant node covering the
        0-indexed line_number, or None if the language is unsupported, the file cannot
        be read or parsed, or no relevant node covers the line.
        """
        lang = self._get_language(file_path)
        if not lang:
            return None  # Skip unsupported languages

        parser = self._get_parser(lang)
        if not parser:
            return None

        try:
            with open(file_path, "rb") as f:
                source_bytes = f.read()
        except OSError:
            return None

        try:
            tree = parser.parse(source_bytes)
            node = self._get_node_at_line(tree.root_node, line_number, lang)
            if not node:
                return None

            signature = None
            if signatures_only:
                signature = self._extract_signature(node, source_bytes, lang)

            return CodeContext(
                file_path=file_path,
                start_line=node.start_point[0] + 1,
                end_line=node.end_point[0] + 1,
                match_line=line_number + 1,
                node_type=node.type,
                name=self._get_node_name(node, source_bytes),
                source=self._node_text(node, source_bytes),
                language=lang,
                signature=signature
            )
        except Exception:
            # Best-effort: a grammar/parser failure on one file must not abort the
            # whole search, so this match is simply dropped.
            return None

    def search(self, query: str, path: str = ".", glob: Optional[str] = None,
               signatures_only: bool = True) -> "List[CodeContext]":
        """Run ripgrep for query under path and expand each hit into its enclosing block.

        Hits expanding to the same (file, start_line, end_line) are reported once.
        Exits the process if tree-sitter mode is requested but
        tree-sitter-language-pack is not installed.
        """
        # Fail fast on misconfiguration instead of only when the first match appears
        # (or never, which would misleadingly look like "no matches").
        if not self.use_maps and not _tree_sitter_available:
            print("Error: tree-sitter-language-pack not available and --use-maps not specified.", file=sys.stderr)
            print("Either install tree-sitter-language-pack or use --use-maps with pre-generated _MAP.md files.", file=sys.stderr)
            sys.exit(1)

        raw_matches = self._run_ripgrep(query, path, glob)
        contexts = []
        seen_ranges = set()  # (file_path, start_line, end_line)

        for match in raw_matches:
            # rg reports non-UTF-8 paths as {"bytes": <base64>} instead of {"text": ...};
            # such files cannot be reopened by name here, so skip them.
            file_path = match.get("path", {}).get("text")
            if file_path is None:
                continue
            line_num = match["line_number"] - 1  # 0-indexed for Tree-sitter

            # Choose expansion strategy based on mode
            if self.use_maps:
                context = _expand_from_maps(file_path, line_num, self.search_root, signatures_only)
            else:
                context = self._expand_context(file_path, line_num, signatures_only)

            if context:
                dedup_key = (context.file_path, context.start_line, context.end_line)
                if dedup_key not in seen_ranges:
                    contexts.append(context)
                    seen_ranges.add(dedup_key)

        return contexts

def main():
    """Command-line entry point: search, expand matches, and print Markdown or JSON."""
    cli = argparse.ArgumentParser(description="Hybrid Grep-to-AST Code Search")
    cli.add_argument("query", help="Search query (passed to ripgrep)")
    cli.add_argument("path", nargs="?", default=".", help="Root directory to search")
    cli.add_argument("--glob", help="Glob pattern for filtering files (e.g. '*.py')")
    cli.add_argument("--json", action="store_true", help="Output JSON instead of Markdown")
    cli.add_argument("--expand-full", action="store_true",
                     help="Show full implementations instead of signatures")
    cli.add_argument(
        "--use-maps",
        action="store_true",
        help="Use pre-generated _MAP.md files from mapping-codebases instead of tree-sitter. "
             "Eliminates need for tree-sitter-language-pack at runtime (#276).",
    )
    args = cli.parse_args()

    want_signatures = not args.expand_full
    retriever = HybridRetriever(use_maps=args.use_maps, search_root=args.path)
    results = retriever.search(args.query, args.path, args.glob, signatures_only=want_signatures)

    if args.json:
        # One flat record per hit; key order is part of the output format.
        records = [
            {
                "file": hit.file_path,
                "name": hit.name,
                "type": hit.node_type,
                "start_line": hit.start_line,
                "end_line": hit.end_line,
                "source": hit.source,
                "signature": hit.signature,
            }
            for hit in results
        ]
        print(json.dumps(records, indent=2))
        return

    if not results:
        print("No structural matches found.")
        return

    print(f"Found {len(results)} matches for '{args.query}':\n")
    for hit in results:
        # Prefer the condensed signature unless the full body was requested.
        if hit.signature and want_signatures:
            shown = hit.signature
        else:
            shown = hit.source
        print(f"### {hit.file_path} matches `{args.query}`")
        print(f"**{hit.node_type}**: `{hit.name}` (Lines {hit.start_line}-{hit.end_line})")
        print(f"```{hit.language}")
        print(shown)
        print("```\n")


if __name__ == "__main__":
    main()
