import random

import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self, encoder_vocab_size, encoder_embding_dims, encoder_hidden_size,
                 decoder_hidden_size, encoder_rnn_layers, p):
        super().__init__()
        self.embd_layer = nn.Embedding(encoder_vocab_size, encoder_embding_dims)
        self.dropout = nn.Dropout(p)  ## use the same dropout probability as the GRU (was left at the default)
        self.rnn = nn.GRU(encoder_embding_dims, encoder_hidden_size, encoder_rnn_layers,
                          batch_first=True, dropout=p, bidirectional=True)
        ## maps the concatenated forward/backward final hidden states to the decoder's hidden size
        self.fc_map = nn.Linear(encoder_hidden_size * 2, decoder_hidden_size)

    def forward(self, x):
        ## x: (batch, seq_len)
        embds = self.dropout(self.embd_layer(x))  ## (batch, seq_len, embd_dims)
        context, hidden = self.rnn(embds)         ## context: (batch, seq_len, encoder_hidden_size*2)
        ## hidden: (num_layers*2, batch, encoder_hidden_size); take the last layer's forward and backward states
        last_hidden = torch.cat([hidden[-2, :, :], hidden[-1, :, :]], dim=-1)
        to_decoder_hidden = self.fc_map(last_hidden)  ## (batch, decoder_hidden_size)
        return context, to_decoder_hidden


########################--------------------##################################
########################--------------------##################################
########################--------------------##################################


class Attention(nn.Module):
    def __init__(self, encoder_output_dims, decoder_hidden_dims):
        super().__init__()
        self.fc_downscale = nn.Linear((encoder_output_dims * 2) + decoder_hidden_dims, decoder_hidden_dims)
        self.alpha = nn.Linear(decoder_hidden_dims, 1, bias=False)

    def forward(self, encoder_output, decoder_hidden):
        ## encoder_output: (batch, seq_len, encoder_hidden_dims*2)
        ## decoder_hidden: (batch, decoder_hidden_dims)
        seq_len = encoder_output.size(1)
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)  ## (batch, seq_len, decoder_hidden_dims)
        ## tanh keeps the two linear layers from collapsing into a single linear map (standard additive attention)
        energy = torch.tanh(self.fc_downscale(torch.cat((decoder_hidden, encoder_output), dim=-1)))
        alphas = self.alpha(energy).squeeze(-1)  ## (batch, seq_len)
        return torch.softmax(alphas, dim=-1)


########################--------------------##################################
########################--------------------##################################
########################--------------------##################################


class Decoder(nn.Module):
    def __init__(self, decoder_vocab_size, decoder_embding_dims, decoder_hidden_size,
                 encoder_hidden_size, attention):
        super().__init__()
        self.attention = attention
        self.embd_layer = nn.Embedding(decoder_vocab_size, decoder_embding_dims)
        self.rnn = nn.GRU((encoder_hidden_size * 2) + decoder_embding_dims, decoder_hidden_size,
                          batch_first=True)
        self.classifier = nn.Linear(decoder_hidden_size, decoder_vocab_size)

    def forward(self, x, encoder_output, hidden_t_1):
        ## x: (batch, 1) one token
        ## encoder_output: (batch, seq_len, encoder_hidden_size*2)
        ## hidden_t_1: (batch, decoder_hidden_size)
        embds = self.embd_layer(x)                                        ## (batch, 1, embd_dims)
        alphas = self.attention(encoder_output, hidden_t_1).unsqueeze(1)  ## (batch, 1, seq_len)
        attention = torch.bmm(alphas, encoder_output)                     ## (batch, 1, encoder_hidden_size*2)
        rnn_input = torch.cat((embds, attention), dim=-1)                 ## (batch, 1, (encoder_hidden_size*2) + decoder_embding_dims)
        output, hidden_t = self.rnn(rnn_input, hidden_t_1.unsqueeze(0))
        output = output.squeeze(1)                                        ## (batch, decoder_hidden_size)
        prediction = self.classifier(output)                              ## (batch, decoder_vocab_size)
        ## alphas are returned for attention visualization
        return prediction, hidden_t.squeeze(0), alphas.squeeze(1)


########################--------------------##################################
########################--------------------##################################
########################--------------------##################################


class Seq2seq_with_attention(nn.Module):
    def __init__(self, encoder_vocab_len, encoder_embding_dims, encoder_hidden_size,
                 encoder_rnn_layers, encoder_dropout, decoder_vocab_len, decoder_embding_dims,
                 decoder_hidden_size):
        super().__init__()
        self.decoder_vocab_len = decoder_vocab_len
        self.encoder = Encoder(encoder_vocab_len, encoder_embding_dims, encoder_hidden_size,
                               decoder_hidden_size, encoder_rnn_layers, encoder_dropout)
        self.attention = Attention(encoder_hidden_size, decoder_hidden_size)
        self.decoder = Decoder(decoder_vocab_len, decoder_embding_dims, decoder_hidden_size,
                               encoder_hidden_size, self.attention)

    def forward(self, x: torch.Tensor, y: torch.Tensor, teacher_force_ratio=0.5):
        batch_size, seq_len = y.size()
        total_outputs = torch.zeros(batch_size, seq_len, self.decoder_vocab_len, device=x.device)
        context, hidden = self.encoder(x)
        step_token = y[:, [0]]  ## start decoding from the first target column (the <sos> token)
        for step in range(1, seq_len):
            logits, hidden, alphas = self.decoder(step_token, context, hidden)
            total_outputs[:, step] = logits
            top1 = logits.argmax(1, keepdim=True)
            ## teacher forcing: feed the ground-truth token with probability teacher_force_ratio,
            ## otherwise feed the model's own prediction (the original overwrote x and never updated step_token)
            step_token = y[:, [step]] if teacher_force_ratio > random.random() else top1
        return total_outputs

    @torch.no_grad()
    def translate(self, source: torch.Tensor, max_tries=50):
        output = [2]  ## assumed <sos> token id
        ## dummy target tensor: only its length is used, since teacher forcing is disabled below
        targets = torch.zeros(source.size(0), max_tries, dtype=torch.long, device=source.device)
        targets[:, 0] = 2  ## seed decoding with the <sos> token (id 2, matching the hard-coded start token above)
        targets_hat = self.forward(source, targets, teacher_force_ratio=0.0)
        targets_hat = targets_hat.argmax(-1).squeeze(0)  ## assumes a single-sentence source (batch of 1)
        for token in targets_hat[1:]:
            output.append(token.item())
            if token == 3:  ## stop at the <eos> token id
                break
        return output
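

########################--------------------##################################
########################--------------------##################################
########################--------------------##################################


if __name__ == "__main__":
    ## Minimal smoke test, not part of the original module: the vocabulary sizes, hyperparameters,
    ## and special token ids used here (0 = <pad>, 2 = <sos>, 3 = <eos>) are illustrative assumptions.
    model = Seq2seq_with_attention(
        encoder_vocab_len=1000, encoder_embding_dims=64, encoder_hidden_size=128,
        encoder_rnn_layers=2, encoder_dropout=0.3,
        decoder_vocab_len=1200, decoder_embding_dims=64, decoder_hidden_size=128,
    )

    ## teacher-forced training-style pass: x is a batch of source sentences, y a batch of targets
    x = torch.randint(0, 1000, (4, 12))            ## (batch=4, source_len=12)
    y = torch.randint(0, 1200, (4, 10))            ## (batch=4, target_len=10)
    y[:, 0] = 2                                    ## assumed <sos> id in the first target column
    logits = model(x, y, teacher_force_ratio=0.5)  ## (4, 10, decoder_vocab_len)
    print(logits.shape)

    ## greedy decoding of a single sentence (translate assumes a batch of size 1)
    model.eval()
    sentence = torch.randint(0, 1000, (1, 12))
    print(model.translate(sentence, max_tries=20))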