Iker commited on
Commit
e142967
·
1 Parent(s): 9febb95

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +168 -0
app.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ import io
4
+ from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
5
+ import time
6
+ import json
7
+ from typing import List
8
+ import torch
9
+ import random
10
+
11
+ if torch.cuda.is_available():
12
+ device = torch.device("cuda:0")
13
+ else:
14
+ device = torch.device("cpu")
15
+ logging.warning("GPU not found, using CPU, translation will be very slow.")
16
+
17
+ st.cache(suppress_st_warning=True, allow_output_mutation=True)
18
+ st.set_page_config(page_title="M2M100 Translator")
19
+
20
+ lang_id = {
21
+ "Afrikaans": "af",
22
+ "Amharic": "am",
23
+ "Arabic": "ar",
24
+ "Asturian": "ast",
25
+ "Azerbaijani": "az",
26
+ "Bashkir": "ba",
27
+ "Belarusian": "be",
28
+ "Bulgarian": "bg",
29
+ "Bengali": "bn",
30
+ "Breton": "br",
31
+ "Bosnian": "bs",
32
+ "Catalan": "ca",
33
+ "Cebuano": "ceb",
34
+ "Czech": "cs",
35
+ "Welsh": "cy",
36
+ "Danish": "da",
37
+ "German": "de",
38
+ "Greeek": "el",
39
+ "English": "en",
40
+ "Spanish": "es",
41
+ "Estonian": "et",
42
+ "Persian": "fa",
43
+ "Fulah": "ff",
44
+ "Finnish": "fi",
45
+ "French": "fr",
46
+ "Western Frisian": "fy",
47
+ "Irish": "ga",
48
+ "Gaelic": "gd",
49
+ "Galician": "gl",
50
+ "Gujarati": "gu",
51
+ "Hausa": "ha",
52
+ "Hebrew": "he",
53
+ "Hindi": "hi",
54
+ "Croatian": "hr",
55
+ "Haitian": "ht",
56
+ "Hungarian": "hu",
57
+ "Armenian": "hy",
58
+ "Indonesian": "id",
59
+ "Igbo": "ig",
60
+ "Iloko": "ilo",
61
+ "Icelandic": "is",
62
+ "Italian": "it",
63
+ "Japanese": "ja",
64
+ "Javanese": "jv",
65
+ "Georgian": "ka",
66
+ "Kazakh": "kk",
67
+ "Central Khmer": "km",
68
+ "Kannada": "kn",
69
+ "Korean": "ko",
70
+ "Luxembourgish": "lb",
71
+ "Ganda": "lg",
72
+ "Lingala": "ln",
73
+ "Lao": "lo",
74
+ "Lithuanian": "lt",
75
+ "Latvian": "lv",
76
+ "Malagasy": "mg",
77
+ "Macedonian": "mk",
78
+ "Malayalam": "ml",
79
+ "Mongolian": "mn",
80
+ "Marathi": "mr",
81
+ "Malay": "ms",
82
+ "Burmese": "my",
83
+ "Nepali": "ne",
84
+ "Dutch": "nl",
85
+ "Norwegian": "no",
86
+ "Northern Sotho": "ns",
87
+ "Occitan": "oc",
88
+ "Oriya": "or",
89
+ "Panjabi": "pa",
90
+ "Polish": "pl",
91
+ "Pushto": "ps",
92
+ "Portuguese": "pt",
93
+ "Romanian": "ro",
94
+ "Russian": "ru",
95
+ "Sindhi": "sd",
96
+ "Sinhala": "si",
97
+ "Slovak": "sk",
98
+ "Slovenian": "sl",
99
+ "Somali": "so",
100
+ "Albanian": "sq",
101
+ "Serbian": "sr",
102
+ "Swati": "ss",
103
+ "Sundanese": "su",
104
+ "Swedish": "sv",
105
+ "Swahili": "sw",
106
+ "Tamil": "ta",
107
+ "Thai": "th",
108
+ "Tagalog": "tl",
109
+ "Tswana": "tn",
110
+ "Turkish": "tr",
111
+ "Ukrainian": "uk",
112
+ "Urdu": "ur",
113
+ "Uzbek": "uz",
114
+ "Vietnamese": "vi",
115
+ "Wolof": "wo",
116
+ "Xhosa": "xh",
117
+ "Yiddish": "yi",
118
+ "Yoruba": "yo",
119
+ "Chinese": "zh",
120
+ "Zulu": "zu",
121
+ }
122
+
123
+
124
+ @st.cache(suppress_st_warning=True, allow_output_mutation=True)
125
+ def load_model(
126
+ pretrained_model: str = "facebook/m2m100_418M",
127
+ cache_dir: str = "models/",
128
+ ):
129
+ tokenizer = M2M100Tokenizer.from_pretrained(pretrained_model, cache_dir=cache_dir)
130
+ model = M2M100ForConditionalGeneration.from_pretrained(
131
+ pretrained_model, cache_dir=cache_dir
132
+ ).to(device)
133
+ model.eval()
134
+ return tokenizer, model
135
+
136
+
137
+ st.title("M2M100 Translator")
138
+
139
+
140
+ user_input: str = st.text_area(
141
+ "Input text",
142
+ height=200,
143
+ max_chars=5120,
144
+ )
145
+
146
+ source_lang = st.selectbox(label="Source language", options=list(lang_id.keys()))
147
+ target_lang = st.selectbox(label="Target language", options=list(lang_id.keys()))
148
+
149
+ if st.button("Run"):
150
+ time_start = time.time()
151
+ tokenizer, model = load_model()
152
+
153
+ src_lang = lang_id[source_lang]
154
+ trg_lang = lang_id[target_lang]
155
+ tokenizer.src_lang = src_lang
156
+ with torch.no_grad():
157
+ encoded_input = tokenizer(user_input, return_tensors="pt").to(device)
158
+ generated_tokens = model.generate(
159
+ **encoded_input, forced_bos_token_id=tokenizer.get_lang_id(trg_lang)
160
+ )
161
+ translated_text = tokenizer.batch_decode(
162
+ generated_tokens, skip_special_tokens=True
163
+ )[0]
164
+
165
+ time_end = time.time()
166
+ st.success(translated_text)
167
+
168
+ st.write(f"Computation time: {round((time_end-time_start),3)} segs")