
Commit 09b9927

refactoring code structure
1 parent 5ae2ceb commit 09b9927

8 files changed: 227 additions (+), 140 deletions (-)


config/params.json

Lines changed: 1 addition & 0 deletions
@@ -16,5 +16,6 @@
   "feed_forward_dim": 1024,
   "n_layer": 6,
   "n_head": 8,
+  "max_len": 20,
   "dropout": 0.1
 }
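The new max_len entry sits alongside the existing model hyperparameters. As a hedged illustration (not code from this repo), a flat JSON config like this can be exposed with attribute access, which is how the model code reads fields such as params.hidden_dim and params.n_layer:

# Hypothetical sketch: loading config/params.json into an attribute-style object.
# The repo's own loader is not shown in this commit; SimpleNamespace is an assumption.
import json
from types import SimpleNamespace

with open('config/params.json') as f:
    params = SimpleNamespace(**json.load(f))   # assumes a flat JSON object

print(params.max_len)   # -> 20, the value added in this commit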

model/attention.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import numpy as np
 
 
 class MultiHeadAttention(nn.Module):
@@ -59,7 +60,7 @@ def forward(self, query, key, value, mask=None):
         # self_attention = [batch size, sentence length, sentence length]
 
         if mask is not None:
-            self_attention = self_attention.masked_fill(mask, -1e10)
+            self_attention = self_attention.masked_fill(mask, -np.inf)
 
         # normalize self attention score by applying soft max function on each row
         attention_score = self.dropout(F.softmax(self_attention, dim=-1))
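Switching the fill value from -1e10 to -np.inf makes masked positions receive exactly zero attention weight after the softmax. A minimal standalone sketch with a toy score row and mask (not the repo's tensors):

# Toy demonstration of the masking behaviour changed above.
import torch
import torch.nn.functional as F

scores = torch.tensor([[1.0, 2.0, 3.0]])        # attention scores for one query
mask = torch.tensor([[False, False, True]])     # True marks positions to hide (e.g. <pad>)

weights = F.softmax(scores.masked_fill(mask, float('-inf')), dim=-1)
print(weights)   # masked position gets weight exactly 0.0; with -1e10 it was merely close to 0

One thing to keep in mind with -inf is that a row in which every position is masked comes out of the softmax as NaN; the padding and subsequent masks built in model/ops.py should always leave at least one unmasked position per row, so this case should not arise here.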

model/decoder.py

Lines changed: 13 additions & 13 deletions
@@ -3,6 +3,7 @@
 
 from model.attention import MultiHeadAttention
 from model.positionwise import PositionWiseFeedForward
+from model.ops import create_positional_encoding, create_non_pad_mask, create_subsequent_mask, create_target_mask
 
 
 class DecoderLayer(nn.Module):
@@ -39,36 +40,35 @@ class Decoder(nn.Module):
     def __init__(self, params):
         super(Decoder, self).__init__()
         self.device = params.device
+        self.hidden_dim = params.hidden_dim
 
-        self.token_embedding = nn.Embedding(params.output_dim, params.hidden_dim)
+        self.token_embedding = nn.Embedding(params.output_dim, params.hidden_dim, padding_idx=params.pad_idx)
         self.decoder_layers = nn.ModuleList([DecoderLayer(params) for _ in range(params.n_layer)])
         self.fc = nn.Linear(params.hidden_dim, params.output_dim)
 
         self.dropout = nn.Dropout(params.dropout)
         self.scale = torch.sqrt(torch.FloatTensor([params.hidden_dim])).to(self.device)
 
-    def forward(self, target, encoder_output, target_mask, dec_enc_mask, positional_encoding, target_non_pad):
+    def forward(self, target, source, encoder_output):
         # target = [batch size, target length]
+        # source = [batch size, source length]
         # encoder_output = [batch size, source length, hidden dim]
+        target_batch, target_len = target.size()
 
-        # target_mask = [batch size, target length, target length]
-        # dec_enc_mask = [batch size, target length, source length]
-        # positional_encoding = [batch size, target length, hidden dim]
+        subsequent_mask = create_subsequent_mask(target)
+        target_mask, dec_enc_mask = create_target_mask(source, target, subsequent_mask)
+        # target_mask = [batch size, target length, target length]
+        # dec_enc_mask = [batch size, target length, source length]
+        target_non_pad = create_non_pad_mask(target)  # [batch size, target length, 1]
 
-        # target_non_pad = [batch size, target length, 1]
-
-        # print(f'[D] Before embedding: {target.shape}')
         embedded = self.token_embedding(target)
-        # print(f'[D] Before embedding: {embedded.shape}')
-
+        positional_encoding = create_positional_encoding(target_batch, target_len, self.hidden_dim)
         target = self.dropout(embedded + positional_encoding)
 
         for decoder_layer in self.decoder_layers:
             target = decoder_layer(target, encoder_output, target_mask, dec_enc_mask, target_non_pad)
         # target = [batch size, target length, hidden dim]
-        # print(f'[D] After decoding: {target.shape}')
+
         output = self.fc(target)
         # output = [batch size, target length, output dim]
-        # print(f'[D] After predicting: {output.shape}')
-        # print('------------------------------------------------------------')
         return output
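The decoder now derives its own masks from source and target instead of receiving them from Transformer. The toy sketch below re-implements the target-mask combination performed by create_subsequent_mask and create_target_mask so it runs without the repo's vocabulary pickle; pad_idx=1 and the token ids are assumed values:

# Standalone illustration of the pad-mask | look-ahead-mask combination used in Decoder.forward.
import torch

pad_idx = 1
target = torch.tensor([[2, 5, 7, 1, 1]])          # <sos>, two tokens, then two <pad> tokens
target_len = target.size(1)

subsequent = torch.triu(torch.ones(target_len, target_len), diagonal=1).bool()   # hide future positions
pad_mask = (target == pad_idx).unsqueeze(1).repeat(1, target_len, 1)             # hide <pad> columns

target_mask = pad_mask | subsequent.unsqueeze(0)
print(target_mask[0].int())
# row i allows attention only to non-pad positions <= i, which is what the decoder layers receive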

model/encoder.py

Lines changed: 10 additions & 13 deletions
@@ -3,13 +3,13 @@
 
 from model.attention import MultiHeadAttention
 from model.positionwise import PositionWiseFeedForward
+from model.ops import create_positional_encoding, create_non_pad_mask, create_source_mask
 
 
 class EncoderLayer(nn.Module):
     def __init__(self, params):
         super(EncoderLayer, self).__init__()
         self.layer_norm = nn.LayerNorm(params.hidden_dim)
-
         self.self_attention = MultiHeadAttention(params)
         self.position_wise_ffn = PositionWiseFeedForward(params)
 
@@ -33,29 +33,26 @@ class Encoder(nn.Module):
     def __init__(self, params):
         super(Encoder, self).__init__()
         self.device = params.device
+        self.hidden_dim = params.hidden_dim
 
-        self.token_embedding = nn.Embedding(params.input_dim, params.hidden_dim)
+        self.token_embedding = nn.Embedding(params.input_dim, params.hidden_dim, padding_idx=params.pad_idx)
         self.encoder_layers = nn.ModuleList([EncoderLayer(params) for _ in range(params.n_layer)])
         self.dropout = nn.Dropout(params.dropout)
         self.scale = torch.sqrt(torch.FloatTensor([params.hidden_dim])).to(self.device)
 
-    def forward(self, source, source_mask, positional_encoding, source_non_pad):
-        # source = [batch size, source length]
-        # source_mask = [batch size, source length, source length]
-        # positional_encoding = [batch size, source length, hidden dim]
-        # source_non_pad = [batch size, source length, 1]
+    def forward(self, source):
+        # source = [batch size, source length]
+        source_batch, source_len = source.size()
 
-        # define positional encoding which encodes token's positional information
-        # print(f'[E] Before embedding: {source.shape}')
-        embedded = self.token_embedding(source)
-        # print(f'[E] After embedding: {embedded.shape}')
+        source_mask = create_source_mask(source)      # [batch size, source length, source length]
+        source_non_pad = create_non_pad_mask(source)  # [batch size, source length, 1]
 
+        embedded = self.token_embedding(source)
+        positional_encoding = create_positional_encoding(source_batch, source_len, self.hidden_dim)
         source = self.dropout(embedded + positional_encoding)
         # source = [batch size, source length, hidden dim]
 
         for encoder_layer in self.encoder_layers:
             source = encoder_layer(source, source_mask, source_non_pad)
         # source = [batch size, source length, hidden dim]
-        # print(f'[E] After encoding: {source.shape}')
-        # print('------------------------------------------------------------')
         return source
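Likewise, the encoder now builds its padding-related masks internally. A small standalone sketch of what create_non_pad_mask returns (pad_idx=1 is assumed here; the real index is read from pickles/eng.pickle in model/ops.py):

# Toy sketch of the non-pad mask now created inside Encoder.forward.
import torch

pad_idx = 1
source = torch.tensor([[2, 9, 4, 3, 1, 1]])       # one sentence with two trailing <pad> tokens

source_non_pad = source.ne(pad_idx).type(torch.float).unsqueeze(-1)
print(source_non_pad.shape)             # torch.Size([1, 6, 1]) -> [batch size, source length, 1]
print(source_non_pad[0].squeeze(-1))    # tensor([1., 1., 1., 1., 0., 0.])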

model/ops.py

Lines changed: 112 additions & 0 deletions
@@ -0,0 +1,112 @@
+import pickle
+import numpy as np
+import torch
+
+pickle_eng = open('pickles/eng.pickle', 'rb')
+eng = pickle.load(pickle_eng)
+pad_idx = eng.vocab.stoi['<pad>']
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+def create_subsequent_mask(target):
+    # target = [batch size, target length]
+    batch_size, target_length = target.size()
+    '''
+    if target length is 5 and diagonal is 1, this function returns
+        [[0, 1, 1, 1, 1],
+         [0, 0, 1, 1, 1],
+         [0, 0, 0, 1, 1],
+         [0, 0, 0, 0, 1],
+         [0, 0, 0, 0, 0]]
+    '''
+    # torch.triu returns the upper triangular part of a matrix based on user defined diagonal
+    subsequent_mask = torch.triu(torch.ones(target_length, target_length), diagonal=1).bool().to(device)
+    # subsequent_mask = [target length, target length]
+
+    # repeat subsequent mask 'batch size' times to cover all data instances in the batch
+    subsequent_mask = subsequent_mask.unsqueeze(0).repeat(batch_size, 1, 1)
+    # subsequent_mask = [batch size, target length, target length]
+
+    return subsequent_mask
+
+
+def create_source_mask(source):
+    '''
+    create masking tensor for encoder's self attention
+    if sentence is [2, 193, 9, 27, 10003, 1, 1, 1, 3] and 2 denotes <sos>, 3 denotes <eos> and 1 denotes <pad>
+    masking tensor will be [False, False, False, False, False, True, True, True, False]
+    :param source: [batch size, source length]
+    :return: source mask
+    '''
+    source_length = source.shape[1]
+
+    # create boolean tensor which will be used to mask padding tokens of the source sentence
+    source_mask = (source == pad_idx)
+    # source_mask = [batch size, source length]
+
+    # repeat sentence masking tensor 'sentence length' times
+    source_mask = source_mask.unsqueeze(1).repeat(1, source_length, 1)
+    # source_mask = [batch size, source length, source length]
+
+    return source_mask
+
+
+def create_target_mask(source, target, subsequent_mask):
+    '''
+    create masking tensor for decoder's self attention and decoder's attention on the output of encoder
+    if sentence is [2, 193, 9, 27, 10003, 1, 1, 1, 3] and 2 denotes <sos>, 3 denotes <eos> and 1 denotes <pad>
+    masking tensor will be [False, False, False, False, False, True, True, True, False]
+    :param source: [batch size, source length]
+    :param target: [batch size, target length]
+    :param subsequent_mask: [batch size, target length, target length]
+    :return: target mask, dec-enc mask
+    '''
+    target_length = target.shape[1]
+
+    # create boolean tensors which will be used to mask padding tokens of both source and target sentence
+    source_mask = (source == pad_idx)
+    target_mask = (target == pad_idx)
+    # target_mask = [batch size, target length]
+
+    # repeat sentence masking tensors 'sentence length' times
+    dec_enc_mask = source_mask.unsqueeze(1).repeat(1, target_length, 1)
+    target_mask = target_mask.unsqueeze(1).repeat(1, target_length, 1)
+
+    # dec_enc_mask = [batch size, target length, source length]
+    # target_mask = [batch size, target length, target length]
+
+    # combine <pad> token masking tensor and subsequent masking tensor for decoder's self attention
+    target_mask = target_mask | subsequent_mask
+    # target_mask = [batch size, target length, target length]
+
+    return target_mask, dec_enc_mask
+
+
+def create_non_pad_mask(sentence):
+    '''
+    create non-pad masking tensor which will be used to extract non-padded tokens from output
+    if sentence is [2, 193, 9, 27, 1, 1, 1, 3]
+    this function returns [[1], [1], [1], [1], [0], [0], [0], [1]]
+    '''
+    return sentence.ne(pad_idx).type(torch.float).unsqueeze(-1)
+
+
+def create_positional_encoding(batch_size, sentence_len, hidden_dim):
+    # PE(pos, 2i) = sin(pos / 10000 ** (2 * i / hidden_dim))
+    # PE(pos, 2i + 1) = cos(pos / 10000 ** (2 * i / hidden_dim))
+    sinusoid_table = np.array([pos / np.power(10000, 2 * i / hidden_dim)
+                               for pos in range(sentence_len) for i in range(hidden_dim)])
+    # sinusoid_table = [sentence length * hidden dim]
+
+    sinusoid_table = sinusoid_table.reshape(sentence_len, -1)
+    # sinusoid_table = [sentence length, hidden dim]
+
+    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # calculate pe for even dimensions
+    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # calculate pe for odd dimensions
+
+    # convert numpy based sinusoid table to torch.tensor and repeat it 'batch size' times
+    sinusoid_table = torch.FloatTensor(sinusoid_table).to(device)
+    sinusoid_table = sinusoid_table.unsqueeze(0).repeat(batch_size, 1, 1)
+    # sinusoid_table = [batch size, sentence length, hidden dim]
+
+    return sinusoid_table
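A quick standalone check of the sinusoid table produced by create_positional_encoding, re-implemented inline so it runs without the vocabulary pickle that model/ops.py loads at import time; the batch size, sentence length and hidden_dim are toy values:

# Standalone check of the positional-encoding table (same computation as above, toy sizes).
import numpy as np
import torch

batch_size, sentence_len, hidden_dim = 2, 4, 8

table = np.array([pos / np.power(10000, 2 * i / hidden_dim)
                  for pos in range(sentence_len) for i in range(hidden_dim)]).reshape(sentence_len, -1)
table[:, 0::2] = np.sin(table[:, 0::2])   # even columns: sine
table[:, 1::2] = np.cos(table[:, 1::2])   # odd columns: cosine

pe = torch.FloatTensor(table).unsqueeze(0).repeat(batch_size, 1, 1)
print(pe.shape)       # torch.Size([2, 4, 8]) -> [batch size, sentence length, hidden dim]
print(pe[0, 0, :4])   # position 0: sine entries are 0.0, cosine entries are 1.0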

model/transformer.py

Lines changed: 2 additions & 103 deletions
@@ -1,5 +1,3 @@
-import numpy as np
-import torch
 import torch.nn as nn
 
 from model.encoder import Encoder
@@ -9,113 +7,14 @@
 class Transformer(nn.Module):
     def __init__(self, params):
         super(Transformer, self).__init__()
-        self.params = params
-        self.hidden_dim = params.hidden_dim
-
-        self.device = params.device
         self.encoder = Encoder(params)
         self.decoder = Decoder(params)
 
-    def create_subsequent_mask(self, target):
-        # target = [batch size, target length]
-
-        batch_size, target_length = target.size()
-        '''
-        if target length is 5 and diagonal is 1, this function returns
-        [[0, 1, 1, 1, 1],
-         [0, 0, 1, 1, 1],
-         [0, 0, 0, 1, 1],
-         [0, 0, 0, 0, 1],
-         [0, 0, 0, 0, 1]]
-        '''
-        # torch.triu returns the upper triangular part of a matrix based on user defined diagonal
-        subsequent_mask = torch.triu(torch.ones(target_length, target_length), diagonal=1).bool().to(self.device)
-        # subsequent_mask = [target length, target length]
-
-        # repeat subsequent mask 'batch size' times to cover all data instances in the batch
-        subsequent_mask = subsequent_mask.unsqueeze(0).repeat(batch_size, 1, 1)
-        # subsequent_mask = [batch size, target length, target length]
-
-        return subsequent_mask
-
-    def create_mask(self, source, target, subsequent_mask):
-        # source = [batch size, source length]
-        # target = [batch size, target length]
-        # subsequent_mask = [batch size, target length, target length]
-        source_length = source.shape[1]
-        target_length = target.shape[1]
-
-        # create boolean tensors which will be used to mask padding tokens of both source and target sentence
-        source_mask = (source == self.params.pad_idx)
-        target_mask = (target == self.params.pad_idx)
-        # source_mask = [batch size, source length]
-        # target_mask = [batch size, target length]
-        '''
-        if sentence is [2, 193, 9, 27, 10003, 1, 1, 1, 3] and 2 denotes <sos>, 3 denotes <eos> and 1 denotes <pad>
-        masking tensor will be [False, False, False, False, False, True, True, True, False]
-        '''
-        # repeat sentence masking tensors 'sentence length' times
-        dec_enc_mask = source_mask.unsqueeze(1).repeat(1, target_length, 1)
-        source_mask = source_mask.unsqueeze(1).repeat(1, source_length, 1)
-        target_mask = target_mask.unsqueeze(1).repeat(1, target_length, 1)
-
-        # source_mask = [batch size, source length, source length]
-        # target_mask = [batch size, target length, target length]
-        # dec_enc_mask = [batch size, target length, source length]
-
-        # combine <pad> token masking tensor and subsequent masking tensor for decoder's self attention
-        target_mask = target_mask | subsequent_mask
-        # target_mask = [batch size, target length, target length]
-
-        return source_mask, target_mask, dec_enc_mask
-
-    def create_non_pad_mask(self, sentence):
-        # padding token shouldn't be used for the output tensor
-        # to use only non padding token, create non-pad masking tensor
-        return sentence.ne(self.params.pad_idx).type(torch.float).unsqueeze(-1)
-
-    def create_positional_encoding(self, batch_size, sentence_len):
-        # PE(pos, 2i) = sin(pos/10000 ** (2*i / hidden_dim)
-        # PE(pos, 2i + 1) = cos(pos/10000 ** (2*i / hidden_dim)
-        sinusoid_table = np.array([pos/np.power(10000, 2*i/self.hidden_dim)
-                                   for pos in range(sentence_len) for i in range(self.hidden_dim)])
-        # sinusoid_table = [sentence length * hidden dim]
-
-        sinusoid_table = sinusoid_table.reshape(sentence_len, -1)
-        # sinusoid_table = [sentence length, hidden dim]
-
-        sinusoid_table[0::2, :] = np.sin(sinusoid_table[0::2, :])  # calculate pe for even numbers
-        sinusoid_table[1::2, :] = np.sin(sinusoid_table[1::2, :])  # calculate pe for odd numbers
-
-        # convert numpy based sinusoid table to torch.tensor and repeat it 'batch size' times
-        sinusoid_table = torch.FloatTensor(sinusoid_table).to(self.device)
-        sinusoid_table = sinusoid_table.unsqueeze(0).repeat(batch_size, 1, 1)
-        # sinusoid_table = [batch size, sentence length, hidden dim]
-
-        return sinusoid_table
-
     def forward(self, source, target):
         # source = [batch size, source length]
         # target = [batch size, target length]
-        source_batch, source_len = source.size()
-        target_batch, target_len = target.size()
-
-        # create masking tensor for self attention (encoder & decoder) and decoder's attention on the output of encoder
-        subsequent_mask = self.create_subsequent_mask(target)
-        source_mask, target_mask, dec_enc_mask = self.create_mask(source, target, subsequent_mask)
-
-        # create non-pad masking tensor which will be used to extract non-padded tokens from output
-        source_non_pad = self.create_non_pad_mask(source)
-        target_non_pad = self.create_non_pad_mask(target)
-        # non_pad = [batch size, sentence length, 1]
-
-        source_positional_encoding = self.create_positional_encoding(source_batch, source_len)
-        target_positional_encoding = self.create_positional_encoding(target_batch, target_len)
-
-        source = self.encoder(source, source_mask, source_positional_encoding, source_non_pad)
-        output = self.decoder(target, source, target_mask, dec_enc_mask, target_positional_encoding, target_non_pad)
-        # output = [batch size, target length, output dim]
-
+        encoder_output = self.encoder(source)  # [batch size, source length, hidden dim]
+        output = self.decoder(target, source, encoder_output)  # [batch size, target length, output dim]
         return output
 
     def count_parameters(self):
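With the helpers moved into model/ops.py, Transformer.forward collapses to two calls. Below is a rough usage sketch, not taken from the repo: the params fields are inferred from the diffs above, pad_idx and the vocabulary sizes are placeholder values, and running it requires the rest of the repository plus pickles/eng.pickle, which model/ops.py opens at import time.

# Hedged end-to-end sketch of the simplified forward pass (assumed field names and values).
import torch
from types import SimpleNamespace

from model.transformer import Transformer   # needs the repo and pickles/eng.pickle on disk

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # matches model/ops.py

params = SimpleNamespace(
    input_dim=10000, output_dim=10000,        # source / target vocabulary sizes (assumed)
    hidden_dim=512, feed_forward_dim=1024,
    n_layer=6, n_head=8, max_len=20, dropout=0.1,
    pad_idx=1,                                # the real index comes from the vocab pickle
    device=device,
)

model = Transformer(params).to(device)
source = torch.randint(4, params.input_dim, (2, 15)).to(device)    # [batch size, source length]
target = torch.randint(4, params.output_dim, (2, 12)).to(device)   # [batch size, target length]

output = model(source, target)
print(output.shape)   # expected: [batch size, target length, output dim], e.g. torch.Size([2, 12, 10000])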
