Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| import os | |
| import time | |
| from typing import List, Union | |
| from pillow_heif import register_heif_opener | |
| register_heif_opener() | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel, HttpUrl | |
| from transformers import pipeline | |
# Runtime configuration, read from environment variables.
# Upper-cased because logging.basicConfig only knows upper-case level names
# and raises "Unknown level" for values such as "info".
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper()
# Largest number of URLs accepted by a single caption request.
MAX_URLS = int(os.getenv("MAX_URLS", 5))
# Passed to the pipeline as max_new_tokens.
MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 200))
# https://huggingface.co/models?pipeline_tag=image-to-text&sort=likes
MODEL = os.getenv("MODEL", "../models/Salesforce/blip-image-captioning-large")
logging.basicConfig(level=LOG_LEVEL)
logger = logging.getLogger(__name__)
app = FastAPI()
captioner = None  # Placeholder for the captioner pipeline
is_initialized = asyncio.Event()  # Event to track initialization status
# Held around every pipeline call so only one inference runs at a time.
lock = asyncio.Lock()
def load_model(loop=None):
    """Load the image-to-text pipeline into ``captioner`` and flag readiness.

    This call blocks while the ``MODEL`` pipeline is built, so it is meant to
    run off the event loop (``startup_event`` runs it via ``asyncio.to_thread``).

    Args:
        loop: The event loop whose coroutines await ``is_initialized``.
            asyncio primitives are not thread-safe, so when this function runs
            in a worker thread the event is set through
            ``loop.call_soon_threadsafe``. Leave as ``None`` only when calling
            from the loop's own thread; the event is then set directly.
    """
    global captioner
    logger.info("Loading model...")
    # simpler model: "ydshieh/vit-gpt2-coco-en"
    captioner = pipeline(
        "image-to-text",
        model=MODEL,
        max_new_tokens=MAX_NEW_TOKENS,
    )
    logger.info("Done loading model.")
    if loop is None:
        is_initialized.set()
    else:
        # Hand the wake-up back to the loop thread so waiters resume reliably.
        loop.call_soon_threadsafe(is_initialized.set)


class Image(BaseModel):
    """Request body for captioning: a single image URL or a list of them."""

    url: Union[HttpUrl, List[HttpUrl]]  # one URL or a list of URLs


# Strong reference to the background model-loading task. The event loop keeps
# only weak references to tasks, so an unreferenced task may be garbage
# collected before it finishes.
_model_load_task = None


def _report_model_load_result(task):
    """Log a failed or cancelled model load, which would otherwise go unseen."""
    if task.cancelled():
        logger.warning("Model loading was cancelled.")
    elif task.exception() is not None:
        logger.error("Model loading failed.", exc_info=task.exception())


async def startup_event():
    """Start loading the model in the background and mount the Gradio UI.

    The model loads in a worker thread so startup is not blocked; readiness is
    signalled through ``is_initialized`` once loading completes.

    NOTE(review): no startup-hook decorator is visible in this listing;
    presumably this is registered as an application startup handler — confirm.
    """
    global app, _model_load_task
    loop = asyncio.get_running_loop()
    _model_load_task = asyncio.create_task(asyncio.to_thread(load_model, loop))
    _model_load_task.add_done_callback(_report_model_load_result)
    # add gradio interface
    iface = gr.Interface(fn=captioner_gradapter, inputs="text", outputs=["text"], allow_flagging="never")
    app = gr.mount_gradio_app(app, iface, path="/gradio")
async def captioner_gradapter(image_url):
    """Gradio callback: return the generated caption text for one image URL.

    Waits until the model is loaded, then runs the pipeline in a worker thread
    while holding the shared ``lock``, and returns the ``generated_text`` of the
    first result.
    """
    await is_initialized.wait()
    async with lock:
        outputs = await asyncio.to_thread(captioner, image_url)
    first_output = outputs[0]
    return first_output["generated_text"]
async def root():
    """Return a fixed greeting payload.

    NOTE(review): no route decorator is visible in this listing; presumably
    served at "/" — confirm.
    """
    greeting = "Hello World"
    return {"message": greeting}
async def create_caption(image: Image):
    """Caption one image URL, or a list of image URLs.

    The URL(s) arrive in the ``url`` field of the JSON request body.

    Returns:
        ``{"caption": ..., "duration": ...}``, where ``caption`` is the raw
        pipeline output and ``duration`` is the inference time in seconds.

    Raises:
        HTTPException: 400 if more than ``MAX_URLS`` URLs are supplied;
            500 if the captioning pipeline raises.
    """
    if isinstance(image.url, list) and len(image.url) > MAX_URLS:
        logger.debug(
            "Request with more than %s URLs received. Refusing the request.",
            MAX_URLS,
        )
        raise HTTPException(
            status_code=400,
            detail=f"Maximum of {MAX_URLS} URLs can be processed at once",
        )

    # Wait for the model before taking the lock (same order as
    # captioner_gradapter), so no request holds the lock while idling.
    await is_initialized.wait()

    # The pipeline expects plain strings; pydantic v2's HttpUrl is not a str
    # subclass and would be rejected by the image loader.
    if isinstance(image.url, list):
        image_url = [str(url) for url in image.url]
    else:
        image_url = str(image.url)

    async with lock:
        # perf_counter is monotonic, so wall-clock adjustments cannot skew it.
        start_time = time.perf_counter()
        try:
            caption = await asyncio.to_thread(captioner, image_url)
        except Exception as e:
            # Top-level request boundary: log with traceback, return a 500.
            logger.exception("Error during caption generation: %s", str(e))
            raise HTTPException(
                status_code=500,
                detail="An error occurred during caption generation. Please try again later.",
            ) from e
        duration = time.perf_counter() - start_time

    logger.debug("Captioning completed. Time taken: %s seconds.", duration)
    return {"caption": caption, "duration": duration}
async def healthz():
    """Liveness probe: unconditionally report ``{"status": "ok"}``.

    NOTE(review): no route decorator is visible in this listing; presumably
    served at "/healthz" — confirm.
    """
    payload = {"status": "ok"}
    return payload
async def readyz():
    """Readiness probe: report ``{"status": "ok"}`` once the model is loaded.

    Raises:
        HTTPException: 503 while ``is_initialized`` has not been set yet.
    """
    if is_initialized.is_set():
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Initialization in progress")