farzadab committed
Commit a4cd5e8 · verified · 1 Parent(s): d11095a

Upload 5 files
preprocessor_config.json ADDED
@@ -0,0 +1,14 @@
+{
+  "chunk_length": 30,
+  "feature_extractor_type": "WhisperFeatureExtractor",
+  "feature_size": 80,
+  "hop_length": 160,
+  "n_fft": 400,
+  "n_samples": 480000,
+  "nb_max_frames": 3000,
+  "padding_side": "right",
+  "padding_value": 0.0,
+  "processor_class": "WhisperProcessor",
+  "return_attention_mask": false,
+  "sampling_rate": 16000
+}
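These are the stock Whisper feature-extractor settings: at `sampling_rate` 16000 and a 30 s `chunk_length`, `n_samples = 30 * 16000 = 480000`, and with `hop_length` 160 that gives `nb_max_frames = 480000 / 160 = 3000` mel frames. A minimal sketch of loading it, assuming a checkpoint directory containing this file (the repo id below is a placeholder, not part of this commit):

```python
from transformers import WhisperFeatureExtractor

# "fixie-ai/ultravox-checkpoint" is a hypothetical id; any local directory that
# contains this preprocessor_config.json behaves the same way.
fe = WhisperFeatureExtractor.from_pretrained("fixie-ai/ultravox-checkpoint")
feats = fe([0.0] * 16000, sampling_rate=16000, return_tensors="np")
print(feats.input_features.shape)  # expected (1, 80, 3000): 80 mel bins, padded to the 30 s window
```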
processor_config.json ADDED
@@ -0,0 +1,11 @@
+{
+  "audio_context_size": 3000,
+  "audio_padding": "longest",
+  "audio_placeholder": "<|audio|>",
+  "auto_map": {
+    "AutoProcessor": "ultravox_processing.UltravoxProcessor"
+  },
+  "encoder_ds_factor": 2,
+  "processor_class": "UltravoxProcessor",
+  "stack_factor": 8
+}
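The `auto_map` entry points `AutoProcessor` at the custom `UltravoxProcessor` class shipped in `ultravox_processing.py`, so loading it requires trusting remote code. A minimal sketch, with a placeholder repo id:

```python
import transformers

# Hypothetical repo id; substitute the checkpoint that actually contains these files.
processor = transformers.AutoProcessor.from_pretrained(
    "fixie-ai/ultravox-checkpoint", trust_remote_code=True
)
# The processor bundles a WhisperProcessor for audio and the language model's tokenizer,
# and expands each "<|audio|>" placeholder in the prompt into audio placeholder tokens.
```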
ultravox_config.py CHANGED
@@ -19,6 +19,8 @@ class LoraConfigSimplified:
     target_modules: Optional[List[str]] = dataclasses.field(
         default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
     )
+    # A list of module names regex patterns to unfreeze. Only used if r == 0.
+    unfreeze_layers: Optional[List[str]] = None
 
 
 class LossFunction(str, Enum):
@@ -28,8 +30,10 @@ class LossFunction(str, Enum):
 
 @dataclasses.dataclass
 class LossConfig:
-    loss_function: LossFunction = LossFunction.KL_Divergence
+    loss_function: LossFunction = LossFunction.CrossEntropy
     kl_temperature: float = 2.0
+    # Number of tokens to ignore from the beginning of the sequence. Only used in LSM
+    initial_tokens_to_ignore: int = 0
 
     @property
     def requires_alt_fields(self):
@@ -45,7 +49,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
-        audio_config (`Wav2Vec2Config`, *optional*):
+        audio_config (`WhisperConfig`, *optional*):
            Custom audio config or dict
        text_config (`Union[AutoConfig, dict]`, *optional*):
            The config object of the text backbone. Can be any of `LlamaConfig` or `MistralConfig`.
@@ -63,15 +67,17 @@ class UltravoxConfig(transformers.PretrainedConfig):
            The LoRA configuration for finetuning the text model.
        audio_model_lora_config (`LoraConfigSimplified`, *optional*):
            The LoRA configuration for finetuning the audio model.
+        audio_latency_block_size (`int`, *optional*, defaults to `None`):
+            The latency block size for simulating audio streaming.
 
 
    Example:
 
    ```python
-    >>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
+    >>> from transformers import UltravoxModel, WhisperConfig, UltravoxConfig, LlamaConfig
 
    >>> # Initializing an audio encoder config
-    >>> audio_config = Wav2Vec2Config()
+    >>> audio_config = WhisperConfig()
 
    >>> # Initializing a Llama config
    >>> text_config = LlamaConfig()
@@ -80,13 +86,13 @@ class UltravoxConfig(transformers.PretrainedConfig):
    >>> configuration = UltravoxConfig(audio_config, text_config)
 
    >>> # Initializing a completely untrained model from the configuration
-    >>> model = UltravoxForConditionalGeneration(configuration)
+    >>> model = UltravoxModel(configuration)
 
    >>> # Accessing the model configuration
    >>> configuration = model.config
 
    >>> # Initialize a model from pretrained checkpoints and random projector weights
-    >>> config = UltravoxConfig(audio_model_id="facebook/wav2vec2-base-960h", text_model_id="meta-llama/Llama-2-7b-chat-hf")
+    >>> config = UltravoxConfig(audio_model_id="openai/whisper-tiny", text_model_id="meta-llama/Llama-2-7b-chat-hf")
    ```"""
 
    model_type = "ultravox"
@@ -99,26 +105,26 @@ class UltravoxConfig(transformers.PretrainedConfig):
        audio_model_id: Optional[str] = None,
        text_model_id: Optional[str] = None,
        ignore_index: int = -100,
-        audio_token_index: int = 32000,
        hidden_size: int = 4096,
        stack_factor: int = 8,
        norm_init: float = 0.4,
        projector_act: str = "swiglu",
+        projector_ln_mid: bool = False,  # defaults to False for compatibility with v0.4.1 and below
        text_model_lora_config: Optional[LoraConfigSimplified] = None,
        audio_model_lora_config: Optional[LoraConfigSimplified] = None,
+        audio_latency_block_size: Optional[int] = None,
        **kwargs,
    ):
        self.ignore_index = ignore_index
 
        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
-        self.audio_token_index = audio_token_index
 
        self.hidden_size = hidden_size
        self.stack_factor = stack_factor
        self.norm_init = norm_init
        self.projector_act = projector_act
-
+        self.projector_ln_mid = projector_ln_mid
        if text_model_id is not None:
            self.text_config: transformers.LlamaConfig = (
                transformers.AutoConfig.from_pretrained(text_model_id)
@@ -136,7 +142,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
        else:
            audio_config = audio_config or {}
            self.audio_config = transformers.CONFIG_MAPPING[
-                audio_config.get("model_type", "wav2vec2")
+                audio_config.get("model_type", "whisper")
            ](**audio_config)
 
        self.text_model_lora_config = (
@@ -149,9 +155,26 @@ class UltravoxConfig(transformers.PretrainedConfig):
            if isinstance(audio_model_lora_config, dict)
            else dataclasses.asdict(audio_model_lora_config or LoraConfigSimplified())
        )
+        self.audio_latency_block_size = audio_latency_block_size
 
        self.vocab_size = self.text_config.vocab_size
 
        self.initializer_range = self.text_config.initializer_range
 
        super().__init__(**kwargs)
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        diff_dict = super().to_diff_dict()
+
+        # remove text_config and audio_config if text_model_id and audio_model_id are present
+        if self.text_model_id is not None:
+            diff_dict.pop("text_config", None)
+        elif "text_config" in diff_dict:
+            diff_dict["text_config"].pop("_attn_implementation_autoset", None)
+
+        if self.audio_model_id is not None:
+            diff_dict.pop("audio_config", None)
+        elif "audio_config" in diff_dict:
+            diff_dict["audio_config"].pop("_attn_implementation_autoset", None)
+
+        return diff_dict
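In short, the config now defaults to a Whisper audio backbone instead of wav2vec2, drops `audio_token_index`, and adds `projector_ln_mid` (projector norm placement) and `audio_latency_block_size` (streaming latency mask). A hedged sketch of constructing the updated config from the file above (assumes `ultravox_config.py` is importable locally; plain dicts are used for the sub-configs to avoid downloading gated checkpoints):

```python
from ultravox_config import UltravoxConfig  # local import of the file in this commit

config = UltravoxConfig(
    audio_config={"model_type": "whisper"},  # CONFIG_MAPPING now falls back to "whisper"
    text_config={"model_type": "llama"},
    projector_ln_mid=True,            # v0.5.0+ layout: RMSNorm after the first linear layer
    audio_latency_block_size=None,    # no simulated streaming latency
)
print(type(config.audio_config).__name__)  # expected: WhisperConfig
```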
ultravox_model.py CHANGED
@@ -1,5 +1,6 @@
 import logging
-from typing import Any, Dict, Optional, Set, Tuple, Union
+import re
+from typing import Any, Dict, Generator, Optional, Set, Tuple, Union
 
 import peft
 import torch
@@ -9,6 +10,7 @@ import transformers
 import transformers.activations
 import transformers.modeling_outputs
 import transformers.models
+from transformers.generation.utils import GenerationMixin
 from transformers.models.whisper import modeling_whisper as whisper
 
 # We must use relative import in this directory to allow uploading to HF Hub
@@ -18,7 +20,7 @@ from .ultravox_config import LossFunction
 from .ultravox_config import UltravoxConfig
 
 
-class UltravoxModel(transformers.LlamaPreTrainedModel):
+class UltravoxModel(transformers.LlamaPreTrainedModel, GenerationMixin):
     """
     The Ultravox model which consists of an audio encoder and a language model.
 
@@ -34,26 +36,31 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
 
    config_class = UltravoxConfig
    config: UltravoxConfig  # for type hinting
-    _no_split_modules = ["Wav2Vec2Model", "WhisperEncoder", "LlamaDecoderLayer"]
-    # We minimize the weights in state_dict in order to reduce the size of the checkpoint
-    # The issue is that load_pretrained() uses state_dict() keys to know what keys are expected
-    # As such we have to tell is to ignore some keys that are not always in the model
-    _keys_to_ignore_on_load_unexpected = ["audio_tower.*", "language_model.*"]
-    # Usually we load encoder weights from a pretrained model, so we don't want to load the decoder weights
-    # Technically we never hit this issue because these keys are already removed from state_dict() however,
-    # but there's no harm in keeping it here for when we change that behavior.
-    _keys_to_ignore_on_load_missing = ["audio_tower.*"]
+    # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
+    _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
+    # Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
+    # see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
+    accepts_loss_kwargs = False
 
    def __init__(self, config: UltravoxConfig):
        super().__init__(config)
+        self._register_load_state_dict_pre_hook(self._pre_load_state_dict_hook)
 
        self.keep_params: Set[str] = set()
        self.vocab_size = config.vocab_size
 
        self.audio_tower = self._create_audio_tower(config)
-        self.multi_modal_projector = UltravoxProjector(config)
+        self.audio_tower_context_length: Optional[int] = None
+        self.audio_tower_context_length = self.audio_tower.max_context_length
+
+        self.multi_modal_projector = self._create_multi_modal_projector(config)
        self.language_model = self._create_language_model(config)
 
+        # Determine no_split_modules dynamically to use with FSDP auto_wrap policy.
+        # FSDP throws an error if some of the layer types are not found in the model.
+        # This would be something like ["LlamaDecoderLayer"] as we don't split audio encoder layers.
+        self._no_split_modules = self.language_model._no_split_modules
+
        self.loss_config = LossConfig()
        self.post_init()
 
@@ -139,6 +146,24 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
        )
        return {"loss": kl_loss}
 
+    def _audio_iter(
+        self, audio_batch_size: torch.Tensor
+    ) -> Generator[Tuple[int, int], None, None]:
+        """
+        Iterate over the audio batch size and yield the batch index and audio index of each audio item.
+
+        Args:
+            audio_batch_size: A tensor of shape (B,) where B is the batch size.
+
+        Returns:
+            A generator that yields a tuple of (start index, length) for each audio item.
+        """
+        audio_index = 0
+        for i_b, batch_count in enumerate(audio_batch_size):
+            for _ in range(batch_count):
+                yield i_b, audio_index
+                audio_index += 1
+
    def forward(
        self,
        input_ids: torch.Tensor,
@@ -147,7 +172,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
        # the alt_* fields are needed for KL divergence loss
        alt_input_ids: Optional[torch.Tensor] = None,
@@ -178,28 +205,37 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
        # B x T -> B x T x D
        inputs_embeds = self.get_input_embeddings().forward(input_ids)
 
-        if audio_values is not None:
+        if audio_values is not None and len(audio_values) > 0:
            assert (
-                audio_token_start_idx is not None and audio_token_len is not None
-            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
+                audio_token_start_idx is not None
+                and audio_token_len is not None
+                and audio_lens is not None
+                and audio_batch_size is not None
+            ), "audio_token_start_idx/audio_token_len/audio_lens must be provided if audio_values are provided."
            assert (
-                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
-            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."
-
-            # B x A/3200 x D
+                len(audio_token_start_idx)
+                == len(audio_token_len)
+                == len(audio_lens)
+                == len(audio_values)
+            ), "audio_token_start_idx/audio_token_len/audio_lens/audio_values must have the same batch size."
+            assert len(audio_batch_size) == len(
+                inputs_embeds
+            ), "audio_batch_size and inputs_embeds must have the same batch size."
+
+            # B x A/3200 x (D=max-audio-length-in-batch)
            audio_tower_output = self.audio_tower.forward(
-                audio_values
+                audio_values.to(self.audio_tower.dtype),
+                audio_len=audio_lens,
            ).last_hidden_state
            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
-
            audio_embeds = self.multi_modal_projector.forward(audio_tower_output)
 
            # combine audio and text embeddings
-            for i, (audio, start, length) in enumerate(
-                zip(audio_embeds, audio_token_start_idx, audio_token_len)
-            ):
-                length = min(length, audio.shape[0])
-                inputs_embeds[i, start : start + length] = audio[:length]
+            for i_b, i_a in self._audio_iter(audio_batch_size):
+                start_idx = audio_token_start_idx[i_a]
+                token_len = audio_token_len[i_a]
+                item_embedding = audio_embeds[i_a][:token_len]
+                inputs_embeds[i_b][start_idx : start_idx + token_len] = item_embedding
 
        lm_output = self.language_model.forward(
            inputs_embeds=inputs_embeds,
@@ -234,6 +270,8 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
+        audio_lens: Optional[torch.Tensor] = None,
+        audio_batch_size: Optional[torch.Tensor] = None,
        past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
@@ -262,26 +300,50 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                audio_token_start_idx - prefill_start_idx
            )
            model_input["audio_token_len"] = audio_token_len
+            model_input["audio_batch_size"] = audio_batch_size
+            model_input["audio_lens"] = audio_lens
 
        return model_input
 
+    @classmethod
+    def _create_multi_modal_projector(
+        cls, config: UltravoxConfig
+    ) -> "UltravoxProjector":
+        projector = UltravoxProjector(config)
+        projector.to(config.torch_dtype)
+        return projector
+
    @classmethod
    def _create_audio_tower(
        cls, config: UltravoxConfig
    ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
        if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id is not None:
+            if "whisper" in config.audio_model_id.lower():
                audio_tower = ModifiedWhisperEncoder.from_pretrained(
-                    config.audio_model_id
+                    config.audio_model_id, torch_dtype=config.torch_dtype
+                )
+                audio_tower.init_latency_mask(
+                    config.audio_latency_block_size, dtype=config.torch_dtype
                )
            else:
+                assert config.audio_latency_block_size in (
+                    None,
+                    0,
+                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                audio_tower = transformers.AutoModel.from_pretrained(
-                    config.audio_model_id
+                    config.audio_model_id, torch_dtype=config.torch_dtype
                )
        else:
-            if "whisper" in config.audio_config._name_or_path:
+            if "whisper" in config.audio_config._name_or_path.lower():
                audio_tower = ModifiedWhisperEncoder(config.audio_config)
+                audio_tower.init_latency_mask(
+                    config.audio_latency_block_size, dtype=config.torch_dtype
+                )
            else:
+                assert config.audio_latency_block_size in (
+                    None,
+                    0,
+                ), "only whisper audio tower supports audio latency masking, got non-zero value for 'audio_latency_block_size'"
                with transformers.modeling_utils.no_init_weights():
                    # we only ever use from_config if the weights are retrained, hence initializing is not
                    # required. This makes the model quite creation faster since init on CPU is quite slow.
@@ -307,21 +369,27 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
    ) -> transformers.LlamaForCausalLM:
        if config.text_model_id is not None:
            language_model = transformers.AutoModelForCausalLM.from_pretrained(
-                config.text_model_id, attn_implementation=config._attn_implementation
+                config.text_model_id,
+                attn_implementation=config._attn_implementation,
+                torch_dtype=config.torch_dtype,
            )
        else:
            with transformers.modeling_utils.no_init_weights():
                # we only ever use from_config if the weights are retrained, hence initializing is not
                # required. This makes the model quite creation faster since init on CPU is quite slow.
                language_model = transformers.AutoModelForCausalLM.from_config(
-                    config.text_config, attn_implementation=config._attn_implementation
+                    config.text_config,
+                    attn_implementation=config._attn_implementation,
+                    torch_dtype=config.torch_dtype,
                )
 
        language_model = apply_lora(language_model, config.text_model_lora_config)
        return language_model
 
-    def _add_language_model_weights_to_keep(self):
-        if self.config.text_model_id is not None:
+    def merge_and_unload(self):
+        if isinstance(self.language_model, peft.PeftModel):
+            self.language_model = self.language_model.merge_and_unload()
+            # no need to download base language model weights anymore, so we can remove the id
            self.config.text_model_id = None
            self.keep_params.update(
                set(
@@ -332,8 +400,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                )
            )
 
-    def _add_audio_tower_weights_to_keep(self):
-        if self.config.audio_model_id is not None:
+        if isinstance(self.audio_tower, peft.PeftModel):
+            self.audio_tower = self.audio_tower.merge_and_unload()
+            # no need to download base audio model weights anymore, so we can remove the id
            self.config.audio_model_id = None
            self.keep_params.update(
                set(
@@ -344,46 +413,44 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
                )
            )
 
-    def merge_and_unload(self):
-        if isinstance(self.language_model, peft.PeftModel):
-            self.language_model = self.language_model.merge_and_unload()
-            # no need to download base language model weights anymore, so we can remove the id
-            self._add_language_model_weights_to_keep()
-
-        if isinstance(self.audio_tower, peft.PeftModel):
-            self.audio_tower = self.audio_tower.merge_and_unload()
-            # no need to download base audio model weights anymore, so we can remove the id
-            self._add_audio_tower_weights_to_keep()
-
        for param in ["text_model_lora_config", "audio_model_lora_config"]:
            if hasattr(self.config, param):
                delattr(self.config, param)
 
    def push_to_hub(self, *args, **kwargs):
        self.merge_and_unload()
-        self.to(self.language_model.dtype)
        return super().push_to_hub(*args, **kwargs)
 
-    def state_dict(self, *args, **kwargs):
-        named_params = dict(self.named_parameters())
-        state_dict = super().state_dict(*args, **kwargs)
+    def diff_state_dict(
+        self, state_dict: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
+        if state_dict is None:
+            state_dict = super().state_dict()
+
+        trainable_params = {k for k, v in self.named_parameters() if v.requires_grad}
+        # normalize the keys to match the original model
+        # Example: audio_tower.base_model.model.layers.0._fsdp_wrapped_module.self_attn.k_proj.lora_B.default.weight
+        trainable_params = {
+            k.replace("_fsdp_wrapped_module.", "") for k in trainable_params
+        }
 
        state_dict = {
            k: v
            for k, v in state_dict.items()
-            if k in self.keep_params
-            or (k in named_params and named_params[k].requires_grad)
+            if k in self.keep_params or k in trainable_params
        }
+
        return state_dict
 
-    def load_state_dict(
-        self,
-        state_dict: Dict[str, Any],
-        *args,
-        **kwargs,
+    def save_pretrained(
+        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
    ):
+        state_dict = self.diff_state_dict(state_dict)
+
+        super().save_pretrained(*args, state_dict=state_dict, **kwargs)
+
+    def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
        self.keep_params.update(set(state_dict.keys()))
-        return super().load_state_dict(state_dict, *args, **kwargs)
 
    def print_trainable_parameters(self):
        """
@@ -414,8 +481,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
    )
 
 
+# TODO: refactor common parts to a shared module
 def is_cache_empty(
-    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
+    past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]],
 ) -> bool:
    """
    Check if the cache is empty.
@@ -431,12 +499,18 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
    """
    Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
    """
+    unfreeze_layers = lora_config.pop("unfreeze_layers", None)
    lora_config = peft.LoraConfig(**lora_config or {})
 
    if lora_config.r == 0:
-        # freeze the model entirely
-        for param in model.parameters():
-            param.requires_grad = False
+        # freeze the model entirely, except for the specified layers
+        for name, param in model.named_parameters():
+            if not unfreeze_layers or not any(
+                re.match(layer, name) for layer in unfreeze_layers
+            ):
+                param.requires_grad = False
+            else:
+                logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
    else:
        model = peft.get_peft_model(model, lora_config)
 
@@ -445,12 +519,8 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
 
 class StackAudioFrames(nn.Module):
    """
-    Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`.
-
-    The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames.
-    NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
-    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
-    In most cases this extra padding will get removed in the model's forward function so it has no effect.
+    Stack the audio embedding frames to reduce the sequence length by a factor
+    of `stack_factor`.
    """
 
    def __init__(self, stack_factor: int = 8):
@@ -460,7 +530,7 @@ class StackAudioFrames(nn.Module):
    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        B, T, C = audio_embeds.shape
        T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor
-        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor))
+        audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T))
        B, T, C = audio_embeds.shape
        audio_embeds = audio_embeds.view(
            B, T // self.stack_factor, C * self.stack_factor
@@ -480,31 +550,43 @@ class SwiGLU(nn.Module):
        return F.silu(gate) * x
 
 
-class UltravoxProjector(nn.Sequential):
+class UltravoxProjector(nn.Module):
    def __init__(self, config: UltravoxConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim, init=config.norm_init)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in, init=config.norm_init)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
        self.act = transformers.activations.get_activation(config.projector_act)
-        dim = dim // 2 if config.projector_act == "swiglu" else dim
-        self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
+        dim_mid = dim_mid // 2 if config.projector_act == "swiglu" else dim_mid
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid, init=config.norm_init)
+            self.ln_post: nn.Module = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        audio_features = self._pad_and_stack(audio_features)
        audio_features = self.ln_pre(audio_features)
        hidden_states = self.linear_1(audio_features)
        hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        hidden_states = self.ln_post(hidden_states)
        return hidden_states
 
 
-class ModifiedWhisperEncoder(whisper.WhisperEncoder):
+class ModifiedWhisperEncoder(
+    whisper.WhisperEncoder, transformers.modeling_utils.ModuleUtilsMixin
+):
    """
    Encoder portion of OpenAI's Whisper model.
 
@@ -518,21 +600,59 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
    """
 
    base_model_prefix = "model.encoder"
+    _no_split_modules = ["WhisperEncoderLayer"]
+
+    def __init__(self, config: transformers.WhisperConfig):
+        super().__init__(config)
+        self.config.is_decoder = False
+
+    @property
+    def max_context_length(self):
+        return (
+            self.config.max_source_positions
+            * self.conv1.stride[0]
+            * self.conv2.stride[0]
+        )
+
+    def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
+        if audio_latency_block_size is None:
+            self.audio_streaming_mask = None
+            return
+
+        # Use max_context_length directly in the calculation
+        max_seqlen = self.max_context_length
+        assert (
+            max_seqlen > 0
+        ), f"maximum sequence length must be positive, got {max_seqlen}"
+        assert (
+            max_seqlen % audio_latency_block_size == 0
+        ), f"audio_latency_block_size {audio_latency_block_size} must divide {max_seqlen} evenly."
+        # Given the block size, we calculate number of blocks.
+        audio_latency_nblocks = max_seqlen // audio_latency_block_size
+        audio_streaming_mask = (
+            torch.tril(
+                torch.ones(audio_latency_nblocks, audio_latency_nblocks),
+                diagonal=0,
+            )
+            .repeat_interleave(audio_latency_block_size, dim=0)
+            .repeat_interleave(audio_latency_block_size, dim=1)
+        )
+        audio_streaming_mask = (1.0 - audio_streaming_mask) * torch.finfo(dtype).min
+        audio_streaming_mask = audio_streaming_mask[None, None, :, :]
+        self.register_buffer(
+            "audio_streaming_mask", audio_streaming_mask, persistent=False
+        )
 
    def forward(
        self,
        input_features,
-        attention_mask=None,
+        audio_len=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
-        expected_seq_length = (
-            self.config.max_source_positions
-            * self.conv1.stride[0]
-            * self.conv2.stride[0]
-        )
+        expected_seq_length = self.max_context_length
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
@@ -565,6 +685,37 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
 
+        # Create attention mask based on audio lengths to mask out padding tokens
+        # For each sample in batch:
+        # - Convert raw audio length to feature length after convolutions
+        # - Create boolean mask that is True for valid positions and False for padding
+        # - Convert to extended attention mask format expected by transformer layers
+        # (1.0 for positions to attend to, large negative for positions to ignore)
+        # This masking ensures consistent behavior between training and inference
+        # by preventing the model from attending to padding tokens in both cases
+        attention_mask = None
+        if audio_len != None:
+            audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
+            max_seq_len = hidden_states.shape[1]
+            attention_mask = torch.arange(max_seq_len, device=hidden_states.device)[
+                None, :
+            ].lt(audio_feature_len.view(-1, 1))
+            attention_mask = self.get_extended_attention_mask(
+                attention_mask,
+                None,
+                dtype=hidden_states.dtype,
+            )
+
+        if self.audio_streaming_mask is not None:
+            seqlen = hidden_states.size(-2)
+            if attention_mask is not None:
+                attention_mask = torch.minimum(
+                    self.audio_streaming_mask[:, :, :seqlen, :seqlen], attention_mask
+                )  # merge
+            else:
+                attention_mask = self.audio_streaming_mask[:, :, :seqlen, :seqlen]
+            attention_mask = attention_mask.to(hidden_states.dtype)
+
        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
@@ -588,14 +739,14 @@ class ModifiedWhisperEncoder(whisper.WhisperEncoder):
                layer_outputs = self._gradient_checkpointing_func(
                    encoder_layer.__call__,
                    hidden_states,
-                    None,
+                    attention_mask,
                    (head_mask[idx] if head_mask is not None else None),
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
-                    None,
+                    attention_mask,
                    layer_head_mask=(
                        head_mask[idx] if head_mask is not None else None
                    ),
@@ -629,6 +780,5 @@ UltravoxModel.register_for_auto_class()
 
 transformers.AutoConfig.register("ultravox", UltravoxConfig)
 transformers.AutoModel.register(UltravoxConfig, UltravoxModel)
-# transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)  # TODO: make processor work standalone
 
 transformers.activations.ACT2FN["swiglu"] = SwiGLU
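A consequence of the new wiring that is easy to miss: the number of audio placeholder tokens per clip is now `ceil(mel_frames / (encoder_ds_factor * stack_factor))`, where the Whisper encoder halves the mel-frame count (`encoder_ds_factor = 2`) and `StackAudioFrames` stacks 8 encoder frames per projected embedding (`stack_factor = 8`). A quick back-of-the-envelope check using the values from the config files above:

```python
import math

sampling_rate = 16000   # preprocessor_config.json
hop_length = 160        # preprocessor_config.json
encoder_ds_factor = 2   # processor_config.json
stack_factor = 8        # processor_config.json

mel_frames = 30 * sampling_rate // hop_length  # 3000, the Whisper encoder context
audio_tokens = math.ceil(mel_frames / (encoder_ds_factor * stack_factor))
print(mel_frames, audio_tokens)  # 3000 188 -> a full 30 s chunk maps to 188 audio embeddings
```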
ultravox_processing.py CHANGED
@@ -1,12 +1,69 @@
-from typing import Optional, Union
+import dataclasses
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 import transformers
 
 from .ultravox_config import UltravoxConfig
 
 
+@dataclasses.dataclass
+class DataCollatorForSeq2SeqWithAudio(transformers.DataCollatorForSeq2Seq):
+    # when enabled, the alt_input_ids, alt_attention_mask, and alt_labels fields are used for computing the KL loss in UltravoxModel
+    include_alt_fields: bool = False
+
+    def __call__(self, features, *args, **kwargs):
+        audio_values = [x for f in features for x in f.pop("audio_values", [])]
+        audio_lens = [x for f in features for x in f.pop("audio_lens", [])]
+        audio_token_len = [x for f in features for x in f.pop("audio_token_len", [])]
+        audio_token_start_idx = [
+            x for f in features for x in f.pop("audio_token_start_idx", [])
+        ]
+
+        if self.include_alt_fields:
+            # these fields are hard-coded in the transformer data collator, so they need special handling before calling the super method
+            alt_features = [
+                {
+                    "input_ids": f.pop("alt_input_ids"),
+                    "attention_mask": f.pop("alt_attention_mask"),
+                    "labels": f.pop("alt_labels"),
+                }
+                for f in features
+            ]
+
+        batch = super().__call__(features, *args, **kwargs)
+        if self.include_alt_fields:
+            alt_batch = super().__call__(alt_features, *args, **kwargs)
+            batch["alt_input_ids"] = alt_batch["input_ids"]
+            batch["alt_attention_mask"] = alt_batch["attention_mask"]
+            batch["alt_labels"] = alt_batch["labels"]
+
+        batch["audio_token_start_idx"] = torch.stack(audio_token_start_idx)
+        batch["audio_lens"] = torch.stack(audio_lens)
+        batch["audio_token_len"] = torch.stack(audio_token_len)
+
+        # Pad the last dimension of all audio_values to the same length, with 0s on the right.
+        if audio_values:
+            max_len = max([x.shape[-1] for x in audio_values])
+            batch["audio_values"] = torch.stack(
+                [F.pad(x, (0, max_len - x.shape[-1])) for x in audio_values]
+            )
+            if self.tokenizer.padding_side == "left":
+                input_ids_lens = torch.LongTensor(
+                    [f["input_ids"].shape[-1] for f in features]
+                )
+                displacement = batch["input_ids"].shape[-1] - input_ids_lens
+                displacement = displacement.repeat_interleave(
+                    batch["audio_batch_size"].squeeze(-1)
+                )
+                batch["audio_token_start_idx"] += displacement.to(
+                    batch["audio_token_start_idx"].device
+                )
+        return batch
+
+
 class UltravoxProcessor(transformers.ProcessorMixin):
     """
     Constructs an Ultravox processor which wraps an audio processor and a tokenizer into a single processor.
@@ -17,11 +74,7 @@ class UltravoxProcessor(transformers.ProcessorMixin):
    """
 
    attributes = ["audio_processor", "tokenizer"]
-    audio_processor_class = (
-        "Wav2Vec2Processor",
-        "SeamlessM4TFeatureExtractor",
-        "WhisperProcessor",
-    )
+    audio_processor_class = ("WhisperProcessor",)
    tokenizer_class = (
        "PreTrainedTokenizer",
        "PreTrainedTokenizerFast",
@@ -35,41 +88,46 @@ class UltravoxProcessor(transformers.ProcessorMixin):
        audio_processor=None,
        tokenizer=None,
        audio_padding: str = "longest",
-        encoder_ds_factor: int = 320,
+        encoder_ds_factor: int = 2,
        stack_factor: int = 8,
        audio_placeholder: str = "<|audio|>",
+        # Defaults to whisper encoder context size
+        audio_context_size: Optional[int] = 3000,
    ):
        """
        Args:
            audio_processor: The audio processor for the audio encoder.
            tokenizer: The tokenizer for the language model.
            audio_padding: The padding strategy for the audio encoder.
-            encoder_ds_factor: The downsample factor of the audio encoder.
            stack_factor: The factor by which the audio encoder output is stacked in the multimodal projector.
+            encoder_ds_factor: The downsampling factor of the audio encoder.
            audio_placeholder: The placeholder for the audio in the text.
+            audio_context_size: The maximum number of frames that the audio encoder can handle.
        """
        self.audio_padding = audio_padding
        self.encoder_ds_factor = encoder_ds_factor
        self.stack_factor = stack_factor
        self.audio_placeholder = audio_placeholder
-        self.audio_token_replacement = tokenizer.eos_token
+        self.audio_context_size = audio_context_size
        assert (
-            self.audio_token_replacement is not None
+            tokenizer.eos_token is not None
        ), "The tokenizer has no EOS token. Cannot recover."
+        self.vocab = tokenizer.get_vocab()
+        self.audio_token_replacement = tokenizer.eos_token
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
 
        super().__init__(audio_processor=audio_processor, tokenizer=tokenizer)
 
    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        config: UltravoxConfig = transformers.AutoConfig.from_pretrained(
            pretrained_model_name_or_path, **kwargs
        )
        audio_processor = transformers.AutoProcessor.from_pretrained(
            config.audio_model_id
            or config.audio_config._name_or_path
-            or "facebook/wav2vec2-base-960h"
+            or "openai/whisper-tiny"
        )
 
        tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -84,30 +142,100 @@ class UltravoxProcessor(transformers.ProcessorMixin):
            stack_factor=config.stack_factor,
        )
 
+    def _chunk_and_pad_audio(
+        self,
+        audio_values: torch.Tensor,
+        audio_lens: torch.Tensor,
+        include_audio_num_chunks: bool = False,
+    ) -> Dict[str, Any]:
+        """
+        Processes the audio batch by chunking any items in the batch according to the audio_context_size,
+        padding the last chunk if needed, and returns a dictionary with updated audio data.
+
+        Args:
+            audio_values (torch.Tensor): A tensor of audio values (e.g., in B, D, T format).
+            audio_lens (torch.Tensor): A tensor of audio lengths.
+
+        Returns:
+            Dict[str, Any]: Dictionary with the following keys:
+                - "audio_values": The concatenated audio tensor after chunking and padding.
+                - "audio_lens": Tensor of lengths for each chunk.
+                - "audio_is_continuation": Tensor of booleans indicating if the chunk is a continuation of the previous chunk.
+                - "audio_batch_size": A Tensor with one integer representing the number of chunks.
+
+        """
+        chunked_audio_values: List[torch.Tensor] = []
+        chunked_audio_lens: List[int] = []
+        is_continuation_list: List[bool] = []
+        num_chunks: List[int] = []
+        context_size = self.audio_context_size or audio_values.shape[-1]
+
+        for i in range(audio_values.shape[0]):  # iterate over the batch
+            num_chunks.append(int(np.ceil(audio_lens[i] / context_size)))
+            for offset in range(0, audio_lens[i], context_size):
+                is_continuation = offset > 0
+                chunk = audio_values[i, :, offset : offset + context_size]
+                if is_continuation and chunk.shape[-1] < context_size:
+                    # N.B. We only need to pad continuation chunks. If none of the samples require chunking, the
+                    # batch might not (need to) be padded all the way to the audio_context_size, in which case
+                    # we've already included the padding above. On the other hand, if we have any continuation
+                    # chunks we know that the batch needs to be padded to audio_context_size because that's what
+                    # we're slicing to.
+                    chunk = F.pad(chunk, (0, context_size - chunk.shape[-1]))
+                chunked_audio_values.append(chunk)
+                chunked_audio_lens.append(
+                    min(int(audio_lens[i].item()) - offset, context_size)
+                )
+                is_continuation_list.append(is_continuation)
+
+        data = {
+            "audio_values": torch.stack(chunked_audio_values, dim=0),
+            "audio_lens": torch.tensor(
+                chunked_audio_lens, dtype=torch.int64, device=audio_values.device
+            ),
+            "audio_is_continuation": torch.tensor(
+                is_continuation_list, dtype=torch.bool, device=audio_values.device
+            ),
+            "audio_batch_size": torch.tensor(
+                [len(chunked_audio_values)], device=audio_values.device
+            ),
+        }
+        if include_audio_num_chunks:
+            data["audio_num_chunks"] = torch.tensor(
+                num_chunks, dtype=torch.int64, device=audio_values.device
+            )
+        return data
+
    def __call__(
        self,
        text: Optional[str] = None,
        audio: Optional[Union[np.ndarray, torch.Tensor]] = None,
+        audios: Optional[
+            Union[
+                List[Union[np.ndarray, torch.Tensor]], Union[np.ndarray, torch.Tensor]
+            ]
+        ] = None,
        sampling_rate: Optional[int] = None,
        return_tensors: Optional[
            Union[str, transformers.TensorType]
        ] = transformers.TensorType.PYTORCH,
+        include_audio_num_chunks: bool = False,
        **kwargs,
    ) -> transformers.BatchFeature:
        """
        Main method to prepare for the model one text sequence and audio. This method forwards the `text`
        and `kwargs` arguments to PreTrainedTokenizerFast's [`~PreTrainedTokenizerFast.__call__`] if `text` is not `None` to encode
        the text. To prepare the audio(s), this method forwards the `audio`, `sampling_rate` and `kwargs` arguments to
-        audio processor's [`~Wav2Vec2Processor.__call__`] if `audio` is not `None`. Please refer to the docstring
+        audio processor's [`~WhisperProcessor.__call__`] if `audio` is not `None`. Please refer to the docstring
        of the above two methods for more information.
 
        Args:
            text (`str`, `List[str]`):
                The sequence to be encoded. Sequence can be a string or (pretokenized string).
            audio (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The audio to be prepared. Audio can be NumPy array or PyTorch tensor. In case of a
-                NumPy array/PyTorch tensor, each audio should be of shape (C, T), where C is a number of channels, and T the
-                sample length of the audio.
+                The audio to be prepared. Audio can be a single-channel (1-dimensional) NumPy array or PyTorch tensor.
+            audios (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                A list or two dimensional array of audio to be prepared.
            sampling_rate (`int`, *optional*, defaults to 16000):
                Sampling rate of the input audio. We expect 16kHz audio. Don't change this value unless you know what
                you are doing.
@@ -131,64 +259,105 @@ class UltravoxProcessor(transformers.ProcessorMixin):
                Returned when `audio` is not `None`.
            - **audio_token_start_idx** -- The index in the tokenized text where the audio starts. Returned when `audio` is not `None`.
        """
-        # TODO: Add support for multiple audio and text inputs.
+        # TODO: Add support for multiple text inputs.
+        if audio is not None and audios is not None:
+            raise ValueError("Only one of `audio` or `audios` should be provided.")
+        elif audio is not None:
+            audios = audio if isinstance(audio, list) or audio.ndim == 2 else [audio]
+        elif audios is None:
+            audios = []
+
        data = {}
-        audio_embed_frames = 0
-        if audio is not None and len(audio) > 0:
-            if self.audio_padding == "max_length":
-                # 30 seconds is the expected length for Whisper
-                assert sampling_rate is not None, "Sampling rate must be provided."
-                audio_len = 30 * sampling_rate
-            else:
-                audio_len = audio.shape[-1]
-            # It's guaranteed that the number of frames is less than or equal to this amount.
-            # For Whisper this is exact AFAICT, but for Wav2Vec2 it's an upper bound.
-            # Currently, StackAudioFrames makes sure an over-estimation won't cause issues by padding the audio embeddings.
-            nb_encoder_frames = int(round(audio_len / self.encoder_ds_factor + 1e-4))
-            audio_embed_frames = int(np.ceil(nb_encoder_frames / self.stack_factor))
-            data["audio_token_len"] = [audio_embed_frames]
+        audio_is_continuation = []
+        if len(audios) > 0:
+            audios = [x.numpy() if isinstance(x, torch.Tensor) else x for x in audios]
+
+            # Pad out each audio to at least 2 hops (the minimum required by the processor).
+            hop_length = self.audio_processor.feature_extractor.hop_length
+            audios = [
+                (
+                    np.pad(x, (0, 2 * hop_length - len(x)), mode="constant")
+                    if len(x) < 2 * hop_length
+                    else x
+                )
+                for x in audios
+            ]
 
            # Main audio processing. The processor is model-specific.
-            x = self.audio_processor(
-                audio,
+            x: transformers.BatchFeature = self.audio_processor(
+                audios,
                sampling_rate=sampling_rate,
                padding="longest",
-                max_length=audio_len,
+                pad_to_multiple_of=hop_length,  # The attention mask effectively gets padded to the hop length, so pad the audio to be consistent.
+                truncation=False,
+                return_attention_mask=True,
                **kwargs,
            )
-            if "input_features" in x:
-                data["audio_values"] = x.input_features
-            else:
-                data["audio_values"] = x.input_values
 
-        if text is not None:
-            assert isinstance(
-                text, str
-            ), "Text must be a string. Batch mode not supported yet."
-            if self.audio_placeholder in text:
-                if "audio_token_len" not in data:
-                    raise ValueError(
-                        f"audio must be provided when using audio placeholder ({self.audio_placeholder}) in text."
-                    )
-
-                start_idx = len(
-                    self.tokenizer.encode(
-                        text[: text.index(self.audio_placeholder)],
-                        add_special_tokens=False,
-                    )
-                )
-                data["audio_token_start_idx"] = [start_idx]
-
-                # Replace the audio placeholder with the audio token.
-                # e.g. "Transcribe\n<|audio|>" -> "Transcribe </s></s></s></s></s></s></s></s>"
-                # where the number of </s> is the number of audio frames.
-                text = text.replace(
-                    self.audio_placeholder,
-                    self.audio_token_replacement * audio_embed_frames,
+            data.update(
+                self._chunk_and_pad_audio(
+                    audio_values=torch.as_tensor(
+                        x.input_features if "input_features" in x else x.input_values
+                    ),
+                    audio_lens=torch.as_tensor(x.attention_mask).sum(-1),
+                    include_audio_num_chunks=include_audio_num_chunks,
                )
+            )
+
+            audio_is_continuation = data.pop("audio_is_continuation")
+            data["audio_token_len"] = torch.ceil(
+                data["audio_lens"] / (self.encoder_ds_factor * self.stack_factor)
+            ).to(dtype=torch.int)
+
+        if text is not None:
+            if not isinstance(text, str):
+                raise ValueError("Text must be a string. Batch mode not supported yet.")
 
            # Special tokens like BOS should already have been added by the caller.
-            data.update(self.tokenizer([text], add_special_tokens=False, **kwargs))
+            tokenized_parts = self.tokenizer(
+                text.split(
+                    "<|audio|>"  # The placeholder isn't part of the vocabulary, so split the text around it.
+                ),
+                add_special_tokens=False,
+                **kwargs,
+            )
+
+            audio_token_start_idx = []
+            placeholder_index = -1
+            split_input_ids = tokenized_parts["input_ids"]
+            input_ids: List[int] = []
+
+            audio_token_replacement_token_id = self.vocab[self.audio_token_replacement]
+
+            for i, token_len in enumerate(data.get("audio_token_len", [])):
+                if not audio_is_continuation[i]:
+                    placeholder_index += 1
+                    if placeholder_index >= len(split_input_ids):
+                        raise ValueError(
+                            f"Text contains too few audio placeholders. (Expected {len(audios)} placeholders)"
+                        )
+
+                    input_ids.extend(split_input_ids[placeholder_index])
+
+                    audio_token_start_idx.append(len(input_ids))
+
+                input_ids.extend([audio_token_replacement_token_id] * token_len)
+
+            # Include any tokens after the last audio.
+            placeholder_index += 1
+            if placeholder_index != len(split_input_ids) - 1:
+                raise ValueError(
+                    f"Text contains too many audio placeholders. (Expected {len(audios)} placeholders)"
+                )
+            input_ids.extend(split_input_ids[placeholder_index])
+
+            if "audio_token_len" in data:
+                data["audio_token_start_idx"] = torch.as_tensor(audio_token_start_idx)
+
+            data["input_ids"] = [input_ids]
+            data["attention_mask"] = [[1] * len(input_ids)]
+
+        # Ensure that there are no audio placeholders after the last audio.
 
        return transformers.BatchFeature(data=data, tensor_type=return_tensors)
 
@@ -205,4 +374,6 @@ class UltravoxProcessor(transformers.ProcessorMixin):
        return list(set(tokenizer_input_names + audio_processor_input_names))
 
 
-transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
+UltravoxProcessor.register_for_auto_class()
+
+transformers.AutoProcessor.register(UltravoxConfig, UltravoxProcessor)
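Putting the five files together, a hedged end-to-end sketch of the intended usage: the repo id and prompt below are placeholders, and `trust_remote_code=True` is needed because the config, model, and processor classes live in this repo rather than in transformers itself.

```python
import numpy as np
import transformers

repo = "fixie-ai/ultravox-checkpoint"  # hypothetical id for a checkpoint containing these files
processor = transformers.AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = transformers.AutoModel.from_pretrained(repo, trust_remote_code=True)

audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz, stand-in for real speech
inputs = processor(
    text="Transcribe the following audio:\n<|audio|>",  # BOS/chat template should be added by the caller
    audio=audio,
    sampling_rate=16000,
)
output_ids = model.generate(**inputs, max_new_tokens=30)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))
```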