import torch
from torch import nn
from torch.utils.data import Dataset

from utils import Callable_tokenizer


class MT_Dataset(Dataset):
    """Parallel-corpus dataset: numericalizes a (source, target) sentence pair
    and wraps each side in <s> ... </s> special tokens."""

    def __init__(self, input_sentences_list, target_sentences_list,
                 input_tokenizer: Callable_tokenizer,
                 target_tokenizer: Callable_tokenizer,
                 reversed_input: bool):
        super().__init__()
        assert len(input_sentences_list) == len(target_sentences_list), (
            f"Length mismatch: input has {len(input_sentences_list)} sentences, "
            f"but target has {len(target_sentences_list)} sentences.")
        self.input_sentences_list = input_sentences_list
        self.target_sentences_list = target_sentences_list
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer
        # When True, the source token order is reversed, a common seq2seq trick
        # that shortens the distance between aligned source/target words.
        self.reversed_input = reversed_input

    def __len__(self):
        return len(self.input_sentences_list)

    def __getitem__(self, index):
        input_sentence = self.input_sentences_list[index]
        target_sentence = self.target_sentences_list[index]

        # Numericalize and surround each sentence with start/end tokens.
        input_numerical_tokens = ([self.input_tokenizer.get_tokenId('<s>')]
                                  + self.input_tokenizer(input_sentence)
                                  + [self.input_tokenizer.get_tokenId('</s>')])
        target_numerical_tokens = ([self.target_tokenizer.get_tokenId('<s>')]
                                   + self.target_tokenizer(target_sentence)
                                   + [self.target_tokenizer.get_tokenId('</s>')])

        input_tensor_tokens = torch.tensor(input_numerical_tokens)
        target_tensor_tokens = torch.tensor(target_numerical_tokens)

        if self.reversed_input:
            input_tensor_tokens = input_tensor_tokens.flip(0)

        return input_tensor_tokens, target_tensor_tokens
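

# --- Illustration only: a hypothetical stand-in for utils.Callable_tokenizer. ---
# MT_Dataset only assumes two things of its tokenizers: calling the tokenizer on
# a sentence returns a list of token ids, and get_tokenId(token) returns the id
# of a single (special) token. This toy whitespace tokenizer satisfies that
# assumed interface so the demo at the bottom of the file is self-contained.
class _ToyWhitespaceTokenizer:
    def __init__(self, sentences):
        specials = ['<pad>', '<s>', '</s>', '<unk>']
        words = sorted({w for s in sentences for w in s.split()})
        self.token_to_id = {tok: i for i, tok in enumerate(specials + words)}

    def get_tokenId(self, token):
        return self.token_to_id[token]

    def __call__(self, sentence):
        unk_id = self.token_to_id['<unk>']
        return [self.token_to_id.get(w, unk_id) for w in sentence.split()]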


class MYCollate:
    """Batch collate function: pads the variable-length source and target
    sequences in a batch to the batch's maximum length."""

    def __init__(self, batch_first=True, pad_value=0):
        # A single pad id is used for both sides, so both vocabularies must
        # agree on the id of their padding token.
        self.pad_value = pad_value
        self.batch_first = batch_first

    def __call__(self, data):
        en_sentences = [ex[0] for ex in data]
        ar_sentences = [ex[1] for ex in data]
        padded_en_sentences = nn.utils.rnn.pad_sequence(
            en_sentences, batch_first=self.batch_first, padding_value=self.pad_value)
        padded_ar_sentences = nn.utils.rnn.pad_sequence(
            ar_sentences, batch_first=self.batch_first, padding_value=self.pad_value)
        return padded_en_sentences, padded_ar_sentences
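

# --- Usage sketch. Assumes the toy tokenizer above; real code would pass
# Callable_tokenizer instances built from the actual training corpora. ---
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    en_sentences = ["i like tea", "where is the station"]
    ar_sentences = ["أحب الشاي", "أين المحطة"]
    en_tok = _ToyWhitespaceTokenizer(en_sentences)
    ar_tok = _ToyWhitespaceTokenizer(ar_sentences)

    dataset = MT_Dataset(en_sentences, ar_sentences, en_tok, ar_tok,
                         reversed_input=True)
    # Both toy tokenizers map '<pad>' to id 0, so one pad_value works for both.
    loader = DataLoader(dataset, batch_size=2,
                        collate_fn=MYCollate(batch_first=True,
                                             pad_value=en_tok.get_tokenId('<pad>')))

    src_batch, tgt_batch = next(iter(loader))
    print(src_batch.shape, tgt_batch.shape)  # each is (batch, max_len)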