zhiqu22
commited on
Commit
·
1c61cca
1
Parent(s):
1d88f72
updates
Browse files- README.md +13 -7
- modeling_mitre.py +27 -26
README.md
CHANGED
@@ -32,13 +32,13 @@ pipeline_tag: translation
|
|
32 |
# MITRE 466M
|
33 |
|
34 |
## Description
|
35 |
-
MITRE (
|
36 |
The technology, i.e., registering, is introduced in our [paper](url_placeholder).
|
37 |
This repository allows you employ our pre-trained model for inference. If you want to reproduce the data mining and training, please refer to this [repository](url_placeholder).
|
38 |
|
39 |
-
The model
|
40 |
-
You can
|
41 |
-
|
42 |
|
43 |
|
44 |
## Usages
|
@@ -65,11 +65,12 @@ After get the objects of the model and the tokenizer, we can do translation.
|
|
65 |
```python
|
66 |
english_text = "I have a red apple."
|
67 |
chinese_text = "我有一个红苹果。"
|
|
|
68 |
model.eval()
|
69 |
|
70 |
-
# Translating from one or several sentences to a
|
71 |
src_tokens = tokenizer.encode_source_tokens_to_input_ids([english_text, ], target_language="zh")
|
72 |
-
# Translating from one or several sentences to
|
73 |
# src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
|
74 |
|
75 |
generated_tokens = model.generate(src_tokens.cuda())
|
@@ -83,11 +84,16 @@ print(results)
|
|
83 |
# 1. The difference between tgt_tokens and labels is that the eos_tokens are moved to the right side.
|
84 |
# 2. We recommend using 'tokenizer.encode_target_tokens_to_labels' instead of modifying tgt_tokens,
|
85 |
# because 'tokenizer.encode_target_tokens_to_input_ids' has pads.
|
86 |
-
# 3. You can refer
|
87 |
# tgt_tokens = tokenizer.encode_target_tokens_to_input_ids(chinese_text)
|
88 |
# labels = tokenizer.encode_target_tokens_to_labels(chinese_text)
|
89 |
```
|
90 |
|
|
|
|
|
|
|
|
|
|
|
91 |
## Languages covered
|
92 |
Germanic: English (en), German (de), Dutch; Flemish (nl), Swedish (sv), Danish (da), Afrikaans (af)
|
93 |
Romance: French (fr), Spanish (es), Italian (it), Portuguese (pt), Romanian; Moldavian; Moldovan (ro)
|
|
|
32 |
# MITRE 466M
|
33 |
|
34 |
## Description
|
35 |
+
MITRE (Multilingual Translation with Registers) is a multilingual, decoder-only model designed for many-to-many translation tasks.
|
36 |
The technology, i.e., registering, is introduced in our [paper](url_placeholder).
|
37 |
This repository allows you employ our pre-trained model for inference. If you want to reproduce the data mining and training, please refer to this [repository](url_placeholder).
|
38 |
|
39 |
+
The model supports direct translation across 552 directions for 24 languages spanning over 5 language families.
|
40 |
+
You can use our models directly via the `transformers` libs.
|
41 |
+
An alternative version of MITRE with 913M parameters is also available in this [repository](url_placeholder).
|
42 |
|
43 |
|
44 |
## Usages
|
|
|
65 |
```python
|
66 |
english_text = "I have a red apple."
|
67 |
chinese_text = "我有一个红苹果。"
|
68 |
+
model.half() # recommended
|
69 |
model.eval()
|
70 |
|
71 |
+
# Translating from one or several sentences to a 'target_language'
|
72 |
src_tokens = tokenizer.encode_source_tokens_to_input_ids([english_text, ], target_language="zh")
|
73 |
+
# Translating from one or several sentences to given languages
|
74 |
# src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
|
75 |
|
76 |
generated_tokens = model.generate(src_tokens.cuda())
|
|
|
84 |
# 1. The difference between tgt_tokens and labels is that the eos_tokens are moved to the right side.
|
85 |
# 2. We recommend using 'tokenizer.encode_target_tokens_to_labels' instead of modifying tgt_tokens,
|
86 |
# because 'tokenizer.encode_target_tokens_to_input_ids' has pads.
|
87 |
+
# 3. You can refer to our code for detailed implementation.
|
88 |
# tgt_tokens = tokenizer.encode_target_tokens_to_input_ids(chinese_text)
|
89 |
# labels = tokenizer.encode_target_tokens_to_labels(chinese_text)
|
90 |
```
|
91 |
|
92 |
+
## Notes
|
93 |
+
We basically follow the style of [M2M](https://huggingface.co/facebook/m2m100_418M), however, we make some necessary improvements to reduce cost in generation.
|
94 |
+
You can refer to the codes of 'generate()' in [modeling_mitre.py](https://huggingface.co/naist-nlp/mitre_466m/blob/main/modeling_mitre.py) for much more details.
|
95 |
+
Moreover, we have a plan to implement FlashAttention V2 to further boost our models, which will be updated as soon as possible.
|
96 |
+
|
97 |
## Languages covered
|
98 |
Germanic: English (en), German (de), Dutch; Flemish (nl), Swedish (sv), Danish (da), Afrikaans (af)
|
99 |
Romance: French (fr), Spanish (es), Italian (it), Portuguese (pt), Romanian; Moldavian; Moldovan (ro)
|
modeling_mitre.py
CHANGED
@@ -74,11 +74,11 @@ class MitreSdpaAttention(nn.Module):
|
|
74 |
attention_mask: Optional[torch.Tensor] = None,
|
75 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
76 |
"""
|
77 |
-
1. MitreModel
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
2.
|
82 |
"""
|
83 |
bsz, tgt_len, _ = hidden_states.size()
|
84 |
|
@@ -777,32 +777,33 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
|
|
777 |
):
|
778 |
"""
|
779 |
Inference with beam search.
|
780 |
-
This code is improved
|
781 |
-
There are
|
782 |
1. 'soft early_stop' in beam search.
|
783 |
a) problem in the vanilla version.
|
784 |
-
In multilingual translation
|
785 |
-
|
786 |
-
|
787 |
-
the ended sequence is fed into the model still, resulting in a heavy memory waste.
|
788 |
b) our improvement.
|
789 |
-
We
|
790 |
-
|
791 |
-
|
792 |
-
|
793 |
-
Based on our
|
794 |
2. mask reusing.
|
795 |
-
a) problem:
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
|
|
800 |
b) our improvement.
|
801 |
-
First, we turncate the source tokens to
|
802 |
-
Second,
|
803 |
-
|
804 |
-
Third,
|
805 |
-
|
|
|
806 |
"""
|
807 |
if generation_config != None:
|
808 |
assert type(generation_config) is GenerationConfig
|
|
|
74 |
attention_mask: Optional[torch.Tensor] = None,
|
75 |
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
76 |
"""
|
77 |
+
1. MitreModel uses MitreSdpaAttention, which is modified from M2M100SdpaAttention.
|
78 |
+
Notably, neither of them supports 'output_attentions=True' or 'layer_head_mask is not None',
|
79 |
+
meaning that attn_weights are not included in the output.
|
80 |
+
Improving this feature is currently a low priority, and we leave this functionality for users to customize.
|
81 |
+
2.We plan to enhance this code with Flash Attention v2 in the future.
|
82 |
"""
|
83 |
bsz, tgt_len, _ = hidden_states.size()
|
84 |
|
|
|
777 |
):
|
778 |
"""
|
779 |
Inference with beam search.
|
780 |
+
This code is an improved version of transformers.generation.utils.GenerationMixin.generate.
|
781 |
+
There are two main improvements:
|
782 |
1. 'soft early_stop' in beam search.
|
783 |
a) problem in the vanilla version.
|
784 |
+
In multilingual translation models such as NLLB and M2M, the vanilla early stop in BeamSearchScorer
|
785 |
+
(the official implementation by HuggingFace) marks ended sequences with pad(1). However, these ended
|
786 |
+
sequences are still fed into the model, leading to significant memory waste.
|
|
|
787 |
b) our improvement.
|
788 |
+
We implemented a "soft early stop" to address this issue. Instead of modifying BeamSearchScorer
|
789 |
+
(to maintain code flexibility), we remove ended sequences from the input. Since this changes the
|
790 |
+
shape of the output hidden states, we insert placeholders to maintain compatibility with
|
791 |
+
BeamSearchScorer's state shapes.
|
792 |
+
Based on our tests, this improvement reduces memory usage by half.
|
793 |
2. mask reusing.
|
794 |
+
a) problem:
|
795 |
+
Registers require attention masks at each step.
|
796 |
+
A sequence may consist of four parts: padding, source tokens, registers, and target tokens.
|
797 |
+
During training, we mask all tokens before registers for target token generation. During generation,
|
798 |
+
we cannot allow target tokens to "see" padding tokens, requiring masks at every step.
|
799 |
+
This leads to computational inefficiency.
|
800 |
b) our improvement.
|
801 |
+
First, we turncate the source tokens and their representations to reduce cost.
|
802 |
+
Second, for source tokens acting as placeholders, we modified the mask generation logic compared to
|
803 |
+
our Fairseq implementation.
|
804 |
+
Third, to avoid regenerating masks at each step, we cache the mask in 'registering_cache', where cached
|
805 |
+
mask is managed like the key-value cache in beam search. Then, At every step, we add a column of zeros
|
806 |
+
to maintain alignment.
|
807 |
"""
|
808 |
if generation_config != None:
|
809 |
assert type(generation_config) is GenerationConfig
|