# StyleSync / app.py
import os
from huggingface_hub import login, snapshot_download
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from dotenv import load_dotenv
import gradio as gr
from diffusers import FluxPipeline
import torch
import spaces # Hugging Face Spaces module
# -----------------------
# Authentication
# -----------------------
# Log in first so the gated FLUX.1-dev download below can use the token.
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)

# -----------------------
# Pre-cache models at startup
# -----------------------
snapshot_download("Salesforce/blip-image-captioning-large", etag_timeout=120)
snapshot_download("noamrot/FuseCap", etag_timeout=120)
snapshot_download("black-forest-labs/FLUX.1-dev", etag_timeout=300)
# -----------------------
# Load models
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# BLIP captioning model
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large").to(device)

# FuseCap captioning model
processor1 = BlipProcessor.from_pretrained("noamrot/FuseCap")
model2 = BlipForConditionalGeneration.from_pretrained("noamrot/FuseCap").to(device)

# FLUX.1-dev text-to-image pipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to(device)
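
# Note (assumption, not in the original code): FLUX.1-dev is large, and on GPUs with
# limited VRAM it may not fit fully on-device. Replacing .to(device) with
# pipe.enable_model_cpu_offload() (a standard diffusers method) is one possible fallback.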
# -----------------------
# Options
# -----------------------
fabrics = ['cotton', 'silk', 'denim', 'linen', 'polyester', 'wool', 'velvet']
patterns = ['striped', 'floral', 'geometric', 'abstract', 'solid', 'polka dots']
textile_designs = ['woven texture', 'embroidery', 'printed fabric', 'hand-dyed', 'quilting']
# -----------------------
# Inference Function
# -----------------------
@spaces.GPU(duration=150)
def generate_caption_and_image(image, f, p, d):
    """Caption the uploaded image with FuseCap and BLIP, then generate a new
    clothing design with FLUX.1-dev from the captions and selected options."""
    if image and f and p and d:
        img = image.convert("RGB")

        # Caption with FuseCap
        inputs = processor1(img, "a picture of ", return_tensors="pt").to(device)
        out = model2.generate(**inputs, num_beams=3)
        caption2 = processor1.decode(out[0], skip_special_tokens=True)

        # Caption with BLIP
        inputs = processor(img, return_tensors="pt").to(device)
        out = model.generate(**inputs)
        caption1 = processor.decode(out[0], skip_special_tokens=True)

        # Compose prompt
        prompt = (
            f"Design a high-quality, stylish clothing item that combines the essence of {caption1} and {caption2}. "
            f"Use luxurious {f} fabric with intricate {d} design elements. "
            f"Incorporate {p} patterns to elevate the garment's aesthetic. "
            "Ensure sophistication, innovation, and timeless elegance."
        )

        # Generate image (fixed seed for reproducible output)
        result = pipe(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(0)
        ).images[0]
        return result
    return None
# -----------------------
# Gradio UI
# -----------------------
iface = gr.Interface(
    fn=generate_caption_and_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(fabrics, label="Select Fabric"),
        gr.Radio(patterns, label="Select Pattern"),
        gr.Radio(textile_designs, label="Select Textile Design")
    ],
    outputs=gr.Image(label="Generated Design"),
    live=True  # re-runs the function whenever an input changes
)
iface.launch()
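# Assumption, not in the original code: when running outside Hugging Face Spaces,
# iface.launch(share=True) can be used to expose a temporary public link.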