# Rsnarsna — "Update app.py" (commit a763c13, verified)
# NOTE(review): the three lines above were Hugging Face page scrape residue
# (author avatar caption, commit message, commit hash) — preserved here as a
# comment so the module parses.
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification, pipeline
# --- Model setup (runs once at import time; from_pretrained downloads the
# checkpoint on first run, so startup requires network access) ---
# Load the multilingual XLM-RoBERTa sentiment model and its tokenizer.
# NOTE(review): downstream code matches labels "Neutral"/"Positive"/"Negative";
# presumably this checkpoint emits exactly those — confirm against the model card.
model_name = "citizenlab/twitter-xlm-roberta-base-sentiment-finetunned"
tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
model = XLMRobertaForSequenceClassification.from_pretrained(model_name)
# Shared pipeline used by both the /predict and /analyse_text handlers.
sentiment_pipeline = pipeline("text-classification", model=model, tokenizer=tokenizer)
app = FastAPI()
# Request body schema shared by the /predict and /analyse_text endpoints.
class TextInput(BaseModel):
    """Request body: the raw input text (document or conversation) to analyse."""

    # The full text to classify; /analyse_text expects "User: message" lines.
    text: str
# --- For /predict ---
def split_text_into_chunks(text, max_tokens=500):
    """Tokenize *text* and slice its token ids into chunks of at most
    ``max_tokens`` ids each.

    Returns a 3-tuple:
        (list of token-id chunks, list of decoded chunk strings,
         list of per-chunk token counts).

    NOTE(review): max_tokens=500 presumably leaves headroom under the
    model's 512-token sequence limit — confirm against the checkpoint.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=False, padding=False)
    ids = encoded["input_ids"][0].tolist()
    id_chunks = []
    for start in range(0, len(ids), max_tokens):
        id_chunks.append(ids[start:start + max_tokens])
    decoded_texts = [tokenizer.decode(chunk, skip_special_tokens=True) for chunk in id_chunks]
    token_counts = [len(chunk) for chunk in id_chunks]
    return id_chunks, decoded_texts, token_counts
# Run the classifier over every decoded chunk produced by split_text_into_chunks.
def analyze_sentiment_chunks(chunks, chunk_texts, chunk_token_counts):
    """Score each decoded chunk with the sentiment pipeline.

    ``chunks`` (the raw token-id lists) is accepted for interface
    compatibility but not consulted — classification works off
    ``chunk_texts``. Returns (list of per-chunk result dicts,
    total token count across all chunks).
    """
    per_chunk_results = []
    running_token_total = 0
    for index, chunk_text in enumerate(chunk_texts):
        running_token_total += chunk_token_counts[index]
        per_chunk_results.append({
            "chunk": index + 1,
            "text": chunk_text,
            "token_count": chunk_token_counts[index],
            # top_k=None returns scores for every label, not just the argmax.
            "analysis": sentiment_pipeline(chunk_text, top_k=None),
        })
    return per_chunk_results, running_token_total
@app.post("/predict")
def predict_sentiment(input_data: TextInput):
    """Chunk the input text, score each chunk, and aggregate sentiment.

    Single-chunk inputs return only the per-chunk results; multi-chunk
    inputs additionally return summed and chunk-averaged scores per label.

    Fixes vs. original: the single-chunk early return now happens BEFORE
    aggregation (the original computed all totals and then discarded them),
    and the per-label accumulation/averaging is driven by one dict instead
    of three parallel if/elif counters with duplicated zero-guards.
    """
    chunks, chunk_texts, chunk_token_counts = split_text_into_chunks(input_data.text)
    results, total_token_count = analyze_sentiment_chunks(chunks, chunk_texts, chunk_token_counts)

    # Short-circuit: a single chunk needs no cross-chunk aggregation.
    if len(results) == 1:
        return {"results": results}

    totals = {"Neutral": 0, "Positive": 0, "Negative": 0}
    for result in results:
        for sentiment in result["analysis"]:
            label = sentiment["label"]
            if label in totals:  # ignore any unexpected label, as before
                totals[label] += sentiment["score"]

    num_chunks = len(results)

    def _avg(total):
        # Guard against an empty result list (e.g. degenerate input).
        return total / num_chunks if num_chunks > 0 else 0

    return {
        "total_chunks": num_chunks,
        "total_token_count": total_token_count,
        "total_neutral_score": totals["Neutral"],
        "total_positive_score": totals["Positive"],
        "total_negative_score": totals["Negative"],
        "overall_neutral_score": _avg(totals["Neutral"]),
        "overall_positive_score": _avg(totals["Positive"]),
        "overall_negative_score": _avg(totals["Negative"]),
        "results": results,
    }
# --- For /analyse_text ---
def split_conversation(conversation, default_user="You"):
    """Parse a newline-separated conversation into structured entries.

    Each line of the form ``User: text`` becomes
    ``{"user": "User", "text": "text"}`` — surrounding double quotes on
    the text are stripped. Non-empty lines WITHOUT a colon are attributed
    to ``default_user`` (the original silently dropped them, leaving the
    ``default_user`` parameter entirely unused). Blank lines are skipped.
    """
    entries = []
    for raw_line in conversation.strip().split("\n"):
        line = raw_line.strip()
        if not line:
            continue  # skip blank lines rather than attributing them
        if ":" in line:
            user, _, text = line.partition(":")
            entries.append({"user": user.strip(), "text": text.strip().strip('"')})
        else:
            # Fix: untagged lines now use default_user instead of vanishing.
            entries.append({"user": default_user, "text": line.strip('"')})
    return entries
# Score each conversation entry and compute per-label averages.
def analyze_sentiment(conversation_list):
    """Score every entry in *conversation_list* and return per-label averages.

    Mutates each entry dict in place, adding an ``"analysis"`` key with the
    pipeline's full label/score output. Returns a list of
    ``{"label": ..., "score": ...}`` dicts averaged across all entries.

    Fixes vs. original: an empty input now returns all-zero scores instead
    of raising ZeroDivisionError, and labels outside the three known keys
    are ignored instead of raising KeyError.
    """
    overall_scores = {"Negative": 0, "Neutral": 0, "Positive": 0}
    total_entries = len(conversation_list)
    if total_entries == 0:
        # Nothing to average over — report neutral zeros rather than crash.
        return [{"label": label, "score": 0.0} for label in overall_scores]
    for entry in conversation_list:
        analysis = sentiment_pipeline(entry["text"], top_k=None)
        entry["analysis"] = analysis
        for sentiment in analysis:
            if sentiment["label"] in overall_scores:
                overall_scores[sentiment["label"]] += sentiment["score"]
    return [
        {"label": label, "score": score / total_entries}
        for label, score in overall_scores.items()
    ]
@app.post("/analyse_text")
def analyse_text(input_data: TextInput):
    """Parse a 'User: message' conversation and attach sentiment per entry.

    analyze_sentiment mutates the parsed entries in place (adding an
    "analysis" key), so ``analyses`` carries per-line results.
    """
    entries = split_conversation(input_data.text)
    overall = analyze_sentiment(entries)
    return {
        "analyses": entries,
        "overall_analysis": overall,
    }
@app.get("/")
def read_root():
    """Landing endpoint: brief description of the available routes."""
    info_text = (
        "This is a sentiment analysis API. Use /predict for chunk-wise analysis"
        " or /analyse_text for structured conversation analysis."
    )
    return {"info": info_text}