import numpy as np
import gradio as gr  # used by create_demo() below
import os
import tempfile
import shutil

from trainer import Trainer
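
# Trainer is a project-local class; from the calls below it is assumed to expose:
#   Trainer(vocab_size, sequence_len, batch_size, nn_epochs, model_type)
#   .predict(text)         for the "lstm" / "bilstm" models
#   .predict_maxent(text)  for the "max_ent" model
#   .predict_svm(text)     for the "svm" model
# (interface inferred from usage in this file, not from trainer.py itself)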


def predict(input_text, model_type):
    """Classify input_text with the single model selected by model_type."""
    if model_type in ["lstm", "bilstm"]:
        predicted_label = trainer.predict(input_text)
    elif model_type == "max_ent":
        predicted_label = trainer_maxent.predict_maxent(input_text)
    elif model_type == "svm":
        predicted_label = trainer_svm.predict_svm(input_text)
    else:
        raise ValueError(f"Unknown model_type: {model_type}")
    return str(predicted_label)


def predict_omni(input_text):
    """Classify input_text with all three models and combine their labels."""
    predicted_label_net = trainer.predict(input_text)
    predicted_label_maxent = trainer_maxent.predict_maxent(input_text)
    predicted_label_svm = trainer_svm.predict_svm(input_text)
    return f"LSTM: {predicted_label_net}, Max Ent: {predicted_label_maxent}, SVM: {predicted_label_svm}"
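
# Example: predict_omni("great acting, loved it") could return something like
# "LSTM: 1, Max Ent: 1, SVM: 0" -- the actual labels depend on trainer.py.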


def create_demo():
    """Build a minimal Gradio page: a textbox in, the combined predictions out."""
    USAGE = """## Text Classification"""
    with gr.Blocks() as demo:
        gr.Markdown(USAGE)
        gr.Interface(fn=predict_omni, inputs="textbox", outputs="textbox")
    return demo
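
# create_demo() is only needed for the web UI; the __main__ block below runs a
# plain stdin loop instead and leaves the demo.launch() path commented out.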


if __name__ == "__main__":
    vocab_size = 8000
    sequence_len = 150
    batch_size = 256  # 1024 was also tried
    nn_epochs = 20

    # Load one trainer per model type so predict_omni() can compare all three.
    # ("bilstm" is also supported but not loaded here.)
    trainer = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "lstm")
    trainer_maxent = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "max_ent")
    trainer_svm = Trainer(vocab_size, sequence_len, batch_size, nn_epochs, "svm")

    # Read one line of text per iteration and print every model's prediction.
    while True:
        input_text = input()
        label = predict_omni(input_text)
        print(label)

    # To serve the Gradio UI instead of the stdin loop:
    # demo = create_demo()
    # demo.launch()

# Usage: python app_local.py