Naahbi's picture
commenting
c35f287 verified
from transformers import SwinForImageClassification, AutoFeatureExtractor
from PIL import Image, ImageOps
import torch
from ultralytics import YOLO
# Label lookup tables keyed by the classifier's predicted class index.
# Despite the "*_to_idx" names, both map index -> label:
#   class_to_idx  -> specific diagnosis
#   nature_to_idx -> broad category of that diagnosis
class_to_idx = dict(enumerate([
    'Actinic keratosis',
    'Basal cell carcinoma',
    'Benign keratosis',
    'Dermatofibroma',
    'Melanoma',
    'Melanocytic nevus',
    'Squamous cell carcinoma',
    'Vascular lesion',
]))
nature_to_idx = dict(enumerate([
    'Precancerous',
    'Malign',
    'Benign',
    'Benign',
    'Malign',
    'Benign',
    'Malign',
    'Benign',
]))
# Loading models: the Swin classifier from a local checkpoint, the image
# preprocessing config of the base Swin checkpoint, and the YOLO detector whose
# box is used to crop the image before classification.
loaded_model_dir = "models/best_swin"
original_model_dir = "microsoft/swin-base-patch4-window7-224"
yolo_model_dir = "models/yolo_BB.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"
# predicted_diagnosis moves the preprocessed inputs to `device`, so the model
# must live there as well; otherwise CUDA hosts fail with a device mismatch.
# NOTE(review): torch_dtype="auto" keeps the checkpoint's stored dtype; if that
# is half precision the float32 pixel values would also need casting — confirm.
model = SwinForImageClassification.from_pretrained(loaded_model_dir, torch_dtype="auto").to(device)
model.eval()
feature_extractor = AutoFeatureExtractor.from_pretrained(original_model_dir)
yolo_model = YOLO(yolo_model_dir)
def predicted_diagnosis(img):
    """Detect the lesion in ``img``, classify it and return the prediction.

    The image is shrunk to fit 640x640 and padded with grey (114, 114, 114) for
    the YOLO detector. The highest-confidence box is cropped and classified.
    The whole padded image is used when no usable box is found.

    Args:
        img: PIL image from the uploader, or ``None`` when nothing was uploaded.

    Returns:
        ``(crop, nature, diagnosis, probability)``: the PIL region that was
        classified, the labels from ``nature_to_idx`` / ``class_to_idx`` for the
        predicted index, and the softmax probability of that class. For a
        ``None`` image, ``(None, "", "", None)`` is returned so the UI shows
        empty results instead of raising.
    """
    if img is None:
        return None, "", "", None

    target_size = (640, 640)
    fill_color = (114, 114, 114)
    # convert() returns a copy, so thumbnail() (which works in place) does not
    # shrink the caller's image. RGB also normalises RGBA/palette/greyscale
    # uploads for the RGB fill colour, the 3-channel preprocessor and the later
    # JPEG save.
    work_img = img.convert("RGB")
    work_img.thumbnail(target_size, Image.LANCZOS)
    padded_image = ImageOps.pad(work_img, target_size, color=fill_color)

    detections = yolo_model(padded_image)[0].boxes
    crop = padded_image  # fallback: use the whole (padded) image
    if len(detections) > 0:
        best = int(detections.conf.argmax())
        x1, y1, x2, y2 = map(int, detections.xyxy[best])
        # Integer truncation can collapse a tiny box to zero area; keep the
        # fallback in that case instead of classifying an empty image.
        if x2 > x1 and y2 > y1:
            crop = padded_image.crop((x1, y1, x2, y2))

    inputs = feature_extractor(images=crop, return_tensors='pt').to(device)
    with torch.no_grad():
        logits = model(**inputs).logits  # shape [1, num_classes]
    probs = torch.softmax(logits, dim=-1)  # per-class probabilities
    pred_class_idx = int(torch.argmax(probs, dim=-1).item())
    pred_class_prob = probs[0, pred_class_idx].item()
    return crop, nature_to_idx[pred_class_idx], class_to_idx[pred_class_idx], pred_class_prob
"""## Platform code
"""
import os
import gradio as gr
import json
# On-disk storage layout, created at import time:
#   my_project/images/     saved images, one sub-folder per user
#   my_project/results/    one JSON results file per user
#   my_project/users.json  registered accounts
BASE_DIR = 'my_project'
IMG_DIR = os.path.join(BASE_DIR, 'images')
JSON_DATA_DIR = os.path.join(BASE_DIR, 'results')
JSON_USERS_PATH = os.path.join(BASE_DIR, 'users.json')
for _required_dir in (BASE_DIR, IMG_DIR, JSON_DATA_DIR):
    os.makedirs(_required_dir, exist_ok=True)
del _required_dir
"""### Main code"""
def show_page(page_name, username):
    """Make one page group visible and hide the others.

    Returns one ``gr.update`` per page group, ordered (classifier, archive,
    login). For ``'archive'`` the archive of ``username`` is (re)loaded, and
    the two updates returned by ``show_archive`` (message, table) are appended.
    Any other page name returns ``None``.
    """
    if page_name == 'archive':
        archive_msg, archive_df = show_archive(username)
        return gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), archive_msg, archive_df

    # Visibility flags in (classifier, archive, login) order.
    visibility_by_page = {
        'classifier': (True, False, False),
        'login': (False, False, True),
    }
    flags = visibility_by_page.get(page_name)
    if flags is None:
        return None
    return tuple(gr.update(visible=flag) for flag in flags)
def save_result(img: Image.Image, predicted_nature, predicted_class, predicted_probability, username):
    """Persist one classification for ``username``: the image plus a JSON record.

    The image is written as JPEG under ``IMG_DIR/<username>/``. The record is
    appended to ``JSON_DATA_DIR/<username>.json``, with ``image_path`` stored
    relative to ``BASE_DIR``.

    Args:
        img: image currently shown in the uploader (the classified crop).
        predicted_nature: broad category from the classifier output box.
        predicted_class: diagnosis from the classifier output box.
        predicted_probability: probability of the predicted class.
        username: value of the ``username_sb`` component (``'user'`` when
            nobody is signed in).

    Returns:
        A ``gr.update`` for the ``save_status`` textbox. That textbox starts
        hidden, so every message (errors included) is returned with
        ``visible=True``.
    """
    # Explicit import: the module-level `from datetime import datetime` only
    # appears further down the file.
    from datetime import datetime

    def _status(message):
        return gr.update(value=message, visible=True)

    # Security checks
    if img is None:
        return _status("Error: no image uploaded!")
    if any(value is None or value == "" for value in (predicted_nature, predicted_class, predicted_probability)):
        return _status("Error: classify first!")
    if username == 'user' or username == '':
        return _status("Error: sign-in first")

    now = datetime.now()
    # Microseconds keep two saves within the same second from overwriting the
    # same image file (both records would otherwise point to one image).
    timestamp_file = now.strftime('%Y-%m-%d-%H-%M-%S-%f')
    timestamp_json = now.isoformat()

    # NOTE(review): `username` is the raw Markdown value (e.g. "**NAME**"), so
    # folder/file names contain '*', which is not valid on Windows.
    img_path = os.path.join(IMG_DIR, username, f'{timestamp_file}.jpg')
    os.makedirs(os.path.dirname(img_path), exist_ok=True)
    # JPEG supports neither alpha nor palettes: normalise to RGB first.
    img.convert('RGB').save(img_path)

    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    try:
        with open(json_path, "r", encoding="utf-8") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []
    results.append({
        'timestamp': timestamp_json,
        'image_path': os.path.relpath(img_path, BASE_DIR),
        'predicted_nature': predicted_nature,
        'predicted_class': predicted_class,
        'predicted_probability': float(predicted_probability),
    })
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)
    return _status("Results saved successfully!")
def show_archive(username):
    """Load the saved classifications of ``username`` for the archive table.

    Returns:
        ``(archive_msg_update, archive_df_update)``. When the user's results
        file is missing or empty, the message "Nessun record salvato." ("No
        saved records.") is shown and the table is hidden. Otherwise the
        message is hidden and the table shows one row per record: timestamp,
        nature, diagnosis, probability, image path.
    """
    records_file = os.path.join(JSON_DATA_DIR, f'{username}.json')
    records = []
    if os.path.exists(records_file):
        with open(records_file, "r") as handle:
            records = json.load(handle)

    if not records:
        return gr.update(value="Nessun record salvato.", visible=True), gr.update(visible=False)

    columns = ("timestamp", "predicted_nature", "predicted_class", "predicted_probability", "image_path")
    table = [[record[key] for key in columns] for record in records]
    return gr.update(value="", visible=False), gr.update(value=table, visible=True)
def show_image(evt: gr.SelectData):
    """Preview the archived image of the clicked archive-table row.

    Returns:
        Updates for ``(selected_image, image_path, delete_btn)``: the image
        (resolved against ``BASE_DIR``) is shown, its relative path is stored
        for ``delete_record``, and the delete button is revealed.
    """
    # Archive rows are (timestamp, nature, diagnosis, probability, image path).
    # The path is column 4; column 3 is the probability, which previously made
    # os.path.join fail on a float.
    img_path = evt.row_value[4]
    return (
        gr.update(value=os.path.join(BASE_DIR, img_path), visible=True),
        gr.update(value=img_path),
        gr.update(visible=True),
    )
def delete_record(table_data, img_path, username):
    """Delete one archived classification (image file and JSON record) of ``username``.

    Args:
        table_data: current value of the archive Dataframe.
        img_path: image path relative to ``BASE_DIR``, as stored in the record
            (taken from the hidden ``image_path`` textbox).
        username: value of the ``username_sb`` component.

    Returns:
        Updates for ``(archive_df, selected_image, delete_btn)``. When nothing
        can be deleted, the table is left unchanged and the other two
        components are not modified.
    """
    json_path = os.path.join(JSON_DATA_DIR, f'{username}.json')
    results = []
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            results = json.load(f)

    # img_path comes from a client-side component: only act on a path that is
    # actually recorded for this user, so arbitrary files cannot be deleted.
    if not img_path or not any(r["image_path"] == img_path for r in results):
        return gr.update(value=table_data), gr.update(), gr.update()

    file_path = os.path.join(BASE_DIR, img_path)
    if os.path.isfile(file_path):
        os.remove(file_path)

    results = [r for r in results if r["image_path"] != img_path]
    with open(json_path, "w") as f:
        json.dump(results, f, indent=2)

    rows = [
        [r["timestamp"], r["predicted_nature"], r["predicted_class"], r["predicted_probability"], r["image_path"]]
        for r in results
    ]
    return gr.update(value=rows, visible=True), gr.update(value=None, visible=False), gr.update(visible=False)
def signin(user, psw):
    """Check ``user``/``psw`` against users.json and update the UI.

    Usernames are compared case-insensitively and passwords exactly. Returns
    eight ``gr.update`` objects in the click handler's output order:
    login_msg, username_sb, login_pg, archive_sb_btn, sign_sb_btn,
    delete_sb_btn, save_btn, classifier_pg.

    On success the sidebar shows the name as ``**USER**`` together with the
    signed-in controls. On failure (or when users.json does not exist) an
    error message is shown and the login page stays open.
    """
    credentials_ok = False
    if os.path.exists(JSON_USERS_PATH):
        with open(JSON_USERS_PATH, 'r') as users_file:
            accounts = json.load(users_file)
        credentials_ok = any(
            account['username'].lower() == user.lower() and account['password'] == psw
            for account in accounts
        )

    if not credentials_ok:
        return (
            gr.update(value='Password e/o Nome utente errato', visible=True),  # login_msg
            gr.update(value='user', visible=False),  # username_sb
            gr.update(visible=True),  # login_pg
            gr.update(visible=False),  # archive_sb_btn
            gr.update(visible=True),  # sign_sb_btn
            gr.update(visible=False),  # delete_sb_btn
            gr.update(visible=False),  # save_btn
            gr.update(visible=False),  # classifier_pg
        )
    return (
        gr.update(value='Login effettuato', visible=True),  # login_msg
        gr.update(value=f"**{user.upper()}**", visible=True),  # username_sb
        gr.update(visible=False),  # login_pg
        gr.update(visible=True),  # archive_sb_btn
        gr.update(visible=False),  # sign_sb_btn
        gr.update(visible=True),  # delete_sb_btn
        gr.update(visible=True),  # save_btn
        gr.update(visible=True),  # classifier_pg
    )
def signup(user, psw):
    """Register a new account in users.json and sign it in.

    The username must be unused (case-insensitive) and non-blank, and must not
    contain '/', '\\' or '*'. The signed-in name is later used as ``**NAME**``
    in per-user file paths, and ``delete_account`` strips '*' from it. The
    password must be non-empty.

    Returns eight ``gr.update`` objects in the click handler's output order:
    login_msg, username_sb, login_pg, archive_sb_btn, sign_sb_btn,
    delete_sb_btn, save_btn, classifier_pg.
    """
    def _rejected(message):
        # Show `message` and keep the user signed out on the login page.
        return (
            gr.update(value=message, visible=True),  # login_msg
            gr.update(value='user', visible=False),  # username_sb
            gr.update(visible=True),  # login_pg
            gr.update(visible=False),  # archive_sb_btn
            gr.update(visible=True),  # sign_sb_btn
            gr.update(visible=False),  # delete_sb_btn
            gr.update(visible=False),  # save_btn
            gr.update(visible=False),  # classifier_pg
        )

    try:
        with open(JSON_USERS_PATH, "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []

    if any(u['username'].lower() == user.lower() for u in results):
        return _rejected('Utente già esistente')
    if not user.strip() or psw == '' or any(ch in user for ch in ('/', '\\', '*')):
        return _rejected('Nome utente e/o password non validi')

    # NOTE(review): passwords are stored in plain text; hashing them would need
    # a matching change in signin.
    results.append({
        'username': user,
        'password': psw,
    })
    with open(JSON_USERS_PATH, "w") as f:
        json.dump(results, f, indent=2)
    return (
        gr.update(value='Login effettuato', visible=True),  # login_msg
        gr.update(value=f"**{user.upper()}**", visible=True),  # username_sb
        gr.update(visible=False),  # login_pg
        gr.update(visible=True),  # archive_sb_btn
        gr.update(visible=False),  # sign_sb_btn
        gr.update(visible=True),  # delete_sb_btn
        gr.update(visible=True),  # save_btn
        gr.update(visible=True),  # classifier_pg
    )
def delete_account(user):
    """Delete the signed-in account and its saved data, then show the signed-out UI.

    Args:
        user: value of the ``username_sb`` Markdown component, i.e. the
            ``**NAME**`` string set by ``signin``/``signup``.

    Returns:
        Seven ``gr.update`` objects for: username_sb (reset to ``'user'``),
        login_pg, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
        classifier_pg.
    """
    import shutil

    plain_name = user.strip('*').lower()
    try:
        with open(JSON_USERS_PATH, 'r') as file:
            users = json.load(file)
    except FileNotFoundError:
        users = []
    # (The former debug print of `users` leaked every password to stdout.)
    remaining = [record for record in users if record['username'].lower() != plain_name]
    with open(JSON_USERS_PATH, 'w') as file:
        json.dump(remaining, file, indent=2)

    # Per-user storage is keyed on the raw Markdown value (see save_result).
    # Remove it so a later sign-up with the same name cannot see this archive.
    # The guard keeps a blank/placeholder/path-like value from touching shared dirs.
    if user and user != 'user' and not any(sep in user for sep in ('/', '\\', os.sep)):
        results_path = os.path.join(JSON_DATA_DIR, f'{user}.json')
        if os.path.isfile(results_path):
            os.remove(results_path)
        shutil.rmtree(os.path.join(IMG_DIR, user), ignore_errors=True)

    return (
        gr.update(value='user', visible=False),  # username_sb
        gr.update(visible=False),  # login_pg
        gr.update(visible=False),  # archive_sb_btn
        gr.update(visible=True),  # sign_sb_btn
        gr.update(visible=False),  # delete_sb_btn
        gr.update(visible=False),  # save_btn
        gr.update(visible=True),  # classifier_pg
    )
from PIL import Image
import gradio as gr
from datetime import datetime
import json
# Custom stylesheet passed to gr.Blocks(css=...) below:
# - `#container`: full-height row.
# - `.sidebar`: grey, full-height, vertically scrollable column.
# - `.content`: padded column with extra bottom padding.
# - `.fixed_image img`: images shown in a 224px-high box (max 224px wide),
#   aspect ratio preserved via object-fit.
# The Italian CSS comments read "maximum height" and "keeps proportions";
# note that `height` is actually fixed, not a maximum.
css = """
#container {min-height: 100vh;}
.sidebar {
background-color: #f0f0f0;
padding: 10px;
height: 100vh;
overflow-y: auto;
}
.content {padding: 20px; padding-bottom: 60px; overflow-y:auto}
.fixed_image img{
max-width: 224px;
height: 224px; /* altezza massima */
object-fit: contain; /* mantiene proporzioni */
}
"""
# UI: a navigation sidebar plus three page groups (classifier, archive, login)
# whose visibility is toggled by the handlers defined above.
with gr.Blocks(css=css) as demo:
    with gr.Row(elem_id='container'):
        # Sidebar
        with gr.Column(scale=1, min_width=200, elem_classes='sidebar'):
            username_sb = gr.Markdown('user', visible=False)
            cls_sb_bt = gr.Button('Classifier', visible=True)
            archive_sb_btn = gr.Button('Archive', visible=False)
            sign_sb_btn = gr.Button('Sign-in/Sign-up', visible=True)
            delete_sb_btn = gr.Button('Delete Account', visible=False)
        # Content
        with gr.Column(scale=4):
            with gr.Column(scale=4, min_width=400, elem_classes="content"):
                # Classifier page
                with gr.Group(visible=True) as classifier:
                    gr.Markdown('### Classifier')
                    with gr.Row():
                        uploader = gr.Image(label='Upload image', sources=['upload', 'webcam'], type='pil',
                                            elem_classes='fixed_image')
                    with gr.Row():
                        # Start empty: save_result treats "" / None as "not
                        # classified yet". A placeholder such as 0 let an
                        # unclassified image be saved with bogus results.
                        output_nature = gr.Textbox(value="", label="Nature", interactive=False)
                        output_diagnosis = gr.Textbox(value="", label="Diagnosis", interactive=False)
                        output_probability = gr.Number(value=None, label="Probability", interactive=False)
                    with gr.Row():
                        # Buttons
                        cls_btn = gr.Button("Classify", visible=True)
                        save_btn = gr.Button('Save Results', visible=False)
                        save_status = gr.Textbox(label='Status', interactive=False, visible=False)
                    # Events: classification replaces the uploaded image with the classified crop.
                    cls_btn.click(
                        fn=predicted_diagnosis,
                        inputs=uploader,
                        outputs=[uploader, output_nature, output_diagnosis, output_probability]
                    )
                    save_btn.click(
                        fn=save_result,
                        inputs=[uploader, output_nature, output_diagnosis, output_probability, username_sb],
                        outputs=save_status
                    )
                # Archive page
                with gr.Group(visible=False) as archive:
                    gr.Markdown('### Archive')
                    with gr.Row():
                        archive_msg = gr.Markdown()
                        archive_df = gr.Dataframe(
                            headers=['Timestamp', 'Nature', 'Diagnosis', 'Probability', 'Image Path'],
                            datatype=['str', 'str', 'str', 'number', 'str'],
                            interactive=False,
                            visible=False
                        )
                        selected_image = gr.Image(label='Immagine selezionata.', visible=False, type='filepath',
                                                  interactive=False, elem_classes='fixed_image')
                        # Hidden holder for the selected record's relative image path.
                        image_path = gr.Textbox(visible=False)
                    with gr.Row():
                        # Buttons
                        load_btn = gr.Button('Refresh Archive')
                        delete_btn = gr.Button('Delete Record', visible=False)
                    # Events
                    load_btn.click(fn=show_archive, inputs=[username_sb], outputs=[archive_msg, archive_df])
                    archive_df.select(
                        fn=show_image,
                        inputs=[],
                        outputs=[selected_image, image_path, delete_btn]
                    )
                    delete_btn.click(
                        fn=delete_record,
                        inputs=[archive_df, image_path, username_sb],
                        outputs=[archive_df, selected_image, delete_btn]
                    )
                # Login page
                with gr.Group(visible=False) as login:
                    gr.Markdown('### Login')
                    with gr.Row():
                        username_tbox = gr.Textbox(label='Username')
                        password_tbox = gr.Textbox(label='Password', type='password')
                    with gr.Row():
                        login_msg = gr.Textbox(label='Status', interactive=False, visible=False)
                    with gr.Row():
                        # Buttons
                        signin_btn = gr.Button('Sign-in')
                        signup_btn = gr.Button('Sign-up')
                    # Events
                    signin_btn.click(
                        fn=signin,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
                    signup_btn.click(
                        fn=signup,
                        inputs=[username_tbox, password_tbox],
                        outputs=[login_msg, username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier]
                    )
    # Sidebar buttons
    cls_sb_bt.click(fn=lambda username: show_page('classifier', username), inputs=[username_sb],
                    outputs=[classifier, archive, login])
    archive_sb_btn.click(fn=lambda username: show_page('archive', username), inputs=[username_sb],
                         outputs=[classifier, archive, login, archive_msg, archive_df])
    sign_sb_btn.click(fn=lambda username: show_page('login', username), inputs=[username_sb],
                      outputs=[classifier, archive, login])
    delete_sb_btn.click(fn=delete_account, inputs=[username_sb],
                        outputs=[username_sb, login, archive_sb_btn, sign_sb_btn, delete_sb_btn, save_btn,
                                 classifier])

if __name__ == '__main__':
    demo.launch(debug=True)