File size: 2,372 Bytes
24d1036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
from Vocabulary import Vocabulary
from torch.utils.data import Dataset
from torch import nn
from utils import Callable_tokenizer

class MT_Dataset(Dataset):
    """Map-style dataset of aligned (input, target) sentence pairs.

    Each item is a pair of 1-D tensors of token ids. Each sentence is
    tokenized with its own tokenizer and wrapped in that tokenizer's ``<s>``
    and ``</s>`` ids. If ``reversed_input`` is true, the whole input tensor
    is flipped, markers included.
    """

    def __init__(self, input_sentences_list, target_sentences_list, input_tokenizer:Callable_tokenizer, target_tokenizer:Callable_tokenizer, reversed_input:bool):
        """Store the parallel corpus and the tokenizers.

        Args:
            input_sentences_list: Source-side sentences, indexable by position.
            target_sentences_list: Target-side sentences, aligned by index
                with ``input_sentences_list``.
            input_tokenizer: Callable that maps a sentence to a list of token
                ids. It must also expose ``get_tokenId`` for the ``<s>`` /
                ``</s>`` special tokens.
            target_tokenizer: Same contract as ``input_tokenizer``, used for
                the target side.
            reversed_input: If true, ``__getitem__`` returns the input id
                tensor in reverse order.

        Raises:
            ValueError: If the two sentence lists differ in length.
        """
        super().__init__()

        # An explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently disable this validation.
        if len(input_sentences_list) != len(target_sentences_list):
            raise ValueError(
                f"Length mismatch: input has {len(input_sentences_list)} sentences, "
                f"but target has {len(target_sentences_list)} sentences."
            )

        self.input_sentences_list = input_sentences_list
        self.target_sentences_list = target_sentences_list
        self.input_tokenizer = input_tokenizer
        self.target_tokenizer = target_tokenizer
        self.reversed_input = reversed_input

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.input_sentences_list)

    @staticmethod
    def _encode(tokenizer, sentence):
        """Return ``[<s>] + tokenizer(sentence) + [</s>]`` as a 1-D id tensor."""
        token_ids = (
            [tokenizer.get_tokenId('<s>')]
            + tokenizer(sentence)
            + [tokenizer.get_tokenId('</s>')]
        )
        return torch.tensor(token_ids)

    def __getitem__(self, index):
        """Return ``(input_ids, target_ids)`` for the pair at ``index``."""
        source_sentence = self.input_sentences_list[index]
        target_sentence = self.target_sentences_list[index]

        source_ids = self._encode(self.input_tokenizer, source_sentence)
        target_ids = self._encode(self.target_tokenizer, target_sentence)

        if self.reversed_input:
            # NOTE(review): flipping the full tensor also swaps the marker
            # positions, so the reversed input starts with </s> and ends with
            # <s>. Confirm this is intended rather than reversing only the
            # inner tokens.
            source_ids = source_ids.flip(0)

        return source_ids, target_ids
    

class MYCollate:
    """Collate function that pads a batch of ``(source, target)`` tensor pairs.

    Element 0 of each pair goes into the first padded batch and element 1
    into the second. Both are padded with
    ``torch.nn.utils.rnn.pad_sequence`` using the configured layout and
    padding value.
    """

    def __init__(self, batch_first=True, pad_value=0):
        """
        Args:
            batch_first: If true, padded outputs are ``(batch, length, ...)``;
                otherwise ``(length, batch, ...)``.
            pad_value: Value written into padded positions.
        """
        self.batch_first = batch_first
        self.pad_value = pad_value

    def _pad(self, sequences):
        """Pad a list of variable-length tensors into one tensor."""
        return nn.utils.rnn.pad_sequence(
            sequences,
            batch_first=self.batch_first,
            padding_value=self.pad_value,
        )

    def __call__(self, data):
        """Return ``(padded_sources, padded_targets)`` for the given batch."""
        sources = [pair[0] for pair in data]
        targets = [pair[1] for pair in data]
        return self._pad(sources), self._pad(targets)