# Hugging Face Spaces status: Runtime error
| """ | |
| RPG Room Generator App | |
| Sets the sampling parameters and provides minimal interface to the user | |
| https://huggingface.co/blog/how-to-generate | |
| """ | |
| import gradio as gr | |
| from gradio import inputs # allows easier doc lookup in Pycharm | |
| import transformers as tr | |
# Local GPT-2 checkpoint used for room generation (loaded as GPT2LMHeadModel).
MPATH = "./models/mdl_roomgen7"
MODEL = tr.GPT2LMHeadModel.from_pretrained(MPATH)

# TODO: save the tokenizer alongside the next checkpoint so this can become a
# plain load from MPATH instead of rebuilding it from the base "gpt2" tokenizer.
# These markers are used to build generation prompts in this file.
# NOTE(review): add_special_tokens assigns ids to newly added tokens in the
# order given, so this order presumably must match the one used when the
# checkpoint was trained -- do not reorder.
SPECIAL_TOKENS = dict(
    eos_token="<|EOS|>",
    bos_token="<|endoftext|>",
    pad_token="<pad>",
    sep_token="<|body|>",
)
TOK = tr.GPT2Tokenizer.from_pretrained("gpt2")
TOK.add_special_tokens(SPECIAL_TOKENS)
# Sampling presets offered to the user, keyed by the label shown in the UI.
# All three numbers are stored as integers; temperature and top_p are
# percentages that generate_room divides by 100 before sampling.
SAMPLING_OPTIONS = {
    label: {"top_k": k, "temperature": temp_pct, "top_p": nucleus_pct}
    for label, k, temp_pct, nucleus_pct in (
        # label,       top_k, temperature %, top_p %
        ("Reasonable", 25, 50, 60),
        ("Odd", 50, 75, 90),
        ("Insane", 300, 100, 85),
    )
}
def generate_room(room_name, room_desc, max_length, sampling_method):
    """
    Use the pretrained model to generate text for a dungeon room.

    Args:
        room_name: Name of the room; placed between the BOS and separator
            markers to condition generation.
        room_desc: Optional opening of the room description for the model to
            continue. It is kept in the returned text.
        max_length: Maximum total sequence length passed to ``MODEL.generate``.
            In transformers this limit includes the prompt tokens. It is cast
            to ``int`` because UI sliders may deliver floats.
        sampling_method: Key of SAMPLING_OPTIONS selecting the preset.

    Returns:
        The room description text. The ``<bos> name <sep>`` prefix is removed.
        Anything from the end-of-sequence marker onward is dropped, and a
        trailing partial sentence is trimmed.

    Raises:
        KeyError: if ``sampling_method`` is not a key of SAMPLING_OPTIONS.
    """
    # Guard against a missing value from the UI so the string joins below work.
    room_name = room_name or ""
    room_desc = room_desc or ""

    # Only the "<bos> name <sep>" header is skipped in the output; the user's
    # room_desc opening is part of the returned description.
    header = " ".join([SPECIAL_TOKENS["bos_token"], room_name, SPECIAL_TOKENS["sep_token"]])
    prompt = " ".join([header, room_desc])
    n_skip = len(TOK.encode(header))
    encoded = TOK(prompt, return_tensors="pt")

    # Presets store percentages for temperature/top_p; convert to fractions.
    preset = SAMPLING_OPTIONS[sampling_method]
    output = MODEL.generate(
        encoded["input_ids"],
        # Explicit mask/pad id: otherwise transformers warns and silently
        # substitutes the EOS id as the pad id.
        attention_mask=encoded["attention_mask"],
        pad_token_id=TOK.pad_token_id,
        max_length=int(max_length),
        do_sample=True,
        top_k=preset["top_k"],
        temperature=preset["temperature"] / 100.0,
        top_p=preset["top_p"] / 100.0,
    )
    text = TOK.decode(output[0][n_skip:], clean_up_tokenization_spaces=True)
    # Never show the raw end-of-sequence marker or anything sampled after it.
    text = text.split(SPECIAL_TOKENS["eos_token"], 1)[0]
    # Decoding around special tokens can leave doubled spaces; collapse them.
    text = text.replace("  ", " ")

    # Slice off the last partial sentence.
    last_period = text.rfind(".")
    if last_period > 0:
        text = text[: last_period + 1]
    return text
if __name__ == "__main__":
    # The four inputs map positionally onto
    # generate_room(room_name, room_desc, max_length, sampling_method).
    iface = gr.Interface(
        title="RPG Room Generator",
        fn=generate_room,
        inputs=[
            inputs.Textbox(lines=1, label="Room Name"),
            inputs.Textbox(lines=3, label="Start of Room Description (Optional)", default=""),
            # step=1 keeps the requested length a whole number of tokens.
            inputs.Slider(minimum=50, maximum=1000, step=1, default=200, label="Length"),
            inputs.Radio(choices=list(SAMPLING_OPTIONS.keys()), default="Odd", label="Craziness"),
        ],
        outputs="text",
        layout="horizontal",
        allow_flagging="never",
        theme="dark",
    )
    # The returned (app, local_url, share_url) values were never used, and the
    # exact return shape of launch() varies across gradio versions.
    iface.launch()