"""Replication driver: boot the kernel under a pty, run one alice model
session on ccpty0, dump the scrollback, classify. Repeat N times.

Usage: python _repl_driver.py <label> <n_runs> [model_secs]
Saves each run's decoded scrollback to /tmp/repl/<label>_run<i>.txt and the
full raw pty capture to /tmp/repl/<label>_run<i>.raw, then prints a per-run
verdict and a summary.
"""
import os, pty, select, signal, sys, time, re, pathlib

# Interpreter used to launch the kernel module; the VENV_PY environment
# variable takes precedence over the interpreter running this script.
VENV_PY = os.getenv("VENV_PY", sys.executable)

# Environment handed to the spawned process: a copy of the current
# environment with the three SEK_* model settings set to fixed values.
ENV = {
    **os.environ,
    "SEK_MODEL_URL": "http://localhost:4000/v1",
    "SEK_MODEL": "llama3.1:8b",
    "SEK_API_KEY": "sk-litellm-dev",
}

def _spawn():
    """Fork a child on a fresh pty and exec the kernel module in it.

    The child replaces itself with ``VENV_PY -m byteb4rb1e.sek.ddist`` using
    ``ENV`` as its environment.

    Returns:
        ``(pid, fd)`` as returned by ``pty.fork()`` in the parent: the
        child's pid and the descriptor the parent uses to talk to the
        child's terminal.
    """
    pid, fd = pty.fork()
    if pid == 0:
        # Child side. execve only returns by raising; if that happens we must
        # not fall back into the parent's Python code (the caller's stack,
        # cleanup handlers, buffered output), so report the failure on the
        # pty and terminate immediately.
        try:
            os.execve(VENV_PY, [VENV_PY, "-m", "byteb4rb1e.sek.ddist"], ENV)
        except Exception as exc:
            os.write(2, f"exec {VENV_PY} failed: {exc}\n".encode("utf-8", "replace"))
        finally:
            os._exit(127)
    return pid, fd

def _read_until(fd, pattern, timeout, sink):
    """Read until `pattern` (bytes) seen in the tail, or timeout. Append all
    read bytes to sink (bytearray). Returns True if pattern seen."""
    deadline = time.monotonic() + timeout
    tail = b""
    while time.monotonic() < deadline:
        r, _, _ = select.select([fd], [], [], 0.3)
        if r:
            try:
                chunk = os.read(fd, 4096)
            except OSError:
                return False
            if not chunk:
                return False
            sink.extend(chunk)
            tail = (tail + chunk)[-256:]
            if pattern in tail:
                return True
    return False

def _drain(fd, secs, sink):
    deadline = time.monotonic() + secs
    while time.monotonic() < deadline:
        r, _, _ = select.select([fd], [], [], 0.3)
        if r:
            try:
                chunk = os.read(fd, 4096)
            except OSError:
                return
            if not chunk:
                return
            sink.extend(chunk)

def run_once(model_secs):
    """Boot one instance, run a single alice session, and capture the result.

    Sequence: wait for "login:", log in as root with an empty password line,
    start the alice session on ccpty0 in the background, let it run for
    `model_secs` seconds, then ask for the ccpty0 scrollback.

    Returns:
        ``(scrollback, raw)``. `raw` is a bytearray with everything read
        from the pty, including the scrollback dump. `scrollback` is the
        dump decoded as UTF-8 (undecodable bytes replaced), or None when
        "login:" never showed up within 35 seconds.

    The child is always sent SIGKILL, the pty descriptor closed, and the
    child reaped, whatever happens above.
    """
    child, master = _spawn()
    raw = bytearray()
    try:
        if not _read_until(master, b"login:", 35, raw):
            return None, raw

        # Log in as root: send the user name, then an empty password line.
        # The outcome of these two waits is not checked.
        handshake = (
            (b"root\n", b"password:", 10),
            (b"\n", b"# ", 12),
        )
        for reply, prompt, wait in handshake:
            os.write(master, reply)
            _read_until(master, prompt, wait, raw)

        # Run the model session in the background and give it time to work.
        os.write(master, b"login -f alice -t ccpty0 -- /bin/sh &\n")
        _drain(master, model_secs, raw)

        # Capture the scrollback dump separately, reading until the root
        # prompt comes back. Nothing but the scrollback carries
        # [assistant]/[user] tags.
        scrollback = bytearray()
        os.write(master, b"stty -f dev://ccpty0 scrollback\n")
        _read_until(master, b"~# ", 12, scrollback)
        raw.extend(scrollback)
        return scrollback.decode("utf-8", "replace"), raw
    finally:
        # Best-effort teardown; each step is attempted even if an earlier
        # one fails with OSError.
        teardown = (
            lambda: os.kill(child, signal.SIGKILL),
            lambda: os.close(master),
            lambda: os.waitpid(child, 0),
        )
        for step in teardown:
            try:
                step()
            except OSError:
                pass

# Case-insensitive, whole-word markers of conversational English. classify()
# labels a run "drift" when any assistant turn matches. The trailing
# `apolog\w*` is a stem: it has to consume the rest of the word, otherwise the
# closing \b can never match inside "apologize" / "apology" / "apologies".
PROSE = re.compile(r"\b(let me|let's|it seems|i think|i'll|i will|i'm|maybe|perhaps|"
                   r"seems like|looks like|sorry|frustrating|trying to|going to|"
                   r"that's|cheat|let us|we can|we could|i need|i should|got it|"
                   r"actually|hmm|oops|whoops|apolog\w*)\b", re.I)

def classify(scroll):
    """Turn one scrollback dump into a ``(verdict, stats)`` pair.

    `scroll` is the decoded scrollback text, or None when the run never
    reached the login prompt. The text is cut at every ``[assistant]``,
    ``[user]`` or ``[system]`` tag and only the assistant turns are examined.

    Verdicts, checked in this order:
        "boot-fail"        -- `scroll` is None (stats is empty)
        "no-turns"         -- no assistant turn found
        "drift"            -- some assistant turn matches PROSE
        "over-imitation"   -- some assistant turn contains "alice@sek"
        "clean"            -- some assistant turn is, or ends with a line
                              reading, "exit"
        "commands-no-exit" -- none of the above

    `stats` holds the assistant turn count ("asst") and, when there are
    turns, the number of prose turns plus the "exit" and "prompt" flags.
    """
    if scroll is None:
        return "boot-fail", {}

    # Zero-width split: each piece starts with its own role tag.
    pieces = re.split(r"(?=\[(?:assistant|user|system)\])", scroll)
    tagged = (re.match(r"\[assistant\]\s?(.*)", piece, re.S) for piece in pieces)
    replies = [hit.group(1).strip() for hit in tagged if hit]
    if not replies:
        return "no-turns", {"asst": 0}

    prose_count = sum(1 for body in replies if PROSE.search(body))
    emits_prompt = any("alice@sek" in body for body in replies)
    has_exit = any(body == "exit" or body.endswith("\nexit") for body in replies)
    stats = {
        "asst": len(replies),
        "prose": prose_count,
        "exit": has_exit,
        "prompt": emits_prompt,
    }

    # First matching condition wins.
    ranked = (
        (prose_count, "drift"),
        (emits_prompt, "over-imitation"),
        (has_exit, "clean"),
    )
    for flagged, verdict in ranked:
        if flagged:
            return verdict, stats
    return "commands-no-exit", stats

def main():
    """Command-line entry point: run N sessions and print a verdict tally.

    Arguments (from sys.argv): ``<label> <n_runs> [model_secs]``, where
    `model_secs` defaults to 20. For each run the decoded scrollback is
    written to /tmp/repl/<label>_run<i>.txt and the raw pty capture to
    /tmp/repl/<label>_run<i>.raw, and the verdict is printed. A final line
    shows how often each verdict occurred.

    Exits with status 2 and a usage line on stderr when the arguments are
    missing or the numeric ones are not integers.
    """
    from collections import Counter

    args = sys.argv[1:]
    try:
        label = args[0]
        n = int(args[1])
        secs = int(args[2]) if len(args) > 2 else 20
    except (IndexError, ValueError):
        print("Usage: python _repl_driver.py <label> <n_runs> [model_secs]",
              file=sys.stderr)
        sys.exit(2)

    outdir = pathlib.Path("/tmp/repl")
    outdir.mkdir(exist_ok=True)
    results = []
    for i in range(1, n + 1):
        scroll, raw = run_once(secs)
        # A failed boot yields scroll=None; still write an (empty) .txt so
        # every run has both files.
        (outdir / f"{label}_run{i}.txt").write_text(scroll or "", encoding="utf-8")
        (outdir / f"{label}_run{i}.raw").write_bytes(bytes(raw))
        verdict, stats = classify(scroll)
        results.append(verdict)
        print(f"  run {i}: {verdict:18s} {stats}", flush=True)
    print(f"=== {label}: {dict(Counter(results))}", flush=True)


if __name__ == "__main__":
    main()
