--- a/scripts/generate_schematic.py +++ b/scripts/generate_schematic.py @@ -1,324 +1,75 @@ #!/usr/bin/env python3 -"""Generate publication-quality scientific diagrams using Nano Banana Pro AI. +""" +CLI wrapper for scientific schematic generation. -Uses OpenRouter to access Nano Banana Pro for image generation and -Gemini 3 Pro for automated quality review with smart iteration. +Usage: + python scripts/generate_schematic.py "diagram description" -o output.png + python scripts/generate_schematic.py "diagram" -o out.png --doc-type journal + python scripts/generate_schematic.py "diagram" -o out.png --iterations 2 -v """ import argparse -import base64 -import json +import sys import os -import sys -import time -from pathlib import Path -import requests +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -OPENROUTER_API_URL = "https://openrouter.ai/api/v1" - -# Model identifiers on OpenRouter -IMAGE_MODEL = "nanobanana/nano-banana-pro" -REVIEW_MODEL = "google/gemini-3-pro" - -# Quality thresholds by document type -THRESHOLDS = { - "journal": 8.5, - "conference": 8.0, - "thesis": 8.0, - "grant": 8.0, - "preprint": 7.5, - "report": 7.5, - "poster": 7.0, - "presentation": 6.5, - "default": 7.5, -} - -SCIENTIFIC_STYLE_PROMPT = ( - "Create a clean, publication-quality scientific diagram. " - "Use a white background with high contrast elements. " - "Apply sans-serif fonts for all labels (minimum 10pt equivalent). " - "Use colorblind-friendly colors from the Okabe-Ito palette. " - "Ensure proper spacing between elements with no overlapping text. " - "Include clear directional arrows for flow/connections. " - "Make it suitable for high-resolution print at 300 DPI." -) - - -def get_api_key() -> str: - """Get OpenRouter API key from environment.""" - key = os.environ.get("OPENROUTER_API_KEY") - if not key: - print("Error: OPENROUTER_API_KEY environment variable not set.") - print("Get your key at: https://openrouter.ai/keys") - sys.exit(1) - return key - - -def generate_image( - api_key: str, - prompt: str, - width: int = 1024, - height: int = 1024, -) -> bytes: - """Generate an image using Nano Banana Pro via OpenRouter. - - Returns raw image bytes (PNG). - """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "HTTP-Referer": "https://doany.ai", - "X-Title": "Scientific Schematics Generator", - } - - full_prompt = f"{SCIENTIFIC_STYLE_PROMPT}\n\nDiagram request: {prompt}" - - payload = { - "model": IMAGE_MODEL, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": full_prompt, - } - ], - } - ], - "max_tokens": 4096, - "response_format": {"type": "image"}, - "image_generation": { - "width": width, - "height": height, - "format": "png", - }, - } - - print(f" Generating image with {IMAGE_MODEL}...") - resp = requests.post( - f"{OPENROUTER_API_URL}/chat/completions", - headers=headers, - json=payload, - timeout=120, - ) - resp.raise_for_status() - - data = resp.json() - - # Extract image from response — OpenRouter returns base64 in content - content = data["choices"][0]["message"]["content"] - if isinstance(content, list): - for part in content: - if part.get("type") == "image_url": - b64 = part["image_url"]["url"].split(",", 1)[-1] - return base64.b64decode(b64) - elif isinstance(content, str) and content.startswith("data:image"): - b64 = content.split(",", 1)[-1] - return base64.b64decode(b64) - - raise ValueError("No image data found in API response") - - -def review_quality( - api_key: str, - image_bytes: bytes, - description: str, - doc_type: str, -) -> dict: - """Review diagram quality using Gemini 3 Pro via OpenRouter. - - Returns dict with keys: score (float), strengths (list), weaknesses (list), - suggestions (list), critique (str). - """ - headers = { - "Authorization": f"Bearer {api_key}", - "Content-Type": "application/json", - "HTTP-Referer": "https://doany.ai", - "X-Title": "Scientific Schematics Quality Review", - } - - b64_image = base64.b64encode(image_bytes).decode("utf-8") - - review_prompt = f"""You are a scientific diagram quality reviewer for a {doc_type} publication. - -Evaluate this diagram against the following request: "{description}" - -Score the diagram on these criteria (0-2 points each, 10 total): -1. **Scientific Accuracy** - Correct concepts, notation, relationships -2. **Clarity and Readability** - Easy to understand, clear hierarchy -3. **Label Quality** - Complete, readable, consistent labels -4. **Layout and Composition** - Logical flow, balanced, no overlaps -5. **Professional Appearance** - Publication-ready quality - -Respond in this exact JSON format: -{{ - "score": , - "criteria": {{ - "scientific_accuracy": , - "clarity": , - "labels": , - "layout": , - "appearance": - }}, - "strengths": ["strength 1", "strength 2"], - "weaknesses": ["weakness 1", "weakness 2"], - "suggestions": ["specific improvement 1", "specific improvement 2"], - "critique": "Brief overall assessment in one paragraph." -}}""" - - payload = { - "model": REVIEW_MODEL, - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": review_prompt}, - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{b64_image}", - }, - }, - ], - } - ], - "max_tokens": 2048, - "temperature": 0.2, - } - - print(f" Reviewing quality with {REVIEW_MODEL}...") - resp = requests.post( - f"{OPENROUTER_API_URL}/chat/completions", - headers=headers, - json=payload, - timeout=90, - ) - resp.raise_for_status() - - data = resp.json() - text = data["choices"][0]["message"]["content"] - - # Parse JSON from response (handle markdown code fences) - text = text.strip() - if text.startswith("```"): - text = text.split("\n", 1)[1].rsplit("```", 1)[0].strip() - - return json.loads(text) - - -def build_improved_prompt(original: str, review: dict) -> str: - """Build an improved generation prompt incorporating review feedback.""" - suggestions = "; ".join(review.get("suggestions", [])) - weaknesses = "; ".join(review.get("weaknesses", [])) - - return ( - f"{original}. " - f"IMPORTANT improvements needed: {suggestions}. " - f"Fix these issues: {weaknesses}. " - "Ensure all labels are clearly readable, layout is balanced, " - "and the diagram meets journal publication standards." - ) - - -def save_review_log(log_path: Path, reviews: list[dict], final_version: int) -> None: - """Save the complete review log as JSON.""" - log = { - "final_version": final_version, - "total_iterations": len(reviews), - "reviews": reviews, - } - log_path.write_text(json.dumps(log, indent=2)) - print(f" Review log saved: {log_path}") +from generate_schematic_ai import ScientificSchematicGenerator def main(): parser = argparse.ArgumentParser( - description="Generate publication-quality scientific diagrams with AI" + description="Generate publication-quality scientific diagrams using AI" ) - parser.add_argument("description", help="Natural language description of the diagram") + parser.add_argument("prompt", help="Description of the diagram to generate") parser.add_argument( - "-o", "--output", default="figures/schematic.png", help="Output image path" + "-o", "--output", required=True, help="Output image path" ) parser.add_argument( - "--doc-type", - default="default", - choices=list(THRESHOLDS.keys()), - help="Document type for quality threshold", + "--doc-type", default="default", + choices=["journal", "conference", "thesis", "grant", + "preprint", "report", "poster", "presentation", "default"], + help="Document type for quality threshold (default: default)", ) parser.add_argument( - "--iterations", - type=int, - default=2, - choices=[1, 2], - help="Maximum refinement iterations (1-2)", + "--iterations", type=int, default=2, + help="Maximum refinement iterations (default: 2, max: 2)", ) - parser.add_argument("--width", type=int, default=1024, help="Image width in pixels") - parser.add_argument("--height", type=int, default=1024, help="Image height in pixels") + parser.add_argument("--api-key", default=None, help="OpenRouter API key") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output") args = parser.parse_args() - - threshold = THRESHOLDS[args.doc_type] - output_path = Path(args.output) - output_path.parent.mkdir(parents=True, exist_ok=True) - - stem = output_path.stem - suffix = output_path.suffix - log_path = output_path.parent / f"{stem}_review.json" - - api_key = get_api_key() - prompt = args.description - reviews = [] + args.iterations = min(args.iterations, 2) print(f"Scientific Schematics Generator") - print(f" Document type: {args.doc_type} (threshold: {threshold}/10)") - print(f" Max iterations: {args.iterations}") - print(f" Output: {output_path}") - print() + print(f"Document type: {args.doc_type}") + print(f"Max iterations: {args.iterations}") + print(f"Output: {args.output}") - for iteration in range(1, args.iterations + 1): - version_path = output_path.parent / f"{stem}_v{iteration}{suffix}" + try: + generator = ScientificSchematicGenerator( + api_key=args.api_key, verbose=args.verbose + ) + results = generator.generate_iterative( + user_prompt=args.prompt, + output_path=args.output, + iterations=args.iterations, + doc_type=args.doc_type, + ) - print(f"--- Iteration {iteration}/{args.iterations} ---") + if results.get("final_image"): + print(f"\nDone! Final score: {results['final_score']}/10") + if results.get("early_stop"): + print(f"Early stop: {results['early_stop_reason']}") + return 0 + else: + print("\nFailed to generate diagram.", file=sys.stderr) + return 1 - # Generate image - image_bytes = generate_image(api_key, prompt, args.width, args.height) - version_path.write_bytes(image_bytes) - print(f" Saved: {version_path} ({len(image_bytes):,} bytes)") - - # Review quality - review = review_quality(api_key, image_bytes, args.description, args.doc_type) - score = review.get("score", 0) - review["iteration"] = iteration - review["version_file"] = str(version_path) - reviews.append(review) - - print(f" Quality score: {score}/10 (threshold: {threshold}/10)") - - if score >= threshold: - print(f" PASS — Quality meets {args.doc_type} threshold!") - # Copy final version to output path - output_path.write_bytes(image_bytes) - print(f" Final output: {output_path}") - save_review_log(log_path, reviews, iteration) - print(f"\nDone! Diagram ready for {args.doc_type} publication.") - return - - print(f" BELOW THRESHOLD — needs improvement") - if iteration < args.iterations: - prompt = build_improved_prompt(args.description, review) - print(f" Refining prompt for next iteration...") - print() - else: - # Max iterations reached — save best version - print(f" Max iterations reached. Saving best available version.") - output_path.write_bytes(image_bytes) - print(f" Final output: {output_path}") - save_review_log(log_path, reviews, iteration) - print(f"\nDone! Best version saved (score: {score}/10, threshold: {threshold}/10).") - print(" Consider running again or adjusting the description for higher quality.") + except Exception as e: + print(f"\nError: {e}", file=sys.stderr) + return 1 if __name__ == "__main__": - main() + sys.exit(main())