Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import numpy as np | |
| import plotly.graph_objects as go | |
# Available classifiers: Hugging Face repo id -> human-readable name.
# The keys populate the model dropdown; the display names are not used by the UI yet.
MODEL_OPTIONS = {"waleko/roberta-arxiv-tags": "RoBERTa Arxiv Tags"}
def load_model(model_name):
    """Load a sequence-classification model and its tokenizer.

    Both are obtained with ``from_pretrained(model_name)``; the tokenizer is
    loaded first.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name)
    clf = AutoModelForSequenceClassification.from_pretrained(model_name)
    return clf, tok
current_model = None
current_tokenizer = None
# Repo id of the model currently held in current_model / current_tokenizer.
current_model_name = None


def get_model_and_tokenizer(model_name):
    """Return ``(model, tokenizer)`` for ``model_name``, loading it lazily.

    A single model is kept in memory. It is (re)loaded via ``load_model``
    when nothing is loaded yet or when a different ``model_name`` is
    requested. Without the name check, picking another model in the UI
    would silently keep using the first one loaded.
    """
    global current_model, current_tokenizer, current_model_name
    if (
        current_model is None
        or current_tokenizer is None
        or current_model_name != model_name
    ):
        current_model, current_tokenizer = load_model(model_name)
        current_model_name = model_name
    return current_model, current_tokenizer
def create_visualization(probs, labels, others_tol=1e-6):
    """Build a donut chart of the selected tag probabilities.

    Args:
        probs: Probabilities of the selected tags (any sequence of numbers,
            e.g. a NumPy array), in plotting order.
        labels: Tag names, one per entry of ``probs`` (list, tuple, ...).
        others_tol: Residual probability mass at or below which no "Others"
            slice is drawn. This suppresses a near-zero wedge that would
            otherwise appear purely from floating-point rounding.

    Returns:
        A ``plotly.graph_objects.Figure`` holding a single pie trace.
    """
    values = [float(p) for p in probs]
    names = list(labels)  # copy, and accept tuples/arrays as well as lists
    # Mass not covered by the selected tags is grouped into one "Others" slice.
    remainder = 1.0 - sum(values)
    if remainder > others_tol:
        names.append('Others')
        values.append(remainder)
    return go.Figure(data=[go.Pie(
        labels=names,
        values=values,
        textinfo='percent',
        textposition='inside',
        hole=.3,
        showlegend=True,
    )])
def classify_text(title, abstract, model_name):
    """Predict arXiv tags for a paper; return an HTML summary and a chart.

    The reported tags are the fewest top-ranked labels whose cumulative
    probability reaches 95% (a single tag when the top one alone does).

    Args:
        title: Paper title; may be empty or ``None`` if ``abstract`` is given.
        abstract: Paper abstract; may be empty or ``None`` if ``title`` is given.
        model_name: Repo id forwarded to ``get_model_and_tokenizer``.

    Returns:
        ``(html, figure)`` on success. If both inputs are empty or
        whitespace-only, ``(error_message, None)``.
    """
    title = (title or '').strip()
    abstract = (abstract or '').strip()
    if not title and not abstract:
        return "Error: At least one of title or abstract must be provided.", None

    model, tokenizer = get_model_and_tokenizer(model_name)
    text = 'Title: ' + title + '\n\nAbstract: ' + abstract
    # Do not request more tokens than the tokenizer says the model accepts:
    # RoBERTa-style encoders have a fixed number of position embeddings and
    # raise on longer sequences. If model_max_length exceeds 1024, the
    # original 1024 cap still applies.
    max_length = min(1024, tokenizer.model_max_length)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_length)
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits[0], dim=0).numpy()

    sorted_idx = np.argsort(probs)[::-1]
    sorted_probs = probs[sorted_idx]
    # Smallest k whose cumulative probability reaches 95%. If rounding keeps
    # the running total just under the threshold, keep every label rather
    # than falling back to a single one.
    reached = np.cumsum(sorted_probs) >= 0.95
    k = int(np.argmax(reached)) + 1 if reached.any() else len(sorted_probs)

    id2label = model.config.id2label
    tags = [id2label[int(idx)] for idx in sorted_idx[:k]]
    # Top tag in bold, the remaining selected tags after it.
    compact_pred = f'<span style="font-weight: 800;">{tags[0]}</span>'
    if len(tags) > 1:
        compact_pred += " " + " ".join(tags[1:])

    viz_data = create_visualization(sorted_probs[:k], tags)
    html_output = f"""
<div>
<h3>Predicted Tags</h3>
<p>{compact_pred}</p>
</div>
"""
    return html_output, viz_data
# Gradio UI: model choice and paper text on the left, predictions on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Arxiv Tags Classification
Classify academic papers into arXiv categories using state-of-the-art language models.
""")
    with gr.Row():
        # Input column.
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                label="Select Model",
                info="Choose the model for classification",
                choices=list(MODEL_OPTIONS),
                value=list(MODEL_OPTIONS)[0],
            )
            title_input = gr.Textbox(
                label="Title",
                lines=1,
                placeholder="Enter paper title (optional if abstract is provided)",
            )
            abstract_input = gr.Textbox(
                label="Abstract",
                lines=5,
                placeholder="Enter paper abstract (optional if title is provided)",
            )
        # Output column: HTML summary of tags plus the probability chart.
        with gr.Column(scale=1):
            output_html = gr.HTML(label="Predicted Tags")
            output_plot = gr.Plot(label="Probability Distribution", show_label=True)

    # Order must match classify_text(title, abstract, model_name).
    inputs = [title_input, abstract_input, model_dropdown]
    btn = gr.Button("Classify", variant="primary")
    btn.click(fn=classify_text, inputs=inputs, outputs=[output_html, output_plot])

if __name__ == "__main__":
    demo.launch()