|
1 | | -import numpy as np |
2 | | -import torch |
3 | 1 | import torch.nn as nn |
4 | 2 |
|
5 | 3 | from model.encoder import Encoder |
|
9 | 7 | class Transformer(nn.Module): |
10 | 8 | def __init__(self, params): |
11 | 9 | super(Transformer, self).__init__() |
12 | | - self.params = params |
13 | | - self.hidden_dim = params.hidden_dim |
14 | | - |
15 | | - self.device = params.device |
16 | 10 | self.encoder = Encoder(params) |
17 | 11 | self.decoder = Decoder(params) |
18 | 12 |
|
19 | | - def create_subsequent_mask(self, target): |
20 | | - # target = [batch size, target length] |
21 | | - |
22 | | - batch_size, target_length = target.size() |
23 | | - ''' |
24 | | - if target length is 5 and diagonal is 1, this function returns |
25 | | - [[0, 1, 1, 1, 1], |
26 | | - [0, 0, 1, 1, 1], |
27 | | - [0, 0, 0, 1, 1], |
28 | | - [0, 0, 0, 0, 1], |
29 | | - [0, 0, 0, 0, 0]] |
30 | | - ''' |
31 | | - # torch.triu returns the upper triangular part of a matrix based on user defined diagonal |
32 | | - subsequent_mask = torch.triu(torch.ones(target_length, target_length), diagonal=1).bool().to(self.device) |
33 | | - # subsequent_mask = [target length, target length] |
34 | | - |
35 | | - # repeat subsequent mask 'batch size' times to cover all data instances in the batch |
36 | | - subsequent_mask = subsequent_mask.unsqueeze(0).repeat(batch_size, 1, 1) |
37 | | - # subsequent_mask = [batch size, target length, target length] |
38 | | - |
39 | | - return subsequent_mask |
40 | | - |
41 | | - def create_mask(self, source, target, subsequent_mask): |
42 | | - # source = [batch size, source length] |
43 | | - # target = [batch size, target length] |
44 | | - # subsequent_mask = [batch size, target length, target length] |
45 | | - source_length = source.shape[1] |
46 | | - target_length = target.shape[1] |
47 | | - |
48 | | - # create boolean tensors which will be used to mask padding tokens of both source and target sentence |
49 | | - source_mask = (source == self.params.pad_idx) |
50 | | - target_mask = (target == self.params.pad_idx) |
51 | | - # source_mask = [batch size, source length] |
52 | | - # target_mask = [batch size, target length] |
53 | | - ''' |
54 | | - if the sentence is [2, 193, 9, 27, 10003, 3, 1, 1, 1], where 2 denotes <sos>, 3 denotes <eos> and 1 denotes <pad>, |
55 | | - the masking tensor will be [False, False, False, False, False, False, True, True, True] |
56 | | - ''' |
57 | | - # repeat sentence masking tensors 'sentence length' times |
58 | | - dec_enc_mask = source_mask.unsqueeze(1).repeat(1, target_length, 1) |
59 | | - source_mask = source_mask.unsqueeze(1).repeat(1, source_length, 1) |
60 | | - target_mask = target_mask.unsqueeze(1).repeat(1, target_length, 1) |
61 | | - |
62 | | - # source_mask = [batch size, source length, source length] |
63 | | - # target_mask = [batch size, target length, target length] |
64 | | - # dec_enc_mask = [batch size, target length, source length] |
65 | | - |
66 | | - # combine <pad> token masking tensor and subsequent masking tensor for decoder's self attention |
67 | | - target_mask = target_mask | subsequent_mask |
68 | | - # target_mask = [batch size, target length, target length] |
69 | | - |
70 | | - return source_mask, target_mask, dec_enc_mask |
71 | | - |
72 | | - def create_non_pad_mask(self, sentence): |
73 | | - # padding tokens shouldn't contribute to the output tensor |
74 | | - # create a non-pad masking tensor so that only non-padding tokens are used |
75 | | - return sentence.ne(self.params.pad_idx).type(torch.float).unsqueeze(-1) |
76 | | - |
77 | | - def create_positional_encoding(self, batch_size, sentence_len): |
78 | | - # PE(pos, 2i)     = sin(pos / 10000 ** (2*i / hidden_dim)) |
79 | | - # PE(pos, 2i + 1) = cos(pos / 10000 ** (2*i / hidden_dim)) |
80 | | - sinusoid_table = np.array([pos/np.power(10000, 2*i/self.hidden_dim) |
81 | | - for pos in range(sentence_len) for i in range(self.hidden_dim)]) |
82 | | - # sinusoid_table = [sentence length * hidden dim] |
83 | | - |
84 | | - sinusoid_table = sinusoid_table.reshape(sentence_len, -1) |
85 | | - # sinusoid_table = [sentence length, hidden dim] |
86 | | - |
87 | | - sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # apply sin to even hidden dimensions |
88 | | - sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # apply cos to odd hidden dimensions |
89 | | - |
90 | | - # convert numpy based sinusoid table to torch.tensor and repeat it 'batch size' times |
91 | | - sinusoid_table = torch.FloatTensor(sinusoid_table).to(self.device) |
92 | | - sinusoid_table = sinusoid_table.unsqueeze(0).repeat(batch_size, 1, 1) |
93 | | - # sinusoid_table = [batch size, sentence length, hidden dim] |
94 | | - |
95 | | - return sinusoid_table |
96 | | - |
97 | 13 | def forward(self, source, target): |
98 | 14 | # source = [batch size, source length] |
99 | 15 | # target = [batch size, target length] |
100 | | - source_batch, source_len = source.size() |
101 | | - target_batch, target_len = target.size() |
102 | | - |
103 | | - # create masking tensor for self attention (encoder & decoder) and decoder's attention on the output of encoder |
104 | | - subsequent_mask = self.create_subsequent_mask(target) |
105 | | - source_mask, target_mask, dec_enc_mask = self.create_mask(source, target, subsequent_mask) |
106 | | - |
107 | | - # create non-pad masking tensor which will be used to extract non-padded tokens from output |
108 | | - source_non_pad = self.create_non_pad_mask(source) |
109 | | - target_non_pad = self.create_non_pad_mask(target) |
110 | | - # non_pad = [batch size, sentence length, 1] |
111 | | - |
112 | | - source_positional_encoding = self.create_positional_encoding(source_batch, source_len) |
113 | | - target_positional_encoding = self.create_positional_encoding(target_batch, target_len) |
114 | | - |
115 | | - source = self.encoder(source, source_mask, source_positional_encoding, source_non_pad) |
116 | | - output = self.decoder(target, source, target_mask, dec_enc_mask, target_positional_encoding, target_non_pad) |
117 | | - # output = [batch size, target length, output dim] |
118 | | - |
| 16 | + encoder_output = self.encoder(source) # [batch size, source length, hidden dim] |
| 17 | + output = self.decoder(target, source, encoder_output) # [batch size, target length, output dim] |
119 | 18 | return output |
120 | 19 |
|
121 | 20 | def count_parameters(self): |
|
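For reference, a minimal standalone sketch of the decoder self-attention masking that the removed create_subsequent_mask / create_mask methods implemented. The pad_idx value and the example target tensor below are illustrative assumptions, not values taken from the repository.

    import torch

    pad_idx = 1                                         # assumed value; the real code reads params.pad_idx
    target = torch.tensor([[2, 193, 9, 27, 3, 1, 1]])   # [batch size = 1, target length = 7], illustrative
    batch_size, target_length = target.size()

    # padding mask: True wherever the token is <pad>
    pad_mask = (target == pad_idx).unsqueeze(1).repeat(1, target_length, 1)

    # subsequent mask: True above the main diagonal, so a position cannot attend to future positions
    subsequent_mask = torch.triu(torch.ones(target_length, target_length), diagonal=1).bool()
    subsequent_mask = subsequent_mask.unsqueeze(0).repeat(batch_size, 1, 1)

    # decoder self-attention mask: a position is masked if it is padding OR lies in the future
    target_mask = pad_mask | subsequent_mask            # [batch size, target length, target length]

Similarly, a sketch of the sinusoidal table built by the removed create_positional_encoding method; hidden_dim = 8 and sentence_len = 5 are assumed toy sizes.

    import numpy as np
    import torch

    hidden_dim, sentence_len = 8, 5
    table = np.array([pos / np.power(10000, 2 * i / hidden_dim)
                      for pos in range(sentence_len) for i in range(hidden_dim)]).reshape(sentence_len, hidden_dim)
    table[:, 0::2] = np.sin(table[:, 0::2])             # sin on even hidden dimensions
    table[:, 1::2] = np.cos(table[:, 1::2])             # cos on odd hidden dimensions
    positional_encoding = torch.FloatTensor(table)      # [sentence length, hidden dim]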