Rsnarsna commited on
Commit
cfcbe55
·
verified ·
1 Parent(s): 1fc75a2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -47
app.py CHANGED
@@ -1,15 +1,13 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import pipeline, XLMRobertaTokenizer, XLMRobertaForSequenceClassification
4
  import torch
5
 
6
- # Load the model and tokenizer
7
- model_name = "citizenlab/twitter-xlm-roberta-base-sentiment-finetunned"
8
- tokenizer = XLMRobertaTokenizer.from_pretrained(model_name)
9
- model = XLMRobertaForSequenceClassification.from_pretrained(model_name)
10
-
11
- # Define the sentiment analysis pipeline using the model and tokenizer
12
- pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
13
 
14
  # Define a request model for the input text
15
  class SentimentRequest(BaseModel):
@@ -23,14 +21,6 @@ class SentimentResponse(BaseModel):
23
  # Initialize FastAPI app
24
  app = FastAPI()
25
 
26
- # Function to split text into smaller chunks
27
- def chunk_text(text, max_length=512):
28
- # Tokenize the text
29
- tokens = tokenizer.encode(text, truncation=False)
30
-
31
- # Split into chunks of max_length tokens each
32
- return [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
33
-
34
  @app.get("/")
35
  async def read_root():
36
  return {"message": "Welcome to the Sentiment Analysis API! Use '/predict' to analyze sentiment."}
@@ -38,38 +28,11 @@ async def read_root():
38
  @app.post("/predict")
39
  async def analyze_sentiment(input: SentimentRequest):
40
  text = input.text
 
 
41
 
42
- # Split the input text into chunks if it exceeds the token limit
43
- chunks = chunk_text(text, max_length=512) # 512 tokens is the max for XLM-Roberta
44
-
45
- # Run sentiment analysis for each chunk
46
- analysis_results = []
47
- for chunk in chunks:
48
- # Convert chunk back to text
49
- chunk_text = tokenizer.decode(chunk, skip_special_tokens=True)
50
-
51
- # Tokenize the chunk text
52
- inputs = tokenizer(chunk_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
53
-
54
- # Run sentiment analysis through the pipeline
55
- with torch.no_grad(): # No need to compute gradients for inference
56
- analysis = pipe(**inputs, top_k=None) # Get all possible labels
57
-
58
- # Extract the result as a dictionary of labels and confidence scores
59
- result = {entry['label']: entry['score'] for entry in analysis}
60
- analysis_results.append(result)
61
-
62
- # Aggregate analysis results
63
- combined_analysis = {}
64
- for result in analysis_results:
65
- for label, score in result.items():
66
- if label in combined_analysis:
67
- combined_analysis[label] += score # Sum up scores for the same label
68
- else:
69
- combined_analysis[label] = score
70
-
71
  # Return the sentiment analysis result as a response
72
- return SentimentResponse(text=text, analysis=combined_analysis)
73
 
74
  # Run the application with Uvicorn (from the terminal/command line)
75
- # uvicorn app_name:app --reload
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import pipeline
4
  import torch
5
 
6
+ # Load the sentiment analysis pipeline
7
+ pipe = pipeline(
8
+ "text-classification",
9
+ model="citizenlab/twitter-xlm-roberta-base-sentiment-finetunned"
10
+ )
 
 
11
 
12
  # Define a request model for the input text
13
  class SentimentRequest(BaseModel):
 
21
  # Initialize FastAPI app
22
  app = FastAPI()
23
 
 
 
 
 
 
 
 
 
24
  @app.get("/")
25
  async def read_root():
26
  return {"message": "Welcome to the Sentiment Analysis API! Use '/predict' to analyze sentiment."}
 
28
  @app.post("/predict")
29
  async def analyze_sentiment(input: SentimentRequest):
30
  text = input.text
31
+ analysis = pipe(text, top_k=None) # top_k=None to get all possible labels
32
+ result = {entry['label']: entry['score'] for entry in analysis}
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Return the sentiment analysis result as a response
35
+ return SentimentResponse(text=text, analysis=result)
36
 
37
  # Run the application with Uvicorn (from the terminal/command line)
38
+ # uvicorn app_name:app --reload