#!/usr/bin/env python3
"""
Generate publication-quality RAG system architecture diagram.
Suitable for Nature/IEEE journal submission.
"""

import os

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.patches import FancyBboxPatch, FancyArrowPatch
from matplotlib.path import Path
import numpy as np

# ── Publication settings ──────────────────────────────────────────────
plt.rcParams.update({
    "font.family": "sans-serif",
    # Candidate sans-serif families, tried in order.
    "font.sans-serif": ["Helvetica", "Arial", "DejaVu Sans"],
    "font.size": 9,
    "axes.linewidth": 0.8,
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.15,
    # matplotlib embeds Type 3 fonts in PDF/PS output by default; journal
    # submission checks commonly require embedded TrueType (Type 42) fonts,
    # which also keep the text editable in vector-graphics tools.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
})

# ── Color palette (muted, colorblind-friendly) ───────────────────────
# Each stage has a saturated edge/text color and a light fill variant.
C_INGEST      = "#4878A8"   # steel blue   – ingestion
C_INGEST_LT   = "#D6E4F0"   # light blue fill
C_RETRIEVE    = "#2A7F62"   # teal         – retrieval
C_RETRIEVE_LT = "#D0EDE3"   # light teal fill
C_GENERATE    = "#C46E3B"   # warm orange  – generation
C_GENERATE_LT = "#F5DFD0"   # light orange fill
C_FEEDBACK    = "#7B6888"   # muted purple – feedback
C_FEEDBACK_LT = "#E0D8E8"   # light purple fill
C_TEXT        = "#1A1A1A"   # primary text
C_LABEL       = "#555555"   # secondary text (sublabels)
C_BG_SECTION  = "#F8F8F8"   # section panel fill

# NOTE(review): a 13-in-wide canvas is normally reduced to a journal's
# column width when placed, which shrinks the 9 pt text proportionally —
# confirm the printed font size against the target journal's guidelines.
fig, ax = plt.subplots(figsize=(13, 6.2))
ax.set_xlim(-0.8, 13.5)
ax.set_ylim(-2.5, 5.8)
ax.set_aspect("equal")   # one data unit has the same length on both axes
ax.axis("off")
fig.patch.set_facecolor("white")

# ── Helper: rounded box ──────────────────────────────────────────────
def draw_box(x, y, w, h, label, fc, ec, fontsize=8.5, bold=False,
             sublabel=None, sublabel_fs=6.5):
    """Draw a filled, rounded box centred on (x, y) and label it.

    Parameters
    ----------
    x, y : float
        Centre of the box in data coordinates.
    w, h : float
        Nominal width and height; the rounded outline ("round,pad=0.12")
        is drawn 0.12 data units outside this rectangle on every side.
    label : str
        Main text, centred in the box.
    fc, ec : color
        Face and edge colours.
    fontsize : float
        Font size of *label*.
    bold : bool
        Render *label* in bold when true.
    sublabel : str or None
        Optional caption placed 0.22 below the nominal bottom edge.
    sublabel_fs : float
        Font size of *sublabel*.

    Returns
    -------
    FancyBboxPatch
        The box patch, already added to the global ``ax``.
    """
    left = x - w/2
    bottom = y - h/2
    patch = FancyBboxPatch((left, bottom), w, h,
                           boxstyle="round,pad=0.12",
                           facecolor=fc, edgecolor=ec,
                           linewidth=1.3, zorder=3)
    ax.add_patch(patch)

    ax.text(x, y, label,
            ha="center", va="center", fontsize=fontsize, color=C_TEXT,
            fontweight="bold" if bold else "normal",
            zorder=4, linespacing=1.2)

    # Caption below the box (top-aligned so it hangs under the outline).
    if sublabel:
        ax.text(x, bottom - 0.22, sublabel,
                ha="center", va="top", fontsize=sublabel_fs,
                color=C_LABEL, zorder=4, linespacing=1.15)
    return patch

# ── Helper: cylinder (vector store) ─────────────────────────────────
def draw_cylinder(cx, cy, w, h, label, fc, ec, fontsize=8.5):
    """Draw a database-style cylinder centred on (cx, cy) with a bold label.

    The straight side walls span ``h`` between the centres of the two
    elliptical end caps; each cap adds 0.18 above the top / below the bottom
    of that span, so the full drawing is ``h + 0.36`` tall.

    Parameters
    ----------
    cx, cy : float
        Centre of the cylinder body in data coordinates.
    w, h : float
        Width of the cylinder and height of its side walls.
    label : str
        Text drawn (bold) at the centre.
    fc, ec : color
        Face and edge colours.
    fontsize : float
        Font size of *label*.
    """
    ell_h = 0.18   # semi-height of each elliptical end cap
    lw = 1.3
    left, right = cx - w/2, cx + w/2
    bottom, top = cy - h/2, cy + h/2

    # Bottom cap: a full, filled ellipse placed *beneath* the body. The body
    # fill hides its upper (back) half, leaving only the lower (front) arc
    # and a filled crescent. An outlined body rectangle would instead draw a
    # straight chord across the base and leave that crescent unfilled.
    ax.add_patch(mpatches.Ellipse((cx, bottom), w, ell_h*2,
                                  facecolor=fc, edgecolor=ec,
                                  linewidth=lw, zorder=3))

    # Body: fill only; its outline comes from the side walls and the caps.
    ax.add_patch(plt.Rectangle((left, bottom), w, h,
                               facecolor=fc, edgecolor="none", zorder=3.1))

    # Side walls, drawn over the body fill.
    for wall_x in (left, right):
        ax.add_line(plt.Line2D([wall_x, wall_x], [bottom, top],
                               color=ec, linewidth=lw,
                               solid_capstyle="butt", zorder=3.2))

    # Top cap: fully visible, drawn over the body and walls.
    ax.add_patch(mpatches.Ellipse((cx, top), w, ell_h*2,
                                  facecolor=fc, edgecolor=ec,
                                  linewidth=lw, zorder=4))

    ax.text(cx, cy, label, ha="center", va="center",
            fontsize=fontsize, color=C_TEXT, fontweight="bold", zorder=5,
            linespacing=1.2)

# ── Helper: arrow with label ────────────────────────────────────────
def draw_arrow(x1, y1, x2, y2, label="", color="#555555", style="-|>",
               connectionstyle="arc3,rad=0", fontsize=7, ls="-",
               label_pos=None, label_ha="center", label_va="center",
               shrinkA=10, shrinkB=10, lw=1.1, zorder=2,
               label_bg=True):
    """Draw an arrow from (x1, y1) to (x2, y2), optionally with an italic label.

    Parameters
    ----------
    x1, y1, x2, y2 : float
        Tail and head positions in data coordinates.
    label : str
        Caption text; nothing extra is drawn when empty.
    color : color
        Colour shared by the arrow and its label.
    style, connectionstyle : str
        Forwarded to ``FancyArrowPatch`` as ``arrowstyle`` / ``connectionstyle``.
    fontsize : float
        Label font size.
    ls, lw : linestyle, float
        Arrow line style and width.
    label_pos : (float, float) or None
        Explicit label position; ``None`` places it at the segment midpoint
        raised by 0.22.
    label_ha, label_va : str
        Label alignment.
    shrinkA, shrinkB : float
        Forwarded to ``FancyArrowPatch`` (matplotlib measures these in points).
    zorder : float
        Arrow z-order; the label is always drawn at z-order 5.
    label_bg : bool
        Put a translucent white box behind the label when true.
    """
    ax.add_patch(FancyArrowPatch(
        (x1, y1), (x2, y2),
        arrowstyle=style,
        connectionstyle=connectionstyle,
        color=color,
        linewidth=lw,
        linestyle=ls,
        shrinkA=shrinkA,
        shrinkB=shrinkB,
        mutation_scale=13,
        zorder=zorder,
    ))

    if not label:
        return

    if label_pos is not None:
        lx, ly = label_pos
    else:
        lx = (x1+x2)/2
        ly = (y1+y2)/2 + 0.22

    background = None
    if label_bg:
        background = dict(boxstyle="round,pad=0.1", facecolor="white",
                          edgecolor="none", alpha=0.85)

    ax.text(lx, ly, label,
            ha=label_ha, va=label_va, fontsize=fontsize, color=color,
            fontstyle="italic", zorder=5, bbox=background)

# ── Section background ──────────────────────────────────────────────
def section_bg(x, y, w, h, title, color_edge):
    """Draw a dashed, translucent panel behind one pipeline stage and title it.

    Parameters
    ----------
    x, y : float
        Lower-left corner of the panel's nominal rectangle (the rounded
        outline adds a 0.18 pad on every side).
    w, h : float
        Nominal panel width and height.
    title : str
        Bold italic heading centred just above the panel.
    color_edge : color
        Colour of both the dashed outline and the title.
    """
    title_x = x + w/2
    title_y = y + h + 0.18   # bottom of the title sits at the padded top edge

    panel = FancyBboxPatch((x, y), w, h,
                           boxstyle="round,pad=0.18",
                           facecolor=C_BG_SECTION,
                           edgecolor=color_edge,
                           linewidth=0.9,
                           linestyle=(0, (5, 4)),
                           alpha=0.5,
                           zorder=1)
    ax.add_patch(panel)

    ax.text(title_x, title_y, title,
            ha="center", va="bottom", fontsize=8.5, color=color_edge,
            fontweight="bold", fontstyle="italic", zorder=2)

# ═══════════════════════════════════════════════════════════════════════
#  LAYOUT
# ═══════════════════════════════════════════════════════════════════════

Y_TOP = 3.6    # query row
Y_MID = 2.2    # main processing row
Y_BOT = 0.75   # storage / search row

# draw_box() outlines its boxes with "round,pad=0.12", so the visible border
# lies this far outside the nominal w x h rectangle. Used below to anchor
# arrows on visible outlines; keep in sync with draw_box().
BOX_PAD = 0.12

# -- Section backgrounds --
section_bg(-0.5, Y_BOT - 0.7, 4.0, Y_TOP - Y_BOT + 1.6, "Ingestion", C_INGEST)
section_bg(4.0,  Y_BOT - 0.7, 4.8, Y_TOP - Y_BOT + 1.6, "Retrieval & Reranking", C_RETRIEVE)
section_bg(9.2,  Y_BOT - 0.7, 4.0, Y_TOP - Y_BOT + 1.6, "Generation", C_GENERATE)

# ===================== INGESTION PIPELINE ============================

draw_box(0.3, Y_MID, 1.3, 0.75, "Documents", C_INGEST_LT, C_INGEST, bold=True,
         sublabel="PDFs, text, web")

draw_box(2.6, Y_MID, 1.45, 0.75, "Chunk &\nEmbed", C_INGEST_LT, C_INGEST, bold=True)

# Arrow: Documents → Chunk & Embed
draw_arrow(0.97, Y_MID, 1.86, Y_MID, "", C_INGEST)

# Vector Store (cylinder)
draw_cylinder(2.6, Y_BOT, 1.45, 0.7, "Vector\nStore", C_INGEST_LT, C_INGEST)

# Arrow: Chunk & Embed → Vector Store (vertical).
# The visible gap between the box outline and the cylinder's top cap is only
# ~0.4 data units, so the arrow is anchored on those visible outlines with
# small shrinks; draw_arrow's default 10-pt shrinks at both ends would leave
# a shaft shorter than the arrowhead.
chunk_bottom = Y_MID - 0.75/2 - BOX_PAD     # visible bottom of Chunk & Embed
store_top = Y_BOT + 0.7/2 + 0.18            # top of the cylinder's upper cap
draw_arrow(2.6, chunk_bottom, 2.6, store_top, "embeddings",
           C_INGEST, label_pos=(1.55, (Y_MID + Y_BOT)/2 + 0.05),
           shrinkA=3, shrinkB=3)

# ===================== RETRIEVAL & RERANKING =========================

draw_box(5.5, Y_TOP, 1.5, 0.62, "User Query", C_RETRIEVE_LT, C_RETRIEVE, bold=True)

draw_box(5.5, Y_MID, 1.45, 0.68, "Query\nEncoder", C_RETRIEVE_LT, C_RETRIEVE, bold=True)

# Arrow: User Query → Query Encoder
draw_arrow(5.5, Y_TOP - 0.34, 5.5, Y_MID + 0.37, "encode",
           C_RETRIEVE, label_pos=(6.15, (Y_TOP + Y_MID)/2))

draw_box(5.5, Y_BOT, 1.45, 0.68, "Similarity\nSearch", C_RETRIEVE_LT, C_RETRIEVE, bold=True)

# Arrow: Query Encoder → Similarity Search
draw_arrow(5.5, Y_MID - 0.37, 5.5, Y_BOT + 0.37, "query vector",
           C_RETRIEVE, label_pos=(6.35, (Y_MID + Y_BOT)/2))

# Arrow: Vector Store → Similarity Search
draw_arrow(3.35, Y_BOT, 4.75, Y_BOT, "index lookup",
           C_RETRIEVE, label_pos=(4.05, Y_BOT + 0.28))

draw_box(7.9, Y_BOT, 1.35, 0.78, "Reranker", C_RETRIEVE_LT, C_RETRIEVE, bold=True)
ax.text(7.9, Y_BOT - 0.20, "cross-encoder", ha="center", va="center",
        fontsize=6, color=C_LABEL, zorder=4, fontstyle="italic")

# Arrow: Similarity Search → Reranker
draw_arrow(6.25, Y_BOT, 7.20, Y_BOT, "top-k",
           C_RETRIEVE, label_pos=(6.72, Y_BOT + 0.25))

# ===================== GENERATION ====================================

# LLM and Response sit side by side. Their visible outlines (nominal width
# plus BOX_PAD on each side) must leave a gap wide enough for the
# connecting arrow; boxes are drawn above arrows, so an arrow inside a box
# outline is hidden.
LLM_X, LLM_W = 10.35, 1.3
RESP_X, RESP_W = 12.4, 1.4
llm_left = LLM_X - LLM_W/2

draw_box(LLM_X, Y_MID, LLM_W, 0.9, "LLM", C_GENERATE_LT, C_GENERATE, fontsize=12, bold=True)

# Arrow: Reranker → LLM (reranked context)
draw_arrow(8.60, Y_BOT + 0.22, llm_left - 0.02, Y_MID - 0.25, "reranked context",
           C_GENERATE, connectionstyle="arc3,rad=-0.15",
           label_pos=(9.0, Y_BOT + 0.65), fontsize=7)

# Arrow: User Query → LLM (original query, curved over the top)
draw_arrow(6.27, Y_TOP + 0.05, llm_left - 0.02, Y_MID + 0.30, "original query",
           C_GENERATE, connectionstyle="arc3,rad=-0.15",
           label_pos=(8.0, Y_TOP + 0.30))

draw_box(RESP_X, Y_MID, RESP_W, 0.75, "Generated\nResponse", C_GENERATE_LT, C_GENERATE, bold=True)

# Arrow: LLM → Response, spanning exactly the gap between the two visible
# outlines, with small shrinks so both the shaft and the head stay visible.
draw_arrow(LLM_X + LLM_W/2 + BOX_PAD, Y_MID, RESP_X - RESP_W/2 - BOX_PAD, Y_MID,
           "", C_GENERATE, shrinkA=2, shrinkB=2)

# ===================== FEEDBACK LOOP (dashed) ========================

fb_y_start = Y_BOT - 0.70       # just below the boxes
# Depth of the Bezier *control points* of the first feedback curve. A cubic
# whose two inner control points share this y only dips ~3/4 of the way
# from fb_y_start toward it, so the curve's real lowest point is higher.
fb_y_bottom = -1.55


def _bezier_point(ctrl, t):
    """Return the (x, y) point at parameter *t* in [0, 1] on the cubic
    Bezier curve defined by the four (x, y) control points *ctrl*."""
    u = 1.0 - t
    weights = (u**3, 3*u*u*t, 3*u*t*t, t**3)
    x = sum(wt * px for wt, (px, _) in zip(weights, ctrl))
    y = sum(wt * py for wt, (_, py) in zip(weights, ctrl))
    return x, y


# Drop from Response box down
draw_arrow(RESP_X, Y_MID - 0.40, RESP_X, fb_y_start, "",
           C_FEEDBACK, ls="--", lw=1.4, zorder=1, shrinkA=4, shrinkB=2)

# Main horizontal feedback path: curved from Response → Reranker → Vector Store
# Segment 1: right side down and across to Reranker
pts1 = [(RESP_X, fb_y_start),
        (RESP_X, fb_y_bottom),
        (7.9,    fb_y_bottom),
        (7.9,    fb_y_start)]
path1 = Path(pts1, [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4])
p1 = FancyArrowPatch(path=path1, arrowstyle="-", color=C_FEEDBACK,
                     linewidth=1.4, linestyle="--", mutation_scale=13, zorder=1)
ax.add_patch(p1)

# Arrow up into Reranker
draw_arrow(7.9, fb_y_start, 7.9, Y_BOT - 0.37, "",
           C_FEEDBACK, ls="--", lw=1.4, zorder=1, shrinkA=2, shrinkB=4)

# Segment 2: from Reranker level across to Vector Store
pts2 = [(7.9,  fb_y_start),
        (7.9,  fb_y_bottom - 0.35),
        (2.6,  fb_y_bottom - 0.35),
        (2.6,  fb_y_start)]
path2 = Path(pts2, [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4])
p2 = FancyArrowPatch(path=path2, arrowstyle="-", color=C_FEEDBACK,
                     linewidth=1.4, linestyle="--", mutation_scale=13, zorder=1)
ax.add_patch(p2)

# Arrow up into Vector Store
draw_arrow(2.6, fb_y_start, 2.6, Y_BOT - 0.50, "",
           C_FEEDBACK, ls="--", lw=1.4, zorder=1, shrinkA=2, shrinkB=4)

# Feedback label badge
ax.text(6.5, fb_y_bottom - 0.08, "User Feedback / Relevance Signal",
        ha="center", va="center", fontsize=7.5, color=C_FEEDBACK,
        fontweight="bold", fontstyle="italic",
        bbox=dict(boxstyle="round,pad=0.3", facecolor=C_FEEDBACK_LT,
                  edgecolor=C_FEEDBACK, linewidth=0.9, alpha=0.92),
        zorder=5)

# Direction indicator on the lower (Reranker → Vector Store) curve. The curve
# never reaches its control-point depth, so both ends of the indicator are
# evaluated on the curve itself: the stretch at roughly x ≈ 3.4–4.3, left of
# the label badge, pointing toward the Vector Store.
ax.annotate("", xy=_bezier_point(pts2, 0.76), xytext=_bezier_point(pts2, 0.62),
            arrowprops=dict(arrowstyle="-|>", color=C_FEEDBACK, lw=1.2,
                            linestyle="--"))

# ── Title ─────────────────────────────────────────────────────────────
ax.text(6.5, 5.45, "RAG System Architecture",
        ha="center", va="bottom", fontsize=13, fontweight="bold", color=C_TEXT)

# ── Save ──────────────────────────────────────────────────────────────
# savefig() does not create missing directories.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/rag_architecture.png", dpi=300, facecolor="white")
fig.savefig("figures/rag_architecture.pdf", facecolor="white")
print("Saved: figures/rag_architecture.png (300 DPI)")
print("Saved: figures/rag_architecture.pdf (vector)")
plt.close(fig)
