From f0f83903bd5856d78858eeac69ebd7d9e579cead Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Fri, 10 May 2019 21:30:18 +0800 Subject: [PATCH] Fix the transformer (use final encoder outputs) --- 16_nlp_with_rnns_and_attention.ipynb | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/16_nlp_with_rnns_and_attention.ipynb b/16_nlp_with_rnns_and_attention.ipynb index 45e1d97..ff30fb1 100644 --- a/16_nlp_with_rnns_and_attention.ipynb +++ b/16_nlp_with_rnns_and_attention.ipynb @@ -1116,17 +1116,18 @@ "metadata": {}, "outputs": [], "source": [ + "Z = encoder_in\n", "for N in range(6):\n", - " encoder_attn = keras.layers.Attention(use_scale=True)\n", - " encoder_in = encoder_attn([encoder_in, encoder_in])\n", - " masked_decoder_attn = keras.layers.Attention(use_scale=True, causal=True)\n", - " decoder_in = masked_decoder_attn([decoder_in, decoder_in])\n", - " decoder_attn = keras.layers.Attention(use_scale=True)\n", - " final_enc = decoder_attn([decoder_in, encoder_in])\n", + " Z = keras.layers.Attention(use_scale=True)([Z, Z])\n", "\n", - "output_layer = keras.layers.TimeDistributed(\n", - " keras.layers.Dense(vocab_size, activation=\"softmax\"))\n", - "outputs = output_layer(final_enc)" + "encoder_outputs = Z\n", + "Z = decoder_in\n", + "for N in range(6):\n", + " Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])\n", + " Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])\n", + "\n", + "outputs = keras.layers.TimeDistributed(\n", + " keras.layers.Dense(vocab_size, activation=\"softmax\"))(Z)" ] }, {