#!/usr/bin/env python3
"""
Scientific Schematic Generator using OpenRouter AI models.

Uses Gemini 2.5 Flash Image ("Nano Banana") for image generation and Gemini 2.5 Flash for quality review.
Implements smart iterative refinement with document-type quality thresholds.
"""

import base64
import json
import os
import re
import shutil
import sys
import time
from pathlib import Path

import requests


# Minimum review score (on the 0-10 scale used by the reviewer) a diagram must
# reach for iterative refinement to stop early, keyed by target document type.
# The "default" entry is used for any doc_type that is not listed explicitly.
QUALITY_THRESHOLDS = dict(
    journal=8.5,
    conference=8.0,
    thesis=8.0,
    grant=8.0,
    preprint=7.5,
    report=7.5,
    poster=7.0,
    presentation=6.5,
    default=7.5,
)

# Style and layout rules prepended to every image-generation prompt. Built
# line by line so each rule is easy to scan and edit in isolation.
SCIENTIFIC_DIAGRAM_GUIDELINES = "\n".join(
    [
        "You are an expert scientific diagram creator. Create a publication-quality scientific diagram following these strict guidelines:",
        "",
        "VISUAL STANDARDS:",
        "- Clean white background",
        "- High contrast for print readability",
        "- Sans-serif font labels (minimum 10pt equivalent)",
        "- Professional muted color palette (colorblind-friendly Okabe-Ito)",
        "- Clear directional arrows showing data/process flow",
        "- Minimum 0.5pt line weight, typical 1-2pt",
        "- Proper spacing to prevent element crowding",
        "- Rounded rectangles for process boxes, cylinders for data stores",
        "- Consistent styling across all elements",
        "",
        "LAYOUT:",
        "- Logical flow direction (left-to-right or top-to-bottom)",
        "- Balanced composition with clear visual hierarchy",
        "- Grouped related components",
        "- Adequate whitespace between elements",
        "",
        "This diagram is for an academic publication. It must be clean, precise, and professional.",
    ]
)

# Instructions sent to the review model together with the generated image.
# The number following "SCORE:" in the reply is compared against
# QUALITY_THRESHOLDS, which are on a 0-10 scale, so the prompt states
# explicitly that the overall score is the SUM of the five 0-2 criteria
# (otherwise a model may report an average in the 0-2 range).
REVIEW_PROMPT = """You are an expert scientific figure reviewer for top-tier journals (Nature, Science, IEEE).

Evaluate this scientific diagram on these 5 criteria, each scored 0-2:

1. **Scientific Accuracy** (0-2): Are concepts, notation, and relationships correct?
2. **Clarity and Readability** (0-2): Is the diagram easy to understand with clear hierarchy?
3. **Label Quality** (0-2): Are labels complete, readable, consistent, and well-positioned?
4. **Layout and Composition** (0-2): Is the flow logical, balanced, with no overlaps?
5. **Professional Appearance** (0-2): Is this publication-ready for a top journal?

The overall SCORE is the sum of the five criterion scores, so it ranges from 0 to 10.

Respond in EXACTLY this format:
SCORE: X.X

STRENGTHS:
- strength 1
- strength 2

ISSUES:
- issue 1
- issue 2

SPECIFIC_IMPROVEMENTS:
- improvement 1
- improvement 2"""


class ScientificSchematicGenerator:
    """Generate scientific schematics with OpenRouter-hosted AI models.

    An image model draws the diagram (``generate_image``) and a second model
    scores it against ``REVIEW_PROMPT`` (``review_image``).
    ``generate_iterative`` alternates the two, feeding the critique back into
    the prompt, until the score meets the ``QUALITY_THRESHOLDS`` entry for the
    document type or the iteration budget is exhausted.
    """

    OPENROUTER_URL = "https://openrouter.ai/api/v1/chat/completions"
    DEFAULT_IMAGE_MODEL = "google/gemini-2.5-flash-image"
    DEFAULT_REVIEW_MODEL = "google/gemini-2.5-flash-preview-05-20"

    # Score used when no number can be parsed from the review text.
    _DEFAULT_SCORE = 5.0

    # Leading magic bytes -> MIME type. Used because the model may return a
    # format (e.g. JPEG) that does not match the file extension we save under.
    _IMAGE_SIGNATURES = (
        (b"\x89PNG\r\n\x1a\n", "image/png"),
        (b"\xff\xd8\xff", "image/jpeg"),
        (b"GIF87a", "image/gif"),
        (b"GIF89a", "image/gif"),
    )
    _SUFFIX_MIME = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
    }

    # "SCORE: 8.5", "**SCORE:** 8.5" and "**SCORE**: 8.5" all match.
    _SCORE_RE = re.compile(r"SCORE[\s:*=]*(\d+(?:\.\d+)?)")
    # Fallback: any "X/10" or "X.X / 10" in the text.
    _OUT_OF_TEN_RE = re.compile(r"(\d+(?:\.\d+)?)\s*/\s*10")
    # Markdown image link: ![alt](https://...)
    _MD_IMAGE_RE = re.compile(r"!\[.*?\]\((https?://[^\s\)]+)\)")
    # Bare URL ending in a common image extension.
    _IMAGE_EXT_URL_RE = re.compile(
        r'(https?://[^\s\'"<>]+\.(?:png|jpg|jpeg|webp|gif))', re.IGNORECASE
    )
    # Any bare URL; only accepted if it contains one of _URL_IMAGE_HINTS.
    _ANY_URL_RE = re.compile(r'(https?://[^\s\'"<>]+)')
    _URL_IMAGE_HINTS = ("image", "img", "pic", "photo", "gen")

    def __init__(self, api_key=None, verbose=False, image_model=None, review_model=None):
        """Configure the client.

        Args:
            api_key: OpenRouter API key; falls back to the
                ``OPENROUTER_API_KEY`` environment variable.
            verbose: If true, ``_log`` messages are printed.
            image_model: OpenRouter model ID used for generation
                (defaults to ``DEFAULT_IMAGE_MODEL``).
            review_model: OpenRouter model ID used for review
                (defaults to ``DEFAULT_REVIEW_MODEL``).

        Raises:
            ValueError: If no API key is available.
        """
        self.api_key = api_key or os.environ.get("OPENROUTER_API_KEY")
        if not self.api_key:
            raise ValueError(
                "OPENROUTER_API_KEY environment variable not set. "
                "Get a key at https://openrouter.ai/keys"
            )
        self.verbose = verbose
        self.image_model = image_model or self.DEFAULT_IMAGE_MODEL
        self.review_model = review_model or self.DEFAULT_REVIEW_MODEL
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/scientific-schematics",
            "X-Title": "Scientific Schematics Generator",
        }

    def _log(self, msg):
        """Print *msg* with an ``[AI]`` prefix when verbose mode is on."""
        if self.verbose:
            print(f"  [AI] {msg}")

    def _post_chat(self, payload, timeout):
        """POST *payload* to the chat-completions endpoint; return parsed JSON.

        Raises:
            requests.HTTPError: On a non-2xx response.
        """
        resp = requests.post(
            self.OPENROUTER_URL, headers=self.headers, json=payload, timeout=timeout
        )
        resp.raise_for_status()
        return resp.json()

    @staticmethod
    def _first_message(data):
        """Return ``data["choices"][0]["message"]``, or ``{}`` if any level is missing.

        Avoids the IndexError an empty ``choices`` list would otherwise cause
        and tolerates JSON ``null`` values.
        """
        if not isinstance(data, dict):
            return {}
        choices = data.get("choices")
        if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
            return {}
        message = choices[0].get("message")
        return message if isinstance(message, dict) else {}

    @staticmethod
    def _block_image_url(block):
        """Return the URL held by an ``image_url`` block, or ``""`` if absent.

        Accepts both ``{"image_url": {"url": ...}}`` and a bare string value.
        """
        value = block.get("image_url")
        if isinstance(value, dict):
            value = value.get("url")
        return value if isinstance(value, str) else ""

    @classmethod
    def _detect_mime_type(cls, img_bytes, fallback="image/png"):
        """Identify an image MIME type from its leading bytes, else *fallback*."""
        if img_bytes[:4] == b"RIFF" and img_bytes[8:12] == b"WEBP":
            return "image/webp"
        for signature, mime in cls._IMAGE_SIGNATURES:
            if img_bytes.startswith(signature):
                return mime
        return fallback

    def generate_image(self, prompt, output_path):
        """Render *prompt* as a diagram with ``self.image_model`` and save it.

        The prompt is prefixed with ``SCIENTIFIC_DIAGRAM_GUIDELINES``. The
        image is taken first from base64 payloads in the response (the
        message ``images`` list or inline image blocks), then from any image
        URL found in the message content. Data URIs are decoded; http(s)
        URLs are downloaded.

        Args:
            prompt: Free-text description of the requested diagram.
            output_path: Destination file. Bytes are written exactly as
                returned, so the real format may differ from the extension.

        Returns:
            *output_path*, unchanged.

        Raises:
            requests.HTTPError: If the API call or the image download fails.
            RuntimeError: If the response contains no recognizable image.
        """
        full_prompt = f"{SCIENTIFIC_DIAGRAM_GUIDELINES}\n\nDIAGRAM REQUEST:\n{prompt}"
        self._log(f"Generating image with {self.image_model}...")

        payload = {
            "model": self.image_model,
            "messages": [{"role": "user", "content": full_prompt}],
            # OpenRouter only returns generated images (in message["images"])
            # when image output is explicitly requested.
            "modalities": ["image", "text"],
        }
        data = self._post_chat(payload, timeout=120)
        content = self._first_message(data).get("content") or ""

        b64_data = self._extract_base64_from_response(data)
        if b64_data:
            img_bytes = base64.b64decode(b64_data)
        else:
            image_url = self._find_image_url(content)
            if not image_url:
                # Dump (a truncated view of) the raw response to aid debugging.
                self._log(f"Response structure: {json.dumps(data, indent=2)[:1000]}")
                raise RuntimeError(
                    "No image found in API response. "
                    f"Model response: {str(content)[:500]}"
                )
            img_bytes = self._fetch_image_url(image_url)

        Path(output_path).write_bytes(img_bytes)
        self._log(f"Saved image to {output_path} ({len(img_bytes)} bytes)")
        return output_path

    def _find_image_url(self, content):
        """Return an image URL found in message *content*, or None.

        *content* may be a plain string or a list of content blocks. For a
        list, later matches override earlier ones, mirroring block order.
        """
        if isinstance(content, str):
            return self._extract_url_from_text(content)
        image_url = None
        if isinstance(content, list):
            for block in content:
                if not isinstance(block, dict):
                    continue
                block_type = block.get("type")
                if block_type == "image_url":
                    image_url = self._block_image_url(block) or image_url
                elif block_type == "text":
                    text = block.get("text") or ""
                    image_url = self._extract_url_from_text(text) or image_url
        return image_url

    def _fetch_image_url(self, image_url):
        """Return raw image bytes for a ``data:image`` URI or an http(s) URL.

        Raises:
            RuntimeError: If a data URI has no comma-separated payload.
            requests.HTTPError: If downloading the URL fails.
        """
        if image_url.startswith("data:image"):
            _, sep, b64_data = image_url.partition(",")
            if not sep:
                raise RuntimeError("Malformed data URI in API response")
            return base64.b64decode(b64_data)
        self._log("Downloading image from URL...")
        img_resp = requests.get(image_url, timeout=60)
        img_resp.raise_for_status()
        return img_resp.content

    def _extract_url_from_text(self, text):
        """Extract an image URL from free text, or return None.

        Tries, in order: a markdown image link, a URL ending in an image
        extension, then any URL containing an image-like keyword.
        """
        if not text:
            return None

        md_match = self._MD_IMAGE_RE.search(text)
        if md_match:
            return md_match.group(1)

        ext_match = self._IMAGE_EXT_URL_RE.search(text)
        if ext_match:
            return ext_match.group(1)

        url_match = self._ANY_URL_RE.search(text)
        if url_match:
            url = url_match.group(1)
            if any(hint in url.lower() for hint in self._URL_IMAGE_HINTS):
                return url

        return None

    def _extract_base64_from_response(self, data):
        """Return base64 image data embedded in the API response, or None.

        For each choice, inspects the message ``images`` list first, which is
        where OpenRouter places generated images, then image blocks inside
        list-valued ``content``. Recognizes ``data:image`` URIs and long
        ``b64_json`` / ``data`` string fields.
        """
        if not isinstance(data, dict):
            return None
        for choice in data.get("choices") or []:
            if not isinstance(choice, dict):
                continue
            msg = choice.get("message")
            if not isinstance(msg, dict):
                continue

            images = msg.get("images")
            blocks = list(images) if isinstance(images, list) else []
            content = msg.get("content")
            if isinstance(content, list):
                blocks.extend(content)

            for block in blocks:
                if not isinstance(block, dict):
                    continue
                url = self._block_image_url(block)
                if url.startswith("data:image"):
                    _, sep, payload = url.partition(",")
                    if sep and payload:
                        return payload
                b64 = block.get("b64_json") or block.get("data")
                # Length check skips short non-image strings stored under "data".
                if isinstance(b64, str) and len(b64) > 100:
                    return b64

        return None

    def review_image(self, image_path):
        """Score the image at *image_path* with ``self.review_model``.

        The image is sent inline as a base64 data URI. Its MIME type is taken
        from the file's magic bytes, falling back to the extension.

        Returns:
            ``(score, critique)``: the parsed 0-10 score and the raw review text.

        Raises:
            requests.HTTPError: If the API call fails.
        """
        self._log(f"Reviewing image with {self.review_model}...")

        img_bytes = Path(image_path).read_bytes()
        b64_image = base64.b64encode(img_bytes).decode("utf-8")

        suffix_mime = self._SUFFIX_MIME.get(Path(image_path).suffix.lower(), "image/png")
        mime_type = self._detect_mime_type(img_bytes, fallback=suffix_mime)

        payload = {
            "model": self.review_model,
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": REVIEW_PROMPT},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:{mime_type};base64,{b64_image}"
                            },
                        },
                    ],
                }
            ],
        }
        data = self._post_chat(payload, timeout=60)

        critique = self._first_message(data).get("content") or ""
        if isinstance(critique, list):
            critique = " ".join(
                b.get("text") or "" for b in critique if isinstance(b, dict)
            )

        score = self._parse_score(critique)
        self._log(f"Review score: {score}/10")

        return score, critique

    def _parse_score(self, critique):
        """Extract a 0-10 score from review text.

        Looks for a ``SCORE:`` line (markdown bold tolerated), then for any
        ``X/10``. The result is clamped to [0, 10]. Returns
        ``_DEFAULT_SCORE`` (5.0) if no number is found or *critique* is not
        a string.
        """
        if not isinstance(critique, str):
            return self._DEFAULT_SCORE
        for pattern in (self._SCORE_RE, self._OUT_OF_TEN_RE):
            match = pattern.search(critique)
            if match:
                return min(max(float(match.group(1)), 0.0), 10.0)
        return self._DEFAULT_SCORE

    def improve_prompt(self, original_prompt, critique, iteration):
        """Append *critique* to *original_prompt* as fix-up instructions.

        *iteration* labels the section header. The original prompt, not the
        previously improved one, is used as the base so critiques do not pile up.
        """
        improved = (
            f"{original_prompt}\n\n"
            f"--- ITERATION {iteration} IMPROVEMENTS ---\n"
            f"Based on quality review, address these issues:\n"
            f"{critique}\n\n"
            f"CRITICAL: Fix all issues mentioned above while maintaining all other quality aspects. "
            f"Ensure maximum clarity, professional appearance, and scientific accuracy."
        )
        return improved

    def generate_iterative(
        self, user_prompt, output_path, iterations=2, doc_type="default"
    ):
        """Generate a diagram with review-driven iterative refinement.

        Each iteration writes ``<stem>_v<i><suffix>`` next to *output_path*
        and reviews it. The loop stops early once the score reaches the
        threshold for *doc_type*. The best-scoring image is copied to
        ``<stem><suffix>``, and a JSON log is written to
        ``<stem>_review_log.json``.

        A failed review never counts as meeting the threshold. It records a
        score of 0.0 with ``"review_error": True``, and the next iteration
        reuses the same prompt instead of injecting the error text into it.

        Returns:
            The results dict that is also written to the review log.
        """
        threshold = QUALITY_THRESHOLDS.get(doc_type, QUALITY_THRESHOLDS["default"])
        output_path = Path(output_path)
        stem = output_path.stem
        suffix = output_path.suffix or ".png"
        parent = output_path.parent
        parent.mkdir(parents=True, exist_ok=True)

        results = {
            "user_prompt": user_prompt,
            "doc_type": doc_type,
            "quality_threshold": threshold,
            "iterations": [],
            "final_score": 0,
            "final_image": None,
            "early_stop": False,
            "early_stop_reason": None,
        }

        current_prompt = user_prompt
        best_score = None
        best_image = None

        for i in range(1, iterations + 1):
            print(f"\n--- Iteration {i}/{iterations} ---")

            version_path = str(parent / f"{stem}_v{i}{suffix}")
            try:
                self.generate_image(current_prompt, version_path)
            except Exception as e:
                # Best-effort: record the failure and try again next iteration.
                print(f"  Generation failed: {e}")
                results["iterations"].append(
                    {
                        "iteration": i,
                        "image_path": version_path,
                        "score": 0,
                        "needs_improvement": True,
                        "critique": f"Generation failed: {e}",
                        "error": True,
                    }
                )
                continue

            review_failed = False
            try:
                score, critique = self.review_image(version_path)
            except Exception as e:
                # Do not invent a passing score for an unreviewed image.
                print(f"  Review failed: {e}")
                score, critique = 0.0, f"Review unavailable: {e}"
                review_failed = True

            needs_improvement = review_failed or score < threshold
            iter_result = {
                "iteration": i,
                "image_path": version_path,
                "score": score,
                "needs_improvement": needs_improvement,
                "critique": critique,
            }
            if review_failed:
                iter_result["review_error"] = True
            results["iterations"].append(iter_result)

            if review_failed:
                print(f"  Score: unavailable (threshold: {threshold})")
            else:
                print(f"  Score: {score}/10 (threshold: {threshold})")

            # "best_image is None" keeps the first generated image even if it
            # scored 0, so a successful generation is never discarded.
            if best_image is None or score > best_score:
                best_score = score
                best_image = version_path

            if not needs_improvement:
                results["early_stop"] = True
                results["early_stop_reason"] = (
                    f"Quality score {score} meets threshold {threshold} for {doc_type}"
                )
                print("  Quality meets threshold - stopping early!")
                break

            if i < iterations:
                if review_failed:
                    print("  Review unavailable, regenerating with the same prompt...")
                else:
                    current_prompt = self.improve_prompt(user_prompt, critique, i + 1)
                    print("  Below threshold, improving prompt for next iteration...")

        if best_image:
            final_path = str(parent / f"{stem}{suffix}")
            shutil.copy2(best_image, final_path)
            results["final_image"] = final_path
            results["final_score"] = best_score
            print(f"\nFinal image: {final_path} (score: {best_score}/10)")
        else:
            results["final_image"] = None
            results["final_score"] = 0
            print("\nNo image was generated successfully.")

        log_path = str(parent / f"{stem}_review_log.json")
        with open(log_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"Review log: {log_path}")

        return results
