# qa-copilot/test_runs.py
"""
Test execution tracking and dashboard queries (#14).

Records test run results (pass/fail/skip per test), computes trends,
detects flaky tests, and tracks self-healing success rates.

Storage: SQLite (test_runs.db), WAL mode.
"""
from __future__ import annotations

import datetime
import sqlite3
import uuid
from pathlib import Path

# Directory that contains this module; the default database lives beside it.
SCRIPT_DIR = Path(__file__).parent
# Default database location, kept as a plain string because every function
# below passes it straight to sqlite3.connect().
_DEFAULT_DB = str(SCRIPT_DIR.joinpath("test_runs.db"))


def init_db(db_path: str = _DEFAULT_DB) -> None:
    """Create the ``runs`` and ``run_results`` tables and their indexes if missing.

    Idempotent: every statement uses ``IF NOT EXISTS``. Also switches the
    database to WAL journal mode.

    Args:
        db_path: Path to the SQLite database file (sqlite3 creates it if absent).
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        # One row per recorded run, holding pre-aggregated status counts.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS runs (
                id          TEXT PRIMARY KEY,
                project     TEXT NOT NULL,
                suite       TEXT NOT NULL DEFAULT '',
                total       INTEGER NOT NULL DEFAULT 0,
                passed      INTEGER NOT NULL DEFAULT 0,
                failed      INTEGER NOT NULL DEFAULT 0,
                skipped     INTEGER NOT NULL DEFAULT 0,
                duration_ms INTEGER NOT NULL DEFAULT 0,
                created_at  TEXT NOT NULL
            )
        """)
        # Serves "latest runs for a project" queries.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_runs_project ON runs(project, created_at DESC)")
        # One row per individual test result within a run. NOTE: SQLite only
        # enforces the REFERENCES clause when PRAGMA foreign_keys=ON is set on
        # the connection, which this module does not do.
        conn.execute("""
            CREATE TABLE IF NOT EXISTS run_results (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                run_id      TEXT NOT NULL REFERENCES runs(id),
                name        TEXT NOT NULL,
                status      TEXT NOT NULL,
                duration_ms INTEGER NOT NULL DEFAULT 0,
                error       TEXT,
                healed      INTEGER NOT NULL DEFAULT 0,
                created_at  TEXT NOT NULL
            )
        """)
        conn.execute("CREATE INDEX IF NOT EXISTS idx_run_results_run ON run_results(run_id)")
        conn.execute("CREATE INDEX IF NOT EXISTS idx_run_results_name ON run_results(name, created_at DESC)")
        conn.commit()
    finally:
        # Release the handle even when a statement fails (locked or read-only
        # database, I/O error) rather than waiting for garbage collection.
        conn.close()


def record_run(project: str, suite: str, results: list[dict], db_path: str = _DEFAULT_DB) -> str:
    """Persist one test run plus its per-test results and return the run id.

    Args:
        project: Project the run belongs to.
        suite: Suite name (may be empty).
        results: One dict per test with required keys ``name`` and ``status``
            and optional ``duration_ms``, ``error`` and ``healed``. Statuses
            ``"passed"``, ``"failed"`` and ``"skipped"`` are tallied in the
            ``runs`` row; any other status counts only toward ``total``.
            A missing or ``None`` ``duration_ms`` is stored as 0.
        db_path: SQLite database file; the tables must already exist.

    Returns:
        The 8-character run id stored in ``runs.id``.

    Raises:
        KeyError: if a result lacks ``name`` or ``status``. This is raised
            before the database is opened, so nothing is written.
        sqlite3.Error: on database failure; the run is not committed.
    """
    # Naive UTC ISO timestamp, byte-identical in format to the former
    # datetime.utcnow().isoformat() (utcnow is deprecated since Python 3.12).
    now = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()

    # Build every row first so malformed input fails before any DB work.
    result_rows = [
        (
            r["name"],
            r["status"],
            r.get("duration_ms") or 0,
            r.get("error"),
            int(r.get("healed", False)),
        )
        for r in results
    ]
    statuses = [row[1] for row in result_rows]
    passed = statuses.count("passed")
    failed = statuses.count("failed")
    skipped = statuses.count("skipped")
    total_dur = sum(row[2] for row in result_rows)

    conn = sqlite3.connect(db_path)
    try:
        # Ids carry only 32 random bits, so collisions become likely over tens
        # of thousands of runs; redraw until the id is unused.
        run_id = str(uuid.uuid4())[:8]
        while conn.execute("SELECT 1 FROM runs WHERE id = ?", (run_id,)).fetchone():
            run_id = str(uuid.uuid4())[:8]
        conn.execute(
            "INSERT INTO runs (id, project, suite, total, passed, failed, skipped, duration_ms, created_at) VALUES (?,?,?,?,?,?,?,?,?)",
            (run_id, project, suite, len(result_rows), passed, failed, skipped, total_dur, now),
        )
        conn.executemany(
            "INSERT INTO run_results (run_id, name, status, duration_ms, error, healed, created_at) VALUES (?,?,?,?,?,?,?)",
            [(run_id, *row, now) for row in result_rows],
        )
        conn.commit()
    finally:
        # close() without commit() discards a partial insert, so a failure
        # never leaves a run row without its results.
        conn.close()
    return run_id


def get_runs(project: str | None = None, limit: int = 50, db_path: str = _DEFAULT_DB) -> list[dict]:
    """Return up to ``limit`` recorded runs, newest first.

    A falsy ``project`` (``None`` or ``""``) disables the project filter and
    returns runs from every project. Each dict carries the ``runs`` columns:
    id, project, suite, total, passed, failed, skipped, duration_ms, created_at.
    """
    columns = ("id", "project", "suite", "total", "passed", "failed", "skipped", "duration_ms", "created_at")
    query = "SELECT id, project, suite, total, passed, failed, skipped, duration_ms, created_at FROM runs"
    params: tuple = (limit,)
    if project:
        query += " WHERE project = ?"
        params = (project, limit)
    query += " ORDER BY created_at DESC LIMIT ?"

    conn = sqlite3.connect(db_path)
    fetched = conn.execute(query, params).fetchall()
    conn.close()
    return [dict(zip(columns, record)) for record in fetched]


def get_trend(project: str, limit: int = 30, db_path: str = _DEFAULT_DB) -> list[dict]:
    """Return the ``limit`` most recent runs of ``project``, oldest first.

    The newest runs are selected, then the list is flipped so it reads in
    chronological order. Each dict has keys run_id, passed, failed, skipped,
    total and created_at.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.execute(
        "SELECT id, passed, failed, skipped, total, created_at FROM runs WHERE project = ? ORDER BY created_at DESC LIMIT ?",
        (project, limit),
    )
    newest_first = cursor.fetchall()
    conn.close()
    keys = ("run_id", "passed", "failed", "skipped", "total", "created_at")
    return [dict(zip(keys, record)) for record in newest_first[::-1]]


def get_flaky_tests(project: str, window: int = 10, min_flips: int = 2, db_path: str = _DEFAULT_DB) -> list[dict]:
    """Find tests whose status changed at least ``min_flips`` times recently.

    Considers the ``window`` most recent runs of ``project``. For each test
    name, it walks that test's results in chronological order and counts how
    often consecutive results differ in status. Any change counts, including
    transitions to or from ``"skipped"``.

    Args:
        project: Project whose runs are inspected.
        window: Number of most recent runs to consider.
        min_flips: Minimum number of status changes for a test to be reported.
        db_path: SQLite database file.

    Returns:
        ``{"name", "flips", "runs"}`` dicts ordered by test name, where
        ``runs`` is the number of results that test has inside the window.
    """
    conn = sqlite3.connect(db_path)
    try:
        # A subquery replaces an expanded "IN (?, ?, ...)" list, so any window
        # size stays within SQLite's bound-parameter limit. Ordering by "id"
        # after created_at breaks ties between results that share a timestamp
        # (e.g. the same name twice in one run), so the flip count is stable.
        rows = conn.execute(
            "SELECT name, status FROM run_results "
            "WHERE run_id IN (SELECT id FROM runs WHERE project = ? ORDER BY created_at DESC LIMIT ?) "
            "ORDER BY name, created_at, id",
            (project, window),
        ).fetchall()
    finally:
        conn.close()

    # Rows arrive sorted by name; dict insertion order preserves that order.
    # Grouping by key (not a truthiness sentinel) also keeps a test named "".
    history: dict[str, list[str]] = {}
    for name, status in rows:
        history.setdefault(name, []).append(status)

    flaky = []
    for name, statuses in history.items():
        flips = sum(1 for prev, cur in zip(statuses, statuses[1:]) if prev != cur)
        if flips >= min_flips:
            flaky.append({"name": name, "flips": flips, "runs": len(statuses)})
    return flaky


def get_healing_rate(project: str, db_path: str = _DEFAULT_DB) -> dict:
    """Return how many of a project's recorded test results were self-healed.

    Covers every run ever recorded for ``project``.

    Returns:
        ``{"healed": int, "total": int, "rate": float}``, where ``rate`` is
        ``healed / total`` rounded to 4 decimals, or ``0.0`` when the project
        has no results.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Join instead of binding each run id into "IN (?, ...)": a project's
        # full history can exceed SQLite's bound-parameter limit (999 before
        # SQLite 3.32), which raised "too many SQL variables".
        healed, total = conn.execute(
            "SELECT COALESCE(SUM(rr.healed), 0), COUNT(*) "
            "FROM run_results AS rr JOIN runs AS r ON r.id = rr.run_id "
            "WHERE r.project = ?",
            (project,),
        ).fetchone()
    finally:
        conn.close()
    return {"healed": healed, "total": total, "rate": round(healed / total, 4) if total else 0.0}
