Spaces:
Sleeping
Sleeping
Update mingru_lm.py
Browse files- mingru_lm.py +15 -12
mingru_lm.py
CHANGED
@@ -59,18 +59,6 @@ class MinGRU(Module):
|
|
59 |
return out
|
60 |
return out, next_prev_hidden
|
61 |
|
62 |
-
if __name__ == "__main__":
|
63 |
-
x = torch.rand(2,256,512)
|
64 |
-
model = MinGRU(dim=512)
|
65 |
-
out , next_prev_hidden = model(x,return_next_prev_hidden=True)
|
66 |
-
|
67 |
-
|
68 |
-
print("out",out[0,0,:3])
|
69 |
-
print("next_prev_hidden",next_prev_hidden[0,0,:3])
|
70 |
-
print("out shape",out.shape)
|
71 |
-
print("X shape",x.shape)
|
72 |
-
assert x.shape == out.shape
|
73 |
-
|
74 |
|
75 |
class FeedForward(nn.Module):
|
76 |
def __init__(self, dim, mult=4):
|
@@ -85,6 +73,20 @@ class FeedForward(nn.Module):
|
|
85 |
def forward(self, x):
|
86 |
return self.net(x)
|
87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
class RMSNorm(nn.Module):
|
89 |
def __init__(self, dim):
|
90 |
super().__init__()
|
@@ -98,6 +100,7 @@ class MinGRU_Layers(nn.Module):
|
|
98 |
def __init__(self, dim, num_tokens):
|
99 |
super().__init__()
|
100 |
self.emb = nn.Embedding(num_tokens, dim)
|
|
|
101 |
self.rms_norm = RMSNorm(dim)
|
102 |
self.gru = MinGRU(dim)
|
103 |
self.ff = FeedForward(dim)
|
|
|
59 |
return out
|
60 |
return out, next_prev_hidden
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
class FeedForward(nn.Module):
|
64 |
def __init__(self, dim, mult=4):
|
|
|
73 |
def forward(self, x):
|
74 |
return self.net(x)
|
75 |
|
76 |
+
class CausalDepthWiseConv1d(nn.Module):
|
77 |
+
def __init__(self, dim, kernel_size):
|
78 |
+
super().__init__()
|
79 |
+
self.kernel_size = kernel_size
|
80 |
+
self.net = nn.Sequential(
|
81 |
+
nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim),
|
82 |
+
nn.Conv1d(dim, dim, kernel_size = 1)
|
83 |
+
)
|
84 |
+
def forward(self, x):
|
85 |
+
x = x.transpose(1, 2) # b n d -> b d n
|
86 |
+
x = F.pad(x, (self.kernel_size - 1, 0), value = 0.)
|
87 |
+
x = self.net(x)
|
88 |
+
return x.transpose(1, 2) # b d n -> b n d
|
89 |
+
|
90 |
class RMSNorm(nn.Module):
|
91 |
def __init__(self, dim):
|
92 |
super().__init__()
|
|
|
100 |
def __init__(self, dim, num_tokens):
|
101 |
super().__init__()
|
102 |
self.emb = nn.Embedding(num_tokens, dim)
|
103 |
+
self.casual_depth = CausalDepthWiseConv1d(dim=dim,kernel_size=3)
|
104 |
self.rms_norm = RMSNorm(dim)
|
105 |
self.gru = MinGRU(dim)
|
106 |
self.ff = FeedForward(dim)
|