"""
PosterMagic — Local Backend Server v2.0
========================================
Three-step pipeline:
  1. OpenAI gpt-image-1 generates themed costume & scene
  2. Replicate face swap pastes real face onto costume
  3. Pillow assembles final poster

Run on your computer:
  pip install fastapi uvicorn openai pillow python-multipart replicate httpx
  python postermagic_server.py
"""

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import base64, os, io, httpx, json
from openai import OpenAI
from PIL import Image

# FastAPI application instance that every endpoint below is registered on.
app = FastAPI(title="PosterMagic API", version="2.0")

# Accept cross-origin requests from any origin, with any method and header, so
# a browser front-end served from elsewhere can call this server.
# NOTE(review): allow_origins=["*"] lets any website call these endpoints and
# spend the server-side API keys (when configured via env vars). Consider
# restricting this to the known front-end origin(s) in hosted mode.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ── Simple per-device usage limiting (no login needed) ────────────────────────
# Each device gets a fixed allowance of paid (AI-costing) generations.
# Counters live in a JSON file on disk so they survive server restarts.
FREE_LIMIT = 5                    # free generations per device before blocking
USAGE_FILE = "usage_counts.json"  # relative path, resolved against the CWD

def _load_usage():
    if os.path.exists(USAGE_FILE):
        try:
            with open(USAGE_FILE, "r") as f:
                return json.load(f)
        except Exception:
            return {}
    return {}

def _save_usage(data):
    with open(USAGE_FILE, "w") as f:
        json.dump(data, f)

# ── Server-side API keys (for hosted/family-shared mode) ────────────────────
# Keys can be configured on the server through environment variables. When they
# are, requests may omit a key entirely, so family members never have to enter
# or see one. A key sent with a request (handy for local testing) still wins
# over the server's key, which allows overriding without touching server config.
SERVER_OPENAI_KEY = os.getenv("POSTERMAGIC_OPENAI_KEY", "")
SERVER_REPLICATE_KEY = os.getenv("POSTERMAGIC_REPLICATE_KEY", "")

def resolve_openai_key(passed_key: str) -> str:
    """Choose which OpenAI key a request should use.

    A non-blank key supplied by the caller (whitespace-stripped) takes
    priority; otherwise the server-configured ``SERVER_OPENAI_KEY`` is used.

    Raises:
        HTTPException(400): neither source provides a key.
    """
    candidate = (passed_key or "").strip()
    if candidate:
        return candidate
    if SERVER_OPENAI_KEY:
        return SERVER_OPENAI_KEY
    raise HTTPException(400, "No OpenAI API key configured. Set POSTERMAGIC_OPENAI_KEY on the server or pass api_key.")

def resolve_replicate_key(passed_key: str) -> str:
    """Choose which Replicate key a request should use.

    A non-blank key supplied by the caller (whitespace-stripped) takes
    priority; otherwise the server-configured ``SERVER_REPLICATE_KEY`` is used.

    Raises:
        HTTPException(400): neither source provides a key.
    """
    candidate = (passed_key or "").strip()
    if candidate:
        return candidate
    if SERVER_REPLICATE_KEY:
        return SERVER_REPLICATE_KEY
    raise HTTPException(400, "No Replicate API key configured. Set POSTERMAGIC_REPLICATE_KEY on the server or pass replicate_key.")

# ── Admin bypass — owner's device skips the free-generation limit entirely ──
# Read once at import time. An empty value (the default) disables the bypass.
ADMIN_SECRET = os.getenv("POSTERMAGIC_ADMIN_SECRET", "")

def check_and_increment_usage(device_id: str, admin_key: str = ""):
    """
    Enforce the per-device free-generation limit and record one more use.

    Call this in every endpoint that costs real API money, before the paid
    work starts.

    Args:
        device_id: client-supplied identifier. Blank values are pooled under
            "unknown".
        admin_key: if it matches the server's configured ADMIN_SECRET, the
            limit is bypassed entirely (the use is still counted, for
            visibility, but never blocks).

    Raises:
        HTTPException(429): the device has already used FREE_LIMIT
            generations and did not present the admin key.

    NOTE(review): device_id comes from the client, so a caller can get a fresh
    quota by sending a different id. This is a courtesy limit, not a security
    boundary.
    """
    import hmac

    if not device_id:
        device_id = "unknown"
    # Constant-time comparison so the secret can't be probed by timing how long
    # a mismatch takes. Compare bytes because compare_digest rejects non-ASCII
    # str arguments with TypeError.
    is_admin = bool(ADMIN_SECRET) and hmac.compare_digest(
        (admin_key or "").encode("utf-8"), ADMIN_SECRET.encode("utf-8")
    )

    usage = _load_usage()
    count = usage.get(device_id, 0)
    if count >= FREE_LIMIT and not is_admin:
        raise HTTPException(
            429,
            f"You've used all {FREE_LIMIT} free generations on this device. "
            f"Please contact the app owner for more."
        )
    usage[device_id] = count + 1
    _save_usage(usage)

# ── Theme prompts — focused on costume quality, not face ──────────────────────
# The generated face is only a stand-in: the pipeline's face-swap step pastes
# the real face on afterwards, so these prompts describe costume, pose and scene.
#
# How /generate picks a prompt:
#   * If the requested theme is itself a key here (e.g. "royal_couple",
#     "embrace_vintage", "anniv_bollywood_bride"), it is used as-is.
#   * Otherwise "_female" is appended when gender == "female", else "_male"
#     (the `gender` form field defaults to male for back-compat).
# Prompts containing the literal placeholder "{COLOR}" have it replaced with
# the user's colour choice or a per-theme default (plain str.replace, not
# str.format, so other braces would be left untouched).
# NOTE(review): "toast" only has a _male variant, so theme="toast" with
# gender="female" resolves to the missing key "toast_female" and /generate
# rejects it with HTTP 400.
THEME_PROMPTS = {
    "cricket_male": """Ultra-realistic cinematic cricket sports poster.
Indian male, age 45-55, medium-dark South Indian complexion.
Royal blue India cricket jersey with orange and white trim, gold '50+' laurel wreath emblem on chest.
Holding a cricket bat upright on right shoulder, wearing blue and white batting gloves on left hand.
Night cricket stadium background with bright stadium floodlights, golden confetti falling around him.
Half-body portrait shot, face looking directly forward at camera, neutral serious expression.
Professional sports photography lighting, photorealistic, sharp detail, 4K quality.""",

    "cricket_female": """Ultra-realistic cinematic cricket sports poster.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features.
Royal blue India cricket jersey with orange and white trim, gold '50+' laurel wreath emblem on chest, hair tied back neatly.
Holding a cricket bat upright on right shoulder, wearing blue and white batting gloves on left hand.
Night cricket stadium background with bright stadium floodlights, golden confetti falling around her.
Half-body portrait shot, face looking directly forward at camera, confident determined expression.
Professional sports photography lighting, photorealistic, sharp detail, 4K quality.""",

    "royal_male": """Ultra-realistic portrait of a royal king.
Indian male, age 45-55, medium-dark South Indian complexion.
Deep {COLOR} velvet robe with gold embroidery and trim, ornate gold crown with gemstones, holding royal sceptre.
Grand palace throne room background, marble columns, golden candlelight, dramatic regal lighting.
Half-body portrait, face looking directly forward at camera, serious dignified expression.
Photorealistic, sharp detail, 4K quality.""",

    # Male-only theme: there is no "toast_female" entry.
    "toast_male": """Ultra-realistic celebratory portrait of a distinguished man at an elegant evening party.
Indian male, age 40-55, medium-dark South Indian complexion, warm confident smile.
Smart {COLOR} formal blazer or sherwani, well-groomed, refined elegant styling.
Holding a champagne flute raised slightly in a toast gesture, angled naturally toward the camera.
Warm string lights and soft bokeh in the background, upscale lounge or rooftop party atmosphere, golden evening lighting.
Half-body portrait, body angled three-quarters, face looking directly forward at camera, joyful celebratory expression.
Photorealistic, sharp detail, warm and festive, 4K quality.""",

    "royal_female": """Ultra-realistic portrait of a royal queen.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features.
Opulent {COLOR} and gold royal gown or silk saree with intricate embroidered details and fine zari work.
Ornate gold crown with sparkling gemstones, layered statement gold jewellery — choker necklace, long haram necklace,
jhumka earrings, maang tikka, and bangles, all richly detailed.
Grand palace throne room background, marble columns, golden candlelight, dramatic regal lighting.
Half-body portrait, face looking directly forward at camera, serene dignified expression, radiant opulent atmosphere.
Photorealistic, sharp detail, luxurious and richly detailed, 4K quality.""",

    # Two-person prompt (man on the LEFT, woman on the RIGHT). No gender
    # suffix: only reachable by sending the full key as the theme.
    "royal_couple": """Ultra-realistic portrait of a royal king and queen standing together side by side.
LEFT PERSON (man): Indian male, age 45-55, medium-dark South Indian complexion.
Deep burgundy velvet robe with gold embroidery, ornate gold crown with gemstones, one arm around the queen's waist.
RIGHT PERSON (woman): Indian female, age 45-55, warm South Indian complexion, elegant features.
Rich emerald green and gold royal gown with embroidered details, matching ornate gold crown with gemstones, holding a small bouquet of red roses.
Both standing close together, facing camera, warm loving expressions, slight smiles.
Grand palace throne room background, marble columns, golden candlelight, soft romantic regal lighting.
Half-body double portrait, both faces clearly visible and forward facing.
Photorealistic, sharp detail, 4K quality.""",

    "bollywood_male": """Ultra-realistic Bollywood film star poster.
Indian male, age 45-55, medium-dark South Indian complexion.
Designer {COLOR} sherwani with gold accents and embroidery, stylish and glamorous.
Dramatic cinematic studio background, bold film poster lighting, dark background with golden tones.
Half-body portrait, face looking directly forward at camera, confident slight smile.
Photorealistic, sharp detail, 4K quality.""",

    "bollywood_female": """Ultra-realistic Bollywood film star poster.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features.
Elegant flowing {COLOR} and gold designer gown with delicate embroidery, statement gold jewellery, glamorous styled hair.
Dramatic cinematic studio background, bold film poster lighting, dark background with golden tones.
Half-body portrait, face looking directly forward at camera, confident radiant smile.
Photorealistic, sharp detail, 4K quality.""",

    "traditional_male": """Ultra-realistic portrait in traditional Indian formal wear.
Indian male, age 45-55, medium-dark South Indian complexion.
Elegant {COLOR} sherwani with gold embroidery, dupatta draped on shoulder.
Warm festive background with marigold flowers, diyas, soft golden lighting.
Half-body portrait, face looking directly forward at camera, dignified warm expression.
Photorealistic, sharp detail, 4K quality.""",

    "traditional_female": """Ultra-realistic portrait in traditional Indian formal wear.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features.
Rich {COLOR} silk saree with gold zari embroidery, traditional gold jewellery, dupatta draped elegantly.
Warm festive background with marigold flowers, diyas, soft golden lighting.
Half-body portrait, face looking directly forward at camera, dignified warm expression.
Photorealistic, sharp detail, 4K quality.""",

    # Female-only theme: there is no "devi_male" entry.
    "devi_female": """Ultra-realistic divine-inspired portrait, celebrating feminine grace and strength.
Indian woman in her mid-30s to 40s, warm South Indian complexion, serene radiant features, gentle glowing skin.
Rich {COLOR} silk saree with intricate gold zari embroidery, ornate gold temple jewellery, elegant gold headpiece with a delicate maang tikka.
Soft halo-like golden light surrounding the figure, intricately carved temple stone backdrop, scattered marigold petals and flickering diya lamps.
Half-body portrait, face looking directly forward at camera, calm benevolent expression, eyes warm and luminous.
Photorealistic, sharp detail, devotional yet tasteful and elegant, 4K quality.""",

    "army_male": """Ultra-realistic portrait of an Indian Army officer.
Indian male, age 45-55, medium-dark South Indian complexion.
Full Indian Army officer uniform, olive green with medals and rank insignia, officer beret.
Patriotic background with Indian flag, dramatic hero lighting.
Half-body portrait, face looking directly forward at camera, proud dignified expression.
Photorealistic, sharp detail, 4K quality.""",

    "army_female": """Ultra-realistic portrait of an Indian Army officer.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features, hair neatly tied back.
Full Indian Army officer uniform, olive green with medals and rank insignia, officer cap.
Patriotic background with Indian flag, dramatic hero lighting.
Half-body portrait, face looking directly forward at camera, proud dignified expression.
Photorealistic, sharp detail, 4K quality.""",

    # Solo prompts below have no _male/_female suffix. /generate uses them
    # as-is when the client sends the full key as the theme.
    "royal_king": """Ultra-realistic portrait of a royal king, solo portrait.
Indian male, age 45-55, medium-dark South Indian complexion.
Deep burgundy velvet robe with gold embroidery and trim, ornate gold crown with emerald and ruby gemstones, holding a small bouquet of red roses in one hand at waist height.
Grand palace throne room background, marble columns, warm golden candlelight, soft romantic regal lighting.
Half-body portrait, face looking directly forward at camera, warm loving smile.
Photorealistic, sharp detail, 4K quality.""",

    "royal_queen": """Ultra-realistic portrait of a royal queen, solo portrait.
Indian female, age 45-55, warm South Indian complexion, elegant graceful features.
Rich emerald green and gold royal gown with intricate embroidered details, ornate gold crown with emerald and ruby gemstones, gold necklace and earrings, holding a small bouquet of red roses at waist height.
Grand palace throne room background, marble columns, warm golden candlelight, soft romantic regal lighting.
Half-body portrait, face looking directly forward at camera, warm loving smile.
Photorealistic, sharp detail, 4K quality.""",

    "anniv_traditional_groom": """Ultra-realistic portrait of an Indian man in traditional wedding attire, solo portrait.
Indian male, age 45-55, medium-dark South Indian complexion.
Elegant cream and gold silk sherwani with intricate gold embroidery, matching turban with a small brooch, holding a single red rose at waist height.
Warm festive background with soft marigold flowers and diyas, golden bokeh lighting.
Half-body portrait, face looking directly forward at camera, warm loving smile.
Photorealistic, sharp detail, 4K quality.""",

    "anniv_traditional_bride": """Ultra-realistic portrait of an Indian woman in traditional wedding attire, solo portrait.
Indian female, age 45-55, warm South Indian complexion, elegant graceful features.
Rich red and gold silk saree with intricate zari embroidery, traditional gold jewellery including maang tikka, necklace and earrings, holding a single red rose at waist height.
Warm festive background with soft marigold flowers and diyas, golden bokeh lighting.
Half-body portrait, face looking directly forward at camera, warm loving smile.
Photorealistic, sharp detail, 4K quality.""",

    "anniv_bollywood_groom": """Ultra-realistic Bollywood film-poster style romantic portrait of a man, solo portrait.
Indian male, age 45-55, medium-dark South Indian complexion.
Sharp black designer bandhgala suit with subtle gold buttons, open collar, confident relaxed pose.
Dramatic cinematic background with warm amber and rose-gold lighting, soft bokeh lights.
Half-body portrait, face looking directly forward at camera, charming confident smile.
Photorealistic, sharp detail, film-poster quality, 4K.""",

    "anniv_bollywood_bride": """Ultra-realistic Bollywood film-poster style romantic portrait of a woman, solo portrait.
Indian female, age 45-55, warm South Indian complexion, elegant graceful features.
Elegant flowing wine-red and gold designer gown with delicate embroidery, soft wavy hair, statement gold earrings.
Dramatic cinematic background with warm amber and rose-gold lighting, soft bokeh lights.
Half-body portrait, face looking directly forward at camera, charming radiant smile.
Photorealistic, sharp detail, film-poster quality, 4K.""",

    "anniv_vintage_groom": """Ultra-realistic vintage-style classic wedding portrait of a man, solo portrait, 1970s studio photography aesthetic.
Indian male, age 45-55, medium-dark South Indian complexion.
Classic cream three-piece suit with wide lapels, neatly combed hair, warm sepia-toned studio background with soft vignette.
Half-body portrait, face looking directly forward at camera, gentle nostalgic smile.
Photorealistic, soft film grain, warm vintage color grading, sharp detail.""",

    "anniv_vintage_bride": """Ultra-realistic vintage-style classic wedding portrait of a woman, solo portrait, 1970s studio photography aesthetic.
Indian female, age 45-55, warm South Indian complexion, elegant graceful features, hair styled in a classic vintage updo.
Elegant traditional silk saree with simple gold border, classic gold jewellery, warm sepia-toned studio background with soft vignette.
Half-body portrait, face looking directly forward at camera, gentle nostalgic smile.
Photorealistic, soft film grain, warm vintage color grading, sharp detail.""",

    # Two-person embrace prompts (MAN on the left, WOMAN on the right);
    # no gender suffix, used as-is.
    "embrace_traditional": """Ultra-realistic romantic portrait of an Indian couple standing close together, arms around each other, celebrating their wedding anniversary.
MAN (left side): Indian male, age 45-55, medium-dark South Indian complexion, wearing elegant cream and gold silk sherwani with gold embroidery, right arm wrapped warmly around the woman's waist/shoulder.
WOMAN (right side): Indian female, age 45-55, warm South Indian complexion, wearing rich red and gold silk saree with zari embroidery, traditional gold jewellery, leaning into the man, one hand resting on his chest or arm.
Both standing very close together, bodies touching, warm affectionate body language, both facing camera with joyful smiles.
Warm festive background with soft marigold flowers and diyas, golden bokeh lighting.
Half-body double portrait, both faces clearly visible and forward-facing, photorealistic, sharp detail, 4K quality.""",

    "embrace_bollywood": """Ultra-realistic Bollywood film-poster style romantic portrait of a couple standing close together, arms around each other.
MAN (left side): Indian male, age 45-55, medium-dark South Indian complexion, wearing a sharp black designer bandhgala suit, right arm around the woman's waist, other hand gently holding her hand.
WOMAN (right side): Indian female, age 45-55, warm South Indian complexion, wearing an elegant flowing wine-red and gold designer gown, leaning into the man's shoulder affectionately.
Both standing very close together, romantic embrace, both facing camera with warm charming smiles.
Dramatic cinematic background with warm amber and rose-gold lighting, soft bokeh lights.
Half-body double portrait, both faces clearly visible and forward-facing, photorealistic, film-poster quality, 4K.""",

    "embrace_vintage": """Ultra-realistic vintage-style classic wedding anniversary portrait of a couple standing close together, arms around each other, 1970s studio photography aesthetic.
MAN (left side): Indian male, age 45-55, medium-dark South Indian complexion, wearing a classic cream three-piece suit with wide lapels, arm around the woman's shoulder.
WOMAN (right side): Indian female, age 45-55, warm South Indian complexion, wearing an elegant traditional silk saree with simple gold border, leaning into the man.
Both standing very close together, gentle nostalgic affection, both facing camera with warm smiles.
Warm sepia-toned studio background with soft vignette.
Half-body double portrait, both faces clearly visible and forward-facing, photorealistic, soft film grain, warm vintage color grading.""",

    "retirement_male": """Ultra-realistic portrait of a distinguished professional at retirement celebration.
Indian male, age 45-55, medium-dark South Indian complexion.
Smart dark navy formal suit with tie, holding a gold trophy award.
Warm celebratory background with golden bokeh and confetti falling.
Half-body portrait, face looking directly forward at camera, proud warm smile.
Photorealistic, sharp detail, 4K quality.""",

    "retirement_female": """Ultra-realistic portrait of a distinguished professional at retirement celebration.
Indian woman in her mid-30s to 40s, warm South Indian complexion, mature graceful features.
Smart formal blazer or elegant silk saree, holding a gold trophy award, tasteful jewellery.
Warm celebratory background with golden bokeh and confetti falling.
Half-body portrait, face looking directly forward at camera, proud warm smile.
Photorealistic, sharp detail, 4K quality."""
}


def pil_to_b64(img, fmt="JPEG", quality=92):
    """Encode an image as a base64 ASCII string (no data-URL prefix).

    ``fmt`` and ``quality`` are forwarded unchanged to ``img.save``.
    """
    with io.BytesIO() as stream:
        img.save(stream, format=fmt, quality=quality)
        raw_bytes = stream.getvalue()
    return base64.b64encode(raw_bytes).decode()


def load_font(size, bold=False):
    """Return a TrueType font at ``size``, falling back gracefully.

    Tries common Arial / DejaVu locations on Windows, Linux and macOS, and
    skips any file that exists but cannot be loaded. If nothing works, returns
    Pillow's built-in default font. The default font is scaled to ``size``
    when the installed Pillow supports that; otherwise it is the fixed-size
    bitmap font.

    Args:
        size: font size passed to ``ImageFont.truetype``.
        bold: prefer the bold variant of each candidate.
    """
    from PIL import ImageFont

    if bold:
        candidates = [
            "C:/Windows/Fonts/arialbd.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
            "/System/Library/Fonts/Supplemental/Arial Bold.ttf",
            "/Library/Fonts/Arial Bold.ttf",
        ]
    else:
        candidates = [
            "C:/Windows/Fonts/arial.ttf",
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
            "/System/Library/Fonts/Supplemental/Arial.ttf",
            "/Library/Fonts/Arial.ttf",
        ]

    for path in candidates:
        if not os.path.exists(path):
            continue
        try:
            return ImageFont.truetype(path, size)
        except OSError:
            # File present but unreadable or corrupt: try the next candidate.
            continue

    try:
        # Pillow >= 10.1 can render the built-in font at an arbitrary size.
        return ImageFont.load_default(size=size)
    except (TypeError, ImportError):
        # Older Pillow (no `size` argument) or no FreeType support:
        # fall back to the fixed-size bitmap default.
        return ImageFont.load_default()


@app.get("/")
def root():
    """Liveness endpoint: always returns a fixed status message."""
    status_message = "PosterMagic v2.0 running — with face swap!"
    return {"status": status_message}


@app.get("/usage-status")
def usage_status(device_id: str = ""):
    """Report a device's free-generation usage so the UI can display it.

    A blank ``device_id`` maps to the shared "unknown" bucket, the same one
    check_and_increment_usage uses. The response carries ``used``, ``limit``
    and ``remaining`` (clamped at zero).
    """
    bucket = device_id if device_id else "unknown"
    counters = _load_usage()
    used = counters.get(bucket, 0)
    return {
        "used": used,
        "limit": FREE_LIMIT,
        "remaining": max(0, FREE_LIMIT - used),
    }


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT 1: Generate themed costume via GPT-image-1
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/generate")
async def generate_costume(
    photo:      UploadFile = File(...),
    theme:      str        = Form(default="cricket"),
    gender:     str        = Form(default="male"),   # "male" or "female"
    color:      str        = Form(default=""),       # optional outfit colour override
    age_group:  str        = Form(default=""),       # optional: "20s-30s" | "30s-40s" | "40s-50s" | "50s-60s"
    name:       str        = Form(default=""),
    api_key:    str        = Form(default=""),
    variants:   int        = Form(default=2),
    device_id:  str        = Form(default=""),
    admin_key:  str        = Form(default=""),
):
    """
    Generate 1-3 themed costume/scene images with OpenAI gpt-image-1.

    The uploaded ``photo`` is accepted for API compatibility but is not sent
    to OpenAI: the generate call takes no reference image. The real face is
    applied afterwards by the face-swap step.

    Returns JSON ``{"success": True, "variants": [<base64 image>, ...],
    "count": n}``.

    Raises:
        HTTPException(400): unknown theme/gender combination, or no OpenAI key.
        HTTPException(429): the device has no free generations left.
        HTTPException(500): the OpenAI call failed.
    """
    import asyncio
    import functools

    # Solo themes are stored as "<theme>_male" / "<theme>_female".
    # Couple/embrace themes (royal_couple, embrace_*) have no gender suffix.
    if theme in THEME_PROMPTS:
        resolved_key = theme  # already a couple/embrace theme, or fully-qualified key
    else:
        suffix = "female" if gender == "female" else "male"
        resolved_key = f"{theme}_{suffix}"

    if resolved_key not in THEME_PROMPTS:
        raise HTTPException(400, f"Unknown theme '{theme}' for gender '{gender}'. Resolved key '{resolved_key}' not found.")

    # Resolve the key before metering, so a request that cannot possibly
    # succeed (bad theme above, or no key here) doesn't burn a free try.
    client = OpenAI(api_key=resolve_openai_key(api_key))
    check_and_increment_usage(device_id, admin_key)

    # Default outfit colour per theme if user didn't pick one, keeps old look unchanged
    DEFAULT_COLORS = {
        "royal_male": "burgundy", "royal_female": "emerald green",
        "toast_male": "navy blue",
        "bollywood_male": "black", "bollywood_female": "wine-red",
        "traditional_male": "ivory and cream", "traditional_female": "red",
        "devi_female": "red and gold",
    }
    chosen_color = color.strip() or DEFAULT_COLORS.get(resolved_key, "")
    # No-op for prompts without the placeholder.
    prompt = THEME_PROMPTS[resolved_key].replace("{COLOR}", chosen_color)

    AGE_LABELS = {
        "20s-30s": "in their 20s to early 30s, youthful features",
        "30s-40s": "in their mid-30s to 40s, mature graceful features",
        "40s-50s": "in their 40s to 50s, dignified mature features",
        "50s-60s": "in their 50s to 60s, distinguished mature features",
    }
    age_label = AGE_LABELS.get(age_group.strip())
    if age_label:
        prompt += (f"\n\nIMPORTANT: The person should look {age_label}. "
                   f"This overrides any other age mentioned above.")

    # Ignore whitespace-only names instead of asking for a blank name on the costume.
    clean_name = name.strip()
    if clean_name:
        prompt += f'\n\nIMPORTANT: The name "{clean_name.upper()}" must appear clearly on the jersey/costume.'

    request_image = functools.partial(
        client.images.generate,
        model="gpt-image-1",
        prompt=prompt,
        n=1,
        size="1024x1536",
        quality="high",
        moderation="low",
    )
    loop = asyncio.get_running_loop()
    results = []

    try:
        for _ in range(max(1, min(variants, 3))):
            # The OpenAI client is synchronous. Running it in a worker thread
            # keeps a long image generation from blocking the event loop, and
            # with it every other request this server is handling.
            response = await loop.run_in_executor(None, request_image)
            results.append(response.data[0].b64_json)
    except Exception as e:
        raise HTTPException(500, f"OpenAI error: {str(e)}") from e

    return JSONResponse({"success": True, "variants": results, "count": len(results)})


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT 2: Face swap — paste real face onto generated costume
# Uses: codeplugtech/face-swap on Replicate (~$0.0024 per swap)
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/faceswap")
async def face_swap(
    face_photo:    UploadFile = File(...),
    costume_image: UploadFile = File(...),
    replicate_key: str        = Form(default=""),
    device_id:     str        = Form(default=""),  # only metered if provided (saree-swap flow)
    admin_key:     str        = Form(default=""),
):
    """
    Paste the face from ``face_photo`` onto ``costume_image`` using the
    codeplugtech/face-swap model on Replicate.

    Usage is metered only when ``device_id`` is sent. The saree-swap flow
    calls /faceswap directly (no prior /generate) and counts as one try. The
    normal poster flow's face swaps stay unmetered because /generate already
    counted that action.

    Returns JSON ``{"success": True, "swapped": <base64 of the result image>}``.

    Raises:
        HTTPException(400): an upload is not a readable image, or no
            Replicate key is configured.
        HTTPException(429): a metered device has no free generations left.
        HTTPException(500): the prediction could not be created, failed, was
            canceled, timed out, or its output could not be downloaded.
    """
    import asyncio

    # Read and decode both uploads.
    face_raw    = await face_photo.read()
    costume_raw = await costume_image.read()
    try:
        face_img    = Image.open(io.BytesIO(face_raw)).convert("RGB")
        costume_img = Image.open(io.BytesIO(costume_raw)).convert("RGB")
    except OSError as e:
        # PIL.UnidentifiedImageError is an OSError subclass.
        raise HTTPException(400, f"Could not read uploaded image: {e}") from e
    face_img.thumbnail((1024, 1024), Image.LANCZOS)

    # Resolved outside the try below, so a missing key surfaces as its own
    # 400 instead of being re-wrapped as a 500 by the generic handler.
    headers = {
        "Authorization": f"Token {resolve_replicate_key(replicate_key)}",
        "Content-Type": "application/json",
    }

    # Meter only once the request is known to be usable, so a bad upload or a
    # missing key doesn't consume a free try.
    if device_id:
        check_and_increment_usage(device_id, admin_key)

    # Both images go to Replicate inline as base64 PNG data URLs.
    face_dataurl    = f"data:image/png;base64,{pil_to_b64(face_img, fmt='PNG')}"
    costume_dataurl = f"data:image/png;base64,{pil_to_b64(costume_img, fmt='PNG')}"

    print(f"Face URL length: {len(face_dataurl)}")
    print(f"Costume URL length: {len(costume_dataurl)}")

    try:
        # Call Replicate predictions API directly
        async with httpx.AsyncClient(timeout=120) as http:
            create_resp = await http.post(
                "https://api.replicate.com/v1/predictions",
                headers=headers,
                json={
                    "version": "278a81e7ebb22db98bcba54de985d22cc1abeead2754eb1f2af717247be69b34",
                    "input": {
                        "swap_image":  face_dataurl,
                        "input_image": costume_dataurl,
                    }
                }
            )
            print(f"Create response: {create_resp.status_code} {create_resp.text[:300]}")

            if create_resp.status_code not in (200, 201):
                raise Exception(f"Prediction create failed: {create_resp.text}")

            poll_url = create_resp.json()["urls"]["get"]

            # Poll every 3 s, up to 60 times (about 3 minutes).
            for attempt in range(60):
                await asyncio.sleep(3)
                poll_resp = await http.get(poll_url, headers=headers)
                poll_data = poll_resp.json()
                status    = poll_data.get("status")
                print(f"Poll {attempt+1}: {status}")

                if status == "succeeded":
                    output = poll_data.get("output")
                    print(f"Output: {str(output)[:200]}")
                    if not output:
                        raise Exception(f"Prediction succeeded but returned no output. Full response: {poll_data}")
                    # Output is either a single URL or a list of URLs.
                    result_url = output if isinstance(output, str) else output[0]
                    img_resp   = await http.get(result_url)
                    if not 200 <= img_resp.status_code < 300:
                        raise Exception(f"Downloading result failed: HTTP {img_resp.status_code}")
                    result_b64 = base64.b64encode(img_resp.content).decode()
                    return JSONResponse({"success": True, "swapped": result_b64})

                # Terminal failure states: stop polling immediately.
                if status in ("failed", "canceled"):
                    raise Exception(f"Prediction {status}: {poll_data.get('error')}")

            raise Exception("Prediction timed out after 3 minutes")

    except Exception as e:
        print(f"Face swap error detail: {e}")
        raise HTTPException(500, f"Face swap error: {str(e)}")


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Virtual Try-On — wear any uploaded garment photo
# Uses: cuuupid/idm-vton on Replicate
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/tryon")
async def virtual_tryon(
    person_photo:  UploadFile = File(...),
    garment_photo: UploadFile = File(...),
    category:      str        = Form(default="upper_body"),  # upper_body | lower_body | dresses
    garment_des:   str        = Form(default=""),
    replicate_key: str        = Form(default=""),
    device_id:     str        = Form(default=""),
    admin_key:     str        = Form(default=""),
):
    """
    Render the person in ``person_photo`` wearing the garment in
    ``garment_photo``, using the cuuupid/idm-vton model on Replicate.

    Form fields:
        category: one of "upper_body", "lower_body", "dresses".
        garment_des: text description of the garment. "garment" is sent
            when it is blank.

    Returns JSON ``{"success": True, "result": <base64 of the result image>}``.

    Raises:
        HTTPException(400): invalid category, unreadable upload, or no
            Replicate key configured.
        HTTPException(429): the device has no free generations left.
        HTTPException(500): the prediction could not be created, failed, was
            canceled, timed out, or its output could not be downloaded.
    """
    import asyncio

    if category not in ("upper_body", "lower_body", "dresses"):
        raise HTTPException(400, "category must be one of: upper_body, lower_body, dresses")

    person_raw  = await person_photo.read()
    garment_raw = await garment_photo.read()
    try:
        person_img  = Image.open(io.BytesIO(person_raw)).convert("RGB")
        garment_img = Image.open(io.BytesIO(garment_raw)).convert("RGB")
    except OSError as e:
        # PIL.UnidentifiedImageError is an OSError subclass.
        raise HTTPException(400, f"Could not read uploaded image: {e}") from e
    person_img.thumbnail((1024, 1024), Image.LANCZOS)
    garment_img.thumbnail((1024, 1024), Image.LANCZOS)

    # Resolved outside the try below, so a missing key surfaces as its own
    # 400 instead of being re-wrapped as a 500 by the generic handler.
    headers = {
        "Authorization": f"Token {resolve_replicate_key(replicate_key)}",
        "Content-Type": "application/json",
    }

    # Meter only after every input check has passed, so an invalid request
    # doesn't consume a free try.
    check_and_increment_usage(device_id, admin_key)

    # Both images go to Replicate inline as base64 PNG data URLs.
    person_dataurl  = f"data:image/png;base64,{pil_to_b64(person_img, fmt='PNG')}"
    garment_dataurl = f"data:image/png;base64,{pil_to_b64(garment_img, fmt='PNG')}"

    try:
        async with httpx.AsyncClient(timeout=180) as http:
            create_resp = await http.post(
                "https://api.replicate.com/v1/predictions",
                headers=headers,
                json={
                    "version": "3b032a70c29aef7b9c3222f2e40b71660201d8c288336475ba326f3ca278a3e1",
                    "input": {
                        "human_img":   person_dataurl,
                        "garm_img":    garment_dataurl,
                        "category":    category,
                        "garment_des": garment_des or "garment",
                        "crop":        False,
                    }
                }
            )
            print(f"Try-on create response: {create_resp.status_code} {create_resp.text[:300]}")

            if create_resp.status_code not in (200, 201):
                raise Exception(f"Prediction create failed: {create_resp.text}")

            poll_url = create_resp.json()["urls"]["get"]

            # Poll every 3 s, up to 60 times (about 3 minutes).
            for attempt in range(60):
                await asyncio.sleep(3)
                poll_resp = await http.get(poll_url, headers=headers)
                poll_data = poll_resp.json()
                status    = poll_data.get("status")
                print(f"Try-on poll {attempt+1}: {status}")

                if status == "succeeded":
                    output = poll_data.get("output")
                    if not output:
                        raise Exception(f"Prediction succeeded but returned no output. Full response: {poll_data}")
                    # Output is either a single URL or a list of URLs.
                    result_url = output if isinstance(output, str) else output[0]
                    img_resp   = await http.get(result_url)
                    if not 200 <= img_resp.status_code < 300:
                        raise Exception(f"Downloading result failed: HTTP {img_resp.status_code}")
                    result_b64 = base64.b64encode(img_resp.content).decode()
                    return JSONResponse({"success": True, "result": result_b64})

                # Terminal failure states: stop polling immediately.
                if status in ("failed", "canceled"):
                    raise Exception(f"Prediction {status}: {poll_data.get('error')}")

            raise Exception("Prediction timed out after 3 minutes")

    except Exception as e:
        print(f"Try-on error detail: {e}")
        raise HTTPException(500, f"Try-on error: {str(e)}")


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Couple face swap — paste two real faces onto one generated couple image
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/faceswap-couple")
async def face_swap_couple(
    face_photo_1:  UploadFile = File(...),   # first person's real face (e.g. husband)
    face_photo_2:  UploadFile = File(...),   # second person's real face (e.g. wife)
    costume_image: UploadFile = File(...),   # GPT-4o generated couple costume image
    replicate_key: str        = Form(default=""),
):
    """
    Swap two real faces onto one generated couple image.

    Runs the Replicate face-swap model version pinned below twice, in
    sequence, on the WHOLE image: face 1 is swapped onto the costume image,
    then face 2 is swapped onto the output of that first pass. No cropping or
    merging is done here, and which face in the scene each pass replaces is
    not controlled by this code.

    Returns JSON {"success": True, "swapped": <base64 PNG>}.
    Raises HTTPException(400) if an upload cannot be decoded as an image,
    HTTPException(500) on any Replicate failure or timeout.
    """
    import asyncio

    async def load_upload(upload, label, max_side=None):
        """Decode an uploaded image to RGB; a bad upload is a client error (400)."""
        raw = await upload.read()
        try:
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as e:
            raise HTTPException(400, f"Could not read {label}: {e}") from e
        if max_side:
            img.thumbnail((max_side, max_side), Image.LANCZOS)
        return img

    face1_img   = await load_upload(face_photo_1, "face_photo_1", 1024)
    face2_img   = await load_upload(face_photo_2, "face_photo_2", 1024)
    costume_img = await load_upload(costume_image, "costume_image")

    def img_to_dataurl(img):
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()
        return f"data:image/png;base64,{b64}"

    headers = {
        "Authorization": f"Token {resolve_replicate_key(replicate_key)}",
        "Content-Type": "application/json",
    }

    async def run_faceswap(swap_face_dataurl, target_dataurl, http):
        """Run one face-swap prediction on Replicate and return the result image."""
        create_resp = await http.post(
            "https://api.replicate.com/v1/predictions",
            headers=headers,
            json={
                "version": "278a81e7ebb22db98bcba54de985d22cc1abeead2754eb1f2af717247be69b34",
                "input": {
                    "swap_image":  swap_face_dataurl,
                    "input_image": target_dataurl,
                }
            }
        )
        if create_resp.status_code not in (200, 201):
            raise Exception(f"Prediction create failed: {create_resp.text}")
        poll_url = create_resp.json()["urls"]["get"]

        for _ in range(60):                      # 60 polls x 3 s = 3-minute budget
            await asyncio.sleep(3)
            poll_resp = await http.get(poll_url, headers=headers)
            poll_data = poll_resp.json()
            status = poll_data.get("status")
            if status == "succeeded":
                output = poll_data.get("output")
                if not output:
                    raise Exception(f"Prediction succeeded but returned no output. Full response: {poll_data}")
                result_url = output if isinstance(output, str) else output[0]
                img_resp = await http.get(result_url)
                # Fail with a clear HTTP error instead of a PIL decode error on e.g. a 404 body.
                img_resp.raise_for_status()
                return Image.open(io.BytesIO(img_resp.content)).convert("RGB")
            if status in ("failed", "canceled"):
                # "canceled" is terminal too; without it we would poll until timeout.
                raise Exception(f"Prediction {status}: {poll_data.get('error')}")
        raise Exception("Prediction timed out")

    try:
        costume_dataurl = img_to_dataurl(costume_img)
        face1_dataurl   = img_to_dataurl(face1_img)
        face2_dataurl   = img_to_dataurl(face2_img)

        async with httpx.AsyncClient(timeout=120) as http:
            print("Swapping face 1 (left person)...")
            step1_img = await run_faceswap(face1_dataurl, costume_dataurl, http)

            step1_dataurl = img_to_dataurl(step1_img)
            print("Swapping face 2 (right person) onto result of step 1...")
            step2_img = await run_faceswap(face2_dataurl, step1_dataurl, http)

        buf = io.BytesIO()
        step2_img.save(buf, format="PNG")
        result_b64 = base64.b64encode(buf.getvalue()).decode()

        return JSONResponse({"success": True, "swapped": result_b64})

    except Exception as e:
        print(f"Couple face swap error: {e}")
        raise HTTPException(500, f"Couple face swap error: {str(e)}")


# Couple theme key -> (groom prompt key, bride prompt key).
# Both values are keys into THEME_PROMPTS. /generate-couple-solo renders the
# first for photo_1 (king/groom portrait) and the second for photo_2.
COUPLE_THEME_PAIRS = {
    "royal_couple": ("royal_king", "royal_queen"),
    "anniv_traditional": ("anniv_traditional_groom", "anniv_traditional_bride"),
    "anniv_bollywood": ("anniv_bollywood_groom", "anniv_bollywood_bride"),
    "anniv_vintage": ("anniv_vintage_groom", "anniv_vintage_bride"),
}


@app.post("/generate-couple-solo")
async def generate_couple_solo(
    photo_1:       UploadFile = File(...),   # husband's face
    photo_2:       UploadFile = File(...),   # wife's face
    name_1:        str        = Form(default=""),
    name_2:        str        = Form(default=""),
    couple_theme:  str        = Form(default="royal_couple"),
    api_key:       str        = Form(default=""),
    replicate_key: str        = Form(default=""),
    device_id:     str        = Form(default=""),
    admin_key:     str        = Form(default=""),
):
    """
    Generate TWO separate solo portraits (groom + bride costume for the chosen
    couple theme) and face-swap each one independently, so neither face can
    leak into the other portrait.

    photo_1 is swapped onto the groom/king costume, photo_2 onto the
    bride/queen costume. name_1 / name_2 are accepted but not used here.

    Cheap validation (theme, photo decoding) runs BEFORE the usage counter is
    incremented and before any paid API call. A typo'd theme or a corrupt
    upload therefore neither consumes a free generation nor wastes OpenAI credit.

    Returns JSON {"success": True, "king": <b64 PNG>, "queen": <b64 PNG>}.
    Raises HTTPException 400 (unknown theme / unreadable photo), 500 (OpenAI or
    Replicate failure), plus whatever check_and_increment_usage raises.
    """
    import asyncio

    if couple_theme not in COUPLE_THEME_PAIRS:
        raise HTTPException(400, f"Unknown couple theme. Choose from: {list(COUPLE_THEME_PAIRS.keys())}")
    groom_key, bride_key = COUPLE_THEME_PAIRS[couple_theme]

    async def load_face(upload, label):
        """Decode an uploaded face photo to RGB, capped at 1024 px; bad upload -> 400."""
        raw = await upload.read()
        try:
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as e:
            raise HTTPException(400, f"Could not read {label}: {e}") from e
        img.thumbnail((1024, 1024), Image.LANCZOS)
        return img

    face1_img = await load_face(photo_1, "photo_1")
    face2_img = await load_face(photo_2, "photo_2")

    client  = OpenAI(api_key=resolve_openai_key(api_key))
    headers = {"Authorization": f"Token {resolve_replicate_key(replicate_key)}", "Content-Type": "application/json"}

    # Only count the generation once the request is known to be well-formed.
    check_and_increment_usage(device_id, admin_key)

    def generate_costume(prompt_key):
        """Blocking: render one costume portrait with gpt-image-1 and decode it."""
        resp = client.images.generate(
            model="gpt-image-1", prompt=THEME_PROMPTS[prompt_key], n=1,
            size="1024x1536", quality="high", moderation="low",
        )
        return Image.open(io.BytesIO(base64.b64decode(resp.data[0].b64_json))).convert("RGB")

    try:
        # The OpenAI SDK client is synchronous; run it in a worker thread so the
        # event loop keeps serving other requests during the long generation.
        king_costume_img  = await asyncio.to_thread(generate_costume, groom_key)
        queen_costume_img = await asyncio.to_thread(generate_costume, bride_key)
    except Exception as e:
        raise HTTPException(500, f"OpenAI error: {str(e)}")

    def img_to_dataurl(img):
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}"

    async def run_faceswap(swap_face_dataurl, target_dataurl, http):
        """Run one face-swap prediction on Replicate and return the result image."""
        create_resp = await http.post(
            "https://api.replicate.com/v1/predictions",
            headers=headers,
            json={
                "version": "278a81e7ebb22db98bcba54de985d22cc1abeead2754eb1f2af717247be69b34",
                "input": {"swap_image": swap_face_dataurl, "input_image": target_dataurl}
            }
        )
        if create_resp.status_code not in (200, 201):
            raise Exception(f"Prediction create failed: {create_resp.text}")
        poll_url = create_resp.json()["urls"]["get"]

        for _ in range(60):                      # 60 polls x 3 s = 3-minute budget
            await asyncio.sleep(3)
            poll_resp = await http.get(poll_url, headers=headers)
            poll_data = poll_resp.json()
            status = poll_data.get("status")
            if status == "succeeded":
                output = poll_data.get("output")
                if not output:
                    raise Exception(f"Prediction succeeded but returned no output. Full response: {poll_data}")
                result_url = output if isinstance(output, str) else output[0]
                img_resp = await http.get(result_url)
                img_resp.raise_for_status()
                return Image.open(io.BytesIO(img_resp.content)).convert("RGB")
            if status in ("failed", "canceled"):
                # "canceled" is terminal too; without it we would poll until timeout.
                raise Exception(f"Prediction {status}: {poll_data.get('error')}")
        raise Exception("Prediction timed out")

    try:
        async with httpx.AsyncClient(timeout=120) as http:
            print("Swapping face onto king costume...")
            king_swapped = await run_faceswap(
                img_to_dataurl(face1_img), img_to_dataurl(king_costume_img), http
            )
            print("Swapping face onto queen costume...")
            queen_swapped = await run_faceswap(
                img_to_dataurl(face2_img), img_to_dataurl(queen_costume_img), http
            )
    except Exception as e:
        print(f"Face swap error: {e}")
        raise HTTPException(500, f"Face swap error: {str(e)}")

    king_buf, queen_buf = io.BytesIO(), io.BytesIO()
    king_swapped.save(king_buf, format="PNG")
    queen_swapped.save(queen_buf, format="PNG")

    return JSONResponse({
        "success": True,
        "king":  base64.b64encode(king_buf.getvalue()).decode(),
        "queen": base64.b64encode(queen_buf.getvalue()).decode(),
    })


@app.post("/combine-couple")
async def combine_couple(
    king_image:  UploadFile = File(...),
    queen_image: UploadFile = File(...),
):
    """
    Join two solo portraits into a single side-by-side image.

    Both portraits are scaled (aspect ratio kept) to the taller one's height,
    then laid out king-left / queen-right with a 20 px navy strip between
    them. The output is meant to feed the /assemble-poster step.

    Returns JSON {"success": True, "combined": <base64 PNG>}.
    """
    portraits = []
    for upload in (king_image, queen_image):
        data = await upload.read()
        portraits.append(Image.open(io.BytesIO(data)).convert("RGB"))

    height = max(p.height for p in portraits)
    left, right = (
        p.resize((int(p.width * height / p.height), height), Image.LANCZOS)
        for p in portraits
    )

    spacing = 20
    canvas = Image.new("RGB", (left.width + right.width + spacing, height), (6, 9, 26))
    canvas.paste(left, (0, 0))
    canvas.paste(right, (left.width + spacing, 0))

    png = io.BytesIO()
    canvas.save(png, format="PNG")
    encoded = base64.b64encode(png.getvalue()).decode()
    return JSONResponse({"success": True, "combined": encoded})


# Theme keys served by the "embrace" endpoints, which hand the reference
# photo(s) straight to gpt-image-1's edit endpoint (no separate face swap).
EMBRACE_THEMES = {
    "embrace_traditional",
    "embrace_bollywood",
    "embrace_vintage",
}


def detect_two_faces(img: Image.Image):
    """
    Locate the two largest frontal faces in *img* with OpenCV's stock Haar cascade.

    Returns a list of two (x, y, w, h) boxes, as produced by
    detectMultiScale, ordered left-to-right by x. Returns None when fewer
    than two faces are found.
    """
    import cv2
    import heapq
    import numpy as np

    bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    grey = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
    boxes = detector.detectMultiScale(grey, scaleFactor=1.1, minNeighbors=5, minSize=(40, 40))

    if len(boxes) < 2:
        return None

    # nlargest(n, ...) is documented as equal to sorted(..., reverse=True)[:n]
    largest_two = heapq.nlargest(2, boxes, key=lambda box: box[2] * box[3])
    return sorted(largest_two, key=lambda box: box[0])


def crop_face_box(img: Image.Image, face_box, pad_factor=1.6):
    """
    Crop a padded region around a face box from *img*.

    With (x, y, w, h) = face_box and r = h * pad_factor * 1.3 / 2, the region
    spans w * pad_factor / 2 either side of the box centre horizontally, and
    from 1.1 * r above to 1.4 * r below the centre vertically. It is clamped
    to the image bounds. This is the same padding merge_face_region uses.
    """
    x, y, w, h = map(int, face_box)
    centre_x = x + w/2
    centre_y = y + h/2
    reach_x = w * pad_factor / 2
    reach_y = h * pad_factor * 1.3 / 2

    region = (
        max(0, int(centre_x - reach_x)),
        max(0, int(centre_y - reach_y * 1.1)),
        min(img.width, int(centre_x + reach_x)),
        min(img.height, int(centre_y + reach_y * 1.4)),
    )
    return img.crop(region)


def face_similarity(img_a: Image.Image, img_b: Image.Image) -> float:
    """
    Score how alike two face images look using colour-histogram correlation.

    Each image is converted to RGB and shrunk to 100x100. A normalised 8x8x8
    BGR histogram is taken, and the two histograms are compared with
    HISTCMP_CORREL, so higher means more similar (roughly -1 to 1). This is
    far cruder than a face-embedding model but needs no extra dependencies,
    which is enough to tell which swapped face belongs to whom.
    """
    import cv2
    import numpy as np

    def colour_histogram(image):
        pixels = np.array(image.convert("RGB").resize((100, 100)))
        histogram = cv2.calcHist(
            [cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)],
            [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256],
        )
        cv2.normalize(histogram, histogram)
        return histogram

    first = colour_histogram(img_a)
    second = colour_histogram(img_b)
    return cv2.compareHist(first, second, cv2.HISTCMP_CORREL)


def merge_face_region(original_img, swapped_img, dest_box, pad_factor=1.6, source_box=None):
    """
    Blend the face region of *swapped_img* into a copy of *original_img*.

    dest_box:   (x, y, w, h) in original_img where the face should end up.
    source_box: (x, y, w, h) in swapped_img to take the face from; when None,
                dest_box is used for both images.

    Both boxes are expanded by the same padding rule. The source region is
    resized to the destination region's size and pasted through a blurred
    elliptical mask so the seam is soft. The inputs are not modified; the
    blended copy is returned.
    """
    from PIL import ImageDraw, ImageFilter

    def padded_region(box, limit_w, limit_h):
        """(left, top, right, bottom) of the padded area around *box*, clamped."""
        bx, by, bw, bh = (int(v) for v in box)
        mid_x, mid_y = bx + bw/2, by + bh/2
        reach_x = bw * pad_factor / 2
        reach_y = bh * pad_factor * 1.3 / 2
        return (
            max(0, int(mid_x - reach_x)),
            max(0, int(mid_y - reach_y * 1.1)),
            min(limit_w, int(mid_x + reach_x)),
            min(limit_h, int(mid_y + reach_y * 1.4)),
        )

    take_from = dest_box if source_box is None else source_box

    left, top, right, bottom = padded_region(dest_box, original_img.width, original_img.height)
    region_w, region_h = right - left, bottom - top

    # Soft ellipse the size of the destination area; blur scales with face width.
    feather = ImageFilter.GaussianBlur(radius=max(8, int(dest_box[2])//8))
    mask = Image.new("L", (region_w, region_h), 0)
    ImageDraw.Draw(mask).ellipse([0, 0, region_w, region_h], fill=255)
    mask = mask.filter(feather)

    patch = swapped_img.crop(padded_region(take_from, swapped_img.width, swapped_img.height))
    # Source and destination areas may differ slightly in size; match exactly.
    patch = patch.resize((region_w, region_h), Image.LANCZOS)

    blended = original_img.copy()
    blended.paste(patch, (left, top), mask)
    return blended


@app.post("/generate-embrace-couple")
async def generate_embrace_couple(
    photo_1:       UploadFile = File(...),   # first person's face (e.g. husband)
    photo_2:       UploadFile = File(...),   # second person's face (e.g. wife)
    embrace_theme: str        = Form(default="embrace_traditional"),
    api_key:       str        = Form(default=""),
    variants:      int        = Form(default=2),
    device_id:     str        = Form(default=""),
    admin_key:     str        = Form(default=""),
):
    """
    Sends BOTH face photos directly to GPT-image-1's edit endpoint in a
    single call, exactly like pasting two photos into ChatGPT and asking
    for a themed couple portrait. No separate face-swap step — GPT-image-1
    handles face placement itself, the same way it does in the ChatGPT app.
    Generates multiple variants (clamped to 1-4) so the user can pick the best
    likeness match, same as retrying in the ChatGPT app.

    The theme and both photos are validated BEFORE the usage counter is
    incremented, so a bad request does not consume a free generation.

    Returns JSON {"success": True, "variants": [<b64 PNG>, ...], "count": N}.
    Raises HTTPException 400 (unknown theme / unreadable photo), 500 (OpenAI
    failure), plus whatever check_and_increment_usage raises.
    """
    import asyncio

    if embrace_theme not in EMBRACE_THEMES:
        # sorted(): set iteration order varies between runs (string hash randomisation)
        raise HTTPException(400, f"Unknown embrace theme. Choose from: {sorted(EMBRACE_THEMES)}")

    prompt = THEME_PROMPTS[embrace_theme]
    prompt += ("\n\nUse the two reference photos provided as the exact likenesses "
               "of the two people in this scene — preserve each person's real facial "
               "features, skin tone, and identity as shown in their reference photo. "
               "The first reference photo is the man, the second reference photo is "
               "the woman.")

    async def to_png_upload(upload, label, filename):
        """Decode an upload, cap it at 1024 px, re-encode as an in-memory PNG file."""
        raw = await upload.read()
        try:
            img = Image.open(io.BytesIO(raw)).convert("RGB")
        except Exception as e:
            raise HTTPException(400, f"Could not read {label}: {e}") from e
        img.thumbnail((1024, 1024), Image.LANCZOS)
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        buf.seek(0)
        buf.name = filename   # filename sent with the multipart upload (keep the .png suffix)
        return buf

    buf1 = await to_png_upload(photo_1, "photo_1", "person1.png")
    buf2 = await to_png_upload(photo_2, "photo_2", "person2.png")

    client = OpenAI(api_key=resolve_openai_key(api_key))

    check_and_increment_usage(device_id, admin_key)

    try:
        # Synchronous SDK call: run it in a worker thread so the event loop isn't blocked.
        resp = await asyncio.to_thread(
            client.images.edit,
            model="gpt-image-1",
            image=[buf1, buf2],
            prompt=prompt,
            size="1024x1536",
            quality="high",
            n=max(1, min(variants, 4)),
        )
        results = [item.b64_json for item in resp.data]
    except Exception as e:
        print(f"Embrace generation error: {e}")
        raise HTTPException(500, f"OpenAI error: {str(e)}")

    return JSONResponse({"success": True, "variants": results, "count": len(results)})


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Generate embrace couple portrait from a SINGLE couple photo
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/generate-embrace-couple-photo")
async def generate_embrace_couple_photo(
    couple_photo:  UploadFile = File(...),
    embrace_theme: str        = Form(default="embrace_traditional"),
    api_key:       str        = Form(default=""),
    variants:      int        = Form(default=2),
    device_id:     str        = Form(default=""),
    admin_key:     str        = Form(default=""),
):
    """
    Generate embrace-theme couple portraits from ONE photo showing both people.

    The photo is capped at 1024 px, re-encoded as PNG and sent with the theme
    prompt to gpt-image-1's edit endpoint. The prompt asks for both people's
    likenesses, with the man on the left and the woman on the right.
    `variants` is clamped to 1-4.

    The theme and photo are validated BEFORE the usage counter is
    incremented, so a bad request does not consume a free generation.

    Returns JSON {"success": True, "variants": [<b64 PNG>, ...], "count": N}.
    Raises HTTPException 400 (unknown theme / unreadable photo), 500 (OpenAI
    failure), plus whatever check_and_increment_usage raises.
    """
    import asyncio

    if embrace_theme not in EMBRACE_THEMES:
        # sorted(): set iteration order varies between runs (string hash randomisation)
        raise HTTPException(400, f"Unknown embrace theme. Choose from: {sorted(EMBRACE_THEMES)}")

    prompt = THEME_PROMPTS[embrace_theme]
    prompt += ("\n\nThe reference photo shows BOTH people together. "
               "Use the exact likenesses of both people visible in this photo — "
               "preserve each person's real facial features, skin tone, age, and "
               "identity as shown. The man should be on the left, the woman on the right.")

    raw = await couple_photo.read()
    try:
        img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Could not read couple_photo: {e}") from e
    img.thumbnail((1024, 1024), Image.LANCZOS)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    buf.name = "couple.png"   # filename sent with the multipart upload (keep the .png suffix)

    client = OpenAI(api_key=resolve_openai_key(api_key))

    check_and_increment_usage(device_id, admin_key)

    try:
        # Synchronous SDK call: run it in a worker thread so the event loop isn't blocked.
        resp = await asyncio.to_thread(
            client.images.edit,
            model="gpt-image-1",
            image=[buf],
            prompt=prompt,
            size="1024x1536",
            quality="high",
            n=max(1, min(variants, 4)),
        )
        results = [item.b64_json for item in resp.data]
    except Exception as e:
        print(f"Embrace couple-photo error: {e}")
        raise HTTPException(500, f"OpenAI error: {str(e)}")

    return JSONResponse({"success": True, "variants": results, "count": len(results)})


@app.post("/assemble-poster")
async def assemble_poster(
    generated_image: UploadFile = File(...),
    name:      str = Form(default=""),
    message:   str = Form(default=""),
    signed_by: str = Form(default="Your Children & Family"),
    occasion:  str = Form(default="50th Birthday"),
    title:     str = Form(default=""),       # optional custom banner title, blank = no banner title
    subtitle:  str = Form(default=""),       # optional tagline under the title
    footer:    str = Form(default="★   Wishing you a lifetime of love and happiness   ★"),
):
    """
    Assemble the single-portrait poster on a navy background with a gold border.

    Top to bottom, the poster contains:
      - an optional title and subtitle;
      - the occasion heading;
      - the generated image, scaled to the content width and framed in gold;
      - an opening quote mark and the "A MESSAGE FROM THE HEART" heading;
      - the word-wrapped message;
      - the signature and the footer.

    The canvas height is found by running the layout once as a dry run on a
    throw-away canvas. The poster therefore fits its content exactly,
    whatever the title, subtitle or message length. The previous word-count
    estimate ignored the title/subtitle and over-reserved space for long
    messages.

    `name` is accepted for API compatibility but is not rendered.

    Returns JSON {"success": True, "poster": <base64 PNG>}.
    Raises HTTPException(400) if the uploaded image cannot be decoded.
    """
    from PIL import ImageDraw

    raw = await generated_image.read()
    try:
        gen_img = Image.open(io.BytesIO(raw)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Could not read generated image: {e}") from e

    GOLD  = (184, 148,  42)
    GOLD2 = (201, 168,  76)
    BG    = (  6,   9,  26)
    WHITE = (240, 230, 200)
    CREAM = (212, 201, 168)
    DIM   = (138, 112,  64)

    PW         = 1800
    MARGIN     = 70
    TOP_PAD    = 55     # y of the first line of text
    BOTTOM_PAD = 70     # space below the footer, clear of the double border
    MSG_LINE_H = 58
    avail_w    = PW - 2*MARGIN

    # Main image (face-swapped), scaled to the full content width.
    img_h   = int(avail_w * gen_img.height / gen_img.width)
    resized = gen_img.resize((avail_w, img_h), Image.LANCZOS)

    f_title  = load_font(96,  bold=True)
    f_sub    = load_font(44)
    f_occ    = load_font(58,  bold=True)
    f_msg_h  = load_font(50,  bold=True)
    f_msg    = load_font(42)
    f_sign   = load_font(38)
    f_footer = load_font(34)
    f_quote  = load_font(140, bold=True)

    def layout(draw, canvas=None):
        """
        Draw all content top-to-bottom with *draw* and return the y just below
        the footer. With canvas=None the image is not pasted or framed, so the
        call is a pure measuring pass.
        """
        def gold_rule(y):
            draw.line([(MARGIN, y), (PW-MARGIN, y)], fill=GOLD, width=2)

        def centered(text, y, font, fill):
            bb = draw.textbbox((0, 0), text, font=font)
            draw.text(((PW-(bb[2]-bb[0]))//2, y), text, font=font, fill=fill)
            return bb[3]-bb[1]

        y = TOP_PAD
        if title:
            y += centered(title.upper(), y, f_title, WHITE) + 18
        if subtitle:
            y += centered(subtitle, y, f_sub, GOLD2) + 28
        if title or subtitle:
            gold_rule(y); y += 28
        y += centered(occasion.upper(), y, f_occ, GOLD2) + 28
        gold_rule(y); y += 30

        if canvas is not None:
            canvas.paste(resized, (MARGIN, y))
            for i in range(4):
                draw.rectangle([MARGIN+i, y+i, MARGIN+avail_w-1-i, y+img_h-1-i], outline=GOLD)
        y += img_h + 45

        gold_rule(y); y += 38

        y += centered("\u201c", y, f_quote, GOLD) + 8            # opening quote mark
        y += centered("A MESSAGE FROM THE HEART", y, f_msg_h, GOLD2) + 28

        # Greedy word wrap: start a new line once the next word would overflow.
        line = ""
        for word in message.split():
            test = line + word + " "
            if draw.textbbox((0, 0), test, font=f_msg)[2] > avail_w - 40 and line:
                centered(line.strip(), y, f_msg, CREAM)
                y += MSG_LINE_H
                line = word + " "
            else:
                line = test
        if line.strip():
            centered(line.strip(), y, f_msg, CREAM)
            y += MSG_LINE_H

        y += 22
        y += centered(f"\u2014 {signed_by}", y, f_sign, DIM) + 55
        gold_rule(y); y += 32
        y += centered(footer, y, f_footer, DIM)
        return y

    # Pass 1: measure on a 1-px-tall scratch canvas (drawing there is clipped).
    content_bottom = layout(ImageDraw.Draw(Image.new("RGB", (PW, 1), BG)))
    PH = content_bottom + BOTTOM_PAD

    poster = Image.new("RGB", (PW, PH), BG)
    draw   = ImageDraw.Draw(poster)

    # Gold border
    for i in range(6):
        draw.rectangle([i, i, PW-1-i, PH-1-i], outline=GOLD)
    for i in range(2):
        draw.rectangle([20+i, 20+i, PW-21-i, PH-21-i], outline=GOLD)

    # Pass 2: draw for real.
    layout(draw, poster)

    out = io.BytesIO()
    poster.save(out, format="PNG", optimize=True)
    return JSONResponse({"success": True,
                         "poster": base64.b64encode(out.getvalue()).decode()})


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Assemble a GROUP poster (multiple people, e.g. 8 family women)
# Arranges N already-generated/face-swapped portraits into a grid, with each
# person's name beneath their photo, plus a shared message/signature/footer.
# Reuses the same gold/navy visual language as the solo poster.
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/assemble-group-poster")
async def assemble_group_poster(
    images: list[UploadFile] = File(...),     # 2-12 generated portraits, in display order
    names:  str = Form(default=""),           # comma-separated names, same order as images
    title:     str = Form(default="Happy Mother's Day"),
    subtitle:  str = Form(default=""),
    message:   str = Form(default=""),
    signed_by: str = Form(default="With all our love"),
    footer:    str = Form(default="★   To the queens of our family   ★"),
    columns:   int = Form(default=2),         # 2 columns x N rows by default
):
    """
    Assemble a group poster: the title, then a grid of 2-12 portraits with a
    name label under each, then an optional message, the signature and the
    footer. It uses the same gold/navy styling as the solo poster.

    Each portrait is center-cropped to exactly fill a 2:3 grid cell.
    `columns` is clamped to 1..len(images). Missing names fall back to
    "Person N".

    The canvas height is found by running the layout once as a dry run, so
    long or short messages no longer leave a large empty band. The previous
    8-words-per-line estimate heavily over-reserved space at this font size.

    Returns JSON {"success": True, "poster": <base64 PNG>}.
    Raises HTTPException(400) for a bad image count or an undecodable image.
    """
    from PIL import ImageDraw, ImageOps

    if not (2 <= len(images) <= 12):
        raise HTTPException(400, "Group poster supports between 2 and 12 people.")

    name_list = [n.strip() for n in names.split(",")] if names else []
    name_list += [""] * (len(images) - len(name_list))   # pad so every image has an entry

    person_imgs = []
    for idx, f in enumerate(images, start=1):
        raw = await f.read()
        try:
            person_imgs.append(Image.open(io.BytesIO(raw)).convert("RGB"))
        except Exception as e:
            raise HTTPException(400, f"Could not read image #{idx}: {e}") from e

    GOLD  = (184, 148,  42)
    GOLD2 = (201, 168,  76)
    BG    = (  6,   9,  26)
    WHITE = (240, 230, 200)
    CREAM = (212, 201, 168)
    DIM   = (138, 112,  64)

    PW         = 1800
    MARGIN     = 70
    GUTTER     = 30                  # space between grid cells
    TOP_PAD    = 50                  # y of the title
    BOTTOM_PAD = 60                  # space below the footer, clear of the border
    NAME_H     = 60                  # space for name label under each photo
    MSG_LINE_H = 56
    n = len(person_imgs)
    cols = max(1, min(columns, n))
    rows = -(-n // cols)             # ceil division

    avail_w = PW - 2*MARGIN
    cell_w  = (avail_w - GUTTER*(cols-1)) // cols
    cell_h  = int(cell_w * 1.5)      # common 2:3 portrait cell so the grid aligns
    grid_h  = rows * (cell_h + NAME_H) + (rows-1) * GUTTER

    # Exact center-crop-to-fill; avoids the 1-px black edge that float
    # truncation in a manual resize+crop could produce.
    cells = [ImageOps.fit(img, (cell_w, cell_h), Image.LANCZOS) for img in person_imgs]

    f_title   = load_font(64, bold=True)
    f_sub     = load_font(30)
    f_name    = load_font(34, bold=True)
    f_msg_h   = load_font(32, bold=True)
    f_msg     = load_font(30)
    f_sign    = load_font(30)
    f_footer  = load_font(24)
    f_quote   = load_font(70, bold=True)

    def layout(draw, canvas=None):
        """
        Draw all content top-to-bottom and return the y just below the footer.
        With canvas=None the portrait grid is skipped, so the call only measures.
        """
        def centered(text, y, font, fill):
            bbox = draw.textbbox((0, 0), text, font=font)
            draw.text(((PW-(bbox[2]-bbox[0]))//2, y), text, font=font, fill=fill)
            return bbox[3]-bbox[1]

        def gold_rule(y):
            draw.line([(MARGIN, y), (PW-MARGIN, y)], fill=GOLD, width=2)

        y = TOP_PAD
        y += centered(title.upper(), y, f_title, WHITE) + 18
        if subtitle:
            y += centered(subtitle, y, f_sub, GOLD2) + 24
        gold_rule(y); y += 30

        grid_top = y
        if canvas is not None:
            for i, cell in enumerate(cells):
                r, c = divmod(i, cols)
                cell_x = MARGIN + c * (cell_w + GUTTER)
                cell_y = grid_top + r * (cell_h + NAME_H + GUTTER)
                canvas.paste(cell, (cell_x, cell_y))
                draw.rectangle([cell_x, cell_y, cell_x+cell_w, cell_y+cell_h], outline=GOLD, width=2)

                label = name_list[i] or f"Person {i+1}"
                lb = draw.textbbox((0, 0), label, font=f_name)
                draw.text((cell_x + (cell_w-(lb[2]-lb[0]))//2, cell_y + cell_h + 12),
                          label, font=f_name, fill=GOLD2)

        y = grid_top + grid_h + 35
        gold_rule(y); y += 30

        if message:
            y += centered('"', y, f_quote, GOLD) + 8
            y += centered("A MESSAGE FROM THE HEART", y, f_msg_h, GOLD2) + 28

            # Greedy word wrap: start a new line once the next word would overflow.
            avail_msg_w = PW - 2*MARGIN - 40
            line = ""
            for word in message.split():
                test = line + word + " "
                if draw.textbbox((0, 0), test, font=f_msg)[2] > avail_msg_w and line:
                    centered(line.strip(), y, f_msg, CREAM)
                    y += MSG_LINE_H
                    line = word + " "
                else:
                    line = test
            if line.strip():
                centered(line.strip(), y, f_msg, CREAM)
                y += MSG_LINE_H

        y += 22
        y += centered(f"— {signed_by}", y, f_sign, DIM) + 45
        gold_rule(y); y += 28
        y += centered(footer, y, f_footer, DIM)
        return y

    # Pass 1: measure on a 1-px-tall scratch canvas (drawing there is clipped).
    PH = layout(ImageDraw.Draw(Image.new("RGB", (PW, 1), BG))) + BOTTOM_PAD

    poster = Image.new("RGB", (PW, PH), BG)
    draw   = ImageDraw.Draw(poster)

    # Outer border
    draw.rectangle([18, 18, PW-18, PH-18], outline=GOLD, width=3)

    # Pass 2: draw for real.
    layout(draw, poster)

    out = io.BytesIO()
    poster.save(out, format="PNG", optimize=True)
    return JSONResponse({"success": True,
                         "poster": base64.b64encode(out.getvalue()).decode()})


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Assemble a TOAST LINE poster (e.g. Father's Day "Cheers" group)
# Arranges N already-generated/face-swapped portraits side-by-side in a single
# horizontal row, slightly overlapping for a cohesive "standing together"
# celebration feel, rather than a boxed grid. Same gold/navy visual language.
# ─────────────────────────────────────────────────────────────────────────────
@app.post("/assemble-toast-poster")
async def assemble_toast_poster(
    images: list[UploadFile] = File(...),     # 2-8 generated portraits, in display order
    names:  str = Form(default=""),           # comma-separated names, same order as images
    title:     str = Form(default="Happy Father's Day"),
    subtitle:  str = Form(default=""),
    message:   str = Form(default=""),
    signed_by: str = Form(default="With all our love"),
    footer:    str = Form(default="★   Cheers to the men of our family   ★"),
):
    """Compose 2-8 portraits into a single-row "toast line" poster.

    Portraits are placed left-to-right in upload order. Each one overlaps its
    neighbour by OVERLAP px and has a name label beneath it. The row is
    followed by an optional word-wrapped message, a sign-off line and a footer.

    Form fields:
        images:    2-8 image uploads, in display order.
        names:     comma-separated labels in the same order; blank or missing
                   entries fall back to "Person N".
        title, subtitle, message, signed_by, footer: poster text.

    Returns:
        JSONResponse ``{"success": True, "poster": <base64-encoded PNG>}``.

    Raises:
        HTTPException(400): fewer than 2 or more than 8 images, or an upload
        that Pillow cannot open as an image.
    """
    from PIL import ImageDraw

    def wrap_lines(measure, text, font, max_w):
        """Greedily word-wrap *text* into lines at most *max_w* px wide.

        A single word wider than *max_w* still gets its own line (words are
        never split). Returns stripped lines; whitespace-only text yields [].
        """
        lines, line = [], ""
        for word in text.split():
            candidate = line + word + " "
            if line and measure.textbbox((0, 0), candidate, font=font)[2] > max_w:
                lines.append(line.strip())
                line = word + " "
            else:
                line = candidate
        if line.strip():
            lines.append(line.strip())
        return lines

    def cover_crop(img, width, height):
        """Scale *img* to fully cover a width x height box, then centre-crop it."""
        img_ratio = img.width / img.height
        # max() guards against float rounding leaving the scaled image 1px short,
        # which would make crop() pad the edge with black.
        if img_ratio > width / height:
            new_w, new_h = max(width, int(height * img_ratio)), height
        else:
            new_w, new_h = width, max(height, int(width / img_ratio))
        resized = img.resize((new_w, new_h), Image.LANCZOS)
        left = (new_w - width) // 2
        top = (new_h - height) // 2
        return resized.crop((left, top, left + width, top + height))

    if not (2 <= len(images) <= 8):
        raise HTTPException(400, "Toast poster supports between 2 and 8 people.")

    name_list = [n.strip() for n in names.split(",")] if names else []
    # Pad so every portrait has an entry; blanks become "Person N" below.
    name_list += [""] * (len(images) - len(name_list))

    person_imgs = []
    for idx, upload in enumerate(images, start=1):
        raw = await upload.read()
        try:
            person_imgs.append(Image.open(io.BytesIO(raw)).convert("RGB"))
        except OSError as err:
            # PIL.UnidentifiedImageError subclasses OSError; truncated files also
            # raise OSError. Report a client error rather than a 500.
            raise HTTPException(
                400, f"Image {idx} could not be read as a picture."
            ) from err

    GOLD  = (184, 148,  42)
    GOLD2 = (201, 168,  76)
    BG    = (  6,   9,  26)
    WHITE = (240, 230, 200)
    CREAM = (212, 201, 168)
    DIM   = (138, 112,  64)

    n = len(person_imgs)
    PW = 1800
    MARGIN = 70
    OVERLAP = 60   # how much each portrait overlaps its neighbour, for a "standing together" feel

    avail_w  = PW - 2*MARGIN
    # Chosen so that n portraits minus (n-1) overlaps span ~avail_w.
    person_w = (avail_w + (n-1)*OVERLAP) // n
    person_h = int(person_w * 1.4)
    name_h   = 60

    f_title   = load_font(64, bold=True)
    f_sub     = load_font(30)
    f_name    = load_font(32, bold=True)
    f_msg_h   = load_font(32, bold=True)
    f_msg     = load_font(30)
    f_sign    = load_font(30)
    f_footer  = load_font(24)

    MSG_LINE_H  = 56
    avail_msg_w = PW - 2*MARGIN - 40

    # Wrap the message with real font metrics *before* sizing the canvas. A
    # word-count guess under-allocates height for messages with long words,
    # pushing the sign-off and footer past the frame.
    measure   = ImageDraw.Draw(Image.new("RGB", (1, 1)))
    msg_lines = wrap_lines(measure, message, f_msg, avail_msg_w) if message else []

    TOP_BLOCK = 230
    row_h     = person_h + name_h
    # At least one line's worth of room is kept even without a message, so
    # message-less posters keep their existing proportions.
    msg_height   = max(1, len(msg_lines)) * 58
    BOTTOM_BLOCK = 280 + msg_height

    PH = TOP_BLOCK + row_h + BOTTOM_BLOCK + 150

    poster = Image.new("RGB", (PW, PH), BG)
    draw   = ImageDraw.Draw(poster)

    def centered(text, y, font, fill):
        """Draw *text* horizontally centred at *y*; return its bbox height."""
        bbox = draw.textbbox((0,0), text, font=font)
        w = bbox[2]-bbox[0]
        draw.text(((PW-w)//2, y), text, font=font, fill=fill)
        return bbox[3]-bbox[1]

    def gold_rule(y):
        """Draw a full-width gold divider at *y*."""
        draw.line([(MARGIN, y), (PW-MARGIN, y)], fill=GOLD, width=2)

    draw.rectangle([18, 18, PW-18, PH-18], outline=GOLD, width=3)

    y = 50
    y += centered(title.upper(), y, f_title, WHITE) + 18
    if subtitle:
        y += centered(subtitle, y, f_sub, GOLD2) + 24
    gold_rule(y); y += 30

    # Horizontal toast line. Draw back-to-front (rightmost first) so the
    # leftmost person ends up on top, like a natural line-up.
    row_top     = y
    step        = person_w - OVERLAP
    row_total_w = person_w + (n-1) * step
    start_x     = MARGIN + (avail_w - row_total_w) // 2

    for i in reversed(range(n)):
        cell_x = start_x + i * step

        poster.paste(cover_crop(person_imgs[i], person_w, person_h), (cell_x, row_top))
        draw.rectangle([cell_x, row_top, cell_x+person_w, row_top+person_h], outline=GOLD, width=2)

        label = name_list[i] or f"Person {i+1}"
        lb = draw.textbbox((0,0), label, font=f_name)
        lw = lb[2]-lb[0]
        label_cx = cell_x + person_w//2
        draw.text((label_cx - lw//2, row_top + person_h + 12), label, font=f_name, fill=GOLD2)

    y = row_top + row_h + 35
    gold_rule(y); y += 30

    if message:
        y += centered('"', y, load_font(70, bold=True), GOLD) + 8
        y += centered("RAISE A GLASS TO YOU", y, f_msg_h, GOLD2) + 28
        for text_line in msg_lines:
            centered(text_line, y, f_msg, CREAM)
            y += MSG_LINE_H

    y += 22
    y += centered(f"— {signed_by}", y, f_sign, DIM) + 45
    gold_rule(y); y += 28
    centered(footer, y, f_footer, DIM)

    out = io.BytesIO()
    poster.save(out, format="PNG", optimize=True)
    return JSONResponse({"success": True,
                         "poster": base64.b64encode(out.getvalue()).decode()})


# ─────────────────────────────────────────────────────────────────────────────
# ENDPOINT: Save image for mobile download
# Mobile browsers can't download data: URLs — this endpoint accepts a base64
# image, saves it as a temp file, and returns a real URL the phone can download.
# ─────────────────────────────────────────────────────────────────────────────
import uuid, time
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

# Scratch folder for files written by /save-image. It is a relative path, so it
# lives under the process's working directory. It is created eagerly at import.
TEMP_DIR = "temp_images"
os.makedirs(name=TEMP_DIR, exist_ok=True)

@app.post("/save-image")
async def save_image(
    image_b64: str = Form(...),
    filename:  str = Form(default="poster.png"),
):
    """Store a base64-encoded image under TEMP_DIR and return a download URL.

    Mobile browsers cannot download ``data:`` URLs. The client posts the image
    here, then navigates to the returned ``/downloads/...`` URL.

    Args:
        image_b64: base64 image payload. A leading ``data:<mime>;base64,``
            prefix is accepted and stripped.
        filename: suggested download name. It is reduced to ASCII letters,
            digits, ``-``, ``_``, ``.`` and spaces (anything else becomes
            ``_``), capped at 100 characters, and prefixed with a short random
            id so concurrent saves do not collide.

    Returns:
        JSONResponse ``{"url": "/downloads/<id>_<name>"}``.

    Raises:
        HTTPException(400): the payload is not valid base64 or decodes to
        nothing.
    """
    # Recreate the folder if it was removed while the server was running.
    os.makedirs(TEMP_DIR, exist_ok=True)

    # Best-effort purge of files older than an hour to avoid disk buildup.
    # Races (file already gone) and permission problems are deliberately ignored.
    now = time.time()
    for entry in os.listdir(TEMP_DIR):
        entry_path = os.path.join(TEMP_DIR, entry)
        try:
            if now - os.path.getmtime(entry_path) > 3600:
                os.remove(entry_path)
        except OSError:
            pass

    payload = image_b64.strip()
    # Accept a full data URL as well as a bare base64 string.
    if payload.startswith("data:") and "," in payload:
        payload = payload.split(",", 1)[1]
    try:
        data = base64.b64decode(payload)
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(400, "image_b64 is not valid base64.") from err
    if not data:
        raise HTTPException(400, "image_b64 is empty.")

    # Whitelist filename characters. Path separators (including Windows "\"),
    # quotes/CR/LF (which would corrupt the Content-Disposition header) and URL
    # metacharacters such as "#", "?" and "%" (which would break the returned
    # URL) all become "_".
    cleaned = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "-_. " else "_"
        for ch in filename
    ).replace("..", "_")[:100].strip(" .")
    safe_name = cleaned or "poster.png"

    stored_name = f"{uuid.uuid4().hex[:8]}_{safe_name}"
    with open(os.path.join(TEMP_DIR, stored_name), "wb") as out_file:
        out_file.write(data)

    return JSONResponse({"url": f"/downloads/{stored_name}"})

@app.get("/downloads/{filename}")
async def download_image(filename: str):
    """Serve a file previously stored by /save-image as a PNG attachment.

    Args:
        filename: name of a file directly inside TEMP_DIR.

    Raises:
        HTTPException(404): the name is not a regular file directly inside
        TEMP_DIR. This covers expired or unknown files, and also escape
        attempts such as ``..`` or, on Windows, ``..\\secret``.
    """
    base_dir  = os.path.realpath(TEMP_DIR)
    file_path = os.path.realpath(os.path.join(base_dir, filename))
    # The route segment cannot contain "/", but ".." (and "\" on Windows) can
    # still resolve outside TEMP_DIR, so serve only direct children of it.
    inside = os.path.normcase(os.path.dirname(file_path)) == os.path.normcase(base_dir)
    if not inside or not os.path.isfile(file_path):
        raise HTTPException(404, "File not found or expired")
    return FileResponse(
        file_path,
        media_type="image/png",
        # Starlette builds Content-Disposition ("attachment; filename=...") and
        # percent-encodes unsafe or non-ASCII names. The raw name is never
        # interpolated into the header.
        filename=os.path.basename(file_path),
    )


if __name__ == "__main__":
    # Script entry point: print a startup banner, then serve on all interfaces.
    import uvicorn

    rule = "=" * 54
    banner = (
        "  PosterMagic Server v2.0",
        "  http://localhost:8000",
        "  Step 1: GPT-image-1  → themed costume",
        "  Step 2: Replicate    → real face swap",
        "  Step 3: Pillow       → final poster",
    )
    print("\n" + rule)
    for text in banner:
        print(text)
    print(rule + "\n")
    uvicorn.run(app, host="0.0.0.0", port=8000)
