"""
Caching layer for QA Copilot (Roadmap 6.1).

Tries Redis (QA_REDIS_URL), falls back to a bounded in-memory dict
(not a true LRU: when full, the entry with the smallest expiry timestamp is evicted).
Thread-safe. TTL support.
"""
from __future__ import annotations

import hashlib
import json
import os
import threading
import time

# Redis connection URL, read once from the environment at import time.
# An empty string means Redis is not configured and _init_redis() does nothing.
_REDIS_URL = os.environ.get("QA_REDIS_URL", "")
# Redis client object; assigned by _init_redis(), None until then.
_redis_client = None
# True only after _init_redis() has successfully pinged the server.
_use_redis = False

# In-memory fallback
# expires_at is a time.monotonic() deadline; 0 means "never expires".
# NOTE(review): `any` in the annotation is the builtin function, not typing.Any.
_mem_store: dict[str, tuple[float, any]] = {}  # key -> (expires_at, value)
# Guards every read and write of _mem_store.
_mem_lock = threading.Lock()
# Size at which cache_set() starts evicting entries to make room.
_MAX_MEM_ENTRIES = 1000


def _redact_redis_url(url: str) -> str:
    """Return *url* without userinfo or query string, so it is safe to print."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(url)
    except ValueError:
        return "<unparseable URL>"
    # Credentials live before the last "@" of the netloc; options such as
    # "?password=..." live in the query string. Drop both.
    host = parts.netloc.rpartition("@")[2]
    return urlunsplit((parts.scheme, host, parts.path, "", ""))


def _init_redis():
    """Try to connect to Redis. Falls back to in-memory if unavailable.

    Does nothing when QA_REDIS_URL is empty. On success, stores the client in
    ``_redis_client`` and sets ``_use_redis``; on any failure (missing
    ``redis`` package, bad URL, unreachable server) both are reset so the
    cache functions use the in-memory store. Never raises for connection
    problems; the outcome is reported on stdout.
    """
    global _redis_client, _use_redis
    if not _REDIS_URL:
        return
    try:
        import redis
        client = redis.from_url(_REDIS_URL, decode_responses=True)
        client.ping()
        _redis_client = client
        _use_redis = True
        # The raw URL may embed a password, so only a redacted form is printed.
        print(f"  ✓ Cache: Redis connected ({_redact_redis_url(_REDIS_URL)})")
    except Exception as e:
        print(f"  ⚠  Cache: Redis unavailable ({e}), using in-memory fallback")
        # Do not keep a client that failed its ping.
        _redis_client = None
        _use_redis = False

def cache_get(key: str) -> any:
    """Look up *key* and return its cached value, or None on a miss.

    When Redis is active it is consulted first under the ``qac:`` prefix and
    its answer (decoded from JSON, or None for a missing/empty value) is
    final. If the Redis call or the decoding raises, the in-memory store is
    used instead. In-memory entries whose deadline has passed are removed
    and reported as a miss.
    """
    if _use_redis and _redis_client:
        try:
            raw = _redis_client.get(f"qac:{key}")
            if not raw:
                return None
            return json.loads(raw)
        except Exception:
            # Best effort: any Redis/JSON failure falls through to memory.
            pass

    with _mem_lock:
        record = _mem_store.get(key)
        if record is None:
            return None
        deadline, cached = record
        # A deadline of 0 marks an entry that never expires.
        expired = bool(deadline) and time.monotonic() > deadline
        if not expired:
            return cached
        del _mem_store[key]
        return None


def cache_set(key: str, value: any, ttl_seconds: int = 3600) -> None:
    """Store *value* under *key*, optionally with a time-to-live.

    Args:
        key: Cache key (prefixed with ``qac:`` when written to Redis).
        value: Value to cache. Must be JSON-serializable to be stored in
            Redis; otherwise it ends up in the in-memory store.
        ttl_seconds: Lifetime in seconds. A falsy value (0 / None) stores the
            entry without expiry in both backends.

    When Redis is active and the write succeeds, nothing is written to the
    in-memory store. Any Redis or serialization error falls back to memory.
    """
    if _use_redis and _redis_client:
        try:
            payload = json.dumps(value, ensure_ascii=False)
            if ttl_seconds:
                _redis_client.setex(f"qac:{key}", ttl_seconds, payload)
            else:
                # SETEX rejects a zero expiry, so a "no TTL" write must use
                # plain SET to match the in-memory meaning of ttl_seconds=0.
                _redis_client.set(f"qac:{key}", payload)
            return
        except Exception:
            # Best effort: fall through to the in-memory store.
            pass
    # In-memory fallback
    with _mem_lock:
        # Only make room when a NEW key would exceed capacity; overwriting an
        # existing key does not grow the store and must not evict anything.
        if key not in _mem_store and len(_mem_store) >= _MAX_MEM_ENTRIES:
            now = time.monotonic()
            # Drop entries that are already dead before sacrificing a live one.
            expired = [k for k, (exp, _) in _mem_store.items() if exp and now > exp]
            for dead_key in expired:
                del _mem_store[dead_key]
            if _mem_store and len(_mem_store) >= _MAX_MEM_ENTRIES:
                # Not a true LRU: evict the entry with the smallest deadline.
                # Entries stored without a TTL carry 0 and therefore sort first.
                victim = min(_mem_store, key=lambda k: _mem_store[k][0])
                del _mem_store[victim]
        expires_at = time.monotonic() + ttl_seconds if ttl_seconds else 0
        _mem_store[key] = (expires_at, value)


def cache_delete(key: str) -> None:
    """Remove *key* from Redis (when active) and from the in-memory store.

    Redis errors are ignored; the in-memory removal always runs and is a
    no-op when the key is absent.
    """
    redis_active = _use_redis and _redis_client
    if redis_active:
        try:
            namespaced = f"qac:{key}"
            _redis_client.delete(namespaced)
        except Exception:
            # Best effort: a failed Redis delete must not block local cleanup.
            pass

    with _mem_lock:
        if key in _mem_store:
            del _mem_store[key]


def cache_clear() -> None:
    """Drop every cached entry from both backends.

    In Redis only keys matching ``qac:*`` are deleted, one at a time; an
    error stops the Redis sweep silently. The in-memory store is always
    emptied.
    """
    if _use_redis and _redis_client:
        try:
            namespaced_keys = _redis_client.scan_iter("qac:*")
            for redis_key in namespaced_keys:
                _redis_client.delete(redis_key)
        except Exception:
            # Best effort: still clear the local store below.
            pass

    with _mem_lock:
        _mem_store.clear()


def make_cache_key(*parts: str) -> str:
    """Build a deterministic 16-hex-character key from *parts*.

    Each part is converted with ``str``, the results are joined with ``:``,
    and the first 16 hex digits of the SHA-256 digest are returned.
    """
    joined = ":".join(map(str, parts))
    digest = hashlib.sha256(joined.encode())
    return digest.hexdigest()[:16]


# Initialize on import: probe Redis once so that _use_redis is settled
# before any cache_* function is called.
_init_redis()
