import torch
import gradio as gr

# Project-local helpers: the tokenizer wrapper, English text preprocessing, and
# the model classes that the pickled seq2seq checkpoint was built from.
from utils import Callable_tokenizer, preprocess_en
from models import Seq2seq_with_attention, Encoder, Decoder, Attention

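# Run inference on the GPU when available, otherwise fall back to the CPU.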
device = 'cuda' if torch.cuda.is_available() else 'cpu'

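# Load the fully pickled seq2seq-with-attention translation model (architecture + weights).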
seq2seq_with_attention = torch.load(
    "./seq2seq_with_attention_df-CoVoST2_df-opus_seed-123_subword.bin",
    map_location=device, weights_only=False)
seq2seq_with_attention.eval()  # switch to inference mode (disables dropout, etc.)

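# Subword tokenizers: English source (16K vocab) and Arabic target (32K vocab).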
en_sp = Callable_tokenizer('./tokenizers/NEW_en_vocab_df-CoVoST2_df-opus_seed-123_vocab-16K_FULL.model')
ar_sp = Callable_tokenizer('./tokenizers/NEW_ar_vocab_df-CoVoST2_df-opus_seed-123_vocab-32K_FULL.model')


def pre_processor(text):
    """Clean the English text, tokenize it, and return a (1, seq_len) batch tensor on `device`."""
    preprocessed = preprocess_en(text)
    en_tokens = torch.tensor(en_sp.user_tokenization(preprocessed)).unsqueeze(0).to(device)
    return en_tokens


def post_processor(raw_output):
    """Strip the first and last (special) tokens and decode the Arabic ids back to text."""
    return ar_sp.decode(raw_output[1:-1])


@torch.no_grad()
def translate_text(raw_input, max_tries=30):
    """Run the full pipeline: preprocess the English input, translate, and decode the Arabic output."""
    en_tokens = pre_processor(raw_input)
    output = seq2seq_with_attention.translate(en_tokens, max_tries)
    return post_processor(output)


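# Gradio UI: English input and examples on the left, Arabic translation on the right.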
custom_css = '.gr-button {background-color: #bf4b04; color: white;}'

with gr.Blocks(css=custom_css) as demo:
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label='English Sentence')
            gr.Examples(['How are you?',
                         'She is a good girl.',
                         'Who is better than me?!'],
                        inputs=input_text, label="Examples")
        with gr.Column():
            output = gr.Textbox(label="Arabic Translation")

    start_btn = gr.Button(value='Submit', elem_classes=["gr-button"])
    start_btn.click(fn=translate_text, inputs=input_text, outputs=output)

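# Start the Gradio server for the demo (local web UI by default).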
demo.launch()
|
|