File size: 3,673 Bytes
f667084
3e3453c
a04adbd
4d54b56
90c9be8
fa3a39e
3e3453c
fa3a39e
 
a04adbd
3e3453c
 
 
 
 
 
 
 
 
 
 
 
 
 
264752e
3e3453c
 
 
 
dcf269f
3e3453c
 
264752e
3e3453c
 
264752e
3e3453c
 
 
 
 
01b1364
 
 
0372f7c
3e3453c
 
 
c83e28c
febdafe
3e3453c
eb4cf9a
3e3453c
 
 
 
3a94231
3e3453c
 
3aefc04
3e3453c
3aefc04
 
c6c4d1c
3e3453c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6833ac1
3e3453c
 
9023169
3e3453c
9023169
 
3e3453c
 
 
 
 
 
 
9023169
 
3e3453c
 
 
 
 
 
 
 
 
f667084
0ef1f3f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
from huggingface_hub import login, snapshot_download
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
from dotenv import load_dotenv
import gradio as gr
from diffusers import FluxPipeline
import torch
import spaces  # Hugging Face Spaces module

# -----------------------
# Pre-cache models at startup
# -----------------------
snapshot_download("Salesforce/blip-image-captioning-large", timeout=120)
snapshot_download("noamrot/FuseCap", timeout=120)
snapshot_download("black-forest-labs/FLUX.1-dev", timeout=300)

# -----------------------
# Authentication
# -----------------------
load_dotenv()
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
if HUGGINGFACE_TOKEN:
    login(token=HUGGINGFACE_TOKEN)

# -----------------------
# Load models
# -----------------------
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large", timeout=120)
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", timeout=120).to(device)

processor1 = BlipProcessor.from_pretrained("noamrot/FuseCap", timeout=120)
model2 = BlipForConditionalGeneration.from_pretrained("noamrot/FuseCap", timeout=120).to(device)

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, timeout=300).to(device)

# -----------------------
# Options
# -----------------------
fabrics = ['cotton', 'silk', 'denim', 'linen', 'polyester', 'wool', 'velvet']
patterns = ['striped', 'floral', 'geometric', 'abstract', 'solid', 'polka dots']
textile_designs = ['woven texture', 'embroidery', 'printed fabric', 'hand-dyed', 'quilting']

# -----------------------
# Inference Function
# -----------------------
@spaces.GPU(duration=150)
def generate_caption_and_image(image, f, p, d):
    if image and f and p and d:
        img = image.convert("RGB")

        # Caption with FuseCap
        inputs = processor(img, "a picture of ", return_tensors="pt").to(device)
        out = model2.generate(**inputs, num_beams=3)
        caption2 = processor1.decode(out[0], skip_special_tokens=True)

        # Caption with BLIP
        inputs = processor(image, return_tensors="pt", padding=True, truncation=True, max_length=250)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        out = model.generate(**inputs)
        caption1 = processor.decode(out[0], skip_special_tokens=True)

        # Compose prompt
        prompt = (
            f"Design a high-quality, stylish clothing item that combines the essence of {caption1} and {caption2}. "
            f"Use luxurious {f} fabric with intricate {d} design elements. "
            f"Incorporate {p} patterns to elevate the garment's aesthetic. "
            "Ensure sophistication, innovation, and timeless elegance."
        )

        # Generate image
        result = pipe(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            generator=torch.Generator('cpu').manual_seed(0)
        ).images[0]

        return result
    return None

# -----------------------
# Gradio UI
# -----------------------
iface = gr.Interface(
    fn=generate_caption_and_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Radio(fabrics, label="Select Fabric"),
        gr.Radio(patterns, label="Select Pattern"),
        gr.Radio(textile_designs, label="Select Textile Design")
    ],
    outputs=gr.Image(label="Generated Design"),
    live=True
)

iface.launch()