import random

import torch
from torch import nn

class Encoder(nn.Module):
    # Bidirectional GRU encoder over embedded source tokens. Despite the
    # ``lstm_*`` argument names (kept to preserve the existing interface),
    # the recurrent unit is a GRU.
    def __init__(self, embd_matrix: torch.Tensor, pretrained: bool, lstm_hidden_size, lstm_layers=1, dropout_probability=0.1):
        super().__init__()
        self.hidden_size = lstm_hidden_size
        self.lstm_layers = lstm_layers
        self.input_size = embd_matrix.size(-1)
        # Initialise from the given matrix (still trainable, freeze=False) or
        # create a fresh embedding of the same shape.
        self.embd_layer = (nn.Embedding.from_pretrained(embd_matrix, freeze=False) if pretrained
                           else nn.Embedding(embd_matrix.size(0), embd_matrix.size(1)))

        self.dropout = nn.Dropout(dropout_probability)
        # Inter-layer dropout only applies to stacked GRUs; passing it with a
        # single layer triggers a PyTorch warning.
        self.gru = nn.GRU(self.input_size, self.hidden_size, self.lstm_layers,
                          dropout=dropout_probability if lstm_layers > 1 else 0.0,
                          batch_first=True, bidirectional=True)
        # Projects the concatenated forward/backward final states to the decoder's hidden size.
        self.fc = nn.Linear(lstm_hidden_size * 2, lstm_hidden_size)

    def forward(self, x):
        # x: (batch, src_len) token ids
        embds = self.dropout(self.embd_layer(x))    # (batch, src_len, embd_dim)
        # output: (batch, src_len, hidden * 2), hidden: (layers * 2, batch, hidden)
        output, hidden = self.gru(embds)

        # Concatenate the final layer's forward (-2) and backward (-1) states,
        # then project them to a single decoder-sized initial hidden state.
        hidden = torch.cat([hidden[-2, :, :], hidden[-1, :, :]], dim=1)
        hidden = torch.tanh(self.fc(hidden))        # (batch, hidden)

        return output, hidden

class Attention(nn.Module):
    # Additive (Bahdanau-style) attention: scores every encoder position against
    # the decoder's previous hidden state and returns a softmax over the source.
    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        self.attn_fc = nn.Linear((encoder_hidden_dim * 2) + decoder_hidden_dim, decoder_hidden_dim)
        self.v_fc = nn.Linear(decoder_hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        # hidden:          (batch, decoder_hidden_dim)
        # encoder_outputs: (batch, src_len, encoder_hidden_dim * 2)
        src_length = encoder_outputs.shape[1]

        # Repeat the decoder state across the source length so it can be
        # concatenated with every encoder output.
        hidden = hidden.unsqueeze(1).repeat(1, src_length, 1)

        pre_energy = torch.cat((hidden, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn_fc(pre_energy))    # (batch, src_len, decoder_hidden_dim)

        attention = self.v_fc(energy).squeeze(2)         # (batch, src_len)

        # Normalised attention weights over source positions.
        return torch.softmax(attention, dim=1)

class Decoder(nn.Module):
    # Single-step GRU decoder with attention over the encoder outputs. Assumes the
    # encoder and decoder share the same hidden size, since the attended context
    # has dimension hidden * 2.
    def __init__(self, embd_matrix: torch.Tensor, pretrained: bool, attention: Attention, lstm_hidden_size, lstm_layers=1, dropout_probability=0):
        super().__init__()
        self.hidden_size = lstm_hidden_size
        self.lstm_layers = lstm_layers
        self.input_size = embd_matrix.size(-1)
        self.embd_layer = (nn.Embedding.from_pretrained(embd_matrix, freeze=False) if pretrained
                           else nn.Embedding(embd_matrix.size(0), embd_matrix.size(1)))
        self.attention = attention

        self.dropout = nn.Dropout(dropout_probability)
        # Input = attended context (hidden * 2) concatenated with the token embedding.
        # Inter-layer dropout only applies when there is more than one layer.
        self.gru = nn.GRU((lstm_hidden_size * 2) + self.input_size, self.hidden_size, self.lstm_layers,
                          dropout=dropout_probability if lstm_layers > 1 else 0.0, batch_first=True)
        self.fc_out = nn.Linear(self.hidden_size, embd_matrix.size(0))

    def forward(self, x, hidden_t_1, encoder_outputs):
        # x:               (batch, 1) current input token
        # hidden_t_1:      (batch, hidden) previous decoder state
        # encoder_outputs: (batch, src_len, hidden * 2)
        embds = self.dropout(self.embd_layer(x))        # (batch, 1, embd_dim)

        a = self.attention(hidden_t_1, encoder_outputs)  # (batch, src_len)
        a = a.unsqueeze(1)                               # (batch, 1, src_len)
        weighted = torch.bmm(a, encoder_outputs)         # (batch, 1, hidden * 2)

        rnn_input = torch.cat((embds, weighted), dim=2)

        # Unsqueezing the previous state to (1, batch, hidden) assumes lstm_layers == 1.
        output, hidden_t = self.gru(rnn_input, hidden_t_1.unsqueeze(0))

        output = output.squeeze(1)                       # (batch, hidden)
        prediction = self.fc_out(output)                 # (batch, vocab_size)

        return prediction, hidden_t.squeeze(0), a.squeeze(1)

class Seq2seq_with_attention(nn.Module):
    def __init__(self, encoder: Encoder, decoder: Decoder):
        super().__init__()

        self.decoder_vocab_size = decoder.embd_layer.weight.size(0)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_force_ratio=0.5):
        # source: (batch, src_len), target: (batch, trg_len).
        # Returns logits of shape (batch, trg_len, vocab); position 0 (the <sos> step) stays zero.
        batch_size, seq_len = target.size()

        total_outputs = torch.zeros(batch_size, seq_len, self.decoder_vocab_size, device=source.device)

        encoder_outputs, hidden = self.encoder(source)

        # The first decoder input is the target's <sos> column.
        x = target[:, [0]]
        for step in range(1, seq_len):
            logits, hidden, _ = self.decoder(x, hidden, encoder_outputs)

            total_outputs[:, step] = logits
            top1 = logits.argmax(1, keepdim=True)
            # Teacher forcing: feed the ground-truth token with the given probability,
            # otherwise feed the model's own prediction.
            x = target[:, [step]] if teacher_force_ratio > random.random() else top1

        return total_outputs

    @torch.no_grad()
    def translate(self, source: torch.Tensor, max_tries=50):
        # Greedy decoding for a single source sequence (batch size 1 is assumed,
        # see the squeeze(0) below). Token ids 2 and 3 are taken to be <sos> and
        # <eos>, consistent with the initial output token and the stop condition.
        output = [2]

        # Dummy targets: with teacher forcing disabled, only column 0 is read
        # (as the first decoder input), so fill the tensor with the <sos> id.
        targets = torch.full((source.size(0), max_tries), 2, dtype=torch.long, device=source.device)
        targets_hat = self.forward(source, targets, 0.0)
        targets_hat = targets_hat.argmax(-1).squeeze(0)
        for token in targets_hat[1:]:
            output.append(token.item())
            if token == 3:
                break

        return output
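

# --- Minimal usage sketch (illustrative only) --------------------------------
# A quick smoke test showing how the modules fit together. The vocabulary size,
# embedding dimension, hidden size, and tensor shapes below are arbitrary
# assumptions, not values from the original training setup; with pretrained=False
# the embedding matrix is only used for its shape.
if __name__ == "__main__":
    vocab_size, embd_dim, hidden_size = 100, 32, 64
    embd_matrix = torch.randn(vocab_size, embd_dim)

    encoder = Encoder(embd_matrix, pretrained=False, lstm_hidden_size=hidden_size)
    attention = Attention(encoder_hidden_dim=hidden_size, decoder_hidden_dim=hidden_size)
    decoder = Decoder(embd_matrix, pretrained=False, attention=attention, lstm_hidden_size=hidden_size)
    model = Seq2seq_with_attention(encoder, decoder)

    source = torch.randint(0, vocab_size, (4, 10))   # (batch, src_len)
    target = torch.randint(0, vocab_size, (4, 12))   # (batch, trg_len)

    logits = model(source, target)                   # (4, 12, vocab_size)
    print(logits.shape)

    tokens = model.translate(source[:1])             # greedy decode on a batch of one
    print(tokens)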