"""Fine-tuning pipeline integration for QA Copilot.

Provides dataset management, training orchestration (local + cloud GPU),
and deployment to Ollama. Follows the monitoring.py module pattern.
"""
from __future__ import annotations

import hashlib
import json
import os
import sqlite3
import threading
import time
from pathlib import Path
from typing import Any

# ── Constants ────────────────────────────────────────────────────────────
# Directory containing this module; the default SQLite database lives here.
_DIR = Path(__file__).resolve().parent
# Sibling "fine-tuning" directory holding build_dataset.py, train_qlora.py and backends/.
_FINETUNE_DIR = _DIR.parent / "fine-tuning"
# Where assembled datasets are written (assemble_final writes final.jsonl here).
_DATASET_DIR = _FINETUNE_DIR / "dataset"
# Database file used by _connect() when no explicit db_path is supplied.
_DB_PATH = str(_DIR / "fine_tuning.db")


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (with a +00:00 offset)."""
    from datetime import datetime as _datetime, timezone as _timezone
    stamp = _datetime.now(tz=_timezone.utc)
    return stamp.isoformat()


def _connect(db_path: str | None = None) -> sqlite3.Connection:
    conn = sqlite3.connect(db_path or _DB_PATH, timeout=10)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA journal_mode=WAL")
    return conn


def _content_hash(messages: list[dict]) -> str:
    """Return the first 16 hex chars of a SHA-256 over canonical JSON of *messages*.

    Keys are sorted so dict key order does not change the hash; this is the
    dedup key stored in ``pairs.content_hash``.
    """
    canonical = json.dumps(messages, ensure_ascii=False, sort_keys=True)
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()[:16]


# ── Database Initialization ──────────────────────────────────────────────

def init_db(*, db_path: str | None = None) -> None:
    """Create the fine_tuning.db schema; safe to call repeatedly.

    Creates the ``pairs``, ``jobs``, ``cloud_sessions`` and ``project_models``
    tables with their indexes. Older ``pairs`` tables are migrated by adding
    the ``project`` column when it is missing.
    """
    conn = _connect(db_path)
    try:
        conn.executescript("""
            CREATE TABLE IF NOT EXISTS pairs (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                messages    TEXT    NOT NULL,
                track       TEXT    NOT NULL DEFAULT 'unknown',
                status      TEXT    NOT NULL DEFAULT 'raw',
                reviewer    TEXT,
                notes       TEXT,
                content_hash TEXT,
                synthetic   INTEGER DEFAULT 0,
                created_at  TEXT    NOT NULL,
                updated_at  TEXT    NOT NULL
            );

            CREATE INDEX IF NOT EXISTS idx_pairs_status ON pairs(status);
            CREATE INDEX IF NOT EXISTS idx_pairs_track ON pairs(track);
            CREATE UNIQUE INDEX IF NOT EXISTS idx_pairs_hash ON pairs(content_hash);

            CREATE TABLE IF NOT EXISTS jobs (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                type        TEXT    NOT NULL,
                status      TEXT    NOT NULL DEFAULT 'pending',
                config      TEXT,
                log         TEXT    DEFAULT '',
                started_at  TEXT,
                finished_at TEXT,
                error       TEXT,
                created_at  TEXT    NOT NULL
            );

            CREATE INDEX IF NOT EXISTS idx_jobs_type ON jobs(type);
            CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);

            CREATE TABLE IF NOT EXISTS cloud_sessions (
                id            INTEGER PRIMARY KEY AUTOINCREMENT,
                provider      TEXT    NOT NULL,
                instance_id   TEXT,
                gpu_tier      TEXT,
                status        TEXT    NOT NULL DEFAULT 'pending',
                credentials   TEXT,
                cost_estimate REAL    DEFAULT 0,
                cost_actual   REAL    DEFAULT 0,
                max_budget    REAL    DEFAULT 10.0,
                idle_timeout  INTEGER DEFAULT 1800,
                created_at    TEXT    NOT NULL,
                updated_at    TEXT    NOT NULL
            );
        """)

        # Per-project model support (Roadmap 3.3). Inspect the schema instead of
        # swallowing every ALTER error, so real failures (locked/corrupt DB) surface.
        pair_columns = {row[1] for row in conn.execute("PRAGMA table_info(pairs)")}
        if "project" not in pair_columns:
            try:
                conn.execute("ALTER TABLE pairs ADD COLUMN project TEXT DEFAULT ''")
            except sqlite3.OperationalError as exc:
                # Another connection may have added the column after our check.
                if "duplicate column" not in str(exc).lower():
                    raise
        # get_approved_examples filters on project; index it like status/track.
        conn.execute("CREATE INDEX IF NOT EXISTS idx_pairs_project ON pairs(project)")

        conn.execute("""
            CREATE TABLE IF NOT EXISTS project_models (
                id          INTEGER PRIMARY KEY AUTOINCREMENT,
                project     TEXT NOT NULL UNIQUE,
                adapter_path TEXT NOT NULL,
                base_model  TEXT NOT NULL,
                pair_count  INTEGER DEFAULT 0,
                trained_at  TEXT,
                created_at  TEXT NOT NULL,
                updated_at  TEXT NOT NULL
            );
        """)

        conn.commit()
    finally:
        conn.close()


# ── Pair CRUD ────────────────────────────────────────────────────────────

def _row_to_pair(row: sqlite3.Row) -> dict:
    """Convert a ``pairs`` row into a plain dict for API consumers.

    ``messages`` is decoded from its JSON text, ``synthetic`` is coerced to a
    bool, and ``project`` defaults to "" when the row has no such column.
    """
    pair = dict(row)
    pair.update(
        messages=json.loads(pair["messages"]),
        synthetic=bool(pair.get("synthetic")),
    )
    if "project" not in pair:
        pair["project"] = ""
    return pair


def insert_pair(*, messages: list[dict], track: str = "unknown", synthetic: bool = False, project: str = "", db_path: str | None = None) -> int:
    """Insert a training pair with status 'raw' and return its id.

    Pairs are deduplicated on the hash of *messages*: if identical content is
    already stored, the existing id is returned and nothing is written.
    """
    ch = _content_hash(messages)
    now = _now_iso()
    conn = _connect(db_path)
    try:
        # INSERT OR IGNORE relies on the UNIQUE index on content_hash, so two
        # writers racing on the same content cannot both insert. A separate
        # SELECT-then-INSERT could raise IntegrityError in that window.
        cur = conn.execute(
            """INSERT OR IGNORE INTO pairs (messages, track, status, synthetic, content_hash, project, created_at, updated_at)
               VALUES (?, ?, 'raw', ?, ?, ?, ?, ?)""",
            (json.dumps(messages, ensure_ascii=False), track, int(synthetic), ch, project, now, now),
        )
        conn.commit()
        if cur.rowcount:
            return cur.lastrowid
        # Row was ignored: the content already exists, so return that id.
        existing = conn.execute("SELECT id FROM pairs WHERE content_hash = ?", (ch,)).fetchone()
        if existing is None:
            raise sqlite3.IntegrityError("pair insert was ignored but no matching content_hash row exists")
        return existing[0]
    finally:
        conn.close()


def get_pair(*, pair_id: int, db_path: str | None = None) -> dict | None:
    """Return the pair with id *pair_id* as a dict, or None if it does not exist."""
    conn = _connect(db_path)
    try:
        found = conn.execute("SELECT * FROM pairs WHERE id = ?", (pair_id,)).fetchone()
    finally:
        conn.close()
    if found is None:
        return None
    return _row_to_pair(found)


def query_pairs(*, status: str | None = None, track: str | None = None, page: int = 1, limit: int = 50, db_path: str | None = None, project: str | None = None) -> dict:
    """Return one page of pairs, newest (highest id) first, optionally filtered.

    Args:
        status: only pairs with this status (falsy means no filter).
        track: only pairs from this track (falsy means no filter).
        page: 1-based page number; values below 1 are treated as 1.
        limit: page size; values below 1 are treated as 1. A limit of 0
            previously crashed with ZeroDivisionError while computing ``pages``.
        db_path: optional database path override.
        project: only pairs for this project when not None ("" is a valid
            project, matching get_approved_examples).

    Returns:
        {"items": [pair dicts], "total": int, "page": int, "pages": int}
    """
    page = max(1, page)
    limit = max(1, limit)
    conditions: list[str] = []
    params: list = []
    if status:
        conditions.append("status = ?")
        params.append(status)
    if track:
        conditions.append("track = ?")
        params.append(track)
    if project is not None:
        conditions.append("project = ?")
        params.append(project)
    # Only fixed column predicates are interpolated; all values are bound via "?".
    where = ("WHERE " + " AND ".join(conditions)) if conditions else ""
    offset = (page - 1) * limit
    conn = _connect(db_path)
    try:
        total = conn.execute(f"SELECT COUNT(*) FROM pairs {where}", params).fetchone()[0]
        rows = conn.execute(f"SELECT * FROM pairs {where} ORDER BY id DESC LIMIT ? OFFSET ?", params + [limit, offset]).fetchall()
    finally:
        conn.close()
    pages = max(1, (total + limit - 1) // limit)
    return {"items": [_row_to_pair(r) for r in rows], "total": total, "page": page, "pages": pages}


def update_pair(*, pair_id: int, status: str | None = None, reviewer: str | None = None, notes: str | None = None, messages: list[dict] | None = None, db_path: str | None = None) -> None:
    """Partially update a pair; only arguments that are not None are written.

    Replacing *messages* also recomputes ``content_hash``. ``updated_at`` is
    bumped whenever something changes; with nothing to change this is a no-op.
    NOTE(review): a messages edit whose hash matches another pair will raise
    sqlite3.IntegrityError from the UNIQUE content_hash index.
    """
    changes: list[tuple[str, object]] = []
    for column, value in (("status", status), ("reviewer", reviewer), ("notes", notes)):
        if value is not None:
            changes.append((column, value))
    if messages is not None:
        changes.append(("messages", json.dumps(messages, ensure_ascii=False)))
        changes.append(("content_hash", _content_hash(messages)))
    if not changes:
        return
    changes.append(("updated_at", _now_iso()))
    assignments = ", ".join(f"{column} = ?" for column, _ in changes)
    values = [value for _, value in changes]
    values.append(pair_id)
    conn = _connect(db_path)
    try:
        conn.execute(f"UPDATE pairs SET {assignments} WHERE id = ?", values)
        conn.commit()
    finally:
        conn.close()


def get_pair_stats(*, db_path: str | None = None) -> dict:
    """Summarize pair counts: overall total, each review status, and per track.

    Statuses with no pairs are reported as 0; ``by_track`` maps track -> count.
    """
    conn = _connect(db_path)
    try:
        status_rows = conn.execute("SELECT status, COUNT(*) as cnt FROM pairs GROUP BY status").fetchall()
        grand_total = conn.execute("SELECT COUNT(*) FROM pairs").fetchone()[0]
        track_rows = conn.execute("SELECT track, COUNT(*) as cnt FROM pairs GROUP BY track").fetchall()
    finally:
        conn.close()
    per_status = dict((r["status"], r["cnt"]) for r in status_rows)
    summary: dict = {"total": grand_total}
    for key in ("raw", "pending", "approved", "rejected"):
        summary[key] = per_status.get(key, 0)
    summary["by_track"] = {r["track"]: r["cnt"] for r in track_rows}
    return summary


def get_approved_examples(*, limit: int = 3, project: str | None = None, db_path: str | None = None) -> list[dict]:
    """Return up to *limit* approved pairs for few-shot prompting.

    Pairs are ordered by ``updated_at`` descending (most recently updated
    first) and returned as dicts with ``messages`` already parsed. When
    *project* is not None, only that project's pairs are considered.
    """
    sql = "SELECT * FROM pairs WHERE status = 'approved'"
    args: list = []
    if project is not None:
        sql += " AND project = ?"
        args.append(project)
    sql += " ORDER BY updated_at DESC LIMIT ?"
    args.append(limit)
    conn = _connect(db_path)
    try:
        rows = conn.execute(sql, args).fetchall()
    finally:
        conn.close()
    return [_row_to_pair(r) for r in rows]


def get_project_model(*, project: str, db_path: str | None = None) -> dict | None:
    """Return the registered project_models row for *project*, or None."""
    conn = _connect(db_path)
    try:
        match = conn.execute("SELECT * FROM project_models WHERE project = ?", (project,)).fetchone()
    finally:
        conn.close()
    return None if match is None else dict(match)


def register_project_model(*, project: str, adapter_path: str, base_model: str, pair_count: int = 0, db_path: str | None = None) -> None:
    """Insert or update the fine-tuned adapter recorded for *project*.

    A new row gets trained_at/created_at/updated_at set to now. For an existing
    project, adapter_path, base_model, pair_count, trained_at and updated_at are
    overwritten and the original created_at is kept.
    """
    stamp = _now_iso()
    values = (project, adapter_path, base_model, pair_count, stamp, stamp, stamp)
    conn = _connect(db_path)
    try:
        conn.execute(
            """INSERT INTO project_models (project, adapter_path, base_model, pair_count, trained_at, created_at, updated_at)
               VALUES (?, ?, ?, ?, ?, ?, ?)
               ON CONFLICT(project) DO UPDATE SET
                   adapter_path=excluded.adapter_path, base_model=excluded.base_model,
                   pair_count=excluded.pair_count, trained_at=excluded.trained_at, updated_at=excluded.updated_at""",
            values,
        )
        conn.commit()
    finally:
        conn.close()


def list_project_models(*, db_path: str | None = None) -> list[dict]:
    """Return every project_models row as a dict, sorted by project name."""
    conn = _connect(db_path)
    try:
        cursor = conn.execute("SELECT * FROM project_models ORDER BY project")
        records = cursor.fetchall()
    finally:
        conn.close()
    return list(map(dict, records))


# ── Job Tracking ─────────────────────────────────────────────────────────

def create_job(*, job_type: str, config: dict | None = None, db_path: str | None = None) -> int:
    """Insert a 'pending' job of *job_type* and return its id.

    *config* is stored as JSON (an empty object when None or empty).
    """
    created = _now_iso()
    conn = _connect(db_path)
    try:
        payload = json.dumps(config or {})
        cursor = conn.execute(
            "INSERT INTO jobs (type, status, config, created_at) VALUES (?, 'pending', ?, ?)",
            (job_type, payload, created),
        )
        conn.commit()
        new_id = cursor.lastrowid
    finally:
        conn.close()
    return new_id


def get_job(*, job_id: int, db_path: str | None = None) -> dict | None:
    """Return the jobs row with id *job_id* as a dict, or None if absent."""
    conn = _connect(db_path)
    try:
        job_row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
    finally:
        conn.close()
    if job_row is None:
        return None
    return dict(job_row)


def update_job(*, job_id: int, status: str | None = None, error: str | None = None, db_path: str | None = None) -> None:
    """Update a job's status and/or error text.

    Switching to "running" stamps started_at. Switching to "completed",
    "failed" or "cancelled" stamps finished_at. Nothing is written when both
    *status* and *error* are None.
    """
    stamp_columns = {
        "running": "started_at",
        "completed": "finished_at",
        "failed": "finished_at",
        "cancelled": "finished_at",
    }
    assignments: list[str] = []
    values: list = []
    if status is not None:
        assignments.append("status = ?")
        values.append(status)
        stamp_column = stamp_columns.get(status)
        if stamp_column is not None:
            assignments.append(f"{stamp_column} = ?")
            values.append(_now_iso())
    if error is not None:
        assignments.append("error = ?")
        values.append(error)
    if not assignments:
        return
    conn = _connect(db_path)
    try:
        conn.execute(f"UPDATE jobs SET {', '.join(assignments)} WHERE id = ?", [*values, job_id])
        conn.commit()
    finally:
        conn.close()


def append_job_log(*, job_id: int, line: str, db_path: str | None = None) -> None:
    """Append *line* plus a newline to the job's stored log text."""
    conn = _connect(db_path)
    try:
        # COALESCE: in SQL, NULL || 'x' is NULL, so a NULL log would otherwise
        # swallow every later append.
        conn.execute("UPDATE jobs SET log = COALESCE(log, '') || ? WHERE id = ?", (line + "\n", job_id))
        conn.commit()
    finally:
        conn.close()


def list_jobs(*, job_type: str | None = None, limit: int = 20, db_path: str | None = None) -> list[dict]:
    """Return up to *limit* jobs, newest first, optionally only of *job_type*."""
    if job_type:
        sql = "SELECT * FROM jobs WHERE type = ? ORDER BY id DESC LIMIT ?"
        args: tuple = (job_type, limit)
    else:
        sql = "SELECT * FROM jobs ORDER BY id DESC LIMIT ?"
        args = (limit,)
    conn = _connect(db_path)
    try:
        job_rows = conn.execute(sql, args).fetchall()
    finally:
        conn.close()
    return [dict(job_row) for job_row in job_rows]


def mark_orphaned_jobs(*, db_path: str | None = None) -> None:
    """Fail every job still marked 'running' with error 'server restarted'.

    Presumably called at startup, when no worker thread can still own such a
    job. NOTE(review): jobs left in 'pending' are not touched.
    """
    finished = _now_iso()
    sql = "UPDATE jobs SET status = 'failed', error = 'server restarted', finished_at = ? WHERE status = 'running'"
    conn = _connect(db_path)
    try:
        conn.execute(sql, (finished,))
        conn.commit()
    finally:
        conn.close()


# ── SSE Broadcasting ─────────────────────────────────────────────────────

# Registered SSE output streams (objects with write()/flush()) that receive
# fine-tuning events from _broadcast(). Mutated only under _ft_sub_lock.
_ft_subscribers: list = []
_ft_sub_lock = threading.Lock()


def add_subscriber(wfile) -> None:
    """Register *wfile* (an SSE output stream) to receive broadcast events."""
    _ft_sub_lock.acquire()
    try:
        _ft_subscribers.append(wfile)
    finally:
        _ft_sub_lock.release()


def remove_subscriber(wfile) -> None:
    """Unregister *wfile*; a stream that is not registered is silently ignored."""
    with _ft_sub_lock:
        if wfile in _ft_subscribers:
            _ft_subscribers.remove(wfile)


def _broadcast(event: dict) -> None:
    """Send *event* as an SSE ``data:`` frame to every registered subscriber.

    The subscriber list is copied under the lock and written to outside it.
    Streams whose write or flush raises are unregistered afterwards.
    """
    frame = ("data: " + json.dumps(event) + "\n\n").encode()
    with _ft_sub_lock:
        targets = _ft_subscribers[:]
    failed = []
    for stream in targets:
        try:
            stream.write(frame)
            stream.flush()
        except Exception:
            failed.append(stream)
    for stream in failed:
        remove_subscriber(stream)


# ── Dataset Collection ────────────────────────────────────────────────────

import importlib.util


def _import_build_dataset():
    """Load fine-tuning/build_dataset.py by path and return the module object.

    The file is executed on every call; the module is not registered in sys.modules.
    """
    script_path = _FINETUNE_DIR / "build_dataset.py"
    module_spec = importlib.util.spec_from_file_location("build_dataset", script_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return module


def collect_dataset(*, tracks: list[str], jql: str = "", limit: int = 1000, config: dict, db_path: str | None = None) -> int:
    """Run dataset collection for *tracks* in a background thread; return the job id immediately.

    Track handling:
    - "istqb" uses the static pairs in build_dataset.
    - "jira", "ac" and "synthetic" need Atlassian credentials in *config*.
    - Tracks whose builder or credentials are missing are logged as skipped.

    Collected pairs are stored via insert_pair with status 'raw'. Progress is
    written to the job log and broadcast to SSE subscribers.

    NOTE: the reported "inserted" count includes pairs that insert_pair
    deduplicated against content already in the database.
    """
    job_id = create_job(job_type="collect", config={"tracks": tracks, "jql": jql, "limit": limit}, db_path=db_path)

    def _emit(line: str) -> None:
        """Append *line* to the job log and broadcast it as a log event."""
        append_job_log(job_id=job_id, line=line, db_path=db_path)
        _broadcast({"type": "log", "job_id": job_id, "line": line})

    def _run():
        update_job(job_id=job_id, status="running", db_path=db_path)
        try:
            bd = _import_build_dataset()
            base_url = config.get("atlassian_base_url", "")
            email = config.get("atlassian_email", "")
            token = config.get("atlassian_api_token", "")
            auth = bd._auth_header(email, token) if email and token else ""
            ollama_url = config.get("ollama_url", "http://localhost:11434")
            ollama_model = config.get("ollama_model", "qwen2.5:7b")

            all_pairs = []
            for track in tracks:
                _emit(f"Collecting track: {track}")
                if track == "istqb" and hasattr(bd, "_ISTQB_PAIRS"):
                    pairs = bd._ISTQB_PAIRS
                elif track == "jira" and auth and hasattr(bd, "build_jira_pairs"):
                    pairs = bd.build_jira_pairs(base_url, auth, jql, limit, False)
                elif track == "ac" and auth and hasattr(bd, "build_ac_pairs"):
                    pairs = bd.build_ac_pairs(base_url, auth, jql, limit, False)
                elif track == "synthetic" and auth and hasattr(bd, "build_synthetic_pairs"):
                    pairs = bd.build_synthetic_pairs(base_url, auth, jql, limit, ollama_url, ollama_model, False)
                else:
                    # Broadcast skips too, so live viewers see the same lines as the stored log.
                    _emit(f"  Skipped {track} (missing config or function)")
                    continue
                # Materialize so a None or generator result still works with len().
                pairs = list(pairs or [])
                _emit(f"  {track}: {len(pairs)} pairs collected")
                all_pairs.extend((p, track) for p in pairs)

            inserted = 0
            for pair, track in all_pairs:
                messages = pair.get("messages", [])
                if messages:
                    # bool() keeps insert_pair's int() from failing on flag values like "true".
                    insert_pair(messages=messages, track=track, synthetic=bool(pair.get("_synthetic", False)), db_path=db_path)
                    inserted += 1

            append_job_log(job_id=job_id, line=f"Collection complete: {inserted} pairs inserted", db_path=db_path)
            _broadcast({"type": "done", "job_id": job_id, "inserted": inserted})
            update_job(job_id=job_id, status="completed", db_path=db_path)
        except Exception as e:
            # Persist the error in the job log as well, as start_local_training does.
            append_job_log(job_id=job_id, line=f"ERROR: {e}", db_path=db_path)
            update_job(job_id=job_id, status="failed", error=str(e), db_path=db_path)
            _broadcast({"type": "error", "job_id": job_id, "message": str(e)})

    threading.Thread(target=_run, daemon=True).start()
    return job_id


# ── Filter and Assemble ───────────────────────────────────────────────────

def filter_pairs(*, min_chars: int = 100, db_path: str | None = None) -> dict:
    """Move every 'raw' pair to 'pending' or 'rejected' based on assistant text length.

    A pair passes when its concatenated assistant content, stripped of
    surrounding whitespace, has at least *min_chars* characters. Rejected pairs
    get a note recording the measured length. All updates share one connection
    and are committed together; an exception rolls back the whole batch.

    Returns:
        {"passed": int, "rejected": int}
    """
    passed = 0
    rejected = 0
    conn = _connect(db_path)
    try:
        rows = conn.execute("SELECT id, messages FROM pairs WHERE status = 'raw'").fetchall()
        now = _now_iso()
        for row in rows:
            messages = json.loads(row["messages"])
            # `or ""` also covers an explicit null content, which used to raise TypeError.
            assistant_text = "".join(
                m.get("content") or "" for m in messages if m.get("role") == "assistant"
            ).strip()
            if len(assistant_text) >= min_chars:
                conn.execute(
                    "UPDATE pairs SET status = 'pending', updated_at = ? WHERE id = ?",
                    (now, row["id"]),
                )
                passed += 1
            else:
                # Report the same (stripped) length that was compared against min_chars.
                note = f"assistant content too short ({len(assistant_text)} < {min_chars})"
                conn.execute(
                    "UPDATE pairs SET status = 'rejected', notes = ?, updated_at = ? WHERE id = ?",
                    (note, now, row["id"]),
                )
                rejected += 1
        conn.commit()
    finally:
        conn.close()
    return {"passed": passed, "rejected": rejected}


def _with_generate_all_suffix(messages: list[dict]) -> list[dict] | None:
    """Return a copy of *messages* with a "Generate: all" line added to the first user turn.

    The first user turn becomes {"role": "user", "content": <content> + newline +
    "Generate: all"}. All other turns are shallow-copied unchanged. Returns None
    when there is no user turn or its content is empty.
    """
    first_user = next((i for i, m in enumerate(messages) if m.get("role") == "user"), None)
    if first_user is None:
        return None
    user_content = messages[first_user].get("content")
    if not user_content:
        return None
    variant = [dict(m) for m in messages]
    # Only the first user turn is changed. The previous version overwrote every
    # user turn with the first one's text, corrupting multi-turn conversations.
    variant[first_user] = {"role": "user", "content": user_content + "\nGenerate: all"}
    return variant


def assemble_final(*, db_path: str | None = None) -> dict:
    """Write every approved pair, plus a "Generate: all" variant, to dataset/final.jsonl.

    Each approved pair contributes its original messages and, when it has a
    non-empty user turn, a variant produced by _with_generate_all_suffix.

    Returns:
        {"count": number of records written, "path": output file path}
    """
    conn = _connect(db_path)
    try:
        rows = conn.execute("SELECT * FROM pairs WHERE status = 'approved'").fetchall()
    finally:
        conn.close()
    final_records = []
    for row in rows:
        messages = json.loads(row["messages"])
        final_records.append({"messages": messages})
        variant = _with_generate_all_suffix(messages)
        if variant is not None:
            final_records.append({"messages": variant})
    _DATASET_DIR.mkdir(parents=True, exist_ok=True)
    output_path = _DATASET_DIR / "final.jsonl"
    # Write to a temp file, then atomically swap it in, so a training run that
    # reads final.jsonl never sees a half-written file.
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        for rec in final_records:
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
    os.replace(tmp_path, output_path)
    return {"count": len(final_records), "path": str(output_path)}


# ── JSONL Import / Export ─────────────────────────────────────────────────

def export_jsonl(*, output_dir, db_path: str | None = None) -> dict:
    """Export pairs to one JSONL file per status (raw/pending/approved/rejected).

    Each record holds ``messages`` and ``track``, plus ``_synthetic`` for
    synthetic pairs and ``_filter_reason`` (the notes) for rejected ones.
    Records are written newest first. A status with no pairs produces no file
    and no entry in the result.

    Returns:
        dict mapping status -> number of records written.
    """
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    statuses = ("raw", "pending", "approved", "rejected")
    # Query directly rather than via query_pairs(limit=100000), which silently
    # truncated larger datasets.
    conn = _connect(db_path)
    try:
        rows_by_status = {
            status: conn.execute(
                "SELECT * FROM pairs WHERE status = ? ORDER BY id DESC", (status,)
            ).fetchall()
            for status in statuses
        }
    finally:
        conn.close()
    counts: dict[str, int] = {}
    for status in statuses:
        rows = rows_by_status[status]
        if not rows:
            continue
        path = output_dir / f"{status}.jsonl"
        with open(path, "w", encoding="utf-8") as f:
            for row in rows:
                item = _row_to_pair(row)
                # Include track so import_jsonl (which reads "track") restores it on round-trip.
                record = {"messages": item["messages"], "track": item["track"]}
                if item.get("synthetic"):
                    record["_synthetic"] = True
                if item.get("notes") and status == "rejected":
                    record["_filter_reason"] = item["notes"]
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        counts[status] = len(rows)
    return counts


def import_jsonl(*, input_dir, db_path: str | None = None) -> dict:
    """Import <status>.jsonl files (raw/pending/approved/rejected) from *input_dir*.

    Missing files and blank lines are ignored, as are records without
    messages. Records whose content hash already exists are counted as
    skipped. New pairs are inserted with the record's "track" (default
    "imported") and then moved to the file's status.

    Returns:
        {"total": pairs inserted, "skipped": duplicates}
    """
    input_dir = Path(input_dir)
    total = 0
    skipped = 0
    # One read connection for all dedup checks (previously one per line).
    conn = _connect(db_path)
    try:
        for status in ("raw", "pending", "approved", "rejected"):
            path = input_dir / f"{status}.jsonl"
            if not path.exists():
                continue
            with open(path, encoding="utf-8") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    record = json.loads(line)
                    messages = record.get("messages", [])
                    if not messages:
                        continue
                    # fetchall() exhausts the cursor so no read snapshot lingers and
                    # rows committed by insert_pair's own connection stay visible.
                    duplicate = conn.execute(
                        "SELECT id FROM pairs WHERE content_hash = ? LIMIT 1",
                        (_content_hash(messages),),
                    ).fetchall()
                    if duplicate:
                        skipped += 1
                        continue
                    pid = insert_pair(messages=messages, track=record.get("track", "imported"), synthetic=record.get("_synthetic", False), db_path=db_path)
                    if status != "raw":
                        update_pair(pair_id=pid, status=status, db_path=db_path)
                    total += 1
    finally:
        conn.close()
    return {"total": total, "skipped": skipped}


# ── Cloud column whitelist ────────────────────────────────────────────────

# Columns _update_cloud_session() may set. Column names are interpolated into
# the UPDATE statement, so only these vetted names are accepted (values are
# still bound through "?" placeholders).
_CLOUD_COLUMNS = {"instance_id", "status", "cost_estimate", "cost_actual", "credentials", "updated_at"}

# ── Training Manager ──────────────────────────────────────────────────────

# Guards the "is a training job already running?" check in start_local_training().
_training_lock = threading.Lock()
# Thread running the current local training job; None until the first job starts.
_training_thread: threading.Thread | None = None
# Set by cancel_training(); the QLoRA progress callback polls it to abort the run.
_training_cancel = threading.Event()


def is_training_active() -> bool:
    """Return True while a local training thread exists and is still alive."""
    thread = _training_thread
    if thread is None:
        return False
    return thread.is_alive()


def start_local_training(*, backend: str, params: dict, db_path: str | None = None) -> int:
    """Launch a local fine-tuning run in a background thread and return its job id.

    Args:
        backend: "mlx" (loads fine-tuning/backends/mlx_train.py) or "qlora"
            (loads fine-tuning/train_qlora.py). Any other value fails the job.
        params: backend options (model, iters/epochs, batch_size, ...); also
            recorded in the job config.
        db_path: optional database path override.

    Raises:
        RuntimeError: if a previous local training thread is still alive.

    Outcome: the job ends as "completed", "cancelled" (via cancel_training) or
    "failed" (missing final.jsonl, unknown backend, or any backend exception).
    """
    global _training_thread

    def _load_module(name: str, path: Path):
        """Execute the Python file at *path* as a fresh module and return it."""
        spec = importlib.util.spec_from_file_location(name, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _run(job_id: int) -> None:
        update_job(job_id=job_id, status="running", db_path=db_path)
        _log(job_id, f"Starting {backend} training...", db_path)
        try:
            dataset_path = str(_DATASET_DIR / "final.jsonl")
            if not Path(dataset_path).exists():
                raise FileNotFoundError(f"Dataset not found: {dataset_path}. Run Assemble first.")
            _log(job_id, f"Dataset: {dataset_path}", db_path)

            def progress_cb(step, total, loss, epoch):
                line = f"[epoch {epoch}] step {step}/{total} loss={loss:.4f}"
                append_job_log(job_id=job_id, line=line, db_path=db_path)
                _broadcast({"type": "metric", "job_id": job_id, "step": step, "total": total, "loss": loss, "epoch": epoch})
                if _training_cancel.is_set():
                    raise InterruptedError("Training cancelled by user")

            # Honour a cancel issued before the backend starts. Mid-run cancellation
            # only works where progress_cb is polled (the mlx backend gets no callback).
            if _training_cancel.is_set():
                raise InterruptedError("Training cancelled by user")

            if backend == "mlx":
                _log(job_id, "Loading mlx-lm backend...", db_path)
                mlx_mod = _load_module("mlx_train", _FINETUNE_DIR / "backends" / "mlx_train.py")
                mlx_model = params.get("model", "mlx-community/Qwen2.5-7B-Instruct-4bit")
                _log(job_id, f"Model: {mlx_model}, iters: {params.get('iters', 1000)}", db_path)
                mlx_mod.train(
                    model=mlx_model,
                    data_path=Path(dataset_path), output_dir=_FINETUNE_DIR / "models",
                    iters=params.get("iters", 1000), batch_size=params.get("batch_size", 1),
                    lora_layers=params.get("lora_layers", 8),
                    llama_cpp_dir=Path(params["llama_cpp"]) if params.get("llama_cpp") else None,
                )
            elif backend == "qlora":
                _log(job_id, "Loading QLoRA backend...", db_path)
                qlora_mod = _load_module("train_qlora", _FINETUNE_DIR / "train_qlora.py")
                qlora_mod.train(
                    model=params.get("model", "Qwen/Qwen2.5-7B-Instruct"),
                    dataset_path=dataset_path,
                    output_dir=str(_FINETUNE_DIR / "output" / "qa_copilot_qlora"),
                    epochs=params.get("epochs", 3), batch_size=params.get("batch_size", 2),
                    grad_acc=params.get("grad_acc", 8), lr=params.get("lr", 2e-4),
                    max_seq=params.get("max_seq", 2048), lora_r=params.get("lora_r", 16),
                    lora_alpha=params.get("lora_alpha", 32),
                    no_quantize=params.get("no_quantize", False),
                    progress_callback=progress_cb,
                )
            else:
                # Previously an unknown backend fell through to "Training completed successfully".
                raise ValueError(f"Unknown backend: {backend}")
            _log(job_id, "Training completed successfully", db_path)
            update_job(job_id=job_id, status="completed", db_path=db_path)
            _broadcast({"type": "done", "job_id": job_id})
        except InterruptedError:
            append_job_log(job_id=job_id, line="Training cancelled", db_path=db_path)
            update_job(job_id=job_id, status="cancelled", db_path=db_path)
            _broadcast({"type": "log", "job_id": job_id, "line": "Training cancelled"})
        except Exception as e:
            _log(job_id, f"ERROR: {e}", db_path)
            update_job(job_id=job_id, status="failed", error=str(e), db_path=db_path)
            _broadcast({"type": "error", "job_id": job_id, "message": str(e)})

    # Hold the lock across the check AND the thread start. Releasing it right
    # after the check let two concurrent callers both pass and train in parallel.
    with _training_lock:
        if is_training_active():
            raise RuntimeError("A training job is already running")
        job_id = create_job(job_type="train", config={"backend": backend, **params}, db_path=db_path)
        _training_cancel.clear()
        _training_thread = threading.Thread(target=_run, args=(job_id,), daemon=True)
        _training_thread.start()
    return job_id


def cancel_training() -> None:
    """Request cooperative cancellation of the running local training job.

    Only sets the shared cancel event; it takes effect where the training loop
    polls that event (the QLoRA progress callback).
    """
    cancel_event = _training_cancel
    cancel_event.set()


# ── Cloud GPU Manager ─────────────────────────────────────────────────────

def _ensure_ssh_key() -> tuple[str, str]:
    """Make sure an RSA keypair exists in ~/.qa-copilot/cloud/ and return it.

    Generates a passphrase-less 4096-bit key with ssh-keygen when the private
    key is missing. If only the public half is missing, it is re-derived from
    the private key instead of failing.

    Returns:
        (private_key_path, public_key_content)

    Raises:
        subprocess.CalledProcessError: if ssh-keygen fails.
    """
    import subprocess
    key_dir = Path.home() / ".qa-copilot" / "cloud"
    # 0o700 keeps the private key directory owner-only (mode applies only on creation).
    key_dir.mkdir(mode=0o700, parents=True, exist_ok=True)
    key_path = key_dir / "id_rsa"
    pub_path = key_dir / "id_rsa.pub"
    if not key_path.exists():
        subprocess.run(
            ["ssh-keygen", "-t", "rsa", "-b", "4096", "-N", "", "-f", str(key_path)],
            check=True, capture_output=True,
        )
    elif not pub_path.exists():
        # `ssh-keygen -y` prints the public key for an existing private key.
        derived = subprocess.run(
            ["ssh-keygen", "-y", "-f", str(key_path)],
            check=True, capture_output=True, text=True,
        )
        pub_path.write_text(derived.stdout.strip() + "\n")
    pub_key = pub_path.read_text().strip()
    return str(key_path), pub_key


def _runpod_provision(api_key: str, gpu_tier: str, ssh_pub_key: str) -> dict:
    """Provision a RunPod instance via the GraphQL API. Returns instance info dict.

    gpu_tier "small"/"medium"/"large" maps to a RunPod GPU type id; any other
    value is passed through as the GPU type id.

    Returns:
        {"instance_id": str, "provider": "runpod", "status": "provisioning"}

    Raises:
        urllib.error.URLError: on transport or HTTP errors.
        RuntimeError: if the response contains no pod id (e.g. GraphQL errors).
    """
    import urllib.request

    gpu_map = {
        "small": "NVIDIA GeForce RTX 3090",
        "medium": "NVIDIA A40",
        "large": "NVIDIA A100 80GB PCIe",
    }
    gpu_type = gpu_map.get(gpu_tier, gpu_tier)

    body = json.dumps({
        "query": (
            "mutation($input: PodFindAndDeployOnDemandInput!) {"
            " podFindAndDeployOnDemand(input: $input) { id imageName machineId desiredStatus } }"
        ),
        "variables": {
            "input": {
                "cloudType": "SECURE",
                "gpuCount": 1,
                "gpuTypeId": gpu_type,
                "containerDiskInGb": 50,
                "volumeInGb": 50,
                "minVcpuCount": 4,
                "minMemoryInGb": 16,
                "imageName": "runpod/pytorch:2.1.0-py3.10-cuda11.8.0-devel-ubuntu22.04",
                "dockerArgs": "",
                "ports": "22/tcp",
                "volumeMountPath": "/workspace",
                "env": [{"key": "PUBLIC_KEY", "value": ssh_pub_key}],
            }
        },
    }).encode()
    req = urllib.request.Request(
        "https://api.runpod.io/graphql",
        data=body,
        headers={
            "Content-Type": "application/json",
            "Authorization": "Bearer " + api_key,
        },
    )
    with urllib.request.urlopen(req, timeout=30) as resp:
        data = json.loads(resp.read())
    # GraphQL reports failures with an "errors" list and "data": null (or a null
    # mutation result). `x.get(k, {})` does not guard against explicit nulls.
    payload = data.get("data") if isinstance(data, dict) else None
    pod = payload.get("podFindAndDeployOnDemand") if isinstance(payload, dict) else None
    instance_id = pod.get("id") if isinstance(pod, dict) else None
    if not instance_id:
        errors = data.get("errors") if isinstance(data, dict) else None
        messages = [
            str(err.get("message", err)) if isinstance(err, dict) else str(err)
            for err in (errors or [])
        ]
        detail = "; ".join(messages) or "no pod id in response"
        raise RuntimeError(f"RunPod provisioning failed: {detail}")
    return {"instance_id": instance_id, "provider": "runpod", "status": "provisioning"}


def _lambda_provision(api_key: str, gpu_tier: str, ssh_pub_key: str) -> dict:
    """Provision a Lambda Labs instance via REST API. Returns instance info dict.

    gpu_tier "small"/"medium"/"large" maps to a Lambda instance type; any other
    value is passed through verbatim. The public key is registered as
    "qa-copilot-key" (best effort), then one instance is launched in us-west-2.

    Returns:
        {"instance_id": str, "provider": "lambda", "status": "provisioning"}

    Raises:
        urllib.error.URLError: if the launch request fails.
        RuntimeError: if the launch response contains no instance id.
    """
    import base64
    import http.client
    import urllib.request

    instance_map = {
        "small": "gpu_1x_a10",
        "medium": "gpu_1x_a100",
        "large": "gpu_1x_a100_sxm4",
    }
    instance_type = instance_map.get(gpu_tier, gpu_tier)
    # Basic auth with the API key as username and an empty password, built once
    # and shared by both requests.
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Basic " + base64.b64encode(f"{api_key}:".encode()).decode(),
    }

    def _post_json(url: str, payload: dict):
        """POST *payload* as JSON to *url* and return the decoded JSON response."""
        req = urllib.request.Request(url, data=json.dumps(payload).encode(), headers=headers)
        with urllib.request.urlopen(req, timeout=30) as resp:
            return json.loads(resp.read())

    # First, add SSH key
    key_name = "qa-copilot-key"
    try:
        key_data = _post_json(
            "https://cloud.lambdalabs.com/api/v1/ssh-keys",
            {"name": key_name, "public_key": ssh_pub_key},
        )
    except (OSError, ValueError, http.client.HTTPException):
        # Best effort: registration may fail (e.g. the key already exists);
        # network/HTTP/JSON errors here should not abort the launch.
        key_data = None
    registered = key_data.get("data") if isinstance(key_data, dict) else None
    if isinstance(registered, dict) and registered.get("name"):
        key_name = registered["name"]

    # Launch instance
    data = _post_json(
        "https://cloud.lambdalabs.com/api/v1/instance-operations/launch",
        {
            "region_name": "us-west-2",
            "instance_type_name": instance_type,
            "ssh_key_names": [key_name],
            "quantity": 1,
        },
    )
    launched = data.get("data") if isinstance(data, dict) else None
    instance_ids = launched.get("instance_ids") if isinstance(launched, dict) else None
    if not instance_ids:
        # Fail loudly instead of returning instance_id=None to the caller.
        raise RuntimeError(f"Lambda launch returned no instance id: {data!r}")
    return {"instance_id": instance_ids[0], "provider": "lambda", "status": "provisioning"}


def _update_cloud_session(session_id: int, *, db_path: str | None = None, **kwargs) -> None:
    """Set whitelisted cloud_sessions columns for *session_id* and bump updated_at.

    Raises:
        ValueError: naming the first keyword that is not in _CLOUD_COLUMNS
            (checked before any database access).
    """
    rejected = [name for name in kwargs if name not in _CLOUD_COLUMNS]
    if rejected:
        raise ValueError(f"Invalid column: {rejected[0]}")
    assignments = [f"{name} = ?" for name in kwargs] + ["updated_at = ?"]
    values = [*kwargs.values(), _now_iso(), session_id]
    conn = _connect(db_path)
    try:
        conn.execute(f"UPDATE cloud_sessions SET {', '.join(assignments)} WHERE id = ?", values)
        conn.commit()
    finally:
        conn.close()


def _create_cloud_session(*, provider: str, gpu_tier: str, max_budget: float, idle_timeout: int, db_path: str | None = None) -> int:
    """Insert a 'pending' cloud_sessions row and return its id.

    created_at and updated_at are both set to the current UTC time.
    """
    stamp = _now_iso()
    row_values = (provider, gpu_tier, max_budget, idle_timeout, stamp, stamp)
    conn = _connect(db_path)
    try:
        cursor = conn.execute(
            "INSERT INTO cloud_sessions (provider, gpu_tier, status, max_budget, idle_timeout, created_at, updated_at) VALUES (?, ?, 'pending', ?, ?, ?, ?)",
            row_values,
        )
        conn.commit()
        new_id = cursor.lastrowid
    finally:
        conn.close()
    return new_id


def start_cloud_training(
    *,
    provider: str,
    api_key: str,
    gpu_tier: str,
    params: dict,
    max_budget: float = 10.0,
    idle_timeout: int = 1800,
    db_path: str | None = None,
) -> int:
    """Provision a cloud GPU, upload data, run training, download results, terminate.

    A ``cloud_train`` job and a ``cloud_sessions`` row are created up front.
    The rest of the pipeline runs on a daemon thread, so this returns the job
    id immediately. Progress is written through ``_log``. The outcome is
    recorded with ``update_job`` / ``_update_cloud_session`` and announced
    with ``_broadcast``.

    Once an instance id is known, the worker always closes its SSH connection
    and attempts to terminate the instance, whether training succeeded or
    failed. A failed run therefore does not leave a paid GPU running. Local
    prerequisites (dataset, training script, paramiko) are checked *before*
    provisioning for the same reason.

    Args:
        provider: ``"runpod"`` or ``"lambda"``.
        api_key: Provider API key.
        gpu_tier: Passed through to the provider's provisioning helper.
        params: Training hyper-parameters. ``epochs``, ``batch_size``,
            ``grad_acc`` and ``lr`` are forwarded to ``train_qlora.py``; the
            whole dict is also stored in the job config.
        max_budget: Stored on the cloud session row.
        idle_timeout: Stored on the cloud session row (presumably seconds;
            not enforced in this function).
        db_path: Optional override for the SQLite database path.

    Returns:
        The id of the created job.
    """
    import shlex

    job_id = create_job(job_type="cloud_train", config={"provider": provider, "gpu_tier": gpu_tier, **params}, db_path=db_path)
    session_id = _create_cloud_session(provider=provider, gpu_tier=gpu_tier, max_budget=max_budget, idle_timeout=idle_timeout, db_path=db_path)

    def _wait_for_ssh(instance_id: str) -> str:
        """Poll until the provider reports an IP and port 22 accepts TCP; return the IP."""
        import socket

        deadline = time.time() + 300
        last_error = None
        while time.time() < deadline:
            try:
                host = _get_instance_ip(provider, api_key, instance_id)
            except (RuntimeError, AttributeError, ValueError, OSError) as exc:
                # Right after provisioning the provider may not have assigned
                # an IP yet (or the lookup hit a transient HTTP error): retry.
                host, last_error = "", exc
            # Never probe an empty host: ("", 22) would resolve to localhost.
            if host:
                try:
                    with socket.create_connection((host, 22), timeout=5):
                        return host
                except OSError as exc:  # includes timeouts and refused connections
                    last_error = exc
            time.sleep(10)
        suffix = f" (last error: {last_error})" if last_error else ""
        raise TimeoutError(f"Instance SSH did not become available within 5 minutes{suffix}")

    def _run():
        update_job(job_id=job_id, status="running", db_path=db_path)
        instance_id = None
        client = None
        try:
            # Fail before paying for a GPU if local prerequisites are missing.
            import paramiko

            dataset_local = _DATASET_DIR / "final.jsonl"
            train_script = _FINETUNE_DIR / "train_qlora.py"
            for required in (dataset_local, train_script):
                if not required.is_file():
                    raise FileNotFoundError(f"Required file not found: {required}")

            _log(job_id, "Ensuring SSH key...", db_path)
            key_path, pub_key = _ensure_ssh_key()

            _log(job_id, f"Provisioning {provider} instance (tier={gpu_tier})...", db_path)
            if provider == "runpod":
                info = _runpod_provision(api_key, gpu_tier, pub_key)
            elif provider == "lambda":
                info = _lambda_provision(api_key, gpu_tier, pub_key)
            else:
                raise ValueError(f"Unknown provider: {provider}")

            instance_id = info.get("instance_id")
            if not instance_id:
                raise RuntimeError(f"{provider} did not return an instance id")
            _update_cloud_session(session_id, instance_id=instance_id, status="provisioning", db_path=db_path)
            _log(job_id, f"Instance provisioned: {instance_id}", db_path)

            _log(job_id, "Waiting for SSH to become available...", db_path)
            ssh_host = _wait_for_ssh(instance_id)
            _update_cloud_session(session_id, status="running", db_path=db_path)
            _log(job_id, f"SSH available at {ssh_host}", db_path)

            client = paramiko.SSHClient()
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(ssh_host, username="ubuntu", key_filename=key_path, timeout=30)

            # Upload dataset and training script; always release the SFTP channel.
            _log(job_id, "Uploading dataset...", db_path)
            sftp = client.open_sftp()
            try:
                sftp.put(str(dataset_local), "/workspace/final.jsonl")
                sftp.put(str(train_script), "/workspace/train_qlora.py")
            finally:
                sftp.close()

            _log(job_id, "Installing dependencies on remote...", db_path)
            _ssh_exec(client, "pip install -q transformers peft trl datasets torch bitsandbytes")

            # Hyper-parameters may come from user input: quote them so they
            # cannot inject shell syntax into the remote command.
            def _arg(key: str, default: Any) -> str:
                return shlex.quote(str(params.get(key, default)))

            cmd = (
                "python3 /workspace/train_qlora.py"
                " --dataset /workspace/final.jsonl"
                " --output /workspace/output"
                f" --epochs {_arg('epochs', 3)}"
                f" --batch {_arg('batch_size', 2)}"
                f" --grad-acc {_arg('grad_acc', 8)}"
                f" --lr {_arg('lr', 2e-4)}"
            )
            _log(job_id, f"Running training: {cmd}", db_path)
            _ssh_exec(client, cmd, job_id=job_id, db_path=db_path)

            _log(job_id, "Downloading merged model...", db_path)
            local_out = _FINETUNE_DIR / "output" / "qa_copilot_qlora" / "merged"
            local_out.mkdir(parents=True, exist_ok=True)
            _sftp_download_dir(client, "/workspace/output/merged", str(local_out))

            _log(job_id, "Training complete. Model downloaded.", db_path)
            update_job(job_id=job_id, status="completed", db_path=db_path)
            _update_cloud_session(session_id, status="completed", db_path=db_path)
            _broadcast({"type": "done", "job_id": job_id})

        except Exception as e:
            update_job(job_id=job_id, status="failed", error=str(e), db_path=db_path)
            _update_cloud_session(session_id, status="failed", db_path=db_path)
            _broadcast({"type": "error", "job_id": job_id, "message": str(e)})
        finally:
            # Close SSH on every path; a close failure must not prevent termination.
            if client is not None:
                try:
                    client.close()
                except Exception as ce:
                    _log(job_id, f"WARNING: failed to close SSH connection: {ce}", db_path)
            if instance_id:
                try:
                    _log(job_id, f"Terminating instance {instance_id}...", db_path)
                    terminate_cloud_instance(provider=provider, api_key=api_key, instance_id=instance_id, db_path=db_path)
                    _update_cloud_session(session_id, status="terminated", db_path=db_path)
                except Exception as te:
                    _log(job_id, f"WARNING: failed to terminate instance: {te}", db_path)

    t = threading.Thread(target=_run, daemon=True)
    t.start()
    return job_id


def _get_instance_ip(provider: str, api_key: str, instance_id: str) -> str:
    """Retrieve the public IP of a provisioned instance."""
    import urllib.request
    if provider == "runpod":
        body = json.dumps({
            "query": "query($input: PodFilter!) { pod(input: $input) { runtime { ports { ip } } } }",
            "variables": {"input": {"podId": instance_id}},
        }).encode()
        req = urllib.request.Request(
            "https://api.runpod.io/graphql",
            data=body,
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer " + api_key,
            },
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            data = json.loads(resp.read())
        ports = data.get("data", {}).get("pod", {}).get("runtime", {}).get("ports", [])
        if ports:
            return ports[0].get("ip", "")
        raise RuntimeError("Could not retrieve RunPod instance IP")
    elif provider == "lambda":
        auth = __import__("base64").b64encode(f"{api_key}:".encode()).decode()
        req = urllib.request.Request(
            f"https://cloud.lambdalabs.com/api/v1/instances/{instance_id}",
            headers={"Authorization": "Basic " + auth},
        )
        with urllib.request.urlopen(req, timeout=30) as resp:
            data = json.loads(resp.read())
        return data.get("data", {}).get("ip", "")
    raise ValueError(f"Unknown provider: {provider}")


def _ssh_exec(client, cmd: str, *, job_id: int | None = None, db_path: str | None = None) -> None:
    """Execute command over SSH, streaming stdout/stderr to job log."""
    stdin, stdout, stderr = client.exec_command(cmd, get_pty=True)
    for line in iter(stdout.readline, ""):
        line = line.rstrip("\n")
        if line and job_id is not None:
            append_job_log(job_id=job_id, line=line, db_path=db_path)
    exit_code = stdout.channel.recv_exit_status()
    if exit_code != 0:
        err = stderr.read().decode(errors="replace")
        raise RuntimeError(f"Remote command failed (exit {exit_code}): {err}")


def _sftp_download_dir(client, remote_dir: str, local_dir: str) -> None:
    """Recursively download *remote_dir* into *local_dir* over a fresh SFTP session.

    The SFTP session is closed even if the transfer fails part-way, so a
    failed download does not leak the channel.

    Args:
        client: A connected paramiko ``SSHClient`` (anything with ``open_sftp``).
        remote_dir: Directory path on the remote host.
        local_dir: Local destination directory (created if missing).
    """
    sftp = client.open_sftp()
    try:
        _sftp_download_recursive(sftp, remote_dir, local_dir)
    finally:
        sftp.close()


def _sftp_download_recursive(sftp, remote_path: str, local_path: str) -> None:
    import stat
    Path(local_path).mkdir(parents=True, exist_ok=True)
    for entry in sftp.listdir_attr(remote_path):
        remote_item = remote_path.rstrip("/") + "/" + entry.filename
        local_item = str(Path(local_path) / entry.filename)
        if stat.S_ISDIR(entry.st_mode):
            _sftp_download_recursive(sftp, remote_item, local_item)
        else:
            sftp.get(remote_item, local_item)


def terminate_cloud_instance(*, provider: str, api_key: str, instance_id: str, db_path: str | None = None) -> None:
    """Terminate a cloud GPU instance."""
    import urllib.request
    if provider == "runpod":
        body = json.dumps({
            "query": "mutation($input: PodTerminateInput!) { podTerminate(input: $input) }",
            "variables": {"input": {"podId": instance_id}},
        }).encode()
        req = urllib.request.Request(
            "https://api.runpod.io/graphql",
            data=body,
            headers={
                "Content-Type": "application/json",
                "Authorization": "Bearer " + api_key,
            },
        )
        with urllib.request.urlopen(req, timeout=30):
            pass
    elif provider == "lambda":
        auth = __import__("base64").b64encode(f"{api_key}:".encode()).decode()
        body = json.dumps({"instance_ids": [instance_id]}).encode()
        req = urllib.request.Request(
            "https://cloud.lambdalabs.com/api/v1/instance-operations/terminate",
            data=body,
            headers={
                "Content-Type": "application/json",
                "Authorization": "Basic " + auth,
            },
        )
        with urllib.request.urlopen(req, timeout=30):
            pass
    else:
        raise ValueError(f"Unknown provider: {provider}")


def get_cloud_status(*, db_path: str | None = None) -> dict | None:
    """Return the cloud session with the highest id as a plain dict, or ``None`` if the table is empty."""
    conn = _connect(db_path)
    try:
        latest = conn.execute("SELECT * FROM cloud_sessions ORDER BY id DESC LIMIT 1").fetchone()
    finally:
        conn.close()
    return None if latest is None else dict(latest)


def check_orphaned_cloud(*, db_path: str | None = None) -> list[dict]:
    """List cloud sessions whose status is still ``provisioning`` or ``running``.

    Rows are returned newest first (by id) as plain dicts.
    """
    query = "SELECT * FROM cloud_sessions WHERE status IN ('provisioning', 'running') ORDER BY id DESC"
    conn = _connect(db_path)
    try:
        return list(map(dict, conn.execute(query).fetchall()))
    finally:
        conn.close()


# ── Deploy Manager ────────────────────────────────────────────────────────

def get_active_job_log(*, db_path: str | None = None) -> dict | None:
    """Return the newest job that is running, completed, or failed, or ``None``.

    The full ``jobs`` row is returned (``SELECT *``) for SSE replay.
    Presumably that row includes the accumulated log text; confirm against
    the ``jobs`` schema.
    """
    query = "SELECT * FROM jobs WHERE status IN ('running', 'completed', 'failed') ORDER BY id DESC LIMIT 1"
    conn = _connect(db_path)
    try:
        newest = conn.execute(query).fetchone()
    finally:
        conn.close()
    return dict(newest) if newest is not None else None


def _log(job_id: int, line: str, db_path: str | None = None) -> None:
    """Append *line* to the job's stored log, then broadcast it as a ``"log"`` event."""
    event = {"type": "log", "job_id": job_id, "line": line}
    append_job_log(job_id=job_id, line=line, db_path=db_path)
    _broadcast(event)


def deploy_to_ollama(*, model_name: str = "qa-copilot:v1", model_path: str | None = None, db_path: str | None = None) -> int:
    """Register a trained model with Ollama on a background thread.

    A ``deploy`` job is created and its id returned immediately. The worker
    then does the following:

    1. It picks the model source. An explicit *model_path* wins. Otherwise it
       uses the MLX GGUF file if present, else the merged QLoRA directory.
    2. It rewrites ``fine-tuning/Modelfile`` to ``FROM <source>``.
    3. It runs ``ollama create <model_name> -f <Modelfile>`` with a
       300-second timeout.

    Success and failure are recorded with ``update_job`` and announced with
    ``_broadcast``.

    Args:
        model_name: Name/tag to register in Ollama.
        model_path: Optional explicit value for the Modelfile ``FROM`` line.
        db_path: Optional override for the SQLite database path.

    Returns:
        The id of the created job.
    """
    import subprocess

    job_id = create_job(job_type="deploy", config={"model_name": model_name, "model_path": model_path}, db_path=db_path)

    def _pick_model_path() -> str:
        """Return the explicit model_path, else the first trained artifact found on disk."""
        if model_path:
            return model_path
        candidates = (
            _FINETUNE_DIR / "models" / "qa_copilot.gguf",
            _FINETUNE_DIR / "output" / "qa_copilot_qlora" / "merged",
        )
        for candidate in candidates:
            if candidate.exists():
                return str(candidate)
        raise FileNotFoundError("No trained model found. Run training first.")

    def _worker() -> None:
        update_job(job_id=job_id, status="running", db_path=db_path)
        _log(job_id, f"Starting deployment of {model_name}...", db_path)
        try:
            chosen = _pick_model_path()
            _log(job_id, f"Model path: {chosen}", db_path)

            modelfile = _FINETUNE_DIR / "Modelfile"
            modelfile.write_text(f"FROM {chosen}\n")
            _log(job_id, f"Updated Modelfile: FROM {chosen}", db_path)

            _log(job_id, f"Registering {model_name} with Ollama...", db_path)
            proc = subprocess.run(
                ["ollama", "create", model_name, "-f", str(modelfile)],
                capture_output=True,
                text=True,
                timeout=300,
            )
            if proc.returncode:
                raise RuntimeError(f"ollama create failed: {proc.stderr}")

            _log(job_id, f"Model {model_name} registered successfully", db_path)
            update_job(job_id=job_id, status="completed", db_path=db_path)
            _broadcast({"type": "done", "job_id": job_id, "model_name": model_name})
        except Exception as exc:
            update_job(job_id=job_id, status="failed", error=str(exc), db_path=db_path)
            _broadcast({"type": "error", "job_id": job_id, "message": str(exc)})

    worker = threading.Thread(target=_worker, daemon=True)
    worker.start()
    return job_id


def smoke_test_model(
    *,
    model_name: str = "qa-copilot:v1",
    prompt: str = "Generate test cases for a login feature",
    ollama_url: str = "http://localhost:11434",
    timeout: float = 120,
) -> str:
    """Send one non-streaming ``/api/generate`` request to Ollama and return the text.

    Args:
        model_name: Ollama model name/tag to query.
        prompt: Prompt text to send.
        ollama_url: Base URL of the Ollama server. A trailing slash is
            tolerated; previously it produced ``//api/generate``.
        timeout: Seconds to wait for the full, non-streamed reply.

    Returns:
        The ``response`` field of the JSON reply, or ``""`` if it is absent.

    Raises:
        urllib.error.URLError: if the server is unreachable or returns an HTTP error.
    """
    import urllib.request

    endpoint = ollama_url.rstrip("/") + "/api/generate"
    body = json.dumps({"model": model_name, "prompt": prompt, "stream": False}).encode()
    req = urllib.request.Request(endpoint, data=body, headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        data = json.loads(resp.read())
    return data.get("response", "")


