"""
Multi-provider LLM adapter for QA Copilot.

Supported providers: ollama | claude | openai | openai-compat
Selected via LLM_PROVIDER env var (default: ollama).
Per-request override: include provider= in POST body.
"""
from __future__ import annotations

import json
import os
import ssl
import urllib.error
import urllib.request
from typing import Iterator

# Work around Python builds (e.g. on macOS) whose default SSL setup cannot
# find a usable CA store: when certifi is installed, make the default HTTPS
# context that urllib requests use trust certifi's CA bundle. This replaces a
# private hook on the ssl module, so it applies process-wide.
try:
    import certifi
except ImportError:
    # certifi not installed — keep the interpreter's default certificate lookup.
    pass
else:
    ssl._create_default_https_context = lambda: ssl.create_default_context(cafile=certifi.where())


class LLMError(Exception):
    """Failure while talking to an LLM provider.

    ``status_code`` is an HTTP-style code describing the failure (this module
    uses e.g. 503 for an unreachable/unconfigured provider and 429 for rate
    limits); it defaults to 500.
    """

    status_code: int

    def __init__(self, message: str, status_code: int = 500):
        super().__init__(message)
        self.status_code = status_code


# ── Model-aware context sizing (Roadmap #56) ─────────────────────────────────
#
# OLLAMA_CTX used to default to 16384 for every model, which is more than
# 7B/8B models need: they respond faster with 4K-8K windows, yet Ollama still
# allocates the full window even when the model can barely fill it, costing
# RAM and first-token latency. resolve_ctx() therefore picks a default per
# model size class; an explicit OLLAMA_CTX env value still takes precedence
# so power users can pin a value.

_CTX_SMALL = 4 * 1024   # 7B / 8B and below — fastest, fits short stories + plan
_CTX_MID = 8 * 1024     # 13B / 14B class — comfortable for most test artifacts
_CTX_LARGE = 16 * 1024  # 32B+ / hosted models — the historical default

# Lower-case substrings matched against the model name by resolve_ctx().
# The same patterns are used by generation._is_small_model; keep them in sync.
_SMALL_MODEL_HINTS = (
    ":7b", ":8b", ":1b", ":3b", ":1.5b", ":3.5b",  # tag style, e.g. "name:7b"
    "-7b", "-8b", "-1b", "-3b",                    # hyphenated names
    "haiku",
)
_MID_MODEL_HINTS = (
    ":13b", ":14b",
    "-13b", "-14b",
)


def resolve_ctx(model: str) -> int:
    """Return the Ollama ``num_ctx`` to use for ``model``.

    Order:
      1. ``OLLAMA_CTX`` env var, if it parses as a positive integer.
      2. Per-model class default, matched by substring on the lower-cased
         name: mid-size hints are checked first, then small-model hints.
      3. ``_CTX_LARGE`` for an empty or unrecognised model name.

    Malformed or non-positive ``OLLAMA_CTX`` values are ignored and the
    per-model default is used instead.
    """
    env = os.environ.get("OLLAMA_CTX")
    if env:
        try:
            pinned = int(env)
        except ValueError:
            pinned = 0
        # A zero/negative context size is meaningless, so treat it like a
        # malformed value instead of sending it to Ollama as num_ctx.
        if pinned > 0:
            return pinned
    if not model:
        return _CTX_LARGE
    name = model.lower()
    if any(h in name for h in _MID_MODEL_HINTS):
        return _CTX_MID
    if any(h in name for h in _SMALL_MODEL_HINTS):
        return _CTX_SMALL
    return _CTX_LARGE


def trim_to_budget(text: str, max_chars: int) -> str:
    """Shrink ``text`` to ``max_chars`` characters by cutting out the middle.

    Prompts usually carry the most relevant material at the end (the user's
    actual question or the latest plan step) plus useful framing at the
    start. So the first third of the budget goes to the head, a fixed
    32-character marker goes in the middle, and the remainder goes to the
    tail. If the budget is too small to fit the marker, only the last
    ``max_chars`` characters are kept.

    ``text`` is returned unchanged when it is empty, already fits, or when
    ``max_chars`` is not positive.
    """
    marker = "\n…[trimmed for context budget]…\n"  # exactly 32 characters
    nothing_to_do = not text or max_chars <= 0 or len(text) <= max_chars
    if nothing_to_do:
        return text
    head_len = max_chars // 3
    tail_len = max_chars - head_len - len(marker)
    if tail_len > 0:
        return "".join((text[:head_len], marker, text[-tail_len:]))
    return text[-max_chars:]


# ── JSON repair helpers (Roadmap #53) ────────────────────────────────────────
#
# Small Ollama models (e.g. qwen2.5-coder:7b) sometimes ignore format="json":
# they wrap the JSON in prose, or hit the num_predict limit mid-object. The
# helpers below run on the finished response and attempt, in order:
#   1. a plain json.loads;
#   2. removing code fences and any prose before the first '{', then parsing;
#   3. appending the closers needed to balance truncated braces/brackets.
# _ollama_json uses them; other small-model paths can reuse them.

# Appended to the system prompt to push the model toward bare JSON output.
_JSON_ONLY_SUFFIX = (
    "\n\n"
    "IMPORTANT: Respond with ONLY a single valid JSON object. "
    "No prose, no markdown fences, no commentary. "
    "Start with '{' and end with '}'."
)


def _close_truncated(text: str) -> str:
    """Best-effort completion of a JSON document that was cut off mid-way.

    Scans ``text`` while tracking string/escape state and the stack of open
    ``{`` / ``[`` containers, then appends what is needed to close them:

    * a dangling backslash at the end of an open string is dropped, because it
      would otherwise escape the closing quote added next;
    * an open string is terminated with ``"``;
    * outside a string, a trailing ``,`` is removed and a trailing ``:`` gets a
      ``null`` value, since either would leave the closed document invalid;
    * open containers are closed innermost-first.

    Text that already ends cleanly is returned unchanged apart from the added
    closers. The result is still not guaranteed to be valid JSON (e.g. when
    truncation happens inside a bare literal such as ``tru`` or right after an
    object key), so callers must handle ``json.JSONDecodeError``.
    """
    stack: list[str] = []
    in_str = False
    esc = False
    for ch in text:
        if esc:
            esc = False
            continue
        if ch == "\\" and in_str:
            esc = True
            continue
        if ch == '"':
            in_str = not in_str
            continue
        if in_str:
            continue
        if ch in "{[":
            stack.append("}" if ch == "{" else "]")
        elif ch in "}]" and stack:
            stack.pop()
    if in_str:
        if esc:
            # ``esc`` is only ever set inside a string: the final character is
            # a lone backslash that would swallow our closing quote.
            text = text[:-1]
        text += '"'
    else:
        tail = text.rstrip()
        if tail.endswith(","):
            text = tail[:-1]
        elif tail.endswith(":"):
            text = tail + " null"
    return text + "".join(reversed(stack))


def _repair_json_text(raw: str) -> dict:
    """Parse JSON from a (possibly noisy) model response.

    Tries, in order:
      1. ``json.loads`` on the raw text.
      2. Decode the first ``{...}`` object while ignoring prose before *and
         after* it. This is tried first on the untouched text, so markdown
         fences inside JSON string values survive. It is then retried with
         code fences removed.
      3. Close truncated strings/braces/brackets via ``_close_truncated``.

    Returns:
        The decoded value. Step 1 returns whatever top-level JSON value the
        text holds; steps 2-3 always yield a dict.

    Raises:
        ValueError: ``raw`` is not a string, contains no ``{``, or cannot be
            parsed even after repair.
    """
    if not isinstance(raw, str):
        raise ValueError("non-string LLM response")
    # 1. Direct parse — fastest, succeeds when format=json works as advertised.
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    # Code fences never contain '{', so this check is valid before stripping.
    start = raw.find("{")
    if start == -1:
        raise ValueError("no JSON object found in response")
    decoder = json.JSONDecoder()
    # 2a. raw_decode stops at the end of the first complete value, so trailing
    #     prose ("Hope this helps!") or a closing fence no longer breaks it.
    try:
        return decoder.raw_decode(raw[start:])[0]
    except json.JSONDecodeError:
        pass
    # 2b. Same, after stripping code fences (handles fences between tokens).
    stripped = raw
    for fence in ("```json", "```JSON", "```"):
        stripped = stripped.replace(fence, "")
    candidate = stripped[stripped.find("{"):].strip()
    try:
        return decoder.raw_decode(candidate)[0]
    except json.JSONDecodeError:
        pass
    # 3. Close truncated braces/brackets.
    try:
        return json.loads(_close_truncated(candidate))
    except json.JSONDecodeError as e:
        raise ValueError(f"could not parse JSON after repair: {e}") from e


OPENAI_COMPAT_PRESETS: dict[str, str | None] = {
    "groq":     "https://api.groq.com/openai/v1",
    "mistral":  "https://api.mistral.ai/v1",
    "together": "https://api.together.xyz/v1",
    "custom":   None,
}

DEFAULT_MODELS: dict[str, list[str]] = {
    "claude":        ["claude-sonnet-4-6", "claude-haiku-4-5", "claude-opus-4-6"],
    "openai":        ["gpt-4o", "gpt-4o-mini", "o1-mini"],
    "openai-compat": [],
    "ollama":        [],
}

_GROQ_MODELS     = ["llama-3.1-70b-versatile", "llama-3.1-8b-instant", "mixtral-8x7b-32768"]
_MISTRAL_MODELS  = ["mistral-large-latest", "mistral-small-latest", "open-mixtral-8x7b"]
_TOGETHER_MODELS = ["meta-llama/Llama-3-70b-chat-hf", "mistralai/Mixtral-8x7B-Instruct-v0.1"]


def get_default_provider() -> str:
    """Return the provider named by the ``LLM_PROVIDER`` env var, or ``"ollama"`` if it is unset."""
    return os.getenv("LLM_PROVIDER", "ollama")


def get_providers_info() -> dict:
    """Describe every provider's availability and model list.

    * ollama — always reported available; models come from
      ``{OLLAMA_URL}/api/tags`` (3 s timeout) and the list is empty if that
      request fails for any reason.
    * claude / openai — available when their API-key env var is set.
    * openai-compat — available when both OPENAI_COMPAT_BASE_URL and
      OPENAI_COMPAT_API_KEY are set. Suggested models are listed only when
      the base URL matches one of ``OPENAI_COMPAT_PRESETS``.
    """
    ollama_url = os.environ.get("OLLAMA_URL", "http://localhost:11434")
    ollama_models: list[str] = []
    try:
        # Use ``with`` so the HTTP response/socket is closed deterministically
        # rather than left for the garbage collector.
        with urllib.request.urlopen(f"{ollama_url}/api/tags", timeout=3) as resp:
            data = json.loads(resp.read())
        ollama_models = [m["name"] for m in data.get("models", [])]
    except Exception:
        # Deliberate best effort: an offline or misbehaving Ollama just
        # produces an empty model list.
        pass

    compat_base = os.environ.get("OPENAI_COMPAT_BASE_URL", "")
    compat_key = os.environ.get("OPENAI_COMPAT_API_KEY", "")
    # Preset URLs come from OPENAI_COMPAT_PRESETS so they are defined once.
    preset_models = {
        "groq": _GROQ_MODELS,
        "mistral": _MISTRAL_MODELS,
        "together": _TOGETHER_MODELS,
    }
    compat_models: list[str] = []
    for preset, models in preset_models.items():
        preset_url = OPENAI_COMPAT_PRESETS.get(preset)
        if preset_url and compat_base.rstrip("/") == preset_url.rstrip("/"):
            compat_models = models
            break

    return {
        "default_provider": get_default_provider(),
        "providers": {
            "ollama": {"available": True, "models": ollama_models},
            "claude": {
                "available": bool(os.environ.get("ANTHROPIC_API_KEY")),
                "models": DEFAULT_MODELS["claude"],
            },
            "openai": {
                "available": bool(os.environ.get("OPENAI_API_KEY")),
                "models": DEFAULT_MODELS["openai"],
            },
            "openai-compat": {
                "available": bool(compat_base and compat_key),
                "base_url": compat_base or None,
                "presets": list(OPENAI_COMPAT_PRESETS.keys()),
                "models": compat_models,
            },
        },
    }


def llm_json(messages: list, system: str, model: str, provider: str,
             schema: dict | None = None, **opts) -> dict:
    """
    Structured JSON call — returns parsed JSON, never raw text.

    Routing by ``provider`` (any unrecognised value goes to Ollama):
      - "claude":        forced ``json_output`` tool call; ``schema`` is the
                         tool's input schema.
      - "openai":        ``response_format`` json_schema when ``schema`` is
                         given, otherwise json_object.
      - "openai-compat": always json_object; ``schema`` is dropped because
                         compatible servers rarely support json_schema.
      - otherwise:       Ollama with ``format`` set to ``schema`` or "json".

    schema: optional JSON Schema dict constraining the output shape. When
            omitted, any valid JSON object is accepted.
    ``opts`` are forwarded unchanged to the provider implementation.
    """
    if provider == "claude":
        return _claude_json(messages, system, model, schema=schema, **opts)
    if provider == "openai":
        return _openai_json(
            messages, system, model,
            base_url="https://api.openai.com/v1",
            api_key=os.environ.get("OPENAI_API_KEY", ""),
            schema=schema, **opts,
        )
    if provider == "openai-compat":
        compat_url = os.environ.get("OPENAI_COMPAT_BASE_URL", "").rstrip("/")
        compat_key = os.environ.get("OPENAI_COMPAT_API_KEY", "")
        # compat APIs rarely support json_schema, so request json_object only.
        return _openai_json(
            messages, system, model,
            base_url=compat_url, api_key=compat_key,
            schema=None, **opts,
        )
    return _ollama_json(messages, system, model, schema=schema, **opts)


def llm_chat(messages: list, system: str, model: str, provider: str, **opts) -> str:
    """Blocking chat call: return the provider's complete reply as one string.

    ``provider`` is "claude", "openai" or "openai-compat"; any other value
    (including "ollama") is routed to Ollama. ``opts`` are forwarded to the
    provider implementation unchanged.
    """
    if provider == "claude":
        return _claude_chat(messages, system, model, **opts)
    if provider == "openai":
        base_url, key_var = "https://api.openai.com/v1", "OPENAI_API_KEY"
    elif provider == "openai-compat":
        base_url = os.environ.get("OPENAI_COMPAT_BASE_URL", "").rstrip("/")
        key_var = "OPENAI_COMPAT_API_KEY"
    else:
        return _ollama_chat(messages, system, model, **opts)
    return _openai_chat(messages, system, model,
                        base_url=base_url,
                        api_key=os.environ.get(key_var, ""), **opts)


def llm_stream(messages: list, system: str, model: str, provider: str, **opts) -> Iterator[str]:
    """Streaming chat call: yield the reply as string tokens as they arrive.

    This is a generator, so no request is made until the first token is
    requested. ``provider`` is "claude", "openai" or "openai-compat"; any
    other value is routed to Ollama. ``opts`` are forwarded unchanged.
    """
    if provider == "claude":
        tokens = _claude_stream(messages, system, model, **opts)
    elif provider in ("openai", "openai-compat"):
        if provider == "openai":
            base_url, key_var = "https://api.openai.com/v1", "OPENAI_API_KEY"
        else:
            base_url = os.environ.get("OPENAI_COMPAT_BASE_URL", "").rstrip("/")
            key_var = "OPENAI_COMPAT_API_KEY"
        tokens = _openai_stream(messages, system, model,
                                base_url=base_url,
                                api_key=os.environ.get(key_var, ""), **opts)
    else:
        tokens = _ollama_stream(messages, system, model, **opts)
    yield from tokens


# ── Ollama ───────────────────────────────────────────────────────────────────

def _ollama_json(messages: list, system: str, model: str,
                 schema: dict | None = None, **opts) -> dict:
    """JSON-mode Ollama call with repair + retry (Roadmap #53).

    Pipeline:
      1. Call Ollama with format=schema (constrained) or "json".
      2. Repair: try direct parse, then strip prose, then close truncation.
      3. If repair fails, retry once with temperature=0, a stricter system
         prompt suffix, and 1.5× max_tokens (to dodge truncation).
      4. If the retry also fails, raise LLMError with the first 200 chars of
         the model's raw output attached so callers can surface a useful
         message.

    Roadmap #55 — pass ``on_token`` to stream tokens. When provided, the
    first attempt uses Ollama's streaming endpoint, accumulates the full
    response, and invokes ``on_token(chunk)`` for each token chunk so the
    trace viewer can show progress mid-generation. Exceptions raised by
    ``on_token`` are swallowed. The retry always uses the blocking endpoint.

    Other ``opts``: ``temperature`` (default 0.1) and ``max_tokens``
    (default 4000, sent as ``num_predict``).

    Raises:
        LLMError: 503 if Ollama cannot be reached; the HTTP status code if
            Ollama responds with an HTTP error; 500 if the stream carries an
            ``error`` field; 502 if the output is still unparseable after
            repair + retry.
    """
    url = os.environ.get("OLLAMA_URL", "http://localhost:11434") + "/api/chat"
    # Roadmap #56: pick num_ctx based on model size when OLLAMA_CTX is unset.
    ctx = resolve_ctx(model)
    on_token = opts.get("on_token")  # Callable[[str], None] | None

    def _request(sys_msg: str, temperature: float, max_tokens: int,
                 stream: bool) -> urllib.request.Request:
        """Build the /api/chat request; only ``stream`` differs between paths."""
        msgs = ([{"role": "system", "content": sys_msg}] if sys_msg else []) + list(messages)
        payload: dict = {
            "model": model, "messages": msgs, "stream": stream,
            "format": schema or "json",
            "options": {
                "temperature": temperature,
                "num_predict": max_tokens,
                "num_ctx": ctx,
            },
        }
        return urllib.request.Request(
            url, data=json.dumps(payload).encode(),
            headers={"Content-Type": "application/json"}, method="POST",
        )

    def _call(sys_msg: str, temperature: float, max_tokens: int) -> str:
        """Blocking request; return the assistant message content."""
        req = _request(sys_msg, temperature, max_tokens, stream=False)
        with urllib.request.urlopen(req, timeout=300) as resp:
            data = json.loads(resp.read())
        return data["message"]["content"]

    def _call_streaming(sys_msg: str, temperature: float, max_tokens: int) -> str:
        """Streaming request: call ``on_token(chunk)`` per content delta and
        return the full accumulated text so it can be run through JSON repair."""
        req = _request(sys_msg, temperature, max_tokens, stream=True)
        buf: list[str] = []
        with urllib.request.urlopen(req, timeout=300) as resp:
            for line in resp:
                if not line.strip():
                    continue
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if obj.get("error"):
                    raise LLMError(obj["error"], 500)
                chunk = (obj.get("message") or {}).get("content", "")
                if chunk:
                    buf.append(chunk)
                    if on_token is not None:
                        try:
                            on_token(chunk)
                        except Exception:
                            pass  # never let the consumer break the model call
                if obj.get("done"):
                    break
        return "".join(buf)

    def _transport_error(e: urllib.error.URLError, when: str) -> LLMError:
        """Map a urllib failure to LLMError.

        HTTPError subclasses URLError, so check for it first. An HTTP error
        means Ollama did answer, and its status and body are more useful than
        a generic "unreachable"/503.
        """
        if isinstance(e, urllib.error.HTTPError):
            body = e.read().decode("utf-8", errors="replace")
            return LLMError(f"Ollama error {e.code}{when}: {body[:200]}", e.code)
        return LLMError(f"Ollama unreachable{when}: {e}", 503)

    base_temp = float(opts.get("temperature", 0.1))
    base_max = int(opts.get("max_tokens", 4000))
    base_system = (system or "") + _JSON_ONLY_SUFFIX

    # Attempt 1: normal call with the JSON-only suffix. Stream when the
    # caller wants token events, otherwise use the blocking path.
    try:
        if on_token is not None:
            raw = _call_streaming(base_system, base_temp, base_max)
        else:
            raw = _call(base_system, base_temp, base_max)
    except urllib.error.URLError as e:
        raise _transport_error(e, "") from e
    try:
        return _repair_json_text(raw)
    except ValueError:
        pass

    # Attempt 2: retry deterministic, more budget, stricter prompt.
    strict_system = (
        base_system
        + "\n\nThe previous attempt produced unparseable text. "
          "Return ONLY the JSON object now — no apologies, no prose."
    )
    try:
        raw_retry = _call(strict_system, 0.0, int(base_max * 1.5))
    except urllib.error.URLError as e:
        raise _transport_error(e, " on retry") from e
    try:
        return _repair_json_text(raw_retry)
    except ValueError as e:
        snippet = (raw_retry or raw or "")[:200].replace("\n", " ")
        raise LLMError(
            f"Ollama returned unparseable JSON after repair + retry: {e}. "
            f"First 200 chars: {snippet!r}",
            502,
        ) from e


def _ollama_chat(messages: list, system: str, model: str, **opts) -> str:
    """Blocking Ollama call: drain ``_ollama_stream`` and return the joined text."""
    pieces = list(_ollama_stream(messages, system, model, **opts))
    return "".join(pieces)


def _ollama_stream(messages: list, system: str, model: str, **opts) -> Iterator[str]:
    """Stream assistant text chunks from Ollama's ``/api/chat`` endpoint.

    ``opts``: ``temperature`` (default 0.3), ``max_tokens`` (default 8000,
    sent as ``num_predict``), and ``format_json`` (sets ``format="json"``).
    ``num_ctx`` comes from ``resolve_ctx(model)``. Blank and non-JSON lines
    in the response are skipped; streaming stops at the ``done`` line.

    Raises:
        LLMError: 503 if Ollama cannot be reached; the HTTP status code if
            Ollama responds with an HTTP error; 500 if a streamed line
            carries an ``error`` field.
    """
    url = os.environ.get("OLLAMA_URL", "http://localhost:11434") + "/api/chat"
    ctx = resolve_ctx(model)  # Roadmap #56: model-aware context window
    msgs: list = ([{"role": "system", "content": system}] if system else []) + list(messages)
    payload: dict = {
        "model": model, "messages": msgs, "stream": True,
        "options": {
            "temperature": float(opts.get("temperature", 0.3)),
            "num_predict": int(opts.get("max_tokens", 8000)),
            "num_ctx": ctx, "top_p": 0.85, "repeat_penalty": 1.1,
        },
    }
    if opts.get("format_json", False):
        payload["format"] = "json"
    req = urllib.request.Request(url, data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json"}, method="POST")
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            for raw in resp:
                if not raw.strip():
                    continue
                try:
                    chunk = json.loads(raw)
                except json.JSONDecodeError:
                    continue
                if chunk.get("error"):
                    raise LLMError(chunk["error"], 500)
                content = (chunk.get("message") or {}).get("content", "")
                if content:
                    yield content
                if chunk.get("done"):
                    return
    except urllib.error.HTTPError as e:
        # HTTPError subclasses URLError and must be caught first: Ollama did
        # answer, so report its status and body rather than "unreachable".
        body = e.read().decode("utf-8", errors="replace")
        raise LLMError(f"Ollama error {e.code}: {body[:200]}", e.code) from e
    except urllib.error.URLError as e:
        raise LLMError(f"Ollama unreachable: {e}", 503) from e


# ── Claude ───────────────────────────────────────────────────────────────────

def _claude_json(messages: list, system: str, model: str,
                 schema: dict | None = None, **opts) -> dict:
    """Structured Claude call: force a ``json_output`` tool call and return its input.

    ``schema`` becomes the input schema of a single tool. ``tool_choice``
    forces the model to call that tool, and the tool input arrives already
    decoded as a dict. Without ``schema``, any JSON object is accepted.
    ``opts``: ``max_tokens`` (default 4000); other keys are ignored.

    Raises:
        LLMError: 503 if ANTHROPIC_API_KEY is unset or the API is
            unreachable; 429 on rate limiting; the HTTP status for other HTTP
            errors; 500 if the reply has no ``json_output`` tool_use block.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if not api_key:
        raise LLMError("ANTHROPIC_API_KEY not set", 503)

    tool_name = "json_output"
    output_tool = {
        "name": tool_name,
        "description": "Return the structured result",
        "input_schema": schema or {"type": "object", "additionalProperties": True},
    }
    payload: dict = {
        "model": model,
        "max_tokens": int(opts.get("max_tokens", 4000)),
        "messages": list(messages),
        "tools": [output_tool],
        "tool_choice": {"type": "tool", "name": tool_name},
    }
    if system:
        payload["system"] = system
    headers = {
        "Content-Type": "application/json",
        "x-api-key": api_key,
        "anthropic-version": "2023-06-01",
    }
    req = urllib.request.Request(
        "https://api.anthropic.com/v1/messages",
        data=json.dumps(payload).encode(),
        headers=headers,
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            data = json.loads(resp.read())
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        if e.code == 429:
            raise LLMError(f"Claude rate limit: {body[:200]}", 429)
        raise LLMError(f"Claude error {e.code}: {body[:200]}", e.code)
    except urllib.error.URLError as e:
        raise LLMError(f"Claude unreachable: {e}", 503)

    for block in data.get("content", []):
        if block.get("type") == "tool_use" and block.get("name") == tool_name:
            return block["input"]  # already a dict — no json.loads needed
    raise LLMError("Claude did not return a tool_use block", 500)


def _claude_chat(messages: list, system: str, model: str, **opts) -> str:
    """Blocking Claude call: drain ``_claude_stream`` and return the joined text."""
    pieces = list(_claude_stream(messages, system, model, **opts))
    return "".join(pieces)


def _claude_stream(messages: list, system: str, model: str, **opts) -> Iterator[str]:
    """Stream text deltas from the Anthropic Messages API (server-sent events).

    Yields the ``text`` of every ``content_block_delta`` event whose delta
    type is ``text_delta``; other event types are ignored, except ``error``.
    ``opts``: ``max_tokens`` (default 8000); other keys are ignored.

    Raises:
        LLMError: 503 if ANTHROPIC_API_KEY is unset or the API is
            unreachable; 429 on rate limiting; the HTTP status for other HTTP
            errors; 529 for an in-stream ``overloaded_error`` event and 500
            for any other in-stream ``error`` event.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if not api_key:
        raise LLMError("ANTHROPIC_API_KEY not set", 503)
    payload: dict = {"model": model, "max_tokens": int(opts.get("max_tokens", 8000)),
                     "stream": True, "messages": messages}
    if system:
        payload["system"] = system
    req = urllib.request.Request(
        "https://api.anthropic.com/v1/messages",
        data=json.dumps(payload).encode(),
        headers={"Content-Type": "application/json", "x-api-key": api_key,
                 "anthropic-version": "2023-06-01"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            for raw in resp:
                line = raw.decode("utf-8", errors="replace").strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str == "[DONE]":
                    return
                try:
                    event = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                event_type = event.get("type")
                if event_type == "error":
                    # The API can report failures (e.g. overloaded_error) as an
                    # SSE event after the 200 response has started. Ignoring it
                    # would make a truncated reply look like a complete one.
                    err = event.get("error")
                    if not isinstance(err, dict):
                        err = {"message": err}
                    status = 529 if err.get("type") == "overloaded_error" else 500
                    detail = err.get("message") or err.get("type") or "unknown error"
                    raise LLMError(f"Claude stream error: {str(detail)[:200]}", status)
                if event_type == "content_block_delta":
                    delta = event.get("delta", {})
                    if delta.get("type") == "text_delta":
                        text = delta.get("text", "")
                        if text:
                            yield text
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        if e.code == 429:
            raise LLMError(f"Claude rate limit: {body[:200]}", 429) from e
        raise LLMError(f"Claude error {e.code}: {body[:200]}", e.code) from e
    except urllib.error.URLError as e:
        raise LLMError(f"Claude unreachable: {e}", 503) from e


# ── OpenAI / OpenAI-compatible ────────────────────────────────────────────────

def _openai_json(messages: list, system: str, model: str,
                 base_url: str, api_key: str,
                 schema: dict | None = None, **opts) -> dict:
    """Structured call against an OpenAI-style ``/chat/completions`` endpoint.

    With ``schema``, the request uses ``response_format`` type ``json_schema``
    with ``strict: True``; without it, ``json_object`` mode. The message
    content is decoded with ``_repair_json_text``, so prose, code fences and
    truncation are tolerated as on the Ollama path; valid JSON decodes exactly
    as with plain ``json.loads``.

    ``base_url`` defaults to the OpenAI API when empty. ``opts``:
    ``temperature`` (default 0.1) and ``max_tokens`` (default 4000).

    NOTE(review): OpenAI strict mode places extra requirements on the schema
    (e.g. ``additionalProperties: false``) — confirm caller schemas comply.

    Raises:
        LLMError: 503 if ``api_key`` is empty or the provider is unreachable;
            429 on rate limiting; the HTTP status for other HTTP errors; 502
            if the response envelope is malformed or its content is not
            parseable JSON.
    """
    if not api_key:
        raise LLMError(f"API key not set for {base_url or 'openai'}", 503)
    msgs: list = ([{"role": "system", "content": system}] if system else []) + list(messages)
    url = (base_url or "https://api.openai.com/v1").rstrip("/") + "/chat/completions"
    payload: dict = {
        "model": model, "messages": msgs, "stream": False,
        "temperature": float(opts.get("temperature", 0.1)),
        "max_tokens": int(opts.get("max_tokens", 4000)),
    }
    if schema:
        payload["response_format"] = {
            "type": "json_schema",
            "json_schema": {"name": "response", "schema": schema, "strict": True},
        }
    else:
        payload["response_format"] = {"type": "json_object"}
    req = urllib.request.Request(url, data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json",
                                          "Authorization": f"Bearer {api_key}"},
                                 method="POST")
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            body = resp.read()
    except urllib.error.HTTPError as e:
        err_body = e.read().decode("utf-8", errors="replace")
        if e.code == 429:
            raise LLMError(f"Rate limit: {err_body[:200]}", 429) from e
        raise LLMError(f"Provider error {e.code}: {err_body[:200]}", e.code) from e
    except urllib.error.URLError as e:
        raise LLMError(f"Provider unreachable: {e}", 503) from e
    # Turn a malformed envelope (non-JSON body, missing choices/message) into
    # LLMError so callers only need to handle one exception type.
    try:
        content = json.loads(body)["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        raise LLMError(f"Provider returned an unexpected response: {e!r}", 502) from e
    try:
        return _repair_json_text(content)
    except ValueError as e:
        snippet = str(content or "")[:200].replace("\n", " ")
        raise LLMError(
            f"Provider returned unparseable JSON: {e}. First 200 chars: {snippet!r}",
            502,
        ) from e


def _openai_chat(messages: list, system: str, model: str,
                 base_url: str, api_key: str, **opts) -> str:
    """Blocking OpenAI-style call: drain ``_openai_stream`` and return the joined text."""
    tokens = _openai_stream(messages, system, model,
                            base_url=base_url, api_key=api_key, **opts)
    return "".join(list(tokens))


def _openai_stream(messages: list, system: str, model: str,
                   base_url: str, api_key: str, **opts) -> Iterator[str]:
    """Stream content deltas from an OpenAI-style ``/chat/completions`` SSE endpoint.

    ``base_url`` defaults to the OpenAI API when empty. ``opts``:
    ``temperature`` (default 0.3) and ``max_tokens`` (default 8000). Yields
    each non-empty ``choices[0].delta.content``; stops at ``[DONE]``.

    Raises:
        LLMError: 503 if ``api_key`` is empty or the provider is unreachable;
            429 on rate limiting; the HTTP status for other HTTP errors; 500
            if a streamed event carries an ``error`` object.
    """
    if not api_key:
        raise LLMError(f"API key not set for {base_url or 'openai'}", 503)
    msgs: list = ([{"role": "system", "content": system}] if system else []) + list(messages)
    url = (base_url or "https://api.openai.com/v1").rstrip("/") + "/chat/completions"
    payload = {"model": model, "messages": msgs, "stream": True,
               "temperature": float(opts.get("temperature", 0.3)),
               "max_tokens": int(opts.get("max_tokens", 8000))}
    req = urllib.request.Request(url, data=json.dumps(payload).encode(),
                                 headers={"Content-Type": "application/json",
                                          "Authorization": f"Bearer {api_key}"},
                                 method="POST")
    try:
        with urllib.request.urlopen(req, timeout=300) as resp:
            for raw in resp:
                line = raw.decode("utf-8", errors="replace").strip()
                if not line.startswith("data: "):
                    continue
                data_str = line[6:]
                if data_str == "[DONE]":
                    return
                try:
                    event = json.loads(data_str)
                except json.JSONDecodeError:
                    continue
                error = event.get("error")
                if error:
                    # Some OpenAI-compatible servers may report failures as an
                    # in-band event after the 200 response has started; raise
                    # so a truncated reply is not mistaken for a complete one.
                    detail = error.get("message", error) if isinstance(error, dict) else error
                    raise LLMError(f"Provider stream error: {str(detail)[:200]}", 500)
                choice = (event.get("choices") or [{}])[0]
                # Defensive: treat a null ``delta`` the same as a missing one.
                content = (choice.get("delta") or {}).get("content")
                if content:
                    yield content
    except urllib.error.HTTPError as e:
        body = e.read().decode("utf-8", errors="replace")
        if e.code == 429:
            raise LLMError(f"Rate limit: {body[:200]}", 429) from e
        raise LLMError(f"Provider error {e.code}: {body[:200]}", e.code) from e
    except urllib.error.URLError as e:
        raise LLMError(f"Provider unreachable: {e}", 503) from e
