import torch
from torch import nn
from torch.utils.data import Dataset

from Vocabulary import Vocabulary
from utils import Callable_tokenizer


class MT_Dataset(Dataset):
    """Machine-translation dataset over parallel source/target sentences.

    Each item is tokenized with its side's tokenizer, wrapped on both ends
    with the id returned by ``get_tokenId('')``, and returned as a pair of
    1-D integer tensors. When ``reversed_input`` is True the input token
    sequence is flipped end-to-end.
    """

    def __init__(self, input_sentences_list, target_sentences_list,
                 input_tokenizer: Callable_tokenizer,
                 target_tokenizer: Callable_tokenizer,
                 reversed_input: bool):
        """Store the parallel corpus and tokenizers.

        Args:
            input_sentences_list: source-side sentences (one string each).
            target_sentences_list: target-side sentences, same length.
            input_tokenizer: callable mapping a string to a list of token ids.
            target_tokenizer: same, for the target language.
            reversed_input: if True, reverse each input token sequence.

        Raises:
            ValueError: if the two sentence lists differ in length.
        """
        super().__init__()
        # ValueError (not assert): validation must survive `python -O`.
        if len(input_sentences_list) != len(target_sentences_list):
            raise ValueError(
                f"Length mismatch: input has {len(input_sentences_list)} sentences, "
                f"but target has {len(target_sentences_list)} sentences.")
        self.input_sentences_list = input_sentences_list
        self.target_sentences_list = target_sentences_list
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer
        self.reversed_input = reversed_input

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.input_sentences_list)

    def __getitem__(self, index):
        """Return ``(input_tensor, target_tensor)`` for the pair at *index*."""
        src_text = self.input_sentences_list[index]
        tgt_text = self.target_sentences_list[index]

        # NOTE(review): '' is passed to get_tokenId on both ends of each
        # sequence — this looks like it was meant to be explicit SOS/EOS
        # markers (e.g. '<sos>'/'<eos>') that got stripped; confirm against
        # Callable_tokenizer before changing. Behavior kept as-is.
        src_ids = ([self.input_tokenizer.get_tokenId('')]
                   + self.input_tokenizer(src_text)
                   + [self.input_tokenizer.get_tokenId('')])
        tgt_ids = ([self.target_tokenizer.get_tokenId('')]
                   + self.target_tokenizer(tgt_text)
                   + [self.target_tokenizer.get_tokenId('')])

        src_tensor = torch.tensor(src_ids)
        tgt_tensor = torch.tensor(tgt_ids)

        if self.reversed_input:
            # Reverse only the input side (seq2seq input-reversal trick).
            src_tensor = src_tensor.flip(0)

        return src_tensor, tgt_tensor


class MYCollate:
    """Collate function padding a batch of (input, target) tensor pairs."""

    def __init__(self, batch_first=True, pad_value=0):
        """
        Args:
            batch_first: if True, padded tensors have shape (batch, seq).
            pad_value: id used to pad shorter sequences.
        """
        self.pad_value = pad_value
        self.batch_first = batch_first

    def __call__(self, data):
        """Pad each side of the batch independently to its longest sequence.

        Args:
            data: list of ``(input_tensor, target_tensor)`` pairs.

        Returns:
            Tuple of two padded tensors (inputs, targets).
        """
        src_batch = [pair[0] for pair in data]
        tgt_batch = [pair[1] for pair in data]
        padded_src = nn.utils.rnn.pad_sequence(
            src_batch, batch_first=self.batch_first, padding_value=self.pad_value)
        padded_tgt = nn.utils.rnn.pad_sequence(
            tgt_batch, batch_first=self.batch_first, padding_value=self.pad_value)
        return padded_src, padded_tgt