damerajee commited on
Commit
fa23921
·
verified ·
1 Parent(s): a39f8ba

Update mingru_lm.py

Browse files
Files changed (1) hide show
  1. mingru_lm.py +15 -12
mingru_lm.py CHANGED
@@ -59,18 +59,6 @@ class MinGRU(Module):
59
  return out
60
  return out, next_prev_hidden
61
 
62
- if __name__ == "__main__":
63
- x = torch.rand(2,256,512)
64
- model = MinGRU(dim=512)
65
- out , next_prev_hidden = model(x,return_next_prev_hidden=True)
66
-
67
-
68
- print("out",out[0,0,:3])
69
- print("next_prev_hidden",next_prev_hidden[0,0,:3])
70
- print("out shape",out.shape)
71
- print("X shape",x.shape)
72
- assert x.shape == out.shape
73
-
74
 
75
  class FeedForward(nn.Module):
76
  def __init__(self, dim, mult=4):
@@ -85,6 +73,20 @@ class FeedForward(nn.Module):
85
  def forward(self, x):
86
  return self.net(x)
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  class RMSNorm(nn.Module):
89
  def __init__(self, dim):
90
  super().__init__()
@@ -98,6 +100,7 @@ class MinGRU_Layers(nn.Module):
98
  def __init__(self, dim, num_tokens):
99
  super().__init__()
100
  self.emb = nn.Embedding(num_tokens, dim)
 
101
  self.rms_norm = RMSNorm(dim)
102
  self.gru = MinGRU(dim)
103
  self.ff = FeedForward(dim)
 
59
  return out
60
  return out, next_prev_hidden
61
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
  class FeedForward(nn.Module):
64
  def __init__(self, dim, mult=4):
 
73
  def forward(self, x):
74
  return self.net(x)
75
 
76
+ class CausalDepthWiseConv1d(nn.Module):
77
+ def __init__(self, dim, kernel_size):
78
+ super().__init__()
79
+ self.kernel_size = kernel_size
80
+ self.net = nn.Sequential(
81
+ nn.Conv1d(dim, dim, kernel_size = kernel_size, groups = dim),
82
+ nn.Conv1d(dim, dim, kernel_size = 1)
83
+ )
84
+ def forward(self, x):
85
+ x = x.transpose(1, 2) # b n d -> b d n
86
+ x = F.pad(x, (self.kernel_size - 1, 0), value = 0.)
87
+ x = self.net(x)
88
+ return x.transpose(1, 2) # b d n -> b n d
89
+
90
  class RMSNorm(nn.Module):
91
  def __init__(self, dim):
92
  super().__init__()
 
100
  def __init__(self, dim, num_tokens):
101
  super().__init__()
102
  self.emb = nn.Embedding(num_tokens, dim)
103
+ self.casual_depth = CausalDepthWiseConv1d(dim=dim,kernel_size=3)
104
  self.rms_norm = RMSNorm(dim)
105
  self.gru = MinGRU(dim)
106
  self.ff = FeedForward(dim)