Upload 10 files
- .gitignore +5 -0
- __init__.py +1 -0
- hf_utils.py +15 -0
- mamba_block.py +354 -0
- mamba_config.py +86 -0
- mamba_model.py +183 -0
- mlp.py +43 -0
- setup.py +159 -0
- switch_mlp.py +91 -0
- utils.py +82 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
*__pycache__/
*.egg-info/
build/
**.so
**.ipynb
__init__.py
ADDED
@@ -0,0 +1 @@
(a single blank line)
hf_utils.py
ADDED
@@ -0,0 +1,15 @@
import json
import torch
import transformers
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


def load_config_hf(model_name):
    resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
    return json.load(open(resolved_archive_file))


def load_state_dict_hf(model_name, device="cpu"):
    resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
    return torch.load(resolved_archive_file, map_location=device)
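
For context, a minimal usage sketch of these two helpers (my addition; the repository id below is a placeholder, not something defined by this upload):

    from hf_utils import load_config_hf, load_state_dict_hf

    # Hypothetical checkpoint id used purely for illustration.
    repo_id = "some-org/some-mamba-checkpoint"
    config_dict = load_config_hf(repo_id)       # parsed config.json as a dict
    state_dict = load_state_dict_hf(repo_id)    # tensors mapped to CPU
    print(config_dict.get("hidden_size"), len(state_dict))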
mamba_block.py
ADDED
@@ -0,0 +1,354 @@
import math
from typing import Optional, Union
import re
from contextlib import nullcontext
from abc import ABC, abstractmethod
from dataclasses import dataclass
import functools
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
except ImportError:
    selective_scan_fn, mamba_inner_fn = None, None

try:
    from ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None

from mamba_layer import MambaLayer
from mamba_config import MambaConfig
from mlp import MLP
from switch_mlp import SwitchMLP


class MambaBlock(nn.Module):
    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)

        if not config.rms_norm:
            self.norm = norm_cls
        else:
            self.norm = norm_cls(config.hidden_size)

        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

class MambaBlockParallelMoe(nn.Module):
    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, norm_moe=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):

        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)
        if not config.rms_norm:
            self.norm = norm_cls
            self.norm_moe = norm_moe
        else:
            self.norm = norm_cls(config.hidden_size)
            self.norm_moe = norm_moe(config.hidden_size)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
            assert isinstance(
                self.norm_moe, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):

        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            hidden_states_moe = self.norm_moe(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
            hidden_states_moe, _ = fused_add_norm_fn(
                hidden_states,
                self.norm_moe.weight,
                self.norm_moe.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm_moe.eps,
            )

        hidden_states = self.mixer(hidden_states, inference_params=inference_params)

        hidden_states_moe = self.moe(hidden_states_moe)
        hidden_states += hidden_states_moe
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


class MoEBlock(nn.Module):
    def __init__(
        self, config, mixer_cls, moe_cls=None, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):

        super().__init__()
        self.config = config
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(config)
        if not config.rms_norm:
            self.norm = norm_cls
        else:
            self.norm = norm_cls(config.hidden_size)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
        if moe_cls is not None:
            self.moe = moe_cls(config)
        else:
            self.moe = None

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            hidden_states, residual = fused_add_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        hidden_states = self.mixer(hidden_states)
        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


def create_block(config, layer_idx):

    if config.rms_norm:
        norm_cls = partial(RMSNorm, eps=config.layernorm_epsilon)
    else:
        norm_cls = partial(nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon)

    if (not config.mamba_moe_layers) or config.mamba_moe_layers[layer_idx-1][0] == 'r':
        if (not config.mamba_moe_layers) or len(config.mamba_moe_layers[layer_idx-1]) == 1:
            mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
            block = MambaBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
        else:
            if config.mamba_moe_layers[layer_idx-1][1] == '1':
                if config.rms_norm:
                    norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
                else:
                    norm_moe = partial(
                        nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
                    )
                mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
                moe_cls = partial(MLP, layer_idx=layer_idx)
                block = MambaBlockParallelMoe(
                    config,
                    mixer_cls=mixer_cls,
                    moe_cls=moe_cls,
                    norm_cls=norm_cls,
                    norm_moe=norm_moe,
                    fused_add_norm=config.fused_add_norm,
                    residual_in_fp32=config.residual_in_fp32,
                )
            else:
                if config.rms_norm:
                    norm_moe = partial(RMSNorm, eps=config.layernorm_epsilon)
                else:
                    norm_moe = partial(
                        nn.LayerNorm if not config.rms_norm else RMSNorm, eps=config.layernorm_epsilon
                    )
                mixer_cls = partial(MambaLayer, layer_idx=layer_idx)
                moe_cls = partial(SwitchMLP, layer_idx=layer_idx)
                block = MambaBlockParallelMoe(
                    config,
                    mixer_cls=mixer_cls,
                    moe_cls=moe_cls,
                    norm_cls=norm_cls,
                    norm_moe=norm_moe,
                    fused_add_norm=config.fused_add_norm,
                    residual_in_fp32=config.residual_in_fp32,
                )
    else:
        if config.mamba_moe_layers[layer_idx-1][0] == '1':
            mixer_cls = partial(MLP, layer_idx=layer_idx)
            block = MoEBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
        else:
            mixer_cls = partial(SwitchMLP, layer_idx=layer_idx)
            block = MoEBlock(
                config,
                mixer_cls=mixer_cls,
                norm_cls=norm_cls,
                fused_add_norm=config.fused_add_norm,
                residual_in_fp32=config.residual_in_fp32,
            )
    block.layer_idx = layer_idx
    return block

class MambaDecoder(nn.Module):
    """Class wrapping a decoder stack of mamba blocks."""

    def __init__(
        self,
        config: MambaConfig,
        post_layer_norm=True,
        pre_process=True,
        post_process=True,
    ):
        super().__init__()

        self.config: MambaConfig = config
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.norm_cls = partial(nn.LayerNorm, eps=self.config.layernorm_epsilon)

        self._build_layers()

    def _build_layers(self):

        num_layers_to_build = self.config.num_layers
        # build the actual mamba layers
        self.layers = torch.nn.ModuleList([create_block(self.config, i + 1) for i in range(num_layers_to_build)])

        if self.post_process and self.post_layer_norm:
            # Final layer norm before output.
            self.final_layernorm = self.norm_cls(self.config.hidden_size, bias=True)

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(self, hidden_states, residual=None, inference_params=None):

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        residual = None
        for i, layer in enumerate(self.layers):
            hidden_states, residual = layer(
                hidden_states=hidden_states,
                residual=residual,
                inference_params=inference_params,
            )

        # Final layer norm.
        if self.post_process and self.post_layer_norm:
            if not self.config.fused_add_norm:
                residual = (hidden_states + residual) if residual is not None else hidden_states
                hidden_states = self.final_layernorm(residual.to(dtype=self.final_layernorm.weight.dtype))
            else:
                # Set prenorm=False here since we don't need the residual
                fused_add_norm_fn = rms_norm_fn if isinstance(self.final_layernorm, RMSNorm) else layer_norm_fn
                hidden_states = fused_add_norm_fn(
                    hidden_states,
                    self.final_layernorm.weight,
                    self.final_layernorm.bias,
                    eps=self.final_layernorm.eps,
                    residual=residual,
                    prenorm=False,
                    residual_in_fp32=self.residual_in_fp32,
                )
        return hidden_states
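
The create_block() dispatch above keys off per-layer codes in config.mamba_moe_layers. The standalone helper below is only my reading of that branching, written out for documentation; the example codes are assumptions inferred from the conditions, not values defined by this upload:

    from typing import Optional

    def describe_layer(code: Optional[str]) -> str:
        # Mirrors create_block(): an empty mamba_moe_layers or a bare "r"
        # gives a plain Mamba block; "r"-prefixed codes add a parallel
        # MLP/MoE branch; other codes replace the mixer entirely.
        if not code or code == "r":
            return "MambaBlock (Mamba mixer only)"
        if code[0] == "r":
            return ("MambaBlockParallelMoe + dense MLP" if code[1] == "1"
                    else "MambaBlockParallelMoe + SwitchMLP")
        return "MoEBlock + dense MLP" if code[0] == "1" else "MoEBlock + SwitchMLP"

    for c in [None, "r", "r1", "r8", "1", "8"]:
        print(c, "->", describe_layer(c))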
mamba_config.py
ADDED
@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn.functional as F
from utils import init_method_normal, scaled_init_method_normal


@dataclass
class MambaConfig():
    base_model_type: str = "mamba"
    num_layers: int = 0
    hidden_size: int = 0
    state_size: int = 0
    vocab_size: int = 50000
    expansion_factor: int = 2
    conv_dimension: int = 0
    conv_bias: bool = True
    bias: bool = True
    use_fast_path: bool = True
    dt_rank: str = "auto"
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random"
    dt_scale: float = 1.0
    dt_init_floor: float = 1e-4
    rms_norm: bool = True
    fused_add_norm: bool = False
    residual_in_fp32: bool = True
    hidden_dropout: float = 0.0
    ffn_hidden_size: int = None
    gated_linear_unit: bool = False
    mamba_moe_layers: str = ""
    routing_mode: str = "sinkhorn"
    device: str = "cuda"
    fp32_residual_connection: bool = False
    layernorm_epsilon: float = 1e-5
    layernorm_zero_centered_gamma: bool = False
    add_bias_linear: bool = True
    activation_func: Callable = F.gelu
    num_moe_experts: int = None

    # initialization
    init_method: Callable = None
    output_layer_init_method: Callable = None
    init_method_std: float = 0.02

    # mixed-precision
    apply_query_key_layer_scaling: bool = True
    attention_softmax_in_fp32: bool = True

    # fusion
    gated_linear_unit: bool = False
    bias_gelu_fusion: bool = False
    persist_layer_norm: bool = False
    bias_dropout_fusion: bool = False


    def __post_init__(self):
        """ Python dataclass method that is used to modify attributes after initialization.
        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
        """
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.ffn_hidden_size is None:
            self.ffn_hidden_size = 4 * self.hidden_size

        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True

        if self.bias_gelu_fusion:
            if not self.add_bias_linear:
                raise ValueError(
                    "When bias_gelu_fusion is True, add_bias_linear must also be True."
                )

            if self.activation_func != F.gelu:
                raise ValueError(f'When bias_gelu_fusion is True, activation_func must be F.gelu.')

        if self.init_method is None:
            self.init_method = init_method_normal(self.init_method_std)

        if self.output_layer_init_method is None:
            self.output_layer_init_method = scaled_init_method_normal(
                self.init_method_std, self.num_layers
            )
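
A quick sketch of constructing a config and checking what __post_init__ fills in (the sizes here are arbitrary assumptions chosen for illustration, not values from this upload):

    from mamba_config import MambaConfig

    cfg = MambaConfig(num_layers=4, hidden_size=256, state_size=16,
                      conv_dimension=4, device="cpu")
    print(cfg.ffn_hidden_size)   # None -> 4 * hidden_size = 1024
    print(cfg.init_method)       # filled in via init_method_normal(0.02)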
mamba_model.py
ADDED
@@ -0,0 +1,183 @@
import logging
from typing import Literal, Optional, Union
import functools
from functools import partial
import torch
import torch.nn as nn
from torch import Tensor
import math
import os
from mamba_block import MambaBlock, MambaDecoder
from mamba_config import MambaConfig
from hf_utils import *
import os, json
from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
from transformers.utils.hub import cached_file


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MambaModel(nn.Module):
    def __init__(
        self,
        config: MambaConfig,
        max_sequence_length: int,
        pre_process: bool = True,
        post_process: bool = True,
        fp16_lm_cross_entropy: bool = False,
        parallel_output: bool = True,
        share_embeddings_and_output_weights: bool = True,
        initializer_cfg=None,
    ) -> None:
        super().__init__()

        self.config: MambaConfig = config
        self.max_sequence_length = max_sequence_length
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.share_embeddings_and_output_weights = share_embeddings_and_output_weights

        if self.pre_process:
            self.embedding = nn.Embedding(self.config.vocab_size, self.config.hidden_size)

        self.decoder = MambaDecoder(
            config=self.config,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        if post_process:
            self.output_layer = nn.Linear(self.config.hidden_size, self.config.vocab_size, bias=self.config.add_bias_linear)
            if self.share_embeddings_and_output_weights and (self.pre_process or self.post_process):
                self.initialize_last_stage_with_word_embeddings()

        # apply weight initialization
        self.apply(
            partial(
                _init_weights,
                n_layer=self.config.num_layers,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def initialize_last_stage_with_word_embeddings(self):
        with torch.no_grad():
            self.output_layer.weight = self.embedding.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        decoder_input: Tensor = None,
        labels: Tensor = None,
        inference_params=None,
    ) -> Tensor:
        if decoder_input is not None:
            pass
        elif self.pre_process:
            decoder_input = self.embedding(input_ids)
        else:
            decoder_input = None

        hidden_states = self.decoder(
            hidden_states=decoder_input,
            residual=None,
            inference_params=inference_params,
        )

        if not self.post_process:
            return hidden_states

        logits = self.output_layer(hidden_states)

        return logits.contiguous()

    @classmethod
    def from_pretrained(cls, pretrained_model_name=None, checkpoint_name=None, config_name=None, **kwargs):
        if pretrained_model_name is not None:
            json_config = load_config_hf(pretrained_model_name)
            loaded = load_state_dict_hf(pretrained_model_name)
        elif checkpoint_name is not None and config_name is not None:
            with open(config_name, 'r') as f:
                jsonstr = f.read()
                json_config = json.loads(jsonstr)
            loaded = torch.load(checkpoint_name, map_location='cpu')
        else:
            return
        model_state_dict = loaded["model"]

        config = MambaConfig(
            num_layers=json_config['num_layers'],
            hidden_size=json_config['hidden_size'],
            state_size=json_config['state_size'],
            conv_dimension=json_config['conv_dimension'],
            vocab_size=json_config['vocab_size'],
            expansion_factor=json_config['expansion_factor'],
            mamba_moe_layers=json_config['mamba_moe_layers'],
            ffn_hidden_size=json_config['ffn_hidden_size'],
            bias=json_config['add_bias_linear'],
            add_bias_linear=json_config['add_bias_linear'],
            gated_linear_unit=json_config['swiglu']
        )

        model = MambaModel(config=config, max_sequence_length=json_config['max_sequence_length'], **kwargs)

        # make keys match
        model_state_dict["embedding.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict["output_layer.weight"] = model_state_dict["embedding.word_embeddings.weight"].clone()
        model_state_dict["embedding.word_embeddings.weight"] = None
        model_state_dict.pop("embedding.word_embeddings.weight")
        model.load_state_dict(loaded["model"])
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.
        Save the model and its configuration file to a directory.
        """
        # Ensure save_directory exists
        if not os.path.exists(save_directory):
            os.makedirs(save_directory)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, 'pytorch_model.bin')
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, 'config.json')
        with open(config_path, 'w') as f:
            json.dump(self.config.__dict__, f)
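
The weight tying done in initialize_last_stage_with_word_embeddings() is the usual embedding/output-head sharing; a self-contained illustration of that pattern (shapes are arbitrary assumptions, not values from this upload):

    import torch
    import torch.nn as nn

    vocab, hidden = 100, 16
    embedding = nn.Embedding(vocab, hidden)
    output_layer = nn.Linear(hidden, vocab, bias=False)
    output_layer.weight = embedding.weight    # the same Parameter object is reused

    assert output_layer.weight.data_ptr() == embedding.weight.data_ptr()
    tokens = torch.randint(0, vocab, (2, 5))
    logits = output_layer(embedding(tokens))  # (2, 5, vocab)
    print(logits.shape)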
mlp.py
ADDED
@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import bias_gelu_impl
from mamba_config import MambaConfig

class MLP(nn.Module):
    def __init__(
        self, config: MambaConfig, is_expert: bool = False, layer_idx=None
    ):
        super().__init__()

        self.config: MambaConfig = config
        self.layer = layer_idx
        ffn_hidden_size_1 = self.config.ffn_hidden_size
        ffn_hidden_size_2 = self.config.ffn_hidden_size

        # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        if self.config.gated_linear_unit:
            ffn_hidden_size_1 *= 2

        self.linear_fc1 = nn.Linear(self.config.hidden_size, ffn_hidden_size_1, bias=self.config.add_bias_linear, device=self.config.device)
        self.linear_fc1.is_expert = is_expert

        if self.config.gated_linear_unit:

            def glu(x):
                x = torch.chunk(x, 2, dim=-1)
                return self.config.activation_func(x[0]) * x[1]

            self.activation_func = glu
        else:
            self.activation_func = self.config.activation_func

        self.linear_fc2 = nn.Linear(ffn_hidden_size_2, self.config.hidden_size, bias=self.config.add_bias_linear, device=self.config.device)

    def forward(self, hidden_states, inference_params=None):
        intermediate = self.linear_fc1(hidden_states)
        intermediate = self.activation_func(intermediate)
        output = self.linear_fc2(intermediate)
        return output
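
A small CPU sketch of the gated-linear-unit path (sizes are assumptions for illustration; note that linear_fc1 is twice as wide as ffn_hidden_size when gated_linear_unit is set):

    import torch
    from mamba_config import MambaConfig
    from mlp import MLP

    cfg = MambaConfig(num_layers=2, hidden_size=8, state_size=4,
                      conv_dimension=4, gated_linear_unit=True, device="cpu")
    mlp = MLP(cfg, layer_idx=0)
    x = torch.randn(2, 3, 8)
    print(mlp(x).shape)                 # torch.Size([2, 3, 8])
    print(mlp.linear_fc1.weight.shape)  # (2 * ffn_hidden_size, hidden) = (64, 8)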
setup.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) 2023, Albert Gu, Tri Dao.
import warnings
import os
from pathlib import Path

from packaging.version import parse, Version
from setuptools import setup, find_packages
import subprocess


import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)

PACKAGE_NAME = "blackmamba"
VERSION = "0.0.1"

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    return nvcc_extra_args + ["--threads", "4"]


ext_modules = []
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none(PACKAGE_NAME)
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_70,code=sm_70")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="selective_scan_cuda",
            sources=[
                "csrc/selective_scan/selective_scan.cpp",
                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
    )


setup(
    name=PACKAGE_NAME,
    version=VERSION,
    description="Blackmamba state-space + MoE model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    packages=find_packages(include=['ops'],),
    exclude=(
        "csrc",
        "blackmamba.egg-info",
    ),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],
)
switch_mlp.py
ADDED
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import pickle
import os
import torch.nn.functional as F

from mamba_config import MambaConfig
from mlp import MLP

def sinkhorn(cost, tol=0.0001):
    "Sinkhorn based MoE routing function"
    cost = torch.exp(2.0 * cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    # d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)
    d1 = 1 / (cost.size(1) * torch.sum(cost, 0))

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)


class SwitchMLP(nn.Module):
    """
    Top-1 Mixture of Experts Layer. Routes input to one of N MLP "experts".
    Currently supports Sinkhorn-based expert routing.
    """

    def __init__(self, config: MambaConfig, layer_idx=None):
        super().__init__()

        self.layer = layer_idx
        self.config: MambaConfig = config
        if config.mamba_moe_layers:
            self.num_moe_experts = int(config.mamba_moe_layers[layer_idx-1][-1])
        else:
            self.num_moe_experts = self.config.num_moe_experts
        self.router = torch.nn.Linear(self.config.hidden_size, self.num_moe_experts)
        self.add_bias = config.add_bias_linear
        self.routing = config.routing_mode  # 'sinkhorn', 'top1', 'top2', 'sinkhorn_top2'
        self.route_algo = sinkhorn
        self.router_activation = torch.sigmoid

        self.num_local_experts = self.num_moe_experts
        self.local_expert_indices = [i for i in range(self.num_local_experts)]

        self.local_experts = torch.nn.ModuleList()
        for _ in range(self.num_local_experts):
            expert = MLP(self.config, is_expert=True, layer_idx=layer_idx)
            self.local_experts.append(expert)

    def gather_indices(self, local_indices):
        return local_indices

    def forward(self, hidden_states, inference_params=None):

        hidden_shape = hidden_states.shape
        route = self.router(hidden_states)
        route = route.view(-1, self.num_moe_experts)

        if self.routing == 'sinkhorn':
            route = self.router_activation(route)
            max_prob, max_ind = torch.max(route, dim=1)
        else:
            route = torch.softmax(route, dim=1)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        hidden_states = hidden_states.view(-1, hidden_shape[-1])

        global_hidden_states = hidden_states
        global_indices = max_ind
        output_total = torch.zeros_like(global_hidden_states)


        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output = expert(hidden)
            output_total[local_indices, :] = output

        output_total = output_total * max_prob
        output_total = output_total.view(hidden_shape)

        return output_total
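
A CPU-only sketch of the top-1 routing path (expert count and tensor shapes are assumptions; with an empty mamba_moe_layers the expert count falls back to num_moe_experts):

    import torch
    from mamba_config import MambaConfig
    from switch_mlp import SwitchMLP

    cfg = MambaConfig(num_layers=2, hidden_size=8, state_size=4,
                      conv_dimension=4, num_moe_experts=4, device="cpu")
    moe = SwitchMLP(cfg, layer_idx=1)
    x = torch.randn(2, 5, 8)
    print(moe(x).shape)   # torch.Size([2, 5, 8]); each token is routed to one expert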
utils.py
ADDED
@@ -0,0 +1,82 @@
from operator import itemgetter
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import math
import torch


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


@torch.jit.script
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


@torch.jit.script
def bias_gelu(bias, y):
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


# gradient of tanh approximation of gelu
# gradient of actual gelu is:
# 0.5 * (1. + torch.erf(x * 0.70710678)) + 0.3989423 * x * torch.exp(-0.5 * x * x)
@torch.jit.script
def bias_gelu_back(g, bias, y):
    x = bias + y
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
    return ff * g


class GeLUFunction(torch.autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, bias):
        ctx.save_for_backward(input, bias)
        return bias_gelu(bias, input)

    @staticmethod
    def backward(ctx, grad_output):
        input, bias = ctx.saved_tensors
        tmp = bias_gelu_back(grad_output, bias, input)
        return tmp, tmp


bias_gelu_impl = GeLUFunction.apply


# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@torch.jit.script
def erf_gelu(x):
    return (
        x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
    )


def init_method_normal(sigma):

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_