@@ -70,13 +70,16 @@ def make_batch(seq_data):
7070# 기존처럼 one-hot 인코딩을 사용한다면 입력값의 형태는 [None, n_class] 여야합니다.
7171Y = tf .placeholder (tf .int32 , [None ])
7272
73+ # dropout prob for RNN
74+ keep_prob = tf .placeholder (tf .float32 , [])
75+
7376W = tf .Variable (tf .random_normal ([n_hidden , n_class ]))
7477b = tf .Variable (tf .random_normal ([n_class ]))
7578
7679# RNN 셀을 생성합니다.
7780cell1 = tf .nn .rnn_cell .BasicLSTMCell (n_hidden )
7881# 과적합 방지를 위한 Dropout 기법을 사용합니다.
79- cell1 = tf .nn .rnn_cell .DropoutWrapper (cell1 , output_keep_prob = 0.5 )
82+ cell1 = tf .nn .rnn_cell .DropoutWrapper (cell1 , output_keep_prob = keep_prob )
8083# 여러개의 셀을 조합해서 사용하기 위해 셀을 추가로 생성합니다.
8184cell2 = tf .nn .rnn_cell .BasicLSTMCell (n_hidden )
8285
@@ -108,7 +111,9 @@ def make_batch(seq_data):
108111
109112for epoch in range (total_epoch ):
110113 _ , loss = sess .run ([optimizer , cost ],
111- feed_dict = {X : input_batch , Y : target_batch })
114+ feed_dict = {X : input_batch ,
115+ Y : target_batch ,
116+ keep_prob : 0.5 })
112117
113118 print ('Epoch:' , '%04d' % (epoch + 1 ),
114119 'cost =' , '{:.6f}' .format (loss ))
@@ -127,7 +132,9 @@ def make_batch(seq_data):
127132input_batch , target_batch = make_batch (seq_data )
128133
129134predict , accuracy_val = sess .run ([prediction , accuracy ],
130- feed_dict = {X : input_batch , Y : target_batch })
135+ feed_dict = {X : input_batch ,
136+ Y : target_batch ,
137+ keep_prob :1 })
131138
132139predict_words = []
133140for idx , val in enumerate (seq_data ):
0 commit comments