Clarify DenseTranspose
parent
400920f0aa
commit
abf3bba2d5
|
@ -380,8 +380,6 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"K = keras.backend\n",
|
||||
"\n",
|
||||
"class DenseTranspose(keras.layers.Layer):\n",
|
||||
" def __init__(self, dense, activation=None, **kwargs):\n",
|
||||
" self.dense = dense\n",
|
||||
|
@ -393,8 +391,8 @@
|
|||
" initializer=\"zeros\")\n",
|
||||
" super().build(batch_input_shape)\n",
|
||||
" def call(self, inputs):\n",
|
||||
" z = inputs @ K.transpose(self.dense.weights[0]) + self.biases\n",
|
||||
" return self.activation(z)"
|
||||
" z = tf.matmul(inputs, self.dense.weights[0], transpose_b=True)\n",
|
||||
" return self.activation(z + self.biases)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -403,20 +401,27 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
"dense_1 = keras.layers.Dense(100, activation=\"selu\")\n",
|
||||
"dense_2 = keras.layers.Dense(30, activation=\"selu\")\n",
|
||||
"\n",
|
||||
"tied_encoder = keras.models.Sequential([\n",
|
||||
" keras.layers.Flatten(input_shape=[28, 28]),\n",
|
||||
" keras.layers.Dense(100, activation=\"selu\"),\n",
|
||||
" keras.layers.Dense(30, activation=\"selu\"),\n",
|
||||
" dense_1,\n",
|
||||
" dense_2\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"tied_decoder = keras.models.Sequential([\n",
|
||||
" DenseTranspose(tied_encoder.layers[2], activation=\"selu\"),\n",
|
||||
" DenseTranspose(tied_encoder.layers[1], activation=\"sigmoid\"),\n",
|
||||
" DenseTranspose(dense_2, activation=\"selu\"),\n",
|
||||
" DenseTranspose(dense_1, activation=\"sigmoid\"),\n",
|
||||
" keras.layers.Reshape([28, 28])\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"tied_ae = keras.models.Sequential([tied_encoder, tied_decoder])\n",
|
||||
"\n",
|
||||
"tied_ae.compile(loss=\"binary_crossentropy\",\n",
|
||||
" optimizer=keras.optimizers.SGD(lr=1.5), metrics=[rounded_accuracy])\n",
|
||||
"history = tied_ae.fit(X_train, X_train, epochs=10,\n",
|
||||
|
@ -473,6 +478,7 @@
|
|||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
"K = keras.backend\n",
|
||||
"X_train_flat = K.batch_flatten(X_train) # equivalent to .reshape(-1, 28 * 28)\n",
|
||||
"X_valid_flat = K.batch_flatten(X_valid)\n",
|
||||
"enc1, dec1, X_train_enc1, X_valid_enc1 = train_autoencoder(\n",
|
||||
|
|
Loading…
Reference in New Issue