Skip to content

Commit 91023ef

Browse files
authored
Change .weights to trainable_weights
1 parent 049fbab commit 91023ef

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ def _map_fn_train(img):
9797
with tf.GradientTape() as tape:
9898
fake_hr_patchs = G(lr_patchs)
9999
mse_loss = tl.cost.mean_squared_error(fake_hr_patchs, hr_patchs, is_mean=True)
100-
grad = tape.gradient(mse_loss, G.weights)
101-
g_optimizer_init.apply_gradients(zip(grad, G.weights))
100+
grad = tape.gradient(mse_loss, G.trainable_weights)
101+
g_optimizer_init.apply_gradients(zip(grad, G.trainable_weights))
102102
step += 1
103103
epoch = step//n_step_epoch
104104
print("Epoch: [{}/{}] step: [{}/{}] time: {}s, mse: {} ".format(
@@ -124,7 +124,7 @@ def _map_fn_train(img):
124124
g_loss = mse_loss + vgg_loss + g_gan_loss
125125
grad = tape.gradient(g_loss, G.trainable_weights)
126126
g_optimizer.apply_gradients(zip(grad, G.trainable_weights))
127-
grad = tape.gradient(d_loss, D.weights)
127+
grad = tape.gradient(d_loss, D.trainable_weights)
128128
d_optimizer.apply_gradients(zip(grad, D.trainable_weights))
129129
step += 1
130130
epoch = step//n_step_epoch

0 commit comments

Comments
 (0)