#!/usr/bin/env python3
"""
Generate a real two-host podcast MP3 from a script JSON.
TTS priority: OpenAI TTS API → macOS `say` + ffmpeg fallback.
Produces a 96 kbps mono MP3; if the result exceeds 9 MB it is re-encoded at 64 kbps.
"""

import argparse
import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
from pathlib import Path

# --------------- OpenAI TTS ---------------

def tts_openai(text: str, voice: str, out_path: str, model: str = "tts-1") -> bool:
    """Synthesize *text* to an MP3 at *out_path* via an OpenAI-compatible TTS API.

    Credentials come from ``OPENAI_API_KEY``; ``OPENAI_BASE_URL`` optionally
    points the client at a compatible endpoint (defaults to api.openai.com).
    Failures are printed rather than raised so callers can fall back to
    another backend.

    Args:
        text: Text to speak.
        voice: OpenAI voice id (e.g. "onyx", "nova").
        out_path: Destination MP3 path (overwritten).
        model: TTS model name; "tts-1" is cheaper/faster, "tts-1-hd" is the
            higher-quality option.

    Returns:
        True if the request succeeded and the written file is at least 1 KB,
        otherwise False.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    if not api_key:
        # Report the real cause instead of an opaque KeyError message.
        print("  ⚠ OpenAI TTS failed: OPENAI_API_KEY is not set")
        return False
    try:
        from openai import OpenAI  # optional dependency, only needed for this backend

        client = OpenAI(
            api_key=api_key,
            base_url=os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
        )
        # with_streaming_response writes the body to disk as it arrives;
        # calling stream_to_file on the plain (non-streaming) response is
        # deprecated in openai-python 1.x.
        with client.audio.speech.with_streaming_response.create(
            model=model,
            voice=voice,
            input=text,
            response_format="mp3",
        ) as resp:
            resp.stream_to_file(out_path)
        # A body under 1 KB is most likely an error payload, not audio.
        if Path(out_path).stat().st_size < 1024:
            print("  ⚠ OpenAI TTS returned tiny file for chunk, likely error")
            return False
        return True
    except Exception as e:
        print(f"  ⚠ OpenAI TTS failed: {e}")
        return False


# --------------- macOS say fallback ---------------

def tts_say(text: str, voice: str, out_mp3: str) -> bool:
    """Synthesize *text* with macOS ``say`` and transcode the result to MP3.

    ``say`` renders into a temporary ``.aiff`` file, which ffmpeg converts to
    96 kbps / 24 kHz mono MP3 (the same settings concat_mp3s uses). The
    temporary file is removed whether or not synthesis succeeds.

    Args:
        text: Text to speak.
        voice: macOS voice name passed to ``say -v``.
        out_mp3: Destination MP3 path (overwritten).

    Returns:
        True if the MP3 was produced and is larger than 512 bytes; False on
        any failure (the error is printed, not raised).
    """
    aiff = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".aiff", delete=False) as tmp:
            aiff = tmp.name
        subprocess.run(["say", "-v", voice, "-o", aiff, text], check=True, timeout=120)
        subprocess.run([
            "ffmpeg", "-y", "-i", aiff,
            "-codec:a", "libmp3lame", "-b:a", "96k", "-ar", "24000", "-ac", "1",
            out_mp3,
        ], check=True, capture_output=True, timeout=120)
        return Path(out_mp3).stat().st_size > 512
    except Exception as e:
        print(f"  ⚠ say fallback failed: {e}")
        return False
    finally:
        # Clean up on every path so failing/timing-out calls don't leak .aiff files.
        if aiff is not None:
            try:
                os.unlink(aiff)
            except OSError:
                pass


# --------------- main pipeline ---------------

# Per-backend voice for each host, keyed by a script line's "speaker" value.
OPENAI_VOICES = dict(male="onyx", female="nova")        # OpenAI TTS voice ids
SAY_VOICES = dict(male="Daniel", female="Samantha")     # macOS `say -v` voice names


def synthesize_line(text: str, speaker: str, out_mp3: str, method: str) -> bool:
    """Render one script line to *out_mp3* using the selected backend.

    *speaker* picks the voice from OPENAI_VOICES or SAY_VOICES depending on
    *method* ("openai" or "say"); an unknown speaker raises KeyError from that
    lookup. Any other *method* value returns False without synthesizing.
    """
    if method == "say":
        return tts_say(text, SAY_VOICES[speaker], out_mp3)
    if method == "openai":
        return tts_openai(text, OPENAI_VOICES[speaker], out_mp3)
    return False


def detect_method() -> str:
    """Choose a TTS backend, exiting the process if none is usable.

    OpenAI is tried first, but only when ``OPENAI_API_KEY`` is set, by
    synthesizing a one-word probe. Otherwise it falls back to macOS ``say``,
    which also needs ``ffmpeg`` on PATH to convert its output to MP3.

    Returns:
        ``"openai"`` or ``"say"``.

    Raises:
        SystemExit: with status 1 when no backend is available.
    """
    if os.environ.get("OPENAI_API_KEY"):
        print("Testing OpenAI TTS endpoint…")
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as f:
            probe = f.name
        try:
            ok = tts_openai("Test.", "onyx", probe)
        finally:
            try:
                os.unlink(probe)
            except OSError:
                pass
        if ok:
            print("✓ OpenAI TTS available — using it.")
            return "openai"
        print("✗ OpenAI TTS probe failed.")

    # shutil.which avoids depending on an external `which` binary, which
    # minimal environments may not ship (that used to crash here).
    have_say = shutil.which("say") is not None
    have_ffmpeg = shutil.which("ffmpeg") is not None
    if have_say and have_ffmpeg:
        print("✓ Falling back to macOS say + ffmpeg.")
        return "say"
    if have_say:
        # Without ffmpeg every say chunk would fail later, one line at a time.
        print("✗ macOS say found, but ffmpeg is missing.")
    print("✗ No TTS backend available!")
    sys.exit(1)


def concat_mp3s(parts: list[str], output: str):
    """Join MP3 segments into *output* using ffmpeg's concat demuxer.

    The segment paths are written to a temporary list file, then re-encoded
    to 96 kbps / 24 kHz mono MP3. The list file is removed even if ffmpeg
    fails.

    Args:
        parts: Segment paths, in playback order.
        output: Destination MP3 path (overwritten).

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits non-zero.
        subprocess.TimeoutExpired: if ffmpeg runs longer than 600 s.
        FileNotFoundError: if ffmpeg is not installed.
    """
    with tempfile.NamedTemporaryFile(
        "w", suffix=".txt", delete=False, encoding="utf-8"
    ) as f:
        listfile = f.name
        for p in parts:
            # Concat-list quoting: a literal ' cannot appear inside '...', so
            # close the quote, emit an escaped \', and reopen: ' -> '\''
            escaped = p.replace("'", "'\\''")
            f.write(f"file '{escaped}'\n")
    try:
        subprocess.run([
            "ffmpeg", "-y", "-f", "concat", "-safe", "0", "-i", listfile,
            "-codec:a", "libmp3lame", "-b:a", "96k", "-ar", "24000", "-ac", "1",
            output,
        ], check=True, capture_output=True, timeout=600)
    finally:
        os.unlink(listfile)


def write_transcript(script: dict, path: Path):
    """Write a Markdown transcript of *script* to *path* (UTF-8).

    The output starts with a ``# <title>`` heading ("Podcast Episode" when the
    script has no title). Each script line follows as its own paragraph,
    prefixed with a bold host label: "male" lines go to Alex, every other
    speaker to Sarah. Missing parent directories are created. Trailing
    whitespace is trimmed and exactly one final newline is written.
    """
    heading = script.get("title", "Podcast Episode")
    paragraphs = [f"# {heading}"]
    paragraphs.extend(
        f"**{'Alex' if entry['speaker'] == 'male' else 'Sarah'}:** {entry['paragraph']}"
        for entry in script["lines"]
    )
    body = "\n\n".join(paragraphs).rstrip()
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(body + "\n", encoding="utf-8")


def main():
    """Command-line entry point: build the podcast MP3 and an optional transcript.

    Steps:
    1. Read ``--script-file`` as UTF-8 JSON with a ``"lines"`` list of
       ``{"speaker": ..., "paragraph": ...}`` objects. An optional
       ``"title"`` is used only by the transcript.
    2. Synthesize each line with the backend chosen by detect_method().
    3. Concatenate the chunks into ``--output-file``.
    4. Re-encode at 64 kbps / 22.05 kHz if the result exceeds 9 MB.

    Exits with status 1 if any line fails to synthesize.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--script-file", required=True, help="podcast script JSON")
    ap.add_argument("--output-file", required=True, help="destination MP3 path")
    ap.add_argument("--transcript-file", required=False,
                    help="optional Markdown transcript path")
    args = ap.parse_args()

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    script = json.loads(Path(args.script_file).read_text(encoding="utf-8"))
    lines = script["lines"]
    method = detect_method()
    out = args.output_file
    total = len(lines)

    # TemporaryDirectory deletes the chunks *and* the directory itself on
    # every exit path, including the sys.exit() below and ffmpeg failures.
    with tempfile.TemporaryDirectory(prefix="podcast_") as tmpdir:
        parts = []
        for i, line in enumerate(lines):
            out_chunk = os.path.join(tmpdir, f"chunk_{i:03d}.mp3")
            speaker = line["speaker"]
            text = line["paragraph"]
            print(f"  [{i+1}/{total}] {speaker}: {text[:60]}…")
            if not synthesize_line(text, speaker, out_chunk, method):
                print(f"  FATAL: could not synthesize line {i+1}")
                sys.exit(1)
            parts.append(out_chunk)
            # small delay to avoid rate-limits on proxy
            if method == "openai" and i < total - 1:
                time.sleep(0.3)

        print("Concatenating segments…")
        Path(out).parent.mkdir(parents=True, exist_ok=True)
        concat_mp3s(parts, out)

    size_mb = Path(out).stat().st_size / (1024 * 1024)
    print(f"✓ Podcast MP3: {out}  ({size_mb:.2f} MB)")
    if size_mb > 9:
        print("⚠ File exceeds 9 MB target — re-encoding at lower bitrate…")
        tmp_out = out + ".tmp.mp3"
        try:
            subprocess.run([
                "ffmpeg", "-y", "-i", out,
                "-codec:a", "libmp3lame", "-b:a", "64k", "-ar", "22050", "-ac", "1",
                tmp_out,
            ], check=True, capture_output=True, timeout=600)
        except BaseException:
            # Don't leave a half-written temp file next to the output.
            Path(tmp_out).unlink(missing_ok=True)
            raise
        # os.replace swaps the re-encoded file in one step, so `out` is never
        # left partially written.
        os.replace(tmp_out, out)
        size_mb = Path(out).stat().st_size / (1024 * 1024)
        print(f"  Re-encoded: {size_mb:.2f} MB")

    if args.transcript_file:
        write_transcript(script, Path(args.transcript_file))
        print(f"✓ Transcript: {args.transcript_file}")

    print("Done!")

if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
