import re
from typing import Optional, List

import vllm
from fire import Fire
from pydantic import BaseModel
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM
class ZeroShotChatTemplate:
    # This is the default template used in llama-factory for training
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
        # Keep only digits and spaces, then return the last number found
        # (eg "5 + 6 = 11 balls" -> "11"); fall back to the raw text if none.
        filtered = "".join([char for char in text if char.isdigit() or char == " "])
        if not filtered.strip():
            return text
        return re.findall(pattern=r"\d+", string=filtered)[-1]
class VLLMModel(BaseModel, arbitrary_types_allowed=True):
    path_model: str
    model: vllm.LLM = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    max_input_length: int = 512
    max_output_length: int = 512
    stopping_words: Optional[List[str]] = None

    def load(self):
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,
            max_tokens=self.max_output_length,
            **kwargs,
        )
        outputs = dict(sampling_params=params, use_tqdm=False)
        return outputs

    def run(self, prompt: str) -> str:
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        pred = pred.split("<|endoftext|>")[0]
        return pred
def upload_to_hub(path: str, repo_id: str):
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    model.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)
def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
    model = VLLMModel(**kwargs)
    demo = ZeroShotChatTemplate()
    model.stopping_words = demo.get_stopping_words()
    prompt = demo.make_prompt(question)
    raw_outputs = model.run(prompt)
    pred = demo.extract_answer(raw_outputs)
    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))
| """ | |
| p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo | |
| p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo | |
| """ | |
| if __name__ == "__main__": | |
| Fire() | |
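# Example programmatic usage (a sketch mirroring what main() does via Fire;
# assumes the repo id from the docstring above is available and that a
# vLLM-capable GPU is present):
#
#   model = VLLMModel(path_model="chiayewken/llama3-8b-gsm8k-rpo")
#   model.stopping_words = ZeroShotChatTemplate.get_stopping_words()
#   prompt = ZeroShotChatTemplate.make_prompt("What is 12 * 7?")
#   answer = ZeroShotChatTemplate.extract_answer(model.run(prompt))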