import os
import base64
import json
import requests
from google import genai
from pathlib import Path
from PIL import Image
import smtplib
from email.mime.text import MIMEText
import openai

# Required credentials: indexing os.environ makes the script fail fast with a
# KeyError at import time if either key is missing.
GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# Optional e-mail alert credentials; an empty string means "not configured"
# (send_alert then only prints the alert).
ALERT_EMAIL, ALERT_PASSWORD = (
    os.environ.get(var_name, "") for var_name in ("ALERT_EMAIL", "ALERT_PASSWORD")
)

# Shared API clients used by the analysis and generation functions below.
gemini_client, openai_client = (
    genai.Client(api_key=GEMINI_API_KEY),
    openai.OpenAI(api_key=OPENAI_API_KEY),
)


def send_alert(subject, message, *, smtp_host="smtp.gmail.com", smtp_port=465, timeout=30):
    """Send a best-effort e-mail alert from ALERT_EMAIL to itself.

    When ALERT_EMAIL or ALERT_PASSWORD is empty, the alert is only printed.
    Any failure while sending is printed rather than raised: this function is
    called from the pipeline's error handler and must not mask the original
    exception.

    Args:
        subject: E-mail subject line.
        message: Plain-text e-mail body.
        smtp_host: SMTP-over-SSL server (Gmail by default).
        smtp_port: SSL port of that server.
        timeout: Socket timeout in seconds, so an unreachable or slow server
            cannot block the caller indefinitely.
    """
    if not ALERT_EMAIL or not ALERT_PASSWORD:
        print(f"ALERTE : {subject} — {message}")
        return
    try:
        msg = MIMEText(message)
        msg["Subject"] = subject
        msg["From"] = ALERT_EMAIL
        msg["To"] = ALERT_EMAIL
        with smtplib.SMTP_SSL(smtp_host, smtp_port, timeout=timeout) as server:
            server.login(ALERT_EMAIL, ALERT_PASSWORD)
            server.send_message(msg)
        print(f"Alerte envoyee : {subject}")
    except Exception as e:  # deliberate best-effort: alerting must never crash the caller
        print(f"Impossible d'envoyer l'alerte : {e}")


def crop_watermark(image_path, output_path=None, *, keep_ratio=0.96):
    """Cut off the bottom strip of an image and pad it back to a white square.

    The bottom strip (4% by default) is where the Gemini / Nano Banana
    watermark sits. The kept top part is pasted at the top-left corner of a
    ``width x width`` white canvas. For a square input, this means the removed
    strip simply becomes white. A taller-than-wide input is clipped at
    ``width`` rows; a wider input gets extra white below.

    Args:
        image_path: Image to process.
        output_path: Destination file; defaults to overwriting ``image_path``.
        keep_ratio: Fraction of the image height to keep, in ``(0, 1]``.

    Returns:
        The path the result was written to.

    Raises:
        ValueError: If ``keep_ratio`` is outside ``(0, 1]``.
    """
    if not 0 < keep_ratio <= 1:
        raise ValueError(f"keep_ratio must be in (0, 1], got {keep_ratio!r}")
    if output_path is None:
        output_path = image_path

    # Close the source file before saving, since the output may overwrite
    # the very same path.
    with Image.open(image_path) as img:
        width, height = img.size
        # Keep at least one row so the crop box is never empty.
        keep_height = max(1, int(height * keep_ratio))
        # Convert to RGBA so transparent pixels can be composited onto white
        # below. A mask-less paste would drop the alpha channel and expose
        # the hidden colour of transparent pixels.
        cropped = img.crop((0, 0, width, keep_height)).convert("RGBA")

    result = Image.new("RGB", (width, width), (255, 255, 255))
    result.paste(cropped, (0, 0), cropped)
    result.save(output_path)
    print(f"Watermark supprime : {output_path}")
    return output_path


def parse_model_json(text):
    """Parse the JSON value contained in a model's text reply.

    Markdown code fences (```json ... ```) are removed. If the model wrapped
    the object in extra prose, only the span from the first ``{`` to the last
    ``}`` is decoded.

    Args:
        text: Raw reply text; may be ``None`` when the model returned nothing.

    Returns:
        The decoded JSON value (a dict for a well-formed reply).

    Raises:
        ValueError: If ``text`` is empty/``None`` or is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    if not text:
        raise ValueError("Empty model reply: no text to parse")
    cleaned = text.strip().replace("```json", "").replace("```", "").strip()
    start, end = cleaned.find("{"), cleaned.rfind("}")
    if start != -1 and end > start:
        cleaned = cleaned[start:end + 1]
    return json.loads(cleaned)


def analyze_with_gemini(photo_path):
    """Ask Gemini to describe the child's visual features as a JSON dict.

    Args:
        photo_path: Path to the portrait photo (JPEG, PNG, ...).

    Returns:
        The dict parsed from Gemini's reply, following the schema written in
        the prompt (skin_tone, hair, eyes, nose/lips/ears sub-dicts, ...).
        The model is not forced to respect that schema, so keys may be
        missing or null.

    Raises:
        ValueError: If Gemini returns no text, or text that is not valid JSON.
    """
    import mimetypes

    print("Analyse Gemini en cours...")
    image_bytes = Path(photo_path).read_bytes()
    # Declare the real image type so PNG/WebP photos are not labelled JPEG.
    mime_type = mimetypes.guess_type(str(photo_path))[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"

    response = gemini_client.models.generate_content(
        model="gemini-2.5-flash",
        contents=[
            {
                "parts": [
                    {
                        "text": """Analyse this portrait image for illustration purposes only.
Describe these visual characteristics and return ONLY valid JSON, no markdown:
{
  "skin_tone": "claire ou metissee ou foncee ou tres_foncee",
  "hair_color": "ex: noir, chatain, blond, roux",
  "hair": "chauve ou ras ou court ou mi_long ou long",
  "hair_texture": "lisse ou ondule ou boucle ou crepu ou aucun si chauve",
  "eye_color": "ex: marron, bleu, vert",
  "eye_shape": "rond ou amande ou bride",
  "eye_size": "petit ou moyen ou grand",
  "gender": "garcon ou fille",
  "face_shape": "rond ou ovale ou carre ou coeur",
  "nose": {
    "size": "petit ou moyen ou grand",
    "shape": "retrousse ou droit ou large ou pointu",
    "nostrils": "fines ou moyennes ou larges"
  },
  "lips": {
    "size": "fines ou moyennes ou pulpeuses",
    "shape": "droites ou en_arc ou tombantes"
  },
  "ears": {
    "size": "petites ou moyennes ou grandes",
    "position": "collees ou normales ou decollees"
  }
}"""
                    },
                    {
                        "inline_data": {
                            "mime_type": mime_type,
                            # google-genai's Blob.data field is typed as bytes;
                            # pass the raw bytes rather than a pre-encoded
                            # base64 string.
                            "data": image_bytes,
                        }
                    },
                ]
            }
        ],
    )

    # response.text can be None (e.g. no candidate returned); parse_model_json
    # turns that into a clear ValueError instead of an AttributeError.
    text = response.text
    print(f"Description Gemini : {(text or '').strip()}")
    return parse_model_json(text)


def build_dalle_prompt(child):
    """Build the English DALL-E 3 prompt for a cartoon face icon.

    Args:
        child: Feature dict as returned by ``analyze_with_gemini``. Expected
            keys include ``gender``, ``hair``, ``eye_color`` and the
            ``nose`` / ``lips`` / ``ears`` sub-dicts. Missing keys, JSON
            ``null`` values and empty strings all fall back to defaults.

    Returns:
        The prompt string. Feature values are inserted verbatim, so they may
        be the French words used by the Gemini schema.
    """

    def field(source, key, default=""):
        # Treat None and "" as missing, so the prompt never contains the
        # literal "None" or doubled spaces.
        value = source.get(key)
        return default if value is None or value == "" else str(value)

    # `or {}` also covers sub-dicts that came back as JSON null.
    nose = child.get("nose") or {}
    lips = child.get("lips") or {}
    ears = child.get("ears") or {}

    if child.get("hair") in ["chauve", "ras"]:
        hair_desc = "completely bald, no hair at all, smooth shaved head, "
    else:
        texture = field(child, "hair_texture")
        # The schema only allows "aucun" (none) for bald heads; drop it if it
        # slips through alongside a non-bald hair length.
        if texture == "aucun":
            texture = ""
        hair_words = [
            word
            for word in (field(child, "hair_color"), texture, field(child, "hair"))
            if word
        ]
        hair_desc = (" ".join(hair_words) + " hair, ") if hair_words else ""

    prompt = (
        "A single cartoon face icon like an emoji or app avatar, "
        "ONLY the face visible, nothing below the chin, "
        "pure white background, no neck, no body, no shoulders, "
        "face from forehead to chin only, like a round face sticker, "
        f"{field(child, 'gender', 'child')} child, "
        + hair_desc
        + f"{field(child, 'eye_color', 'brown')} {field(child, 'eye_size', 'big')} "
        f"{field(child, 'eye_shape', 'round')} eyes, "
        f"{field(child, 'face_shape', 'round')} face shape, "
        f"{field(nose, 'size', 'medium')} {field(nose, 'shape', 'cute')} nose, "
        f"{field(lips, 'size', 'medium')} lips, "
        f"{field(ears, 'position', 'normal')} ears, "
        f"{field(child, 'skin_tone', 'medium')} skin tone, "
        "big expressive eyes with white highlights, joyful smile, "
        "bold clean outlines, smooth cel-shaded skin, warm vibrant colors, "
        "Disney Pixar junior style, children book illustration, high quality"
    )
    return prompt


def generate_with_dalle(child, output_path="cartoon_head.png"):
    """Generate the cartoon face with DALL-E 3 (fallback generator).

    Args:
        child: Feature dict from ``analyze_with_gemini``; turned into a
            prompt by ``build_dalle_prompt``.
        output_path: Where to write the downloaded image.

    Returns:
        ``output_path``.

    Raises:
        requests.HTTPError: If downloading the generated image fails.
    """
    print("Fallback DALL-E 3 en cours...")
    prompt = build_dalle_prompt(child)
    print(f"Prompt DALL-E : {prompt}")

    response = openai_client.images.generate(
        model="dall-e-3",
        prompt=prompt,
        size="1024x1024",
        quality="hd",
        n=1,
    )

    # The API returns a URL; download the image from it. The timeout keeps a
    # stalled download from hanging the pipeline, and raise_for_status stops
    # an HTTP error page from being saved as the portrait.
    image_response = requests.get(response.data[0].url, timeout=60)
    image_response.raise_for_status()
    with open(output_path, "wb") as f:
        f.write(image_response.content)

    print(f"Portrait DALL-E sauvegarde : {output_path}")
    return output_path


def generate_portrait_fal(child, photo_path="photo_test.jpg", output_path="cartoon_head.png"):
    """Generate the 3D cartoon head with fal.ai's Nano Banana Pro edit model.

    The photo is uploaded to fal.ai and edited according to a prompt built
    from the Gemini feature dict. The resulting PNG is then downloaded.

    Args:
        child: Feature dict from ``analyze_with_gemini``.
        photo_path: Photo to upload. Defaults to ``photo_test.jpg`` for
            backward compatibility; callers should pass the analysed photo.
        output_path: Where to write the downloaded PNG.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: If fal.ai returns no image.
        requests.HTTPError: If downloading the generated image fails.
    """
    import mimetypes

    import fal_client

    print("Generation fal.ai Nano Banana Pro en cours...")

    # Upload the child's photo so the edit endpoint can reference it by URL.
    mime_type = mimetypes.guess_type(str(photo_path))[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    with open(photo_path, "rb") as f:
        photo_url = fal_client.upload(f.read(), mime_type)

    # `or {}` also covers sub-dicts that came back as JSON null.
    nose = child.get("nose") or {}
    lips = child.get("lips") or {}
    prompt = (
        "Transform the person in the image into a 3D stylized Disney Pixar cartoon child avatar. "
        "Full head only, NO BODY, NO NECK, NO SHOULDERS, face only like an emoji sticker, "
        "centered on pure white background, studio lighting, soft shadows, global illumination, "
        "subsurface scattering, ultra clean render. "
        f"{child.get('gender', 'child')} child, "
        + ("completely bald, no hair at all, smooth head, "
           if child.get("hair") in ["chauve", "ras"]
           else f"{child.get('hair_color', '')} {child.get('hair_texture', '')} hair, clean stylized hair, ")
        + f"{child.get('skin_tone', 'medium')} skin tone, smooth skin, "
        f"{child.get('eye_color', 'brown')} big expressive eyes with white highlights, "
        f"{child.get('face_shape', 'round')} face shape, soft rounded features, "
        f"{nose.get('shape', 'cute')} nose, "
        f"{lips.get('size', 'medium')} lips, "
        "neutral friendly expression, same face identity, "
        "Disney Pixar 3D style, high quality 3D render, "
        "pure white background, no text, no watermark, no body parts below chin"
    )

    result = fal_client.subscribe(
        "fal-ai/nano-banana-pro/edit",
        arguments={
            "prompt": prompt,
            "image_urls": [photo_url],
            "aspect_ratio": "1:1",
            "output_format": "png",
            "resolution": "1K",
            "safety_tolerance": "4",
            "num_images": 1,
        },
    )

    images = result.get("images") or []
    if not images:
        raise RuntimeError(f"fal.ai returned no image: {result!r}")
    # Timeout so a stalled download cannot hang the pipeline; raise_for_status
    # so an HTTP error page is never written out as the portrait.
    image_response = requests.get(images[0]["url"], timeout=60)
    image_response.raise_for_status()
    with open(output_path, "wb") as f:
        f.write(image_response.content)

    print(f"Portrait sauvegarde : {output_path}")
    return output_path


def run_pipeline(photo_path):
    """Analyse one photo, generate its cartoon head and save a summary.

    Writes ``head_result.json`` containing the detected skin tone and the
    portrait path. On any failure, an alert is sent via ``send_alert`` and
    the exception is re-raised.

    Args:
        photo_path: Path to the child's photo.

    Returns:
        The summary dict written to ``head_result.json``.
    """
    try:
        child = analyze_with_gemini(photo_path)
        # Give fal.ai the same photo that was analysed, not a fixed file name.
        portrait = generate_portrait_fal(child, photo_path=photo_path)

        result = {
            "skin_tone": child["skin_tone"],
            "portrait_path": portrait,
        }

        with open("head_result.json", "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)

        print(f"Resultat final : {result}")
        return result

    except Exception as e:
        # Top-level boundary: notify, then re-raise so the caller still fails.
        send_alert(
            "ERREUR CRITIQUE pipeline livre enfant",
            f"Le pipeline a completement echoue : {e}"
        )
        raise

if __name__ == "__main__":
    import sys

    # Exactly one CLI argument is expected: the path of the child's photo.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python3 generate_head.py photo_test.jpg")
        sys.exit(1)
    run_pipeline(cli_args[0])