"""
QA Copilot — Platform Monitoring
SQLite-backed metrics: request logs, error logs, system health snapshots.
"""

import datetime
import json
import os
import resource
import sqlite3
import sys
import threading
import time
import urllib.request
from pathlib import Path

SCRIPT_DIR = Path(__file__).parent
METRICS_DB = SCRIPT_DIR / "metrics.db"

_LOG_RETENTION_DAYS = int(os.environ.get("QA_LOG_RETENTION_DAYS", "30"))
_HEALTH_INTERVAL = int(os.environ.get("QA_HEALTH_INTERVAL", "60"))
_CLEANUP_INTERVAL = int(os.environ.get("QA_CLEANUP_INTERVAL", "3600"))


def _now_iso() -> str:
    """Current UTC time as an ISO-8601 string with a 'Z' suffix.

    The range queries below compare these strings lexicographically, which
    agrees with chronological order because they share a common UTC layout.
    """
    return datetime.datetime.now(datetime.UTC).isoformat().replace("+00:00", "Z")


def _connect(db_path: str | Path | None = None) -> sqlite3.Connection:
    conn = sqlite3.connect(str(db_path or METRICS_DB))
    conn.execute("PRAGMA journal_mode=WAL")
    conn.row_factory = sqlite3.Row
    return conn


def init_metrics_db(db_path: str | None = None) -> None:
    """Create metrics.db tables and seed default config (idempotent)."""
    conn = _connect(db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS request_logs (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp   TEXT NOT NULL,
                method      TEXT NOT NULL,
                path        TEXT NOT NULL,
                status_code INTEGER NOT NULL,
                duration_ms REAL NOT NULL,
                user_id     TEXT,
                username    TEXT,
                tokens_used INTEGER,
                model       TEXT,
                error       TEXT
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS system_snapshots (
                id                 INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp          TEXT NOT NULL,
                ollama_up          INTEGER NOT NULL,
                rag_up             INTEGER NOT NULL,
                active_connections INTEGER NOT NULL,
                memory_mb          REAL NOT NULL,
                cpu_percent        REAL NOT NULL
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS error_logs (
                id        INTEGER PRIMARY KEY AUTOINCREMENT,
                timestamp TEXT NOT NULL,
                level     TEXT NOT NULL,
                source    TEXT NOT NULL,
                message   TEXT NOT NULL,
                path      TEXT,
                user_id   TEXT
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS monitoring_config (
                key   TEXT PRIMARY KEY,
                value TEXT NOT NULL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_request_logs_timestamp ON request_logs(timestamp)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_request_logs_path ON request_logs(path)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_request_logs_user ON request_logs(user_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_snapshots_timestamp ON system_snapshots(timestamp)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_error_logs_timestamp ON error_logs(timestamp)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_error_logs_level ON error_logs(level)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_error_logs_source ON error_logs(source)")
        conn.execute(
            "INSERT OR IGNORE INTO monitoring_config (key, value) VALUES (?, ?)",
            ("retention_days", str(_LOG_RETENTION_DAYS)),
        )
        conn.commit()
    finally:
        conn.close()
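
# Usage sketch (illustrative; not invoked by this module itself): initialise
# the schema once at startup, or against a throwaway path in tests.
#
#     init_metrics_db()                                # METRICS_DB next to this file
#     init_metrics_db(db_path="/tmp/test_metrics.db")  # hypothetical test path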


# ---------------------------------------------------------------------------
# SSE subscriber management
# ---------------------------------------------------------------------------

_log_subscribers: list = []
_log_sub_lock = threading.Lock()


def add_log_subscriber(wfile) -> None:
    with _log_sub_lock:
        _log_subscribers.append(wfile)


def remove_log_subscriber(wfile) -> None:
    with _log_sub_lock:
        try:
            _log_subscribers.remove(wfile)
        except ValueError:
            pass


def _broadcast_log_entry(entry: dict) -> None:
    payload = "data: " + json.dumps(entry) + "\n\n"
    dead = []
    with _log_sub_lock:
        subscribers = list(_log_subscribers)
    for wfile in subscribers:
        try:
            wfile.write(payload.encode())
            wfile.flush()
        except Exception:
            dead.append(wfile)
    for wfile in dead:
        remove_log_subscriber(wfile)
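
# Expected subscriber lifecycle (illustrative; `handler` is a hypothetical
# http.server request handler whose wfile is the SSE response stream; this
# module only manages the subscriber list, it does not own the sockets):
#
#     add_log_subscriber(handler.wfile)
#     try:
#         while handler.connection_open:   # hypothetical flag; hold the stream open
#             time.sleep(1)
#     finally:
#         remove_log_subscriber(handler.wfile)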


# ---------------------------------------------------------------------------
# Request log CRUD
# ---------------------------------------------------------------------------

def insert_request_log(
    *,
    db_path=None,
    method: str,
    path: str,
    status_code: int,
    duration_ms: float,
    user_id: str | None = None,
    username: str | None = None,
    tokens_used: int | None = None,
    model: str | None = None,
    error: str | None = None,
) -> None:
    timestamp = _now_iso()
    conn = _connect(db_path)
    try:
        conn.execute(
            """
            INSERT INTO request_logs
                (timestamp, method, path, status_code, duration_ms,
                 user_id, username, tokens_used, model, error)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (timestamp, method, path, status_code, duration_ms,
             user_id, username, tokens_used, model, error),
        )
        conn.commit()
    finally:
        conn.close()
    entry = {
        "timestamp": timestamp,
        "method": method,
        "path": path,
        "status_code": status_code,
        "duration_ms": duration_ms,
        "user_id": user_id,
        "username": username,
        "tokens_used": tokens_used,
        "model": model,
        "error": error,
    }
    _broadcast_log_entry(entry)
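
# Sketch of how a request handler might log a completed request (illustrative;
# `started` and all values shown are hypothetical, supplied by the caller):
#
#     started = time.time()
#     # ... handle the request ...
#     insert_request_log(
#         method="POST", path="/api/chat", status_code=200,
#         duration_ms=(time.time() - started) * 1000,
#         user_id="u123", username="alice",
#         tokens_used=412, model="llama3",
#     )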


def query_request_logs(
    *,
    db_path=None,
    page: int = 1,
    limit: int = 100,
    path_filter: str | None = None,
    user_filter: str | None = None,
    status_min: int | None = None,
    status_max: int | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    conditions = []
    params: list = []

    if path_filter is not None:
        conditions.append("path = ?")
        params.append(path_filter)
    if user_filter is not None:
        conditions.append("user_id = ?")
        params.append(user_filter)
    if status_min is not None:
        conditions.append("status_code >= ?")
        params.append(status_min)
    if status_max is not None:
        conditions.append("status_code <= ?")
        params.append(status_max)
    if date_from is not None:
        conditions.append("timestamp >= ?")
        params.append(date_from)
    if date_to is not None:
        conditions.append("timestamp <= ?")
        params.append(date_to)

    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    conn = _connect(db_path)
    try:
        total = conn.execute(
            f"SELECT COUNT(*) FROM request_logs {where}", params
        ).fetchone()[0]

        offset = (page - 1) * limit
        rows = conn.execute(
            f"SELECT * FROM request_logs {where} ORDER BY id DESC LIMIT ? OFFSET ?",
            params + [limit, offset],
        ).fetchall()
    finally:
        conn.close()

    pages = max(1, (total + limit - 1) // limit)
    return {
        "items": [dict(r) for r in rows],
        "total": total,
        "page": page,
        "pages": pages,
    }
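
# Example query (illustrative): second page of 5xx responses since a given
# date. Timestamps compare lexicographically, so plain ISO strings work here.
#
#     result = query_request_logs(
#         page=2, limit=50, status_min=500,
#         date_from="2024-01-01T00:00:00",   # hypothetical cutoff
#     )
#     result["items"], result["total"], result["pages"]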


# ---------------------------------------------------------------------------
# Error log CRUD
# ---------------------------------------------------------------------------

_LEVEL_EMOJI = {
    "WARN": "\u26a0",      # ⚠
    "WARNING": "\u26a0",
    "ERROR": "\u2717",     # ✗
    "CRITICAL": "\U0001f534",  # 🔴
}


def insert_error_log(
    *,
    db_path=None,
    level: str,
    source: str,
    message: str,
    path: str | None = None,
    user_id: str | None = None,
) -> None:
    timestamp = _now_iso()
    emoji = _LEVEL_EMOJI.get(level.upper(), "\u2139")
    print(f"{emoji} [{level}] {source}: {message}")
    conn = _connect(db_path)
    try:
        conn.execute(
            """
            INSERT INTO error_logs (timestamp, level, source, message, path, user_id)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (timestamp, level, source, message, path, user_id),
        )
        conn.commit()
    finally:
        conn.close()
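
# Example call (illustrative values throughout):
#
#     insert_error_log(
#         level="ERROR", source="ollama",
#         message="upstream timed out",
#         path="/api/chat", user_id="u123",
#     )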


def query_error_logs(
    *,
    db_path=None,
    page: int = 1,
    limit: int = 50,
    level_filter: str | None = None,
    source_filter: str | None = None,
    date_from: str | None = None,
    date_to: str | None = None,
) -> dict:
    conditions = []
    params: list = []

    if level_filter is not None:
        conditions.append("level = ?")
        params.append(level_filter)
    if source_filter is not None:
        conditions.append("source = ?")
        params.append(source_filter)
    if date_from is not None:
        conditions.append("timestamp >= ?")
        params.append(date_from)
    if date_to is not None:
        conditions.append("timestamp <= ?")
        params.append(date_to)

    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""

    conn = _connect(db_path)
    try:
        total = conn.execute(
            f"SELECT COUNT(*) FROM error_logs {where}", params
        ).fetchone()[0]

        offset = (page - 1) * limit
        rows = conn.execute(
            f"SELECT * FROM error_logs {where} ORDER BY id DESC LIMIT ? OFFSET ?",
            params + [limit, offset],
        ).fetchall()
    finally:
        conn.close()

    pages = max(1, (total + limit - 1) // limit)
    return {
        "items": [dict(r) for r in rows],
        "total": total,
        "page": page,
        "pages": pages,
    }
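
# Mirrors query_request_logs, e.g. (illustrative):
#
#     query_error_logs(level_filter="ERROR", limit=25)["items"]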


# ---------------------------------------------------------------------------
# System snapshots
# ---------------------------------------------------------------------------

def insert_snapshot(
    *,
    db_path=None,
    ollama_up: bool,
    rag_up: bool,
    active_connections: int,
    memory_mb: float,
    cpu_percent: float,
) -> None:
    """Insert a row into system_snapshots."""
    conn = _connect(db_path)
    try:
        conn.execute(
            """
            INSERT INTO system_snapshots
                (timestamp, ollama_up, rag_up, active_connections, memory_mb, cpu_percent)
            VALUES (?, ?, ?, ?, ?, ?)
            """,
            (
                _now_iso(),
                int(ollama_up),
                int(rag_up),
                active_connections,
                round(memory_mb, 1),
                round(cpu_percent, 1),
            ),
        )
        conn.commit()
    finally:
        conn.close()


def get_latest_snapshot(*, db_path=None) -> dict | None:
    """Return the most recent system snapshot, or None if table is empty."""
    conn = _connect(db_path)
    try:
        row = conn.execute(
            "SELECT * FROM system_snapshots ORDER BY id DESC LIMIT 1"
        ).fetchone()
    finally:
        conn.close()
    return dict(row) if row is not None else None


def get_snapshot_history(*, db_path=None, hours: int = 24) -> list:
    """Return snapshots from the last *hours* hours, ordered oldest-first."""
    cutoff = (
        datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=hours)
    ).isoformat().replace("+00:00", "Z")
    conn = _connect(db_path)
    try:
        rows = conn.execute(
            "SELECT * FROM system_snapshots WHERE timestamp >= ? ORDER BY id ASC",
            (cutoff,),
        ).fetchall()
    finally:
        conn.close()
    return [dict(r) for r in rows]
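
# Example (illustrative): points for a 6-hour memory chart.
#
#     history = get_snapshot_history(hours=6)
#     points = [(s["timestamp"], s["memory_mb"]) for s in history]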


# ---------------------------------------------------------------------------
# Stats aggregation
# ---------------------------------------------------------------------------

_GENERATION_PATHS = ("/api/stream", "/api/chat", "/api/webhook/generate")


def get_stats(*, db_path=None) -> dict:
    """Return aggregated stats for today (since midnight UTC).

    The exception is active_users_24h, which uses a rolling 24-hour window
    to match its name.
    """
    today = datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%dT00:00:00")
    cutoff_24h = (
        datetime.datetime.now(datetime.UTC) - datetime.timedelta(hours=24)
    ).isoformat().replace("+00:00", "Z")
    placeholders = ",".join("?" * len(_GENERATION_PATHS))

    conn = _connect(db_path)
    try:
        requests_today = conn.execute(
            "SELECT COUNT(*) FROM request_logs WHERE timestamp >= ?",
            (today,),
        ).fetchone()[0]

        generations_today = conn.execute(
            f"SELECT COUNT(*) FROM request_logs WHERE timestamp >= ? AND path IN ({placeholders})",
            (today, *_GENERATION_PATHS),
        ).fetchone()[0]

        avg_row = conn.execute(
            "SELECT AVG(duration_ms) FROM request_logs WHERE timestamp >= ?",
            (today,),
        ).fetchone()[0]
        avg_response_ms = round(avg_row, 1) if avg_row is not None else 0.0

        error_count = conn.execute(
            "SELECT COUNT(*) FROM request_logs WHERE timestamp >= ? AND status_code >= 500",
            (today,),
        ).fetchone()[0]
        error_rate_pct = round(error_count / requests_today * 100, 1) if requests_today else 0.0

        active_users_24h = conn.execute(
            """
            SELECT COUNT(DISTINCT user_id) FROM request_logs
            WHERE timestamp >= ? AND user_id IS NOT NULL
            """,
            (cutoff_24h,),
        ).fetchone()[0]

        tokens_row = conn.execute(
            "SELECT SUM(tokens_used) FROM request_logs WHERE timestamp >= ?",
            (today,),
        ).fetchone()[0]
        tokens_today = tokens_row if tokens_row is not None else 0
    finally:
        conn.close()

    return {
        "requests_today": requests_today,
        "generations_today": generations_today,
        "avg_response_ms": avg_response_ms,
        "error_rate_pct": error_rate_pct,
        "active_users_24h": active_users_24h,
        "tokens_today": tokens_today,
    }
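
# Example (illustrative): the returned dict serialises directly for a JSON
# endpoint.
#
#     body = json.dumps(get_stats()).encode()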


def get_usage_breakdown(*, db_path=None, days: int = 7) -> list:
    """Per-day breakdown for the last *days* days."""
    cutoff = (
        datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
    ).isoformat().replace("+00:00", "Z")
    placeholders = ",".join("?" * len(_GENERATION_PATHS))

    conn = _connect(db_path)
    try:
        rows = conn.execute(
            f"""
            SELECT
                DATE(timestamp)                                       AS date,
                COUNT(*)                                              AS requests,
                SUM(CASE WHEN path IN ({placeholders}) THEN 1 ELSE 0 END) AS generations,
                COUNT(DISTINCT user_id)                               AS unique_users,
                COALESCE(SUM(tokens_used), 0)                        AS tokens
            FROM request_logs
            WHERE timestamp >= ?
            GROUP BY DATE(timestamp)
            ORDER BY DATE(timestamp) ASC
            """,
            (*_GENERATION_PATHS, cutoff),
        ).fetchall()
    finally:
        conn.close()
    return [dict(r) for r in rows]
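
# Each row is a plain dict; shape sketch with hypothetical values:
#
#     {"date": "2024-01-01", "requests": 420, "generations": 37,
#      "unique_users": 5, "tokens": 18234}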


# ---------------------------------------------------------------------------
# Config CRUD
# ---------------------------------------------------------------------------

def get_retention_days(*, db_path=None) -> int:
    """Read retention_days from monitoring_config, fallback to _LOG_RETENTION_DAYS."""
    conn = _connect(db_path)
    try:
        row = conn.execute(
            "SELECT value FROM monitoring_config WHERE key = 'retention_days'"
        ).fetchone()
    finally:
        conn.close()
    if row is not None:
        return int(row[0])
    return _LOG_RETENTION_DAYS


def set_retention_days(days: int, *, db_path=None) -> None:
    """Persist retention_days into monitoring_config."""
    conn = _connect(db_path)
    try:
        conn.execute(
            "INSERT OR REPLACE INTO monitoring_config (key, value) VALUES (?, ?)",
            ("retention_days", str(days)),
        )
        conn.commit()
    finally:
        conn.close()
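
# Round-trip sketch (illustrative):
#
#     set_retention_days(14)
#     assert get_retention_days() == 14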


# ---------------------------------------------------------------------------
# Retention cleanup
# ---------------------------------------------------------------------------

def run_cleanup(*, db_path=None) -> int:
    """Delete rows older than the configured retention period.

    Returns the total number of rows deleted across all three tables.
    """
    retention = get_retention_days(db_path=db_path)
    cutoff = (
        datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=retention)
    ).isoformat().replace("+00:00", "Z")

    conn = _connect(db_path)
    try:
        deleted_requests = conn.execute(
            "DELETE FROM request_logs WHERE timestamp < ?", (cutoff,)
        ).rowcount
        deleted_errors = conn.execute(
            "DELETE FROM error_logs WHERE timestamp < ?", (cutoff,)
        ).rowcount
        deleted_snaps = conn.execute(
            "DELETE FROM system_snapshots WHERE timestamp < ?", (cutoff,)
        ).rowcount
        conn.commit()
    finally:
        conn.close()

    total = deleted_requests + deleted_errors + deleted_snaps
    if total > 0:
        print(
            f"[cleanup] Deleted {deleted_requests} request_logs, "
            f"{deleted_errors} error_logs, {deleted_snaps} system_snapshots "
            f"(cutoff={cutoff})"
        )
    return total
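
# Normally driven by _cleanup_loop, but safe to invoke manually, e.g.
# (assuming this file is importable as a module named `monitoring`):
#
#     python -c "import monitoring; monitoring.run_cleanup()"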


def get_user_activity(*, db_path=None, days: int = 7) -> list:
    """Per-user activity summary for the last *days* days, ordered by generations DESC."""
    cutoff = (
        datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
    ).isoformat().replace("+00:00", "Z")
    placeholders = ",".join("?" * len(_GENERATION_PATHS))

    conn = _connect(db_path)
    try:
        rows = conn.execute(
            f"""
            SELECT
                user_id,
                MAX(username)                                         AS username,
                SUM(CASE WHEN path IN ({placeholders}) THEN 1 ELSE 0 END) AS generations,
                MAX(timestamp)                                        AS last_active,
                COALESCE(SUM(tokens_used), 0)                        AS total_tokens
            FROM request_logs
            WHERE timestamp >= ? AND user_id IS NOT NULL
            GROUP BY user_id
            ORDER BY generations DESC
            """,
            (*_GENERATION_PATHS, cutoff),
        ).fetchall()
    finally:
        conn.close()
    return [dict(r) for r in rows]
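
# Example row shape with hypothetical values:
#
#     get_user_activity(days=30)[0]
#     # {"user_id": "u123", "username": "alice", "generations": 52,
#     #  "last_active": "2024-01-07T09:30:00Z", "total_tokens": 91000}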


# ---------------------------------------------------------------------------
# Active connection counter
# ---------------------------------------------------------------------------

_active_connections = 0
_conn_lock = threading.Lock()


def inc_connections():
    global _active_connections
    with _conn_lock:
        _active_connections += 1


def dec_connections():
    global _active_connections
    with _conn_lock:
        _active_connections = max(0, _active_connections - 1)


def get_active_connections() -> int:
    with _conn_lock:
        return _active_connections
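
# Intended bracketing around each in-flight request (illustrative):
#
#     inc_connections()
#     try:
#         pass   # handle the request
#     finally:
#         dec_connections()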


# ---------------------------------------------------------------------------
# CPU tracking state
# ---------------------------------------------------------------------------

_last_process_time = 0.0
_last_wall_time = 0.0


# ---------------------------------------------------------------------------
# Health collector
# ---------------------------------------------------------------------------

def _collect_health(ollama_url: str, rag_url: str):
    """Collect one health snapshot and schedule the next."""
    global _last_process_time, _last_wall_time

    # Ollama check - GET {ollama_url}/api/tags with 5s timeout
    ollama_up = False
    try:
        req = urllib.request.Request(f"{ollama_url}/api/tags", method="GET")
        with urllib.request.urlopen(req, timeout=5):
            ollama_up = True
    except Exception:
        pass

    # RAG check - GET {rag_url}/health with 3s timeout (skip if rag_url empty)
    rag_up = False
    if rag_url:
        try:
            req = urllib.request.Request(f"{rag_url}/health", method="GET")
            with urllib.request.urlopen(req, timeout=3):
                rag_up = True
        except Exception:
            pass

    # Peak RSS via getrusage: ru_maxrss is reported in bytes on macOS but in
    # kilobytes on Linux.
    usage = resource.getrusage(resource.RUSAGE_SELF)
    if sys.platform == "darwin":
        memory_mb = usage.ru_maxrss / (1024 * 1024)
    else:
        memory_mb = usage.ru_maxrss / 1024

    # CPU estimate via process_time delta
    now_process = time.process_time()
    now_wall = time.time()
    if _last_wall_time > 0:
        wall_delta = now_wall - _last_wall_time
        proc_delta = now_process - _last_process_time
        cpu_percent = (proc_delta / wall_delta * 100) if wall_delta > 0 else 0.0
    else:
        cpu_percent = 0.0
    _last_process_time = now_process
    _last_wall_time = now_wall

    try:
        insert_snapshot(
            ollama_up=ollama_up, rag_up=rag_up,
            active_connections=get_active_connections(),
            memory_mb=memory_mb, cpu_percent=cpu_percent,
        )
    finally:
        # Always schedule the next run, even if the insert failed, so one bad
        # write does not silently end the health-collection chain.
        t = threading.Timer(_HEALTH_INTERVAL, _collect_health, args=(ollama_url, rag_url))
        t.daemon = True
        t.start()


# ---------------------------------------------------------------------------
# Cleanup loop
# ---------------------------------------------------------------------------

def _cleanup_loop():
    """Run cleanup periodically."""
    while True:
        time.sleep(_CLEANUP_INTERVAL)
        try:
            run_cleanup()
        except Exception as e:
            print(f"  \u26a0  Cleanup error: {e}")


# ---------------------------------------------------------------------------
# Background thread starter
# ---------------------------------------------------------------------------

def start_background_threads(ollama_url: str, rag_url: str):
    """Start health collector and cleanup threads. Call once from main()."""
    # First health sample after a short warm-up delay; _collect_health then
    # reschedules itself every _HEALTH_INTERVAL seconds.
    t = threading.Timer(2, _collect_health, args=(ollama_url, rag_url))
    t.daemon = True
    t.start()

    cleanup = threading.Thread(target=_cleanup_loop, daemon=True)
    cleanup.start()
    print(f"  \u2713  Monitoring: health every {_HEALTH_INTERVAL}s, cleanup every {_CLEANUP_INTERVAL}s")
