Spaces:
Running
Running
""" | |
File: model_translation.py | |
Description: | |
Loading models for text translations | |
Author: Didier Guillevic | |
Date: 2024-03-16 | |
""" | |
import spaces | |
import logging | |
logger = logging.getLogger(__name__) | |
logging.basicConfig(level=logging.INFO) | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration | |
from transformers import BitsAndBytesConfig | |
from model_spacy import nlp_xx as model_spacy | |
# bitsandbytes settings for loading model weights in 8-bit.
# llm_int8_threshold is the outlier threshold of LLM.int8(); it is set far
# above the transformers default (6.0) on purpose, following
# https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
quantization_config = BitsAndBytesConfig(
    llm_int8_threshold=200.0,
    load_in_8bit=True,
)
# Display label -> language code for the 100 languages supported by the
# facebook/m2m100_418M model (https://huggingface.co/facebook/m2m100_418M),
# plus an 'AUTOMATIC' entry (code 'auto') for which a language detector is
# used instead of a user-chosen source language.
# Each label ends with its code in parentheses.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    'Greek (el)': 'el',
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu',
}
# Target-language choices: display label -> language code.
tgt_language_codes: dict[str, str] = {
    'English (en)': 'en',
    'French (fr)': 'fr',
}
def build_text_chunks(
    text: str,
    sents_per_chunk: int = 5,
    words_per_chunk: int = 200,
) -> list[str]:
    """Split text into chunks of whole sentences, bounded in sentences and words.

    The text is segmented into sentences with the spaCy pipeline imported as
    ``model_spacy``; whitespace-only sentences are dropped. Consecutive
    sentences are then packed greedily into chunks (one sentence per line)
    so that each chunk holds at most ``sents_per_chunk`` sentences and at
    most ``words_per_chunk`` whitespace-separated words.

    A single sentence longer than ``words_per_chunk`` words is never split:
    it becomes a chunk of its own and exceeds the word budget.

    Args:
        text: The text to split.
        sents_per_chunk: Maximum number of sentences per chunk.
        words_per_chunk: Maximum number of words per chunk (see the
            single-long-sentence exception above).

    Returns:
        The non-empty chunks, in text order. The list is empty when ``text``
        contains no non-blank sentence.
    """
    # Split text into sentences, discarding whitespace-only ones.
    sentences = [
        sent.text.strip() for sent in model_spacy(text).sents if sent.text.strip()
    ]
    logger.info("TEXT: %s, NB_SENTS: %d", text[:25], len(sentences))

    chunks: list[str] = []
    chunk_sents: list[str] = []
    chunk_nb_words = 0
    for sent in sentences:
        sent_nb_words = len(sent.split())
        # Close the current chunk when adding this sentence would break either
        # limit. Only a non-empty chunk is flushed: otherwise an oversize first
        # sentence (or sents_per_chunk <= 0) would produce an empty '' chunk.
        if chunk_sents and (
            chunk_nb_words + sent_nb_words > words_per_chunk
            or len(chunk_sents) + 1 > sents_per_chunk
        ):
            chunks.append('\n'.join(chunk_sents))
            chunk_sents = []
            chunk_nb_words = 0
        chunk_sents.append(sent)
        chunk_nb_words += sent_nb_words

    # Append the last, partially filled chunk.
    if chunk_sents:
        chunks.append('\n'.join(chunk_sents))
    return chunks
class Singleton(type):
    """Metaclass giving each class that uses it a single shared instance.

    The first call ``Cls(*args, **kwargs)`` constructs the instance normally
    and records it in ``_instances`` (one dict shared by all such classes,
    keyed by class). Every later call returns that same object; the
    arguments of later calls are ignored.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        # First call for this class: build it the regular way, then remember it.
        instance = super().__call__(*args, **kwargs)
        registry[cls] = instance
        return instance
class ModelM2M100(metaclass=Singleton):
    """Loads an instance of the M2M100 model (418M) and translates text with it.

    Model: https://huggingface.co/facebook/m2m100_418M

    Because of the ``Singleton`` metaclass, the tokenizer and model are loaded
    only once, on the first ``ModelM2M100()`` call. That instance is shared
    afterwards.
    """

    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        # Loaded in float16; the module-level 8-bit quantization_config is
        # not applied to this model.
        self._model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self._model = torch.compile(self._model)

    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200,
    ) -> str:
        """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

        When ``chunk_text`` is true, the text is first split with
        ``build_text_chunks`` so the pieces stay well within the model's
        input length (usually 512 tokens). Otherwise it is translated in one
        piece. Empty or whitespace-only pieces are skipped.

        Args:
            text: Text to translate.
            src_lang: M2M100 language code of the source text (e.g. 'fr').
                NOTE(review): it is assigned to the tokenizer as-is, so the
                'auto' placeholder from language_codes presumably has to be
                resolved by the caller first — confirm.
            tgt_lang: M2M100 language code to translate into (e.g. 'en').
            chunk_text: Whether to split the text into chunks first.
            sents_per_chunk: Maximum sentences per chunk (when chunking).
            words_per_chunk: Maximum words per chunk (when chunking).

        Returns:
            The translated pieces joined with newlines, or '' when there is
            nothing to translate.
        """
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
        else:
            chunks = [text]
        # Blank pieces carry nothing to translate.
        chunks = [chunk for chunk in chunks if chunk.strip()]

        # M2M100 takes the source language from the tokenizer. The target
        # language is imposed by forcing the first generated token.
        # NOTE(review): this mutates the shared (singleton) tokenizer, so
        # concurrent calls with different src_lang could interfere — confirm
        # requests are serialized.
        self._tokenizer.src_lang = src_lang
        tgt_lang_id = self._tokenizer.get_lang_id(tgt_lang)

        translated_chunks = []
        for chunk in chunks:
            input_ids = self._tokenizer(
                chunk, return_tensors="pt"
            ).input_ids.to(self._model.device)
            # NOTE(review): no max_length is passed, so generation uses the
            # model's default length limit; ModelMADLAD passes one explicitly.
            # Confirm the default is long enough for full-size chunks.
            outputs = self._model.generate(
                input_ids=input_ids,
                forced_bos_token_id=tgt_lang_id,
            )
            translated_chunks.append(
                self._tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            )
        return '\n'.join(translated_chunks)

    def model_name(self):
        """Return the model identifier passed to ``from_pretrained``."""
        return self._model_name

    def tokenizer(self):
        """Return the loaded M2M100 tokenizer."""
        return self._tokenizer

    def model(self):
        """Return the loaded (torch.compile-wrapped) model."""
        return self._model

    def device(self):
        """Return the device reported by the model."""
        return self._model.device
class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD-400 model (3B) and translates text.

    Model: https://huggingface.co/google/madlad400-3b-mt

    The weights are loaded with the module-level 8-bit ``quantization_config``.
    Because of the ``Singleton`` metaclass, the model is loaded only once, on
    the first ``ModelMADLAD()`` call. That instance is shared afterwards.
    """

    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._input_max_length = 512  # config.json n_positions
        self._output_max_length = 512  # config.json n_positions
        # Pass the name string itself. ``self.model_name`` is the accessor
        # method defined below, and from_pretrained cannot load a bound method.
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config,
        )
        self._model = torch.compile(self._model)

    def translate(
        self,
        text: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200,
    ) -> str:
        """Translate ``text`` into ``tgt_lang``; no source language is required.

        MADLAD selects the output language from a ``<2{tgt_lang}>`` prefix
        prepended to each input. When ``chunk_text`` is true, the text is first
        split with ``build_text_chunks``; otherwise it is translated in one
        piece. Each piece is truncated to ``self._input_max_length`` tokens,
        and its translation is capped at ``self._output_max_length`` tokens.
        Text beyond those limits is cut off. Blank pieces are skipped.

        Args:
            text: Text to translate.
            tgt_lang: Target language code used in the ``<2xx>`` prefix.
            chunk_text: Whether to split the text into chunks first.
            sents_per_chunk: Maximum sentences per chunk (when chunking).
            words_per_chunk: Maximum words per chunk (when chunking). The
                default of 200 matches ModelM2M100 and max_words_per_chunk.

        Returns:
            The translated pieces joined with newlines, or '' when there is
            nothing to translate.
        """
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)
        else:
            chunks = [text]

        translated_chunks = []
        for chunk in chunks:
            if not chunk.strip():
                continue  # nothing to translate
            input_text = f"<2{tgt_lang}> {chunk}"
            logger.info(" Translating: %s", input_text[:50])
            input_ids = self._tokenizer(
                input_text,
                return_tensors="pt",
                max_length=self._input_max_length,
                truncation=True,
                padding="longest",
            ).input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                max_length=self._output_max_length,
            )
            translated_chunks.append(
                self._tokenizer.decode(outputs[0], skip_special_tokens=True)
            )
        return '\n'.join(translated_chunks)

    def model_name(self):
        """Return the model identifier passed to ``from_pretrained``."""
        return self._model_name

    def tokenizer(self):
        """Return the loaded tokenizer."""
        return self._tokenizer

    def model(self):
        """Return the loaded (torch.compile-wrapped) model."""
        return self._model

    def device(self):
        """Return the device reported by the model."""
        return self._model.device
# Bilingual Helsinki-NLP OPUS-MT models, keyed by source-language code.
# The target side varies by entry (per the model names): English for most,
# French for "en", and "itc" (presumably the Italic language group) for "fa".
# NOTE(review): multi-target models such as opus-mt-tc-big-fa-itc may need a
# target-language token in the input — confirm how callers handle "fa".
model_names = {
    "ar": "Helsinki-NLP/opus-mt-ar-en",
    "en": "Helsinki-NLP/opus-mt-en-fr",
    "fa": "Helsinki-NLP/opus-mt-tc-big-fa-itc",
    "fr": "Helsinki-NLP/opus-mt-fr-en",
    "he": "Helsinki-NLP/opus-mt-tc-big-he-en",
    "ja": "Helsinki-NLP/opus-mt-ja-en",
    "zh": "Helsinki-NLP/opus-mt-zh-en",
}
# Source languages that have a bilingual model. It is derived from
# model_names so every advertised language actually has a model to load.
src_langs = set(model_names)
# Cache of bilingual models loaded so far: source-language code ->
# (tokenizer, model). Filled lazily by get_tokenizer_model_for_src_lang.
tokenizer_model_registry: dict[str, tuple] = {}
# Device the bilingual models are moved to.
device = "cpu"
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) pair for translating from ``src_lang``.

    On first use for a language, the pair is loaded from the model named in
    ``model_names``. The model is converted to float16 unless its config
    already declares float16, moved to ``device``, and the pair is cached in
    ``tokenizer_model_registry``. Later calls return the cached pair.

    Args:
        src_lang: Source-language code, case-insensitive (e.g. "FR" or "fr").

    Returns:
        A ``(tokenizer, model)`` tuple.

    Raises:
        ValueError: If no bilingual model is configured for ``src_lang``.
            ValueError subclasses Exception, so callers catching Exception
            still work.
    """
    src_lang = src_lang.lower()

    # Already loaded? Return the cached pair.
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # The models are left on `device` (currently the CPU).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # NOTE(review): float16 on CPU is slow, and some older PyTorch releases
    # lack CPU float16 kernels — confirm this conversion is intended while
    # `device` is 'cpu'.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return tokenizer, model
# Word budget for one chunk of input text. The models usually accept about
# 512 tokens (maximum position encodings as well as maximum length), so the
# budget in words is kept somewhat below that token threshold.
max_words_per_chunk: int = 200