"""
LLM request queue for QA Copilot.

Fair scheduling — one request per user at a time. Additional requests
from the same user are queued. Different users run concurrently (up to
max_concurrent workers).
"""
from __future__ import annotations

import os
import queue
import threading
from collections import defaultdict, deque
from collections.abc import Callable
from typing import Any

MAX_CONCURRENT = int(os.environ.get("QA_LLM_MAX_CONCURRENT", "2"))

_queue: queue.Queue[LLMRequest] = queue.Queue()  # dispatch queue for all users
_active: dict[str, int] = defaultdict(int)  # user_id -> in-flight request count
_pending: dict[str, deque[LLMRequest]] = defaultdict(deque)  # user_id -> parked requests
_lock = threading.Lock()  # guards _active and _pending
_semaphore = threading.Semaphore(MAX_CONCURRENT)  # bounds concurrent LLM calls


class LLMRequest:
    def __init__(self, user_id: str, fn: Callable[..., Any], args: tuple, kwargs: dict):
        self.user_id = user_id
        self.fn = fn
        self.args = args
        self.kwargs = kwargs
        self.result: Any = None
        self.error: Exception | None = None
        self.done = threading.Event()

    def wait(self, timeout: float = 600):
        """Block until the request completes; return the result or re-raise
        the error. Raises TimeoutError if `timeout` elapses first (the
        request itself keeps running)."""
        if not self.done.wait(timeout=timeout):
            raise TimeoutError(f"LLM request for {self.user_id!r} timed out after {timeout}s")
        if self.error is not None:
            raise self.error
        return self.result


def submit(user_id: str, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> LLMRequest:
    """Submit an LLM call to the queue. Returns an LLMRequest to wait on."""
    req = LLMRequest(user_id, fn, args, kwargs)
    _queue.put(req)
    return req
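

# A minimal usage sketch (illustrative only; `generate_answer` stands in for
# a caller-supplied LLM function and is not part of this module):
#
#     req = submit("user-42", generate_answer, "Summarize the failing tests")
#     try:
#         answer = req.wait(timeout=120)
#     except TimeoutError:
#         answer = None  # the request keeps running; only this wait gave up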


def _worker():
    """Worker loop: run requests while enforcing one in-flight request per user."""
    while True:
        try:
            req = _queue.get(timeout=1)
        except queue.Empty:
            continue
        with _lock:
            if _active.get(req.user_id, 0) >= 1:
                # This user already has a request in flight: park it so other
                # users are not blocked behind it. It is re-dispatched when
                # the in-flight request finishes.
                _pending[req.user_id].append(req)
                _queue.task_done()
                continue
            _active[req.user_id] = 1
        _semaphore.acquire()
        try:
            req.result = req.fn(*req.args, **req.kwargs)
        except Exception as e:
            req.error = e
        finally:
            _semaphore.release()
            with _lock:
                del _active[req.user_id]
                if _pending[req.user_id]:
                    # Re-dispatch this user's next parked request in FIFO order.
                    _queue.put(_pending[req.user_id].popleft())
                else:
                    del _pending[req.user_id]
            req.done.set()
            _queue.task_done()


def get_queue_status() -> dict:
    """Return a snapshot of queue activity for monitoring."""
    with _lock:
        return {
            "queue_size": _queue.qsize(),  # approximate, per queue.Queue contract
            "active_requests": dict(_active),
            "pending_per_user": {u: len(d) for u, d in _pending.items()},
            "max_concurrent": MAX_CONCURRENT,
        }
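

# Shape of the returned snapshot (illustrative values, assuming one user with
# two parked requests while another user's request sits in the dispatch queue):
#
#     {"queue_size": 1,
#      "active_requests": {"user-42": 1},
#      "pending_per_user": {"user-42": 2},
#      "max_concurrent": 2}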


# Start worker threads at import time. daemon=True so they never block
# interpreter shutdown.
for _i in range(MAX_CONCURRENT):
    threading.Thread(target=_worker, name=f"llm-worker-{_i}", daemon=True).start()
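

# Minimal smoke test (illustrative only; `fake_llm` is a stub standing in for
# a real model call). Submits three requests for one user and one for another;
# the single user's requests run one at a time, in submission order.
if __name__ == "__main__":
    import random
    import time

    def fake_llm(prompt: str) -> str:
        time.sleep(random.uniform(0.05, 0.15))  # simulate model latency
        return f"answer to {prompt!r}"

    reqs = [
        submit("alice", fake_llm, "q1"),
        submit("alice", fake_llm, "q2"),
        submit("alice", fake_llm, "q3"),
        submit("bob", fake_llm, "q4"),
    ]
    time.sleep(0.05)  # let workers pick up and park requests
    print(get_queue_status())
    for r in reqs:
        print(r.user_id, "->", r.wait(timeout=10))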
