"""
generation.py — Generation pipeline for QA Copilot.

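Owns all test-artifact generation logic:
  - Self-review (second-pass LLM review of generated artifacts)
  - Chain-of-thought decomposition (analyze → plan → generate → validate)
  - Coverage-aware context injection
  - Differential generation (regenerate only what changed since the last snapshot)
  - Template-based scaffolding
  - Confidence scoring
  - generate() main entry point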

Lazy-imports from server.py to avoid circular dependencies.
"""

import datetime
import hashlib
import json
import os
import re
import sqlite3
from pathlib import Path

import cache
import fine_tuning
from llm_providers import llm_chat, llm_json

# Directory scanned by _get_template_context() for *.json scaffold templates.
# Resolved relative to this file, so it does not depend on the working directory.
_TEMPLATES_DIR = str(Path(__file__).resolve().parent / "templates")

# ---------------------------------------------------------------------------
# Lazy imports from server.py (avoids circular dependency)
# ---------------------------------------------------------------------------
_server_mod = None  # the `server` module, bound on first use by _ensure_server()


def _ensure_server():
    """Import ``server`` on first call and remember it in ``_server_mod``.

    The import happens at call time rather than at module load to avoid a
    circular import with ``server``. Once bound, later calls do nothing.
    """
    global _server_mod
    if _server_mod is not None:
        return
    import server as _s
    _server_mod = _s


def _extract_json(text: str) -> dict:
    """Parse JSON out of raw LLM text by delegating to ``server._extract_json``.

    The server module is imported lazily on first use.
    """
    _ensure_server()
    parser = _server_mod._extract_json
    return parser(text)


def _get_module_schema(module: str) -> dict | None:
    """Return the JSON schema for *module* via ``server._get_module_schema``.

    Callers in this file treat a falsy return as "no schema available".
    """
    _ensure_server()
    lookup = _server_mod._get_module_schema
    return lookup(module)


# ---------------------------------------------------------------------------
# Self-review constants and functions (moved from server.py)
# ---------------------------------------------------------------------------

# System prompt for self_review(): the reviewer model gets the user story plus
# the generated artifact and must return an improved copy in the SAME JSON
# schema. Runtime text — editing it changes model behavior.
_SELF_REVIEW_SYSTEM = """\
You are a senior QA reviewer. You receive a JSON test artifact generated from a user story.
Your job is to review it and return an IMPROVED version of the same JSON (same schema, same keys).

Check for and fix:
- Missing negative / unhappy-path scenarios (invalid input, unauthorized access, timeouts)
- Missing boundary conditions (empty strings, zero values, max-length inputs)
- Missing edge cases (concurrent access, special characters, locale-specific behavior)
- Duplicate or near-duplicate test cases — remove redundancy
- Missing or incomplete preconditions
- Ambiguous or vague expected results — make them specific and verifiable
- Missing security / validation scenarios (XSS, SQL injection, CSRF where relevant)

Rules:
- Return ONLY the improved JSON — no markdown fences, no commentary
- Preserve the original JSON structure and schema exactly
- Keep all valid existing items; only add, remove duplicates, or clarify
- Do NOT invent features not implied by the user story
"""

# max_tokens for the self-review call when QA_MAX_TOKENS is not set.
_MAX_TOKENS_DEFAULT = 8000


# ── Fast-mode detection (Roadmap #52) ────────────────────────────────────────
#
# Small local models (7B/8B parameter range) get marginal benefit from
# self-review + confidence scoring + CoT and pay a huge latency cost for them.
# Auto-enable "fast mode" for those: skip the secondary LLM passes unless
# QA_FAST_MODE is explicitly set ("0" forces full pipeline, "1" forces fast).
#
# Cost of full pipeline on qwen2.5-coder:7b for a single cas generation:
#   - CoT analyze + plan + generate + validate = 4 LLM calls
#   - Self-review                              = 1 LLM call
#   - Confidence scoring                       = 1 LLM call
#   - Total                                    = 6+ LLM calls ≈ 60-90 s
# Fast mode collapses this to 1 call ≈ 8-15 s.

_SMALL_MODEL_PATTERNS = (
    ":7b", ":8b", ":1b", ":3b", ":1.5b", ":3.5b",
    "-7b", "-8b", "-1b", "-3b",
)
_HAIKU_PATTERNS = ("haiku",)  # Claude Haiku is fast/cheap — treat as small for our purposes


def _is_small_model(model: str) -> bool:
    """True when the model is in the 7B/8B class (or Claude Haiku)."""
    if not model:
        return False
    name = model.lower()
    if any(p in name for p in _SMALL_MODEL_PATTERNS):
        return True
    if any(p in name for p in _HAIKU_PATTERNS):
        return True
    return False


def is_fast_mode(model: str) -> bool:
    """Decide whether the optional LLM passes should be skipped for *model*.

    ``QA_FAST_MODE`` is read on every call:
      * ``"1"`` forces fast mode on,
      * ``"0"`` forces it off,
      * anything else (including unset) defers to ``_is_small_model(model)``.
    """
    forced = os.environ.get("QA_FAST_MODE")
    overrides = {"1": True, "0": False}
    if forced in overrides:
        return overrides[forced]
    return _is_small_model(model)


def self_review(generated: dict, user_story: str, provider: str, model: str) -> dict:
    """Run a second LLM pass to review and patch generated test artifacts.

    Checks ``QA_SELF_REVIEW`` env var at call time (not module load time);
    ``"0"`` disables the pass. ``QA_MAX_TOKENS`` caps the review call and
    falls back to ``_MAX_TOKENS_DEFAULT`` when unset or not an integer.

    The reviewer's output is only accepted when it is a non-empty JSON object
    that still contains every top-level key of ``generated`` (the prompt
    demands the same schema; dropping keys means the model drifted).
    Otherwise, and on any LLM/parse failure, the original is returned.
    """
    if os.environ.get("QA_SELF_REVIEW", "1") == "0":
        return generated

    try:
        max_tokens = int(os.environ.get("QA_MAX_TOKENS", _MAX_TOKENS_DEFAULT))
    except ValueError:
        # A malformed override should not silently disable the review pass.
        max_tokens = _MAX_TOKENS_DEFAULT

    try:
        review_messages = [
            {"role": "user", "content": (
                f"## Original user story\n{user_story}\n\n"
                f"## Generated test artifact to review\n```json\n"
                f"{json.dumps(generated, ensure_ascii=False, indent=2)}\n```"
            )},
        ]
        raw = llm_chat(
            review_messages, _SELF_REVIEW_SYSTEM, model, provider,
            temperature=0.1, max_tokens=max_tokens, format_json=True,
        )
        reviewed = _extract_json(raw)
    except Exception as exc:
        print(f"  ⚠  Self-review skipped (error): {exc}")
        return generated

    # Never replace a good artifact with something unusable.
    if not isinstance(reviewed, dict) or not reviewed:
        print("  ⚠  Self-review skipped (reviewer returned no JSON object)")
        return generated
    if isinstance(generated, dict):
        dropped = [key for key in generated if key not in reviewed]
        if dropped:
            print(f"  ⚠  Self-review discarded (reviewer dropped keys: {dropped})")
            return generated

    print(
        f"  🔄 Self-review pass applied "
        f"({len(json.dumps(generated))} → {len(json.dumps(reviewed))} chars)"
    )
    return reviewed


def parse_and_review(raw: str, user_story: str, provider: str, model: str) -> str:
    """Parse *raw* LLM output, self-review it, and re-serialize it as JSON text.

    If parsing, reviewing or serializing raises, a warning is printed and
    *raw* is returned untouched.
    """
    try:
        return json.dumps(
            self_review(_extract_json(raw), user_story, provider, model),
            ensure_ascii=False,
        )
    except Exception as err:
        print(f"  ⚠  Parse-and-review skipped: {err}")
        return raw


# ---------------------------------------------------------------------------
# Chain-of-thought decomposition (Roadmap 2.1)
# ---------------------------------------------------------------------------

# Step 1 of chain-of-thought: prompt + schema used by _cot_analyze() to pull
# criteria / boundaries / ambiguities out of the raw user story.
_COT_ANALYZE_SYSTEM = """\
You are a requirements analyst for QA test generation.
Given a user story, extract:
- criteria: acceptance criteria implied or stated
- boundaries: boundary conditions and limits
- ambiguities: unclear or missing requirements

Return structured JSON only.
"""

_COT_ANALYZE_SCHEMA = {
    "type": "object",
    "properties": {
        "criteria": {"type": "array", "items": {"type": "string"}},
        "boundaries": {"type": "array", "items": {"type": "string"}},
        "ambiguities": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["criteria", "boundaries", "ambiguities"],
}

# Step 2: prompt + schema used by _cot_plan() to turn the analysis into a list
# of typed test scenarios; generate() injects the plan into the system prompt.
_COT_PLAN_SYSTEM = """\
You are a test planner for QA test generation.
Given an analysis of a user story (criteria, boundaries, ambiguities),
create a list of test scenarios that cover all criteria and boundaries.

Each scenario should have:
- title: short descriptive name
- type: positive, negative, boundary, or edge
- covers: list of criteria/boundaries this scenario addresses

Return structured JSON only.
"""

_COT_PLAN_SCHEMA = {
    "type": "object",
    "properties": {
        "scenarios": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "title": {"type": "string"},
                    "type": {"type": "string"},
                    "covers": {"type": "array", "items": {"type": "string"}},
                },
                "required": ["title", "type", "covers"],
            },
        },
    },
    "required": ["scenarios"],
}

# Step 4: prompt used by _cot_validate() to check the generated artifact
# against the plan. No fixed schema here — _cot_validate() passes the
# per-module schema when one exists.
_COT_VALIDATE_SYSTEM = """\
You are a completeness validator for QA test generation.
You receive a generated test artifact and the test plan it was based on.
Check that every planned scenario is covered in the generated output.
If any scenarios are missing, add them. Remove duplicates.
Return the improved artifact in the SAME schema as the input — no markdown, no commentary.
"""


def _cot_analyze(user_story: str, provider: str, model: str) -> dict:
    """CoT step 1: ask the LLM for the story's criteria, boundaries and ambiguities.

    Returns the ``llm_json`` result, requested with ``_COT_ANALYZE_SCHEMA``.
    """
    prompt = f"Analyze this user story:\n\n{user_story}"
    return llm_json(
        [{"role": "user", "content": prompt}],
        _COT_ANALYZE_SYSTEM,
        model,
        provider,
        schema=_COT_ANALYZE_SCHEMA,
        temperature=0.1,
    )


def _cot_plan(analysis: dict, user_story: str, provider: str, model: str) -> dict:
    """CoT step 2: turn the step-1 *analysis* into a list of test scenarios.

    The story and the pretty-printed analysis are sent together; the result
    is requested with ``_COT_PLAN_SCHEMA``.
    """
    analysis_json = json.dumps(analysis, ensure_ascii=False, indent=2)
    prompt = (
        f"## User story\n{user_story}\n\n"
        f"## Analysis\n```json\n{analysis_json}\n```"
    )
    return llm_json(
        [{"role": "user", "content": prompt}],
        _COT_PLAN_SYSTEM,
        model,
        provider,
        schema=_COT_PLAN_SCHEMA,
        temperature=0.1,
    )


def _cot_validate(
    generated: dict, plan: dict, user_story: str,
    provider: str, model: str, module: str,
) -> dict:
    """CoT step 4: have the LLM check *generated* against *plan* and fill gaps.

    The module's schema (if ``_get_module_schema`` returns one) constrains
    the response; otherwise only temperature is set.
    """
    schema_for_module = _get_module_schema(module)
    plan_json = json.dumps(plan, ensure_ascii=False, indent=2)
    artifact_json = json.dumps(generated, ensure_ascii=False, indent=2)
    prompt = (
        f"## User story\n{user_story}\n\n"
        f"## Test plan\n```json\n{plan_json}\n```\n\n"
        f"## Generated artifact\n```json\n{artifact_json}\n```"
    )
    call_options: dict = {"temperature": 0.1}
    if schema_for_module:
        call_options["schema"] = schema_for_module
    return llm_json(
        [{"role": "user", "content": prompt}],
        _COT_VALIDATE_SYSTEM,
        model,
        provider,
        **call_options,
    )


# ---------------------------------------------------------------------------
# Coverage-aware generation (Roadmap 2.2)
# ---------------------------------------------------------------------------

# Jira-style issue key ("PROJ-123"): an uppercase project key of 2+ chars
# (letters, digits, underscore), "-", then digits. Callers use .search(), so
# the FIRST match in the user text becomes the story key.
# NOTE(review): this also matches non-Jira tokens such as "UTF-8" or
# "SHA-256"; a story mentioning one before its real key would be mis-keyed.
_JIRA_KEY_RE = re.compile(r'\b([A-Z][A-Z0-9_]+-\d+)\b')


def _db_get_covered(story_key: str) -> list:
    """Return the tests already recorded as covering *story_key*.

    Proxies ``server._db_get_covered``. Uses the shared ``_ensure_server()``
    lazy import (which avoids the circular import with ``server``) instead of
    duplicating that logic, matching the other server proxies in this file.
    """
    _ensure_server()
    return _server_mod._db_get_covered(story_key)


def _get_coverage_context(user_text: str) -> str:
    """Summarize tests that already cover the story referenced in *user_text*.

    The first Jira-style key found in the text selects the story. Each
    covered test becomes a ``- <title> (<module>)`` line. Results, including
    the empty string for "no key" / "no coverage", are cached for one hour,
    keyed on the full text. A failed lookup prints a warning, returns ``""``
    and is not cached.
    """
    ttl = 3600  # 1h TTL
    cache_key = cache.make_cache_key("coverage", user_text)
    hit = cache.cache_get(cache_key)
    if hit is not None:
        return hit

    key_match = _JIRA_KEY_RE.search(user_text)
    if key_match is None:
        cache.cache_set(cache_key, "", ttl_seconds=ttl)
        return ""

    story_key = key_match.group(1)
    try:
        existing = _db_get_covered(story_key)
        if not existing:
            cache.cache_set(cache_key, "", ttl_seconds=ttl)
            return ""
        summary_lines = [
            f"- {test['title']} ({test.get('module', 'unknown')})"
            for test in existing
        ]
        print(f"  📊 Coverage: {len(existing)} existing tests for {story_key}")
        summary = "\n".join(summary_lines)
        cache.cache_set(cache_key, summary, ttl_seconds=ttl)
        return summary
    except Exception as e:
        print(f"  ⚠  Coverage check failed: {e}")
        return ""


# ── Differential Generation History (Roadmap 3.2) ──────────────────────

def _get_coverage_db_path() -> str:
    """Return ``server.COVERAGE_DB`` as a string (server imported lazily)."""
    _ensure_server()
    coverage_db = _server_mod.COVERAGE_DB
    return str(coverage_db)


def _get_generation_history(
    story_key: str,
    *,
    module: str = "",
    framework: str = "",
    db_path: str | None = None,
) -> dict | None:
    """Most-recent generation snapshot for a (story_key, module, framework) tuple.

    Roadmap #57: scope by module + framework so cas/code/perf for the same
    story don't cross-pollute each other's cached result.
    """
    db = db_path or _get_coverage_db_path()
    conn = sqlite3.connect(db)
    conn.row_factory = sqlite3.Row
    try:
        row = conn.execute(
            "SELECT * FROM generation_history "
            "WHERE story_key = ? AND module = ? AND framework = ? "
            "ORDER BY created_at DESC LIMIT 1",
            (story_key, module or "", framework or ""),
        ).fetchone()
        return dict(row) if row else None
    finally:
        conn.close()


def _save_generation_history(
    story_key: str,
    story_content: str,
    generated: dict,
    *,
    module: str = "",
    framework: str = "",
    db_path: str | None = None,
) -> None:
    """Save a generation snapshot for future diffing, scoped by module + framework."""
    db = db_path or _get_coverage_db_path()
    content_hash = hashlib.sha256(story_content.encode()).hexdigest()[:16]
    now = datetime.datetime.now(datetime.timezone.utc).isoformat()
    conn = sqlite3.connect(db)
    try:
        conn.execute(
            "INSERT INTO generation_history "
            "(story_key, story_content_hash, story_content, generated_json, "
            " created_at, module, framework) "
            "VALUES (?, ?, ?, ?, ?, ?, ?)",
            (
                story_key, content_hash, story_content,
                json.dumps(generated, ensure_ascii=False),
                now, module or "", framework or "",
            ),
        )
        conn.commit()
    finally:
        conn.close()


# ---------------------------------------------------------------------------
# Differential Generation — Diff & Merge (Roadmap 3.2)
# ---------------------------------------------------------------------------

# Prompt + schema for _compute_story_diff(): the model compares the OLD and NEW
# story text and lists added, modified (before/after) and removed criteria.
# generate_differential() regenerates tests for new + changed "after" items
# and filters old tests by the removed ones.
_DIFF_SYSTEM = """\
You are a requirements analyst. You receive two versions of a user story.
Compare them and identify:
- new_criteria: acceptance criteria present in the NEW version but not the OLD
- changed_criteria: criteria that exist in both but were modified (include before/after)
- removed_criteria: criteria present in the OLD version but removed from the NEW

Return structured JSON only.
"""

_DIFF_SCHEMA = {
    "type": "object",
    "properties": {
        "new_criteria": {"type": "array", "items": {"type": "string"}},
        "changed_criteria": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "before": {"type": "string"},
                    "after": {"type": "string"},
                },
                "required": ["before", "after"],
            },
        },
        "removed_criteria": {"type": "array", "items": {"type": "string"}},
    },
    "required": ["new_criteria", "changed_criteria", "removed_criteria"],
}


def _compute_story_diff(old_content: str, new_content: str, provider: str, model: str) -> dict:
    """Ask the LLM for a structured diff between two versions of a story.

    Returns the ``llm_json`` result, requested with ``_DIFF_SCHEMA``
    (``new_criteria`` / ``changed_criteria`` / ``removed_criteria``).
    """
    prompt = f"## OLD version\n{old_content}\n\n## NEW version\n{new_content}"
    print(f"  🔀 Computing story diff ({len(old_content)} → {len(new_content)} chars)")
    return llm_json(
        [{"role": "user", "content": prompt}],
        _DIFF_SYSTEM,
        model,
        provider,
        schema=_DIFF_SCHEMA,
        temperature=0.1,
    )


def _merge_results(old_generated: dict, delta_generated: dict, removed: list) -> dict:
    """Merge old generated tests with delta, filtering out removed criteria.

    Uses title substring matching to identify tests that correspond to removed criteria.
    """
    merged = {}
    for key, old_items in old_generated.items():
        if not isinstance(old_items, list):
            merged[key] = old_items
            continue
        # Filter out tests matching removed criteria
        kept = []
        for item in old_items:
            title = item.get("title", "") if isinstance(item, dict) else str(item)
            should_remove = any(
                removed_crit.lower() in title.lower()
                for removed_crit in removed
            )
            if not should_remove:
                kept.append(item)
        # Append delta items
        delta_items = delta_generated.get(key, [])
        if isinstance(delta_items, list):
            kept.extend(delta_items)
        merged[key] = kept

    # Add keys that only exist in delta
    for key, delta_items in delta_generated.items():
        if key not in merged:
            merged[key] = delta_items

    return merged


def generate_differential(
    story_key: str,
    story_content: str,
    provider: str,
    model: str,
    module: str = "cas",
    framework: str = "",
    **opts,
) -> dict | None:
    """Attempt differential generation for a known story.

    Scoped by (story_key, module, framework) so cas/code/perf snapshots for
    the same story remain isolated (Roadmap #57).

    Returns:
      - None if no history exists (caller should do full generation)
      - Cached result if content hash matches
      - Merged result if story changed (diff + delta generation + merge).
        The merged result is saved as the new snapshot, so the next call with
        the same story content is a hash hit instead of repeating the LLM
        diff and delta generation. A failed save is logged, not raised.

    ``opts`` is accepted for call-site compatibility and is not used. LLM and
    JSON errors propagate to the caller.
    """
    history = _get_generation_history(story_key, module=module, framework=framework)
    if history is None:
        return None

    content_hash = hashlib.sha256(story_content.encode()).hexdigest()[:16]
    if history["story_content_hash"] == content_hash:
        print(f"  ♻️  Differential: same content hash for {story_key}, returning cached")
        return json.loads(history["generated_json"])

    # Story changed — compute diff and generate delta
    old_content = history["story_content"]
    old_generated = json.loads(history["generated_json"])

    diff = _compute_story_diff(old_content, story_content, provider, model)

    # `or []` also covers keys the model returned as null.
    new_criteria = diff.get("new_criteria") or []
    changed_criteria = diff.get("changed_criteria") or []
    removed_criteria = diff.get("removed_criteria") or []

    # Build delta prompt from new + changed criteria. Blank entries and
    # malformed "changed" items are skipped so they don't become empty bullets.
    delta_items = [c for c in new_criteria if isinstance(c, str) and c.strip()]
    for change in changed_criteria:
        after = change.get("after", "") if isinstance(change, dict) else ""
        if isinstance(after, str) and after.strip():
            delta_items.append(after)

    if delta_items:
        delta_prompt = (
            f"Generate test cases for these NEW or CHANGED requirements only:\n"
            + "\n".join(f"- {item}" for item in delta_items)
            + f"\n\nOriginal story context:\n{story_content}"
        )
        mod_schema = _get_module_schema(module)
        kwargs: dict = {"temperature": 0.3}
        if mod_schema:
            kwargs["schema"] = mod_schema
        delta_generated = llm_json(
            [{"role": "user", "content": delta_prompt}],
            "You are a QA test generator. Generate test cases for the given requirements. Return JSON only.",
            model, provider, **kwargs,
        )
        print(f"  ➕ Differential: generated {len(delta_items)} delta criteria for {story_key}")
    else:
        delta_generated = {}

    result = _merge_results(old_generated, delta_generated, removed_criteria)

    # Persist the merged snapshot for the new content (best-effort).
    try:
        _save_generation_history(
            story_key, story_content, result,
            module=module, framework=framework,
        )
    except Exception as exc:
        print(f"  ⚠  History save failed ({exc})")
    return result


# ---------------------------------------------------------------------------
# Template context injection and confidence scoring
# ---------------------------------------------------------------------------


# Template cache (Roadmap #52). Templates rarely change between requests, so
# parse them once and reuse. Invalidate when the dir's mtime moves so editing a
# template hot-reloads on the next generation call.
_TEMPLATE_CACHE: dict[tuple[str, str], str] = {}
_TEMPLATE_CACHE_MTIME: float = -1.0


def _templates_dir_mtime() -> float:
    """Return the newest mtime of the templates dir and its ``*.json`` files.

    Used as a cheap change signal for ``_TEMPLATE_CACHE``. Adding or removing
    a file moves the directory mtime, and editing a file moves that file's
    mtime. Returns ``0.0`` when the directory is missing or cannot be stat'ed.
    """
    if not os.path.isdir(_TEMPLATES_DIR):
        return 0.0
    try:
        latest = os.path.getmtime(_TEMPLATES_DIR)
        for fname in os.listdir(_TEMPLATES_DIR):
            if fname.endswith(".json"):
                latest = max(latest, os.path.getmtime(os.path.join(_TEMPLATES_DIR, fname)))
        return latest
    except OSError:
        return 0.0


def _get_template_context(module: str, framework: str = "") -> str:
    """Return template skeleton for injection into prompt.

    Scans ``*.json`` files in ``_TEMPLATES_DIR`` in sorted filename order.
    The first template whose ``matches.module`` contains *module* (and whose
    ``matches.framework`` contains *framework*, when one is given) is
    rendered as a ``### <name>`` header, one ``**key:**`` section per
    skeleton entry, and its instructions.

    Cached per (module, framework) within a process. Invalidates if any
    template file is touched. Returns empty string if no matching template.
    Unreadable or malformed template files are skipped with a warning.
    """
    global _TEMPLATE_CACHE_MTIME
    if not os.path.isdir(_TEMPLATES_DIR):
        return ""

    current_mtime = _templates_dir_mtime()
    if current_mtime != _TEMPLATE_CACHE_MTIME:
        _TEMPLATE_CACHE.clear()
        _TEMPLATE_CACHE_MTIME = current_mtime

    key = (module, framework)
    if key in _TEMPLATE_CACHE:
        return _TEMPLATE_CACHE[key]

    try:
        # Sorted so "first matching template wins" is deterministic:
        # os.listdir() order is arbitrary and filesystem-dependent.
        fnames = sorted(os.listdir(_TEMPLATES_DIR))
    except OSError as exc:
        print(f"  ⚠  Template dir unreadable: {exc}")
        fnames = []

    result = ""
    for fname in fnames:
        if not fname.endswith(".json"):
            continue
        try:
            # Explicit UTF-8: templates may hold non-ASCII text and the
            # platform default encoding is not guaranteed to be UTF-8.
            with open(os.path.join(_TEMPLATES_DIR, fname), encoding="utf-8") as f:
                tmpl = json.load(f)
            matches = tmpl.get("matches", {})
            modules = matches.get("module", [])
            frameworks = matches.get("framework", [])
            if module in modules and (not framework or framework in frameworks):
                skeleton = tmpl.get("skeleton", {})
                instructions = tmpl.get("instructions", "")
                name = tmpl.get("name", fname)
                parts = [f"### {name}"]
                for section, text in skeleton.items():
                    parts.append(f"**{section}:**\n{text}")
                if instructions:
                    parts.append(f"\n{instructions}")
                print(f"  📄 Template: loaded {name}")
                result = "\n\n".join(parts)
                break
        except Exception as exc:
            # Best-effort: one bad template must not block the others.
            print(f"  ⚠  Template {fname} skipped: {exc}")
            continue
    _TEMPLATE_CACHE[key] = result
    return result


# Prompt + schema for score_confidence(): per-test 1-5 ratings on several
# dimensions plus an "overall" score. score_confidence() counts entries whose
# "overall" reaches _auto_push_threshold() as auto-push ready.
# NOTE(review): the prompt hard-codes "flag if overall < 4", but the
# auto-push threshold drops to 3 for small/Ollama models — confirm whether
# the flag rule should follow the threshold.
_CONFIDENCE_SYSTEM = (
    "You are a test quality evaluator. Rate each test case on a 1-5 scale.\n\n"
    "Dimensions:\n"
    "- correctness: Does the test verify the right behavior?\n"
    "- completeness: Are preconditions, steps, and expected results all present?\n"
    "- clarity: Is the test easy to understand and unambiguous?\n"
    "- maintainability: Will this test be easy to maintain as requirements change?\n"
    "- overall: Your holistic quality score (1=poor, 5=excellent)\n"
    '- flag: Brief note if overall < 4 (empty string otherwise)\n\n'
    "Return JSON: {\"scores\": [{\"title\": \"...\", \"correctness\": N, \"completeness\": N, "
    "\"clarity\": N, \"maintainability\": N, \"overall\": N, \"flag\": \"...\"}]}\n"
    "Return ONLY valid JSON."
)

# Only "title" and "overall" are required; the other dimensions are optional.
_CONFIDENCE_SCHEMA = {
    "type": "object",
    "properties": {
        "scores": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "title":           {"type": "string"},
                    "correctness":     {"type": "integer"},
                    "completeness":    {"type": "integer"},
                    "clarity":         {"type": "integer"},
                    "maintainability": {"type": "integer"},
                    "overall":         {"type": "integer"},
                    "flag":            {"type": "string"},
                },
                "required": ["title", "overall"],
            },
        },
    },
    "required": ["scores"],
}


def _auto_push_threshold(provider: str, model: str) -> int:
    """Minimum 'overall' score that qualifies a test case for auto-push.

    Roadmap #58 — confidence scoring used a flat ``>= 4`` threshold calibrated
    against Claude / GPT-4 outputs. Empirically, qwen2.5-coder:7b et al.
    median around 3 on the 1-5 scale even when output is usable, so the
    auto-push lane never fires for local users. Lower the bar for small
    models (and any Ollama-provided model when uncertain).

    Override via ``QA_AUTO_PUSH_THRESHOLD`` (integer 1-5).
    """
    env = os.environ.get("QA_AUTO_PUSH_THRESHOLD")
    if env:
        try:
            v = int(env)
            if 1 <= v <= 5:
                return v
        except ValueError:
            pass
    # Small local models score conservatively — drop one notch.
    if _is_small_model(model):
        return 3
    # Ollama large models without a clear "small" hint — still local, still
    # rate themselves modestly.
    if provider == "ollama":
        return 3
    # Claude / OpenAI / openai-compat — keep the historical threshold.
    return 4


def score_confidence(generated: dict, user_story: str, provider: str, model: str) -> dict | None:
    """Score generated test cases on quality dimensions.
    Returns {"scores": [...], "auto_push_count": N, "needs_review_count": N,
             "auto_push_threshold": N}
    or None if scoring is disabled, empty, or fails."""
    if os.environ.get("QA_CONFIDENCE_SCORING", "1") == "0":
        return None

    items = (generated.get("cas")
             or generated.get("gherkin", {}).get("scenarios")
             or generated.get("code", {}).get("body", "")
             or generated.get("perf", {}).get("body", "")
             or [])
    if not items:
        return None

    threshold = _auto_push_threshold(provider, model)

    try:
        content = (
            f"## User Story\n{user_story}\n\n"
            f"## Generated Tests\n```json\n{json.dumps(items, ensure_ascii=False, indent=2)}\n```"
        )
        result = llm_json(
            [{"role": "user", "content": content}],
            _CONFIDENCE_SYSTEM, model, provider,
            schema=_CONFIDENCE_SCHEMA, temperature=0.1, max_tokens=2000,
        )
        scores = result.get("scores", [])
        auto_push = sum(1 for s in scores if s.get("overall", 0) >= threshold)
        needs_review = sum(1 for s in scores if s.get("overall", 0) < threshold)
        print(
            f"  ⭐ Confidence (threshold≥{threshold}): "
            f"{auto_push} auto-push ready, {needs_review} need review"
        )
        return {
            "scores": scores,
            "auto_push_count": auto_push,
            "needs_review_count": needs_review,
            "auto_push_threshold": threshold,
        }
    except Exception as e:
        print(f"  ⚠  Confidence scoring failed: {e}")
        return None


# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------

def generate(
    messages: list,
    system: str,
    model: str,
    provider: str,
    module: str | None = None,
    schema: dict | None = None,
    user_text: str = "",
    framework: str = "",
    on_token=None,
    **opts,
) -> dict:
    """Full generation pipeline.

    Pipeline stages:
      0. Per-project model selection (Roadmap 3.3) and differential lookup
         (Roadmap #57), which may return early with a cached/merged result
      1. Inject coverage context
      2. Inject template context
      3. Chain-of-thought (if enabled): analyze -> plan -> generate -> validate
         OR single-shot LLM call + self-review
      4. Confidence scoring
      5. Persist a history snapshot for later differential generation

    Fast mode (``is_fast_mode``) skips CoT, self-review and confidence
    scoring on every path, including the differential early return.

    Args:
        messages: chat messages for the generation call.
        system: base system prompt; context sections are appended to it.
        model: model name; for ``provider == "ollama"`` it may be replaced
            by a per-project model.
        provider: LLM provider name.
        module: artifact kind (e.g. "cas", "gherkin", "code", "perf").
        schema: optional JSON schema; when given, ``llm_json`` is used.
        user_text: raw user story text (story-key detection, coverage,
            CoT, scoring, history).
        framework: target framework; scopes templates and history.
        on_token: optional per-token callback, single-shot path only.
        **opts: extra keyword arguments forwarded to the generation call.

    Returns ``{"result": dict, "confidence": dict | None, "cached": bool}``.
    """
    # Per-project model auto-selection (Roadmap 3.3)
    story_key_match = _JIRA_KEY_RE.search(user_text)
    story_key = story_key_match.group(1) if story_key_match else ""
    if story_key:
        project = story_key.rsplit("-", 1)[0]  # "SPEED-123" → "SPEED"
        try:
            pm = fine_tuning.get_project_model(project=project)
            if pm and provider == "ollama":
                model = pm["adapter_path"].split("/")[-1] or model
                print(f"  🎯 Project model: using {model} for project {project}")
        except Exception as exc:
            # Best-effort: keep the requested model, but say why.
            print(f"  ⚠  Project model lookup failed for {project} ({exc})")

    # Resolve fast mode once, after any project-model swap, so every path
    # below — including the differential early return — honours it.
    fast = is_fast_mode(model)

    # Differential generation (Roadmap #57). Scoped by (story_key, module,
    # framework) so cas/code/perf for the same story stay isolated. Disable
    # via QA_DIFFERENTIAL=0 if needed.
    differential_enabled = bool(
        story_key
        and module
        and os.environ.get("QA_DIFFERENTIAL", "1") != "0"
    )
    if differential_enabled:
        try:
            diff_result = generate_differential(
                story_key=story_key,
                story_content=user_text,
                provider=provider,
                model=model,
                module=module,
                framework=framework,
            )
            if diff_result is not None:
                # Confidence scoring is another LLM call — skipped in fast mode.
                confidence = (
                    None if fast
                    else score_confidence(diff_result, user_text, provider, model)
                )
                return {"result": diff_result, "confidence": confidence, "cached": True}
        except Exception as exc:
            print(f"  ⚠  Differential lookup failed ({exc}), running full generation")

    # 1. Coverage context
    cov_ctx = _get_coverage_context(user_text)
    if cov_ctx:
        system = system + "\n\n" + cov_ctx

    # 2. Template context
    if module:
        tpl_ctx = _get_template_context(module, framework)
        if tpl_ctx:
            system = system + "\n\n" + tpl_ctx

    # Roadmap #56: trim the assembled system prompt to roughly half the model's
    # context window so the model has room to produce a real answer. The
    # tokenizer-to-char ratio for most code models is ~3.5, so we budget at
    # (num_ctx / 2) * 3.5 chars and trim the middle if we overflow.
    try:
        from llm_providers import resolve_ctx, trim_to_budget
        budget_chars = int(resolve_ctx(model) * 0.5 * 3.5)
        if len(system) > budget_chars:
            print(f"  ✂  Trimming system prompt {len(system)} → ~{budget_chars} chars for {model}")
            system = trim_to_budget(system, budget_chars)
    except Exception as exc:
        # Best-effort: an untrimmed prompt is still usable.
        print(f"  ⚠  System prompt trimming skipped ({exc})")

    # 3. Chain-of-thought or single-shot
    # CoT only for text-based modules (cas, gherkin, risques, data).
    # Code/perf need direct schema-constrained generation — the CoT plan
    # steers toward test-case format and confuses code output.
    if fast:
        print(f"  ⚡ Fast mode active for {model} — skipping CoT + self-review + confidence")
    cot_enabled = (
        not fast
        and os.environ.get("QA_CHAIN_OF_THOUGHT", "1") != "0"
    )
    cot_modules = {"cas", "gherkin", "risques", "data"}
    used_cot = False

    if cot_enabled and module and module in cot_modules:
        try:
            # Step 1: Analyze
            analysis = _cot_analyze(user_text, provider, model)
            # Step 2: Plan
            plan = _cot_plan(analysis, user_text, provider, model)
            # Step 3: Generate with plan context injected
            plan_ctx = (
                "\n\n## Test Plan (follow this)\n```json\n"
                + json.dumps(plan, ensure_ascii=False, indent=2)
                + "\n```"
            )
            enriched_system = system + plan_ctx
            if schema:
                result = llm_json(
                    messages, enriched_system, model, provider, schema=schema, **opts,
                )
            else:
                raw = llm_chat(messages, enriched_system, model, provider, **opts)
                result = _extract_json(raw)
            # Step 4: Validate
            result = _cot_validate(result, plan, user_text, provider, model, module)
            used_cot = True
            print("  ✅ Chain-of-thought pipeline completed")
        except Exception as exc:
            print(f"  ⚠  Chain-of-thought failed ({exc}), falling back to single-shot")
            used_cot = False

    if not used_cot:
        # Single-shot path. Forward on_token for live token streaming
        # (Roadmap #55) — the callback is invoked per token chunk from the
        # provider, and the caller (typically GenerateTestsTool) routes it
        # onto the agent's event queue for the trace viewer.
        single_shot_opts = dict(opts)
        if on_token is not None:
            single_shot_opts["on_token"] = on_token
        if schema:
            result = llm_json(
                messages, system, model, provider, schema=schema, **single_shot_opts,
            )
        else:
            raw = llm_chat(messages, system, model, provider, **single_shot_opts)
            result = _extract_json(raw)
        # Self-review — skip for code/perf (review has no schema constraint
        # and the LLM converts code output back to CAS format). Also skip
        # in fast mode for small models (Roadmap #52).
        if module not in ("code", "perf") and not fast:
            result = self_review(result, user_text, provider, model)

    # 4. Confidence scoring — skipped in fast mode (it's another LLM call)
    # and for non-dict results, which cannot be scored.
    if fast or not isinstance(result, dict):
        confidence = None
    else:
        confidence = score_confidence(result, user_text, provider, model)

    # 5. Persist for differential generation (Roadmap #57). Only save when we
    # have a real story key + module; skip ad-hoc prompts that wouldn't match
    # on a later call.
    if differential_enabled and isinstance(result, dict):
        try:
            _save_generation_history(
                story_key, user_text, result,
                module=module, framework=framework,
            )
        except Exception as exc:
            print(f"  ⚠  History save failed ({exc})")

    return {"result": result, "confidence": confidence, "cached": False}
