Spaces:
Runtime error
Runtime error
| import nltk | |
| import ssl | |
| import re | |
# Route default HTTPS context creation through the unverified-context factory
# when this Python build exposes one; builds without
# ssl._create_unverified_context are left unchanged.
# NOTE(review): this turns off certificate verification for code that relies on
# ssl._create_default_https_context -- presumably a workaround so the NLTK
# downloads below succeed on hosts with certificate problems; confirm it is
# still needed.
_create_unverified_https_context = getattr(ssl, "_create_unverified_context", None)
if _create_unverified_https_context is not None:
    ssl._create_default_https_context = _create_unverified_https_context

# Fetch the NLTK resources at import time, in the same order as before.
for _nltk_resource in ('punkt', 'stopwords'):
    nltk.download(_nltk_resource)
| from transformers import BartTokenizer, PegasusTokenizer | |
| from transformers import BartForConditionalGeneration, PegasusForConditionalGeneration | |
| from tqdm.notebook import tqdm | |
class Abstractive_Summarization_Model:
    """Abstractive text summarizer built on a pretrained Yale-LILY BRIO checkpoint.

    Depending on ``is_cnndm`` the instance loads either the BART-based
    ``Yale-LILY/brio-cnndm-uncased`` checkpoint or the Pegasus-based
    ``Yale-LILY/brio-xsum-cased`` checkpoint, together with its tokenizer.

    Attributes:
        text: Placeholder kept for backward compatibility; not read by this class.
        IS_CNNDM: True to use the CNNDM checkpoint, False for the XSum one.
        LOWER: When True, input text is lower-cased before tokenization.
        max_length: Maximum number of input tokens; longer inputs are truncated.
        model: The loaded conditional-generation model.
        tokenizer: The tokenizer matching ``model``.
    """

    # Input truncation limits used for each checkpoint family.
    _MAX_LENGTH_CNNDM = 1024
    _MAX_LENGTH_XSUM = 512

    def __init__(self, *, is_cnndm=True, lower=False):
        """Configure the summarizer and load the model and tokenizer.

        Args:
            is_cnndm: Use the CNNDM (BART) checkpoint when True (default),
                otherwise the XSum (Pegasus) checkpoint.
            lower: Lower-case the input before tokenization. Defaults to False.
        """
        self.text = None
        self.IS_CNNDM = is_cnndm  # whether to use CNNDM dataset or XSum dataset
        self.LOWER = lower
        # The XSum checkpoint uses a shorter truncation limit than the CNNDM one.
        self.max_length = (
            self._MAX_LENGTH_CNNDM if self.IS_CNNDM else self._MAX_LENGTH_XSUM
        )
        self.model, self.tokenizer = self.load_model()

    def load_model(self):
        """Load the checkpoint selected by ``self.IS_CNNDM``.

        Returns:
            A ``(model, tokenizer)`` tuple for the selected checkpoint.
        """
        print('[INFO]: Loading model ...')
        if self.IS_CNNDM:
            checkpoint = 'Yale-LILY/brio-cnndm-uncased'
            model = BartForConditionalGeneration.from_pretrained(checkpoint)
            tokenizer = BartTokenizer.from_pretrained(checkpoint)
        else:
            checkpoint = 'Yale-LILY/brio-xsum-cased'
            model = PegasusForConditionalGeneration.from_pretrained(checkpoint)
            tokenizer = PegasusTokenizer.from_pretrained(checkpoint)
        print('[INFO]: Model Successfully Loaded :)')
        return model, tokenizer

    def summarize(self, text):
        """Generate a summary for ``text``.

        Args:
            text: The article to summarize.

        Returns:
            The decoded summary string, or ``""`` when ``text`` is an empty or
            whitespace-only string (there is nothing to summarize, so the
            model is not invoked).
        """
        # Only short-circuit for genuine blank strings; any other input takes
        # the normal path so existing error behaviour is unchanged.
        if isinstance(text, str) and not text.strip():
            return ""

        article = text.lower() if self.LOWER else text
        inputs = self.tokenizer(
            [article],
            max_length=self.max_length,
            return_tensors="pt",
            truncation=True,
        )
        # Generate Summary
        summary_ids = self.model.generate(inputs["input_ids"])
        decoded = self.tokenizer.batch_decode(
            summary_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return decoded[0]