# Text_translation / model_translation.py
# Didier's picture
# Using smaller m2m100 model
# 4661832
"""
File: model_translation.py
Description:
Loading models for text translations
Author: Didier Guillevic
Date: 2024-03-16
"""
import spaces
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
from transformers import BitsAndBytesConfig
from model_spacy import nlp_xx as model_spacy
# bitsandbytes 8-bit quantization settings, passed to from_pretrained() by
# ModelMADLAD. Background on the outlier threshold value:
# https://discuss.huggingface.co/t/correct-usage-of-bitsandbytesconfig/33809/5
quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=200.0)
# The 100 languages supported by the facebook/m2m100_418M model
# https://huggingface.co/facebook/m2m100_418M
# plus the 'AUTOMATIC' option where we will use a language detector.
# Maps a human-readable label ("Name (code)") to the language code.
language_codes = {
    'AUTOMATIC': 'auto',
    'Afrikaans (af)': 'af',
    'Albanian (sq)': 'sq',
    'Amharic (am)': 'am',
    'Arabic (ar)': 'ar',
    'Armenian (hy)': 'hy',
    'Asturian (ast)': 'ast',
    'Azerbaijani (az)': 'az',
    'Bashkir (ba)': 'ba',
    'Belarusian (be)': 'be',
    'Bengali (bn)': 'bn',
    'Bosnian (bs)': 'bs',
    'Breton (br)': 'br',
    'Bulgarian (bg)': 'bg',
    'Burmese (my)': 'my',
    'Catalan; Valencian (ca)': 'ca',
    'Cebuano (ceb)': 'ceb',
    'Central Khmer (km)': 'km',
    'Chinese (zh)': 'zh',
    'Croatian (hr)': 'hr',
    'Czech (cs)': 'cs',
    'Danish (da)': 'da',
    'Dutch; Flemish (nl)': 'nl',
    'English (en)': 'en',
    'Estonian (et)': 'et',
    'Finnish (fi)': 'fi',
    'French (fr)': 'fr',
    'Fulah (ff)': 'ff',
    'Gaelic; Scottish Gaelic (gd)': 'gd',
    'Galician (gl)': 'gl',
    'Ganda (lg)': 'lg',
    'Georgian (ka)': 'ka',
    'German (de)': 'de',
    # Label previously misspelled as 'Greeek (el)'.
    'Greek (el)': 'el',
    'Gujarati (gu)': 'gu',
    'Haitian; Haitian Creole (ht)': 'ht',
    'Hausa (ha)': 'ha',
    'Hebrew (he)': 'he',
    'Hindi (hi)': 'hi',
    'Hungarian (hu)': 'hu',
    'Icelandic (is)': 'is',
    'Igbo (ig)': 'ig',
    'Iloko (ilo)': 'ilo',
    'Indonesian (id)': 'id',
    'Irish (ga)': 'ga',
    'Italian (it)': 'it',
    'Japanese (ja)': 'ja',
    'Javanese (jv)': 'jv',
    'Kannada (kn)': 'kn',
    'Kazakh (kk)': 'kk',
    'Korean (ko)': 'ko',
    'Lao (lo)': 'lo',
    'Latvian (lv)': 'lv',
    'Lingala (ln)': 'ln',
    'Lithuanian (lt)': 'lt',
    'Luxembourgish; Letzeburgesch (lb)': 'lb',
    'Macedonian (mk)': 'mk',
    'Malagasy (mg)': 'mg',
    'Malay (ms)': 'ms',
    'Malayalam (ml)': 'ml',
    'Marathi (mr)': 'mr',
    'Mongolian (mn)': 'mn',
    'Nepali (ne)': 'ne',
    'Northern Sotho (ns)': 'ns',
    'Norwegian (no)': 'no',
    'Occitan (post 1500) (oc)': 'oc',
    'Oriya (or)': 'or',
    'Panjabi; Punjabi (pa)': 'pa',
    'Persian (fa)': 'fa',
    'Polish (pl)': 'pl',
    'Portuguese (pt)': 'pt',
    'Pushto; Pashto (ps)': 'ps',
    'Romanian; Moldavian; Moldovan (ro)': 'ro',
    'Russian (ru)': 'ru',
    'Serbian (sr)': 'sr',
    'Sindhi (sd)': 'sd',
    'Sinhala; Sinhalese (si)': 'si',
    'Slovak (sk)': 'sk',
    'Slovenian (sl)': 'sl',
    'Somali (so)': 'so',
    'Spanish (es)': 'es',
    'Sundanese (su)': 'su',
    'Swahili (sw)': 'sw',
    'Swati (ss)': 'ss',
    'Swedish (sv)': 'sv',
    'Tagalog (tl)': 'tl',
    'Tamil (ta)': 'ta',
    'Thai (th)': 'th',
    'Tswana (tn)': 'tn',
    'Turkish (tr)': 'tr',
    'Ukrainian (uk)': 'uk',
    'Urdu (ur)': 'ur',
    'Uzbek (uz)': 'uz',
    'Vietnamese (vi)': 'vi',
    'Welsh (cy)': 'cy',
    'Western Frisian (fy)': 'fy',
    'Wolof (wo)': 'wo',
    'Xhosa (xh)': 'xh',
    'Yiddish (yi)': 'yi',
    'Yoruba (yo)': 'yo',
    'Zulu (zu)': 'zu',
}
# Label -> language code for the target-language choices
# (presumably the subset offered as translation targets; verify against callers).
tgt_language_codes = dict([
    ('English (en)', 'en'),
    ('French (fr)', 'fr'),
])
def build_text_chunks(
    text: str,
    sents_per_chunk: int = 5,
    words_per_chunk: int = 200,
) -> list[str]:
    """Split a text into chunks of whole sentences.

    The text is split into sentences with the spaCy pipeline ``model_spacy``.
    Sentences are then grouped greedily, in order, into chunks holding at most
    ``sents_per_chunk`` sentences and at most ``words_per_chunk``
    whitespace-separated words. Sentences are never split: a single sentence
    longer than ``words_per_chunk`` becomes a chunk of its own.

    Args:
        text: The text to split.
        sents_per_chunk: Maximum number of sentences per chunk.
        words_per_chunk: Maximum number of words per chunk.

    Returns:
        The list of chunks, each with one sentence per line. No chunk is ever
        empty; the list is empty when the text contains no sentence.
    """
    # Split text into sentences, dropping whitespace-only ones.
    sentences = [
        stripped
        for sent in model_spacy(text).sents
        if (stripped := sent.text.strip())
    ]
    logger.info(f"TEXT: {text[:25]}, NB_SENTS: {len(sentences)}")

    chunks: list[str] = []
    current: list[str] = []  # sentences of the chunk being built
    current_nb_words = 0
    for sent in sentences:
        sent_nb_words = len(sent.split())
        would_overflow = (
            current_nb_words + sent_nb_words > words_per_chunk
            or len(current) + 1 > sents_per_chunk
        )
        # Flush only a non-empty chunk: an over-long first sentence must not
        # produce an empty chunk in front of it.
        if current and would_overflow:
            chunks.append('\n'.join(current))
            current = []
            current_nb_words = 0

        # Append sentence to current chunk. One sentence per line.
        current.append(sent)
        current_nb_words += sent_nb_words

    # Append last chunk
    if current:
        chunks.append('\n'.join(current))

    return chunks
class Singleton(type):
    """Metaclass that hands out a single shared instance per class.

    The first call to a class using this metaclass builds the instance with
    the given arguments; every later call returns that same instance and
    ignores its arguments.
    """

    # Maps each class to its unique instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
class ModelM2M100(metaclass=Singleton):
    """Singleton wrapper around the facebook/m2m100_418M translation model.

    Model: https://huggingface.co/facebook/m2m100_418M
    """

    def __init__(self):
        self._model_name = "facebook/m2m100_418M"
        self._tokenizer = M2M100Tokenizer.from_pretrained(self._model_name)
        # Loaded in float16; quantization_config is not passed for this model.
        loaded_model = M2M100ForConditionalGeneration.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        )
        self._model = torch.compile(loaded_model)

    @spaces.GPU
    def translate(
        self,
        text: str,
        src_lang: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        words_per_chunk: int = 200,
    ) -> str:
        """Translate ``text`` from ``src_lang`` into ``tgt_lang``.

        When ``chunk_text`` is true, the text is first split with
        ``build_text_chunks(text, sents_per_chunk, words_per_chunk)``;
        otherwise it is translated as a single piece. Each piece is translated
        on its own and the translations are joined with newlines.
        """
        if chunk_text:
            pieces = build_text_chunks(text, sents_per_chunk, words_per_chunk)
        else:
            pieces = [text]

        tokenizer = self._tokenizer
        tokenizer.src_lang = src_lang

        translations = []
        for piece in pieces:
            encoded = tokenizer(piece, return_tensors="pt")
            generated = self._model.generate(
                input_ids=encoded.input_ids.to(self._model.device),
                # Forces the target language token as first generated token.
                forced_bos_token_id=tokenizer.get_lang_id(tgt_lang),
            )
            decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
            translations.append(decoded[0])

        return '\n'.join(translations)

    @property
    def model_name(self):
        """Name of the checkpoint that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """The tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The model object (as returned by torch.compile)."""
        return self._model

    @property
    def device(self):
        """The device reported by the model."""
        return self._model.device
class ModelMADLAD(metaclass=Singleton):
    """Loads an instance of the Google MADLAD model (3B).

    Model: https://huggingface.co/google/madlad400-3b-mt
    """

    def __init__(self):
        self._model_name = "google/madlad400-3b-mt"
        self._input_max_length = 512  # config.json n_positions
        self._output_max_length = 512  # config.json n_positions
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._model_name, use_fast=True
        )
        self._model = AutoModelForSeq2SeqLM.from_pretrained(
            self._model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=quantization_config,
        )
        self._model = torch.compile(self._model)

    @spaces.GPU
    def translate(
        self,
        text: str,
        tgt_lang: str,
        chunk_text: bool = True,
        sents_per_chunk: int = 5,
        # NOTE(review): 5 words per chunk looks like a typo for 200 (the value
        # used by ModelM2M100.translate); left unchanged pending confirmation.
        words_per_chunk: int = 5,
    ) -> str:
        """Translate given text into the target language.

        When ``chunk_text`` is true, the text is first split with
        ``build_text_chunks(text, sents_per_chunk, words_per_chunk)``;
        otherwise it is translated as a single piece. Each piece is prefixed
        with the ``<2{tgt_lang}>`` target-language tag, tokenized (truncated
        to the input max length of 512 tokens), translated, and the
        translations are joined with newlines.

        Args:
            text: The text to translate.
            tgt_lang: Target language code, used to build the ``<2xx>`` tag.
            chunk_text: Whether to split the text into chunks first.
            sents_per_chunk: Maximum number of sentences per chunk.
            words_per_chunk: Maximum number of words per chunk.

        Returns:
            The translated text, one translated chunk per line.
        """
        chunks = [text]
        if chunk_text:
            chunks = build_text_chunks(text, sents_per_chunk, words_per_chunk)

        translated_chunks = []
        for chunk in chunks:
            input_text = f"<2{tgt_lang}> {chunk}"
            logger.info(f" Translating: {input_text[:50]}")
            input_ids = self._tokenizer(
                input_text,
                return_tensors="pt",
                max_length=self._input_max_length,
                truncation=True,
                padding="longest").input_ids.to(self._model.device)
            outputs = self._model.generate(
                input_ids=input_ids,
                max_length=self._output_max_length)
            translated_chunk = self._tokenizer.decode(
                outputs[0],
                skip_special_tokens=True)
            translated_chunks.append(translated_chunk)

        return '\n'.join(translated_chunks)

    @property
    def model_name(self):
        """Name of the checkpoint that was loaded."""
        return self._model_name

    @property
    def tokenizer(self):
        """The tokenizer loaded for the checkpoint."""
        return self._tokenizer

    @property
    def model(self):
        """The model object (as returned by torch.compile)."""
        return self._model

    @property
    def device(self):
        """The device reported by the model."""
        return self._model.device
# Bi-lingual individual models
# Source language codes for the bilingual models.
# NOTE(review): "ja" is listed here but has no entry in model_names below.
src_langs = {"ar", "en", "fa", "fr", "he", "ja", "zh"}

# Source language code -> Hugging Face model name.
model_names = dict(
    ar="Helsinki-NLP/opus-mt-ar-en",
    en="Helsinki-NLP/opus-mt-en-fr",
    fa="Helsinki-NLP/opus-mt-tc-big-fa-itc",
    fr="Helsinki-NLP/opus-mt-fr-en",
    he="Helsinki-NLP/opus-mt-tc-big-he-en",
    zh="Helsinki-NLP/opus-mt-zh-en",
)

# Registry for all loaded bilingual models: src_lang -> (tokenizer, model)
tokenizer_model_registry = dict()

# Device the bilingual models are moved to.
device = "cpu"
def get_tokenizer_model_for_src_lang(
    src_lang: str,
) -> tuple[AutoTokenizer, AutoModelForSeq2SeqLM]:
    """Return the (tokenizer, model) for a given source language.

    The pair is loaded on first use and kept in ``tokenizer_model_registry``,
    so later calls for the same language return the same objects.

    Args:
        src_lang: Source language code (case-insensitive), a key of
            ``model_names``.

    Returns:
        The ``(tokenizer, model)`` pair for that language.

    Raises:
        ValueError: If no model is defined for ``src_lang``.
    """
    src_lang = src_lang.lower()

    # Already loaded?
    cached = tokenizer_model_registry.get(src_lang)
    if cached is not None:
        return cached

    # Load tokenizer and model
    model_name = model_names.get(src_lang)
    if not model_name:
        raise ValueError(f"No model defined for language: {src_lang}")

    # We will leave the models on the CPU (for now)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    # Convert to half precision unless the checkpoint is already float16.
    # NOTE(review): float16 on CPU may be slow or unsupported for some ops
    # depending on the torch version -- confirm.
    if model.config.torch_dtype != torch.float16:
        model = model.half()
    model.to(device)

    tokenizer_model_registry[src_lang] = (tokenizer, model)
    return (tokenizer, model)
# Max number of words for a given input text chunk.
# - Models usually accept 512 tokens (max position encodings, as well as max length).
# - Set to a number of words somewhat lower than that threshold, e.g. 200 words.
max_words_per_chunk = 200