File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -97,8 +97,8 @@ def _map_fn_train(img):
9797 with tf .GradientTape () as tape :
9898 fake_hr_patchs = G (lr_patchs )
9999 mse_loss = tl .cost .mean_squared_error (fake_hr_patchs , hr_patchs , is_mean = True )
100- grad = tape .gradient (mse_loss , G .weights )
101- g_optimizer_init .apply_gradients (zip (grad , G .weights ))
100+ grad = tape .gradient (mse_loss , G .trainable_weights )
101+ g_optimizer_init .apply_gradients (zip (grad , G .trainable_weights ))
102102 step += 1
103103 epoch = step // n_step_epoch
104104 print ("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} " .format (
@@ -124,7 +124,7 @@ def _map_fn_train(img):
124124 g_loss = mse_loss + vgg_loss + g_gan_loss
125125 grad = tape .gradient (g_loss , G .trainable_weights )
126126 g_optimizer .apply_gradients (zip (grad , G .trainable_weights ))
127- grad = tape .gradient (d_loss , D .weights )
127+ grad = tape .gradient (d_loss , D .trainable_weights )
128128 d_optimizer .apply_gradients (zip (grad , D .trainable_weights ))
129129 step += 1
130130 epoch = step // n_step_epoch
You can’t perform that action at this time.
0 commit comments