zhiqu22 commited on
Commit
1c61cca
·
1 Parent(s): 1d88f72
Files changed (2) hide show
  1. README.md +13 -7
  2. modeling_mitre.py +27 -26
README.md CHANGED
@@ -32,13 +32,13 @@ pipeline_tag: translation
32
  # MITRE 466M
33
 
34
  ## Description
35
- MITRE (multilingual translation with registers) is a multilingual decoder-only model trained for many-to-many translation.
36
  The technology, i.e., registering, is introduced in our [paper](url_placeholder).
37
  This repository allows you employ our pre-trained model for inference. If you want to reproduce the data mining and training, please refer to this [repository](url_placeholder).
38
 
39
- The model can directly translate between the 552 directions of 24 languages spanning more than 5 language families.
40
- You can directly use our models by `transformers` libs.
41
- MITRE has another version with 913M parameters, which can be found in this [repository](url_placeholder).
42
 
43
 
44
  ## Usages
@@ -65,11 +65,12 @@ After get the objects of the model and the tokenizer, we can do translation.
65
  ```python
66
  english_text = "I have a red apple."
67
  chinese_text = "我有一个红苹果。"
 
68
  model.eval()
69
 
70
- # Translating from one or several sentences to a sole language
71
  src_tokens = tokenizer.encode_source_tokens_to_input_ids([english_text, ], target_language="zh")
72
- # Translating from one or several sentences to corresponding languages
73
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
74
 
75
  generated_tokens = model.generate(src_tokens.cuda())
@@ -83,11 +84,16 @@ print(results)
83
  # 1. The difference between tgt_tokens and labels is that the eos_tokens are moved to the right side.
84
  # 2. We recommend using 'tokenizer.encode_target_tokens_to_labels' instead of modifying tgt_tokens,
85
  # because 'tokenizer.encode_target_tokens_to_input_ids' has pads.
86
- # 3. You can refer our codes to know the details in implementation.
87
  # tgt_tokens = tokenizer.encode_target_tokens_to_input_ids(chinese_text)
88
  # labels = tokenizer.encode_target_tokens_to_labels(chinese_text)
89
  ```
90
 
 
 
 
 
 
91
  ## Languages covered
92
  Germanic: English (en), German (de), Dutch; Flemish (nl), Swedish (sv), Danish (da), Afrikaans (af)
93
  Romance: French (fr), Spanish (es), Italian (it), Portuguese (pt), Romanian; Moldavian; Moldovan (ro)
 
32
  # MITRE 466M
33
 
34
  ## Description
35
+ MITRE (Multilingual Translation with Registers) is a multilingual, decoder-only model designed for many-to-many translation tasks.
36
  The technology, i.e., registering, is introduced in our [paper](url_placeholder).
37
  This repository allows you employ our pre-trained model for inference. If you want to reproduce the data mining and training, please refer to this [repository](url_placeholder).
38
 
39
+ The model supports direct translation across 552 directions for 24 languages spanning over 5 language families.
40
+ You can use our models directly via the `transformers` libs.
41
+ An alternative version of MITRE with 913M parameters is also available in this [repository](url_placeholder).
42
 
43
 
44
  ## Usages
 
65
  ```python
66
  english_text = "I have a red apple."
67
  chinese_text = "我有一个红苹果。"
68
+ model.half() # recommended
69
  model.eval()
70
 
71
+ # Translating from one or several sentences to a 'target_language'
72
  src_tokens = tokenizer.encode_source_tokens_to_input_ids([english_text, ], target_language="zh")
73
+ # Translating from one or several sentences to given languages
74
  # src_tokens = tokenizer.encode_source_tokens_to_input_ids_with_different_tags([english_text, english_text, ], target_languages_list=["de", "zh", ])
75
 
76
  generated_tokens = model.generate(src_tokens.cuda())
 
84
  # 1. The difference between tgt_tokens and labels is that the eos_tokens are moved to the right side.
85
  # 2. We recommend using 'tokenizer.encode_target_tokens_to_labels' instead of modifying tgt_tokens,
86
  # because 'tokenizer.encode_target_tokens_to_input_ids' has pads.
87
+ # 3. You can refer to our code for detailed implementation.
88
  # tgt_tokens = tokenizer.encode_target_tokens_to_input_ids(chinese_text)
89
  # labels = tokenizer.encode_target_tokens_to_labels(chinese_text)
90
  ```
91
 
92
+ ## Notes
93
+ We basically follow the style of [M2M](https://huggingface.co/facebook/m2m100_418M), however, we make some necessary improvements to reduce cost in generation.
94
+ You can refer to the codes of 'generate()' in [modeling_mitre.py](https://huggingface.co/naist-nlp/mitre_466m/blob/main/modeling_mitre.py) for much more details.
95
+ Moreover, we have a plan to implement FlashAttention V2 to further boost our models, which will be updated as soon as possible.
96
+
97
  ## Languages covered
98
  Germanic: English (en), German (de), Dutch; Flemish (nl), Swedish (sv), Danish (da), Afrikaans (af)
99
  Romance: French (fr), Spanish (es), Italian (it), Portuguese (pt), Romanian; Moldavian; Moldovan (ro)
modeling_mitre.py CHANGED
@@ -74,11 +74,11 @@ class MitreSdpaAttention(nn.Module):
74
  attention_mask: Optional[torch.Tensor] = None,
75
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
76
  """
77
- 1. MitreModel is using MitreSdpaAttention, which is modifed from M2M100SdpaAttention.
78
- Notabley, both of them do not support `output_attentions=True` or `layer_head_mask` not None,
79
- leading to 'attn_weights' always being None in output.
80
- The plan of improving this point has a low priority.
81
- 2. We plan to improve this code with Flash Attention v2.
82
  """
83
  bsz, tgt_len, _ = hidden_states.size()
84
 
@@ -777,32 +777,33 @@ class MitreForConditionalGeneration(MitrePreTrainedModel, GenerationMixin):
777
  ):
778
  """
779
  Inference with beam search.
780
- This code is improved from 'transformers.generation.utils.GenerationMixin.generate'.
781
- There are **two main improved points**:
782
  1. 'soft early_stop' in beam search.
783
  a) problem in the vanilla version.
784
- In multilingual translation model, e.g., NLLB and M2M, they adopt the 'vanilla early_
785
- stop' in BeamSearchScorer (the official implementation provided by HuggingFace), i.e.,
786
- the sequence, which is labled by 'end', is filled by 'pad(1)' still, in other words,
787
- the ended sequence is fed into the model still, resulting in a heavy memory waste.
788
  b) our improvement.
789
- We implement soft early_stop to resolve the problem. Specifically, we do not change
790
- anything in BeamSearchScorer to keep the codes' flexibility, rather we remove the ended
791
- sequence from the input. Then, given that the output hidden states' shape is changed,
792
- we insert some placeholders to keep the shape of BeamSearchScorer's states.
793
- Based on our test, this improvement can decrease the memory cost to half than before.
794
  2. mask reusing.
795
- a) problem: registers need attention masks in each step.
796
- A sequence possibly consists 4 parts, i.e., pads, source tokens, registers, and target
797
- tokens. In training, we mask all tokens before registers for the generation of target
798
- tokens. As a result, in generation, we cannot allow the target tokens to 'see' pads.
799
- So, we need masks in each step, leading to computational resource waste.
 
800
  b) our improvement.
801
- First, we turncate the source tokens to save cost.
802
- Second, given that there still exists some source tokens playing the role of placeholders,
803
- we modify the mask generation compared to our codes in fairseq.
804
- Third, in order to avoid re-generating masks, we add the mask into 'registering_cache'.
805
- Then, we manage its order as the kv cache in beam search, and add a column of 0. every step.
 
806
  """
807
  if generation_config != None:
808
  assert type(generation_config) is GenerationConfig
 
74
  attention_mask: Optional[torch.Tensor] = None,
75
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
76
  """
77
+ 1. MitreModel uses MitreSdpaAttention, which is modified from M2M100SdpaAttention.
78
+ Notably, neither of them supports 'output_attentions=True' or 'layer_head_mask is not None',
79
+ meaning that attn_weights are not included in the output.
80
+ Improving this feature is currently a low priority, and we leave this functionality for users to customize.
81
+ 2.We plan to enhance this code with Flash Attention v2 in the future.
82
  """
83
  bsz, tgt_len, _ = hidden_states.size()
84
 
 
777
  ):
778
  """
779
  Inference with beam search.
780
+ This code is an improved version of transformers.generation.utils.GenerationMixin.generate.
781
+ There are two main improvements:
782
  1. 'soft early_stop' in beam search.
783
  a) problem in the vanilla version.
784
+ In multilingual translation models such as NLLB and M2M, the vanilla early stop in BeamSearchScorer
785
+ (the official implementation by HuggingFace) marks ended sequences with pad(1). However, these ended
786
+ sequences are still fed into the model, leading to significant memory waste.
 
787
  b) our improvement.
788
+ We implemented a "soft early stop" to address this issue. Instead of modifying BeamSearchScorer
789
+ (to maintain code flexibility), we remove ended sequences from the input. Since this changes the
790
+ shape of the output hidden states, we insert placeholders to maintain compatibility with
791
+ BeamSearchScorer's state shapes.
792
+ Based on our tests, this improvement reduces memory usage by half.
793
  2. mask reusing.
794
+ a) problem:
795
+ Registers require attention masks at each step.
796
+ A sequence may consist of four parts: padding, source tokens, registers, and target tokens.
797
+ During training, we mask all tokens before registers for target token generation. During generation,
798
+ we cannot allow target tokens to "see" padding tokens, requiring masks at every step.
799
+ This leads to computational inefficiency.
800
  b) our improvement.
801
+ First, we turncate the source tokens and their representations to reduce cost.
802
+ Second, for source tokens acting as placeholders, we modified the mask generation logic compared to
803
+ our Fairseq implementation.
804
+ Third, to avoid regenerating masks at each step, we cache the mask in 'registering_cache', where cached
805
+ mask is managed like the key-value cache in beam search. Then, At every step, we add a column of zeros
806
+ to maintain alignment.
807
  """
808
  if generation_config != None:
809
  assert type(generation_config) is GenerationConfig