Clarify DenseTranspose

main
Aurélien Geron 2019-06-10 17:42:31 +08:00
parent 400920f0aa
commit abf3bba2d5
1 changed files with 14 additions and 8 deletions

View File

@ -380,8 +380,6 @@
"metadata": {},
"outputs": [],
"source": [
"K = keras.backend\n",
"\n",
"class DenseTranspose(keras.layers.Layer):\n",
" def __init__(self, dense, activation=None, **kwargs):\n",
" self.dense = dense\n",
@ -393,8 +391,8 @@
" initializer=\"zeros\")\n",
" super().build(batch_input_shape)\n",
" def call(self, inputs):\n",
" z = inputs @ K.transpose(self.dense.weights[0]) + self.biases\n",
" return self.activation(z)"
" z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)\n",
" return self.activation(z + self.biases)"
]
},
{
@ -403,20 +401,27 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
"dense_1 = keras.layers.Dense(100, activation=\"selu\")\n",
"dense_2 = keras.layers.Dense(30, activation=\"selu\")\n",
"\n",
"tied_encoder = keras.models.Sequential([\n",
" keras.layers.Flatten(input_shape=[28, 28]),\n",
" keras.layers.Dense(100, activation=\"selu\"),\n",
" keras.layers.Dense(30, activation=\"selu\"),\n",
" dense_1,\n",
" dense_2\n",
"])\n",
"\n",
"tied_decoder = keras.models.Sequential([\n",
" DenseTranspose(tied_encoder.layers[2], activation=\"selu\"),\n",
" DenseTranspose(tied_encoder.layers[1], activation=\"sigmoid\"),\n",
" DenseTranspose(dense_2, activation=\"selu\"),\n",
" DenseTranspose(dense_1, activation=\"sigmoid\"),\n",
" keras.layers.Reshape([28, 28])\n",
"])\n",
"\n",
"tied_ae = keras.models.Sequential([tied_encoder, tied_decoder])\n",
"\n",
"tied_ae.compile(loss=\"binary_crossentropy\",\n",
" optimizer=keras.optimizers.SGD(lr=1.5), metrics=[rounded_accuracy])\n",
"history = tied_ae.fit(X_train, X_train, epochs=10,\n",
@ -473,6 +478,7 @@
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
"K = keras.backend\n",
"X_train_flat = K.batch_flatten(X_train) # equivalent to .reshape(-1, 28 * 28)\n",
"X_valid_flat = K.batch_flatten(X_valid)\n",
"enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(\n",