piliguori committed
Commit 42dfa32 · verified · 1 Parent(s): 089e833

Update app.py

Files changed (1)
  app.py  +21 -11
app.py CHANGED
@@ -1,21 +1,31 @@
-import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import gradio as gr
 
-model_id = "OSS-Forge/codet5p-770m-vhdl"
+model_id = "OSS-Forge/codet5p-770m-vhdl"
 
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 
-def generate_vhdl(description):
-    inputs = tokenizer(description, return_tensors="pt")
-    outputs = model.generate(**inputs, max_new_tokens=256)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+def generate_output(input_text):
+    inputs = tokenizer.encode(input_text, return_tensors='pt')
+    outputs = model.generate(
+        inputs,
+        max_length=256,
+        num_beams=5,
+        early_stopping=True,
+    )
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return generated_text
+
 
 iface = gr.Interface(
-    fn=generate_vhdl,
-    inputs=gr.Textbox(lines=4, label="Enter VHDL description"),
-    outputs=gr.Textbox(lines=12, label="Generated VHDL"),
-    title="VHDL Code Generator",
+    fn=generate_output,
+    inputs=gr.Textbox(lines=5, placeholder='Insert the English description here...'),
+    outputs=gr.Textbox(),
+    title='VHDL Code Generator (CodeT5+ 770M)',
+    description='Generate VHDL code from an English description using a fine-tuned CodeT5+ model.'
 )
 
-iface.launch()
+
+iface.launch()
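
For a quick sanity check of the new decoding path outside the Space, the snippet below is a minimal sketch: it loads the same checkpoint and calls generate with the same beam-search settings as the updated generate_output. The prompt string is a made-up example input, not something from this commit.

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "OSS-Forge/codet5p-770m-vhdl"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Hypothetical English description, used only as a smoke-test input.
prompt = "Design a 4-bit synchronous counter with an asynchronous reset."
inputs = tokenizer.encode(prompt, return_tensors="pt")

# Same decoding settings as generate_output in the updated app.py.
outputs = model.generate(
    inputs,
    max_length=256,
    num_beams=5,
    early_stopping=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))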