import torch
from torch import nn
from torch.utils.data import Dataset

from Vocabulary import Vocabulary
from utils import Callable_tokenizer


class MT_Dataset(Dataset):
    """Pairs source and target sentences and numericalizes them with the
    provided tokenizers, wrapping each sequence in <s> ... </s> markers."""

    def __init__(self, input_sentences_list, target_sentences_list,
                 input_tokenizer: Callable_tokenizer, target_tokenizer: Callable_tokenizer,
                 reversed_input: bool):
        super().__init__()

        assert len(input_sentences_list) == len(target_sentences_list), (
            f"Length mismatch: input has {len(input_sentences_list)} sentences, "
            f"but target has {len(target_sentences_list)} sentences.")

        self.input_sentences_list = input_sentences_list
        self.target_sentences_list = target_sentences_list
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer
        self.reversed_input = reversed_input

    def __len__(self):
        return len(self.input_sentences_list)

    def __getitem__(self, index):
        input_sentence = self.input_sentences_list[index]
        target_sentence = self.target_sentences_list[index]

        # Numericalize and wrap each sentence in start/end-of-sequence markers.
        input_numerical_tokens = ([self.input_tokenizer.get_tokenId('<s>')]
                                  + self.input_tokenizer(input_sentence)
                                  + [self.input_tokenizer.get_tokenId('</s>')])
        target_numerical_tokens = ([self.target_tokenizer.get_tokenId('<s>')]
                                   + self.target_tokenizer(target_sentence)
                                   + [self.target_tokenizer.get_tokenId('</s>')])

        input_tensor_tokens = torch.tensor(input_numerical_tokens)
        target_tensor_tokens = torch.tensor(target_numerical_tokens)

        # Optionally reverse the source sequence, a common seq2seq trick
        # (Sutskever et al., 2014).
        if self.reversed_input:
            input_tensor_tokens = input_tensor_tokens.flip(0)

        return input_tensor_tokens, target_tensor_tokens


class MYCollate:
    """Collate callable that pads each batch of variable-length source (en)
    and target (ar) sequences to a common length."""

    def __init__(self, batch_first=True, pad_value=0):
        self.pad_value = pad_value
        self.batch_first = batch_first

    def __call__(self, data):
        en_sentences = [ex[0] for ex in data]
        ar_sentences = [ex[1] for ex in data]

        padded_en_sentences = nn.utils.rnn.pad_sequence(en_sentences, batch_first=self.batch_first,
                                                        padding_value=self.pad_value)
        padded_ar_sentences = nn.utils.rnn.pad_sequence(ar_sentences, batch_first=self.batch_first,
                                                        padding_value=self.pad_value)
        return padded_en_sentences, padded_ar_sentences
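

# --- Minimal usage sketch ---------------------------------------------------
# The real `Callable_tokenizer` lives in utils; here we only assume it is
# callable on a sentence (returning a list of token ids) and exposes
# get_tokenId(token). `_ToyTokenizer` below is a hypothetical stand-in with
# that same interface, used only so this demo is self-contained.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    class _ToyTokenizer:
        def __init__(self, sentences):
            words = sorted({w for s in sentences for w in s.split()})
            self.vocab = {'<pad>': 0, '<s>': 1, '</s>': 2}
            self.vocab.update({w: i + 3 for i, w in enumerate(words)})

        def __call__(self, sentence):
            return [self.vocab[w] for w in sentence.split()]

        def get_tokenId(self, token):
            return self.vocab[token]

    en = ["good morning", "thank you very much"]
    ar = ["صباح الخير", "شكرا جزيلا"]

    dataset = MT_Dataset(en, ar, _ToyTokenizer(en), _ToyTokenizer(ar), reversed_input=True)
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=MYCollate(batch_first=True, pad_value=0))

    src_batch, tgt_batch = next(iter(loader))
    print(src_batch.shape, tgt_batch.shape)  # -> torch.Size([2, 6]) torch.Size([2, 4])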