#!/usr/bin/env python3
"""Compare two UBER joeylog.txt files by per-op ops/sec.

Sibling of diff-uber-hashes (which compares pixel correctness). This
tool drives Phase 10 of project_planar_68k_plan.md: pick the
biggest perf gaps vs the IIgs reference and target asm/algorithmic
optimization at those.

Usage:
    tools/diff-uber-perf <reference-log> <test-log> [--threshold 1.0]

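Output is sorted by speed ratio (test/ref) ascending, so the worst
gaps print first. Ops present only in the reference log are flagged
MISSING and count toward the failure exit code; ops present only in the
test log are listed last as EXTRA and do not affect the exit code. The
threshold flag (default 1.0) marks ops below that ratio as FAIL --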
project_perf_directive.md says "IIgs is the perf floor; every
other target must match or beat it", so parity = 1.0x. Use
--threshold 0.8 for the project_planar_68k_plan looser acceptance.

Exit code:
    0 = all common ops at >= threshold
    1 = at least one common op below threshold, or a reference op missing from the test log
    2 = usage error or missing file
"""

import re
import sys

# Match e.g.:
#   UBER: drawCircle r=80: 56 iters / 4 frames = 840 ops/sec | hash=A1B2C3D4
# The op name is captured lazily (``.+?``) rather than as ``[^:]+`` so that
# op names which themselves contain a colon (e.g. "blit: masked") are still
# captured whole. With ``[^:]+`` such lines failed to match at all, and the
# op silently vanished from both logs without ever being reported.
LINE_RE = re.compile(
    r"UBER:\s+(?P<op>.+?):\s+\d+\s+iters\s+/\s+\d+\s+frames\s+=\s+(?P<ops>\d+)\s+ops/sec"
)


def parse_log(path):
    """Return a dict {op_name: ops_per_sec} parsed from a UBER log file.

    The dict keeps first-seen op order. Multiple runs may be concatenated
    in one file (joeyLog appends); the last value seen for each op wins,
    matching the most recent run.

    The file is decoded as UTF-8 with undecodable bytes replaced, so stray
    non-UTF-8 output in the log cannot abort parsing (a UnicodeDecodeError
    is not an OSError and would otherwise escape the caller's error
    handling as a traceback).

    Raises:
        OSError: if the file cannot be opened or read.
    """
    perf = {}
    with open(path, encoding="utf-8", errors="replace") as f:
        for line in f:
            m = LINE_RE.search(line)
            if m:
                perf[m.group("op").strip()] = int(m.group("ops"))
    return perf


def main(argv):
    """Compare two UBER logs and print a per-op ops/sec table.

    argv is a sys.argv-style list; argv[0] (the program name) is skipped.
    Expects two positional log paths (reference first, then test) plus an
    optional ``--threshold X`` or ``--threshold=X`` anywhere among them.

    Returns the process exit code:
        0 -- every reference op is present in the test log at >= threshold
        1 -- at least one op is below threshold or missing from the test log
        2 -- usage error, bad threshold, unreadable file, or a log with no
             UBER lines
    """
    import math  # local: only needed to validate the threshold

    threshold = 1.0
    args = []
    i = 1
    while i < len(argv):
        arg = argv[i]
        if arg.startswith("--threshold="):
            value = arg[len("--threshold="):]
            i += 1
        elif arg == "--threshold" and i + 1 < len(argv):
            value = argv[i + 1]
            i += 2
        else:
            # Includes a trailing bare "--threshold", which then trips the
            # positional-count check below as a usage error.
            args.append(arg)
            i += 1
            continue
        try:
            threshold = float(value)
        except ValueError:
            threshold = math.nan
        # NaN makes every ``ratio >= threshold`` False (and inf nearly so),
        # which would silently FAIL every op; reject non-finite values.
        if not math.isfinite(threshold):
            sys.stderr.write(f"error: bad threshold {value}\n")
            return 2

    if len(args) != 2:
        sys.stderr.write(
            "usage: diff-uber-perf <reference-log> <test-log> [--threshold 1.0]\n"
        )
        return 2
    ref_path, test_path = args

    try:
        ref = parse_log(ref_path)
        test = parse_log(test_path)
    except OSError as e:
        sys.stderr.write(f"error: {e}\n")
        return 2

    for path, perf in ((ref_path, ref), (test_path, test)):
        if not perf:
            sys.stderr.write(f"error: no UBER lines found in {path}\n")
            return 2

    rows = []
    for op, ref_ops in ref.items():
        test_ops = test.get(op)
        if test_ops is None:
            rows.append((op, ref_ops, None, None, "MISSING"))
            continue
        if ref_ops == 0:
            # Reference did no work: any test throughput is an unbounded
            # speedup, and 0 vs 0 counts as parity.
            ratio = float("inf") if test_ops > 0 else 1.0
        else:
            ratio = test_ops / ref_ops
        status = "ok" if ratio >= threshold else "FAIL"
        rows.append((op, ref_ops, test_ops, ratio, status))

    # Ops only in the test log are informational and never fail the run.
    extras = [
        (op, None, test_ops, None, "EXTRA")
        for op, test_ops in test.items()
        if op not in ref
    ]

    # MISSING first, then FAIL, then ok; within a group ascending by ratio
    # (worst gap first), ties broken by op name.
    status_rank = {"MISSING": 0, "FAIL": 1, "ok": 2}

    def sort_key(row):
        op, _refv, _testv, ratio, status = row
        return (status_rank[status], 0.0 if ratio is None else ratio, op)

    rows.sort(key=sort_key)

    # Both logs are non-empty here, so the generators always yield.
    op_w = max(len("op"), *(len(op) for op in ref), *(len(op) for op in test))

    print(f"{'op':<{op_w}}  {'ref':>10}  {'test':>10}  {'ratio':>7}  status")
    print(f"{'-'*op_w}  {'-'*10}  {'-'*10}  {'-'*7}  ------")
    for op, refv, testv, ratio, status in rows + extras:
        refs = "" if refv is None else str(refv)
        tests = "" if testv is None else str(testv)
        rats = "" if ratio is None else f"{ratio:.2f}x"
        print(f"{op:<{op_w}}  {refs:>10}  {tests:>10}  {rats:>7}  {status}")

    below = sum(1 for row in rows if row[4] == "FAIL")
    missing = sum(1 for row in rows if row[4] == "MISSING")
    compared = len(rows) - missing

    print()
    # FAIL and MISSING both fail the run, but a missing op was never
    # compared, so report it separately instead of inflating either count.
    print(
        f"threshold: {threshold:.2f}x  ({compared} ops compared, "
        f"{below} below threshold, {missing} missing, {len(extras)} extra)"
    )
    return 1 if below or missing else 0


if __name__ == "__main__":
    # Hand main()'s integer result to the interpreter as the exit status.
    raise SystemExit(main(sys.argv))
