"""qa-copilot/audit_retention_sweep.py

Standalone hard-delete sweep for `audit_log` rows past per-org retention.
Designed to run nightly via cron, launchd, or systemd-timer:

    0 4 * * *  python3 /path/to/qa-copilot/audit_retention_sweep.py

Reads ``audit_retention_days`` from each org's ``settings_json``; defaults
to 365 days. Uses the Plan 1 retention-sweep escape hatch: creates a TEMP
TABLE ``_audit_retention_sweep_active`` on its own connection, which the
``audit_log_no_delete`` trigger checks for before allowing DELETE.

Spec §9.3.
"""
from __future__ import annotations

import argparse
import datetime
import json
import logging
import os
import sqlite3
import sys


def _retention_cutoff(
    org_id: str,
    settings_json: str | None,
    default_retention_days: int,
    now: datetime.datetime,
) -> str | None:
    """Return the ``created_at`` cutoff for one org, or ``None`` to skip it.

    The cutoff is ``now`` minus the org's ``audit_retention_days`` (or
    ``default_retention_days`` when the key is absent), rendered as an
    ISO-8601 string with a ``Z`` suffix.

    ``None`` is returned, and a warning logged, when the retention period
    cannot be trusted: ``settings_json`` is not valid JSON or not a JSON
    object, the value is not an integer-like number, it is negative (which
    would put the cutoff in the future and wipe the org's whole log), or it
    is too large to subtract from ``now``.
    """
    try:
        settings = json.loads(settings_json or "{}")
    except (TypeError, ValueError) as exc:  # JSONDecodeError is a ValueError
        logging.warning(
            "audit_retention_sweep: org=%s skipped: unreadable settings_json (%s)",
            org_id, exc,
        )
        return None
    if not isinstance(settings, dict):
        logging.warning(
            "audit_retention_sweep: org=%s skipped: settings_json is not a JSON object",
            org_id,
        )
        return None

    raw_days = settings.get("audit_retention_days", default_retention_days)
    try:
        days = int(raw_days)
        if days < 0:
            raise ValueError("retention must not be negative")
        cutoff = now - datetime.timedelta(days=days)
    except (TypeError, ValueError, OverflowError) as exc:
        logging.warning(
            "audit_retention_sweep: org=%s skipped: invalid audit_retention_days=%r (%s)",
            org_id, raw_days, exc,
        )
        return None
    return cutoff.isoformat().replace("+00:00", "Z")


def sweep(db_path: str, *, default_retention_days: int = 365) -> dict[str, int]:
    """Hard-delete expired audit rows. Returns {org_id: deleted_count}.

    For every organization with ``deleted_at IS NULL``, deletes the
    ``audit_log`` rows whose ``created_at`` is older than that org's
    retention period (``audit_retention_days`` in ``settings_json``, else
    ``default_retention_days``).

    Args:
        db_path: Path of the SQLite database to sweep.
        default_retention_days: Retention applied to orgs whose settings do
            not specify ``audit_retention_days``.

    Returns:
        Mapping of org id to the number of rows deleted for it. Orgs whose
        retention settings are unusable are skipped with a logged warning
        and do not appear in the mapping; because the delete is
        irreversible, nothing is removed for them rather than guessing.

    Raises:
        sqlite3.Error: On database failures. The single ``commit`` happens
            only after every org has been processed, so a failure mid-sweep
            leaves no deletions committed.
    """
    # One reference instant for the whole run so every org is measured
    # against the same "now". ``timezone.utc`` rather than ``datetime.UTC``
    # keeps the script runnable on interpreters older than 3.11.
    now = datetime.datetime.now(datetime.timezone.utc)
    conn = sqlite3.connect(db_path)
    deleted: dict[str, int] = {}
    try:
        # Activate the sweep marker so the trigger lets DELETE through.
        # TEMP tables are private to this connection.
        conn.execute(
            "CREATE TEMP TABLE IF NOT EXISTS _audit_retention_sweep_active(x INTEGER)"
        )
        org_rows = conn.execute(
            "SELECT id, settings_json FROM organizations WHERE deleted_at IS NULL"
        ).fetchall()
        for org_id, settings_json in org_rows:
            cutoff = _retention_cutoff(
                org_id, settings_json, default_retention_days, now
            )
            if cutoff is None:
                # Bad settings for this org must not block the others.
                continue
            cur = conn.execute(
                "DELETE FROM audit_log WHERE org_id=? AND created_at < ?",
                (org_id, cutoff),
            )
            deleted[org_id] = cur.rowcount
        conn.commit()
    finally:
        conn.close()
    return deleted


def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: parse options, run the sweep, log totals.

    ``argv`` is handed to argparse unchanged (``None`` lets argparse read
    the process arguments). Logs one summary line followed by one line per
    org in sorted org-id order, then returns 0. Exceptions raised by the
    sweep are not caught here.
    """
    # Default database is ``userdata.db`` located next to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    default_db = os.path.join(script_dir, "userdata.db")

    cli = argparse.ArgumentParser(description="Audit log retention sweep")
    cli.add_argument("--db", default=default_db)
    cli.add_argument("--default-retention-days", type=int, default=365)
    opts = cli.parse_args(argv)

    # Configure logging only after argument parsing has succeeded.
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(message)s")

    per_org = sweep(opts.db, default_retention_days=opts.default_retention_days)
    logging.info(
        "audit_retention_sweep: deleted %d rows across %d orgs",
        sum(per_org.values()),
        len(per_org),
    )
    for org in sorted(per_org):
        logging.info("  org=%s deleted=%d", org, per_org[org])
    return 0


# Script entry point: run the CLI and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
