Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import warnings | |
class Chatbot:
    """Persona-conditioned chit-chat bot backed by a fine-tuned causal LM.

    The model prompt is ``self.info`` (topic/persona header built by
    ``initialize``) followed by every turn in ``self.talk``. Turns are stored
    as ``"P01<sos>...<eos>"`` (user) and ``"P02<sos>...<eos>"`` (bot). They
    are appended in pairs, so user turns sit at even indices and bot turns at
    odd indices. ``test`` relies on that pairing when it builds its return
    value.

    NOTE(review): the non-ASCII literals in this class look like UTF-8 Korean
    that was decoded with the wrong code page (some bytes appear lost). They
    are kept byte-for-byte here. Restore the originals from version control;
    otherwise the special tokens and prompt prefix presumably won't match what
    the checkpoint was trained on.
    """

    # Token budget for prompt + reply; passed to ``generate`` as ``max_length``.
    MAX_LENGTH = 512

    def __init__(self, model_path="/workspace/test_trainer/checkpoint-10000", device="cuda"):
        """Load the KoGPT tokenizer (plus special tokens) and the fine-tuned model.

        Args:
            model_path: directory of the fine-tuned checkpoint to load.
            device: torch device the model and its inputs are moved to.
        """
        self.tokenizer = AutoTokenizer.from_pretrained('kakaobrain/kogpt', revision='KoGPT6B-ryan1.5b')
        special_tokens_dict = {'additional_special_tokens': [
            '<sep>', '<eos>', '<sos>',
            '#@์ด๋ฆ#', '#@๊ณ์ #', '#@์ ์#', '#@์ ๋ฒ#', '#@๊ธ์ต#',
            '#@๋ฒํธ#', '#@์ฃผ์#', '#@์์#', '#@๊ธฐํ#',
        ]}
        self.tokenizer.add_special_tokens(special_tokens_dict)
        self.model = AutoModelForCausalLM.from_pretrained(model_path)
        # Grow the embedding matrix so ids of the special tokens added above are valid.
        self.model.resize_token_embeddings(len(self.tokenizer))
        self.device = device
        self.model = self.model.to(device)
        self.info = None   # prompt header; set by initialize()
        self.talk = []     # alternating user/bot turns, user first
        self.state = None  # latest context-length warning from test(), or None

    def initialize(self, topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex):
        """Build the topic/persona prompt header and return it in display form.

        Args:
            topic: conversation topic placed at the start of the header.
            bot_addr, bot_age, bot_sex: the bot's persona (P02).
            my_addr, my_age, my_sex: the user's persona (P01).

        Returns:
            The header as rendered by ``info_check`` (one line per segment).
        """
        def bucket(age):
            # Slider values may arrive as floats; truncate so the bucket
            # renders as "40", not "40.0".
            age = int(age)
            if age < 20:
                return "20๋ ๋ฏธ๋ง"
            if age >= 70:
                return "70๋ ์ด์"
            return str(age // 10 * 10) + "๋"

        bot_age = bucket(bot_age)
        my_age = bucket(my_age)
        self.info = f"์ผ์ ๋ํ {topic}<sep>P01:{my_addr} {my_age} {my_sex}<sep>P02:{bot_addr} {bot_age} {bot_sex}<sep>"
        return self.info_check()

    def info_check(self):
        """Return ``self.info`` with ``<sep>`` as newlines and speaker tags replaced by display names."""
        return self.info.replace('<sep>', '\n').replace('P01', '๋น์ ').replace('P02', '์ฑ๋ด')

    def reset_talk(self):
        """Forget the conversation history (the persona header is kept)."""
        self.talk = []
        self.state = None

    def test(self, myinp):
        """Send one user utterance to the bot and return the whole chat history.

        Args:
            myinp: the user's utterance.

        Returns:
            List of ``(user_text, bot_text)`` pairs, oldest first.

        Raises:
            ValueError: if ``initialize`` has not been called yet.
        """
        if self.info is None:
            raise ValueError("Persona is not set: call initialize() before test().")
        backup = list(self.talk)
        self.talk.append("P01<sos>" + myinp + "<eos>")
        self.talk.append("P02<sos>")
        try:
            inputs, seq_len = self._fit_prompt()
            reply = self._generate(inputs, seq_len)
        except Exception:
            # Roll back so a failed call doesn't leave a dangling "P02<sos>"
            # turn that would misalign the user/bot pairing on the next call.
            self.talk = backup
            raise
        self.talk[-1] += reply
        # Strip the 8-char "P0x<sos>" prefix and the 5-char "<eos>" suffix.
        return [(self.talk[i][8:-5], self.talk[i + 1][8:-5]) for i in range(0, len(self.talk) - 1, 2)]

    def _fit_prompt(self):
        """Tokenize info + talk, dropping the oldest exchanges until the prompt fits MAX_LENGTH."""
        state = None
        while True:
            now_inp = self.info + "".join(self.talk)
            inputs = self.tokenizer(now_inp, max_length=1024, truncation='longest_first', return_tensors='pt')
            seq_len = inputs.input_ids.size(1)
            if seq_len > self.MAX_LENGTH * 0.8:
                state = f"<์ฃผ์> ํ์ฌ ๋ํ ๊ธธ์ด๊ฐ ๊ณง ์ต๋ ๊ธธ์ด์ ๋๋ฌํฉ๋๋ค. ({seq_len} / {self.MAX_LENGTH})"
            # Drop a whole (user, bot) exchange so even indices stay user
            # turns, and never drop the turn currently being answered.
            if seq_len >= self.MAX_LENGTH and len(self.talk) > 2:
                state = "<์ฃผ์> ๋ํ ๊ธธ์ด๊ฐ ๋๋ฌด ๊ธธ์ด์ก๊ธฐ ๋๋ฌธ์, ์ดํ ๋ํ๋ ๋งจ ์์ ๋ฐํ๋ฅผ ์กฐ๊ธ์ฉ ์ง์ฐ๋ฉด์ ์งํ๋ฉ๋๋ค."
                self.talk = self.talk[2:]
            else:
                break
        self.state = state
        return inputs, seq_len

    def _generate(self, inputs, seq_len):
        """Sample the bot's continuation for the tokenized prompt; the result always ends in '<eos>'."""
        out = self.model.generate(
            inputs=inputs.input_ids.to(self.device),
            attention_mask=inputs.attention_mask.to(self.device),
            max_length=self.MAX_LENGTH,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.convert_tokens_to_ids('<eos>'),
        )
        # Decode only the newly generated ids. Slicing the fully decoded text
        # at len(prompt) assumes decode(encode(prompt)) == prompt exactly.
        reply = self.tokenizer.decode(out[0, seq_len:])
        # If generation stopped on the length limit, close the turn so the
        # "<eos>"-stripping in test() and later prompts stay well-formed.
        if not reply.endswith('<eos>'):
            reply += '<eos>'
        return reply
if __name__ == "__main__":
    # Build and launch the Gradio UI: persona/topic pickers, an apply button
    # that sets the prompt header, and a chat panel wired to Chatbot.test.
    warnings.filterwarnings("ignore")
    chatbot = Chatbot()
    # Region choices shared by the bot and user persona dropdowns.
    regions = ['์์ธํน๋ณ์', '๊ฒฝ๊ธฐ๋', '๋ถ์ฐ๊ด์ญ์', '๋์ ๊ด์ญ์', '๊ด์ฃผ๊ด์ญ์', '์ธ์ฐ๊ด์ญ์', '๊ฒฝ์๋จ๋', '์ธ์ฒ๊ด์ญ์', '์ถฉ์ฒญ๋ถ๋', '์ ์ฃผ๋', '๊ฐ์๋', '์ถฉ์ฒญ๋จ๋', '์ ๋ผ๋ถ๋', '๋๊ตฌ๊ด์ญ์', '์ ๋ผ๋จ๋', '๊ฒฝ์๋ถ๋', '์ธ์ข ํน๋ณ์์น์', '๊ธฐํ']
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# <center>MINDs Lab Brain's Fast Neural Chit-Chatbot</center>")
        with gr.Row():
            with gr.Column():
                topic = gr.Radio(label="Topic", choices=['์ฌ๊ฐ ์ํ', '์์ฌ/๊ต์ก', '๋ฏธ์ฉ๊ณผ ๊ฑด๊ฐ', '์์๋ฃ', '์๊ฑฐ๋(์ผํ)', '์ผ๊ณผ ์ง์ ', '์ฃผ๊ฑฐ์ ์ํ', '๊ฐ์ธ ๋ฐ ๊ด๊ณ', 'ํ์ฌ'])
            with gr.Column():
                gr.Markdown("Bot's persona")
                bot_addr = gr.Dropdown(label="์ง์ญ", choices=regions)
                bot_age = gr.Slider(label="๋์ด", minimum=10, maximum=80, value=45, step=1)
                bot_sex = gr.Radio(label="์ฑ๋ณ", choices=["๋จ์ฑ", "์ฌ์ฑ"])
            with gr.Column():
                gr.Markdown("Your persona")
                my_addr = gr.Dropdown(label="์ง์ญ", choices=regions)
                my_age = gr.Slider(label="๋์ด", minimum=10, maximum=80, value=45, step=1)
                my_sex = gr.Radio(label="์ฑ๋ณ", choices=["๋จ์ฑ", "์ฌ์ฑ"])
        with gr.Row():
            # A Button's caption is its `value` (first positional argument);
            # `label=` does not set the caption.
            apply_btn = gr.Button("์ ์ฉ")
            state = gr.Textbox(label="์ํ")
        apply_btn.click(
            fn=chatbot.initialize,
            inputs=[topic, bot_addr, bot_age, bot_sex, my_addr, my_age, my_sex],
            outputs=state,
        )
        with gr.Column():
            screen = gr.Chatbot(label="์ต๋ช ์ ์๋")
            with gr.Row():
                speak = gr.Textbox(label="์ ๋ ฅ์ฐฝ")
                talk_btn = gr.Button("Talk")
        talk_btn.click(
            fn=chatbot.test,
            inputs=speak,
            outputs=screen,
        )
    demo.launch(share=True)