Fix the transformer (use final encoder outputs)
parent 7b1c890195
commit f0f83903bd

@@ -1116,17 +1116,18 @@
    "metadata": {},
    "outputs": [],
    "source": [
     "Z = encoder_in\n",
     "for N in range(6):\n",
-    "    encoder_attn = keras.layers.Attention(use_scale=True)\n",
-    "    encoder_in = encoder_attn([encoder_in, encoder_in])\n",
-    "    masked_decoder_attn = keras.layers.Attention(use_scale=True, causal=True)\n",
-    "    decoder_in = masked_decoder_attn([decoder_in, decoder_in])\n",
-    "    decoder_attn = keras.layers.Attention(use_scale=True)\n",
-    "    final_enc = decoder_attn([decoder_in, encoder_in])\n",
+    "    Z = keras.layers.Attention(use_scale=True)([Z, Z])\n",
     "\n",
-    "output_layer = keras.layers.TimeDistributed(\n",
-    "    keras.layers.Dense(vocab_size, activation=\"softmax\"))\n",
-    "outputs = output_layer(final_enc)"
+    "encoder_outputs = Z\n",
+    "Z = decoder_in\n",
+    "for N in range(6):\n",
+    "    Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])\n",
+    "    Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])\n",
+    "\n",
+    "outputs = keras.layers.TimeDistributed(\n",
+    "    keras.layers.Dense(vocab_size, activation=\"softmax\"))(Z)"
    ]
   },
  {
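
For anyone reading this diff without opening the notebook: the changed cell's "source" strings decode to roughly the plain Python below (comments added here; `keras`, `encoder_in`, `decoder_in`, and `vocab_size` come from earlier cells of the notebook). The previous version interleaved encoder and decoder attention inside one loop, so each decoder cross-attention saw whatever `encoder_in` held at that iteration; the rewritten cell runs the six encoder layers first, saves the result as `encoder_outputs`, and only then builds the decoder stack, so every cross-attention layer attends over the final encoder outputs, as the commit title says. This remains the book's bare-bones transformer sketch, with no embeddings, positional encodings, skip connections, layer normalization, feed-forward blocks, or multi-head projections; note also that `causal=True` in the `Attention` constructor is the tf.keras API of that era, which newer Keras releases deprecate in favor of a `use_causal_mask` call argument.

Z = encoder_in
for N in range(6):
    # Encoder stack: 6 self-attention layers over the (already embedded) encoder inputs
    Z = keras.layers.Attention(use_scale=True)([Z, Z])

encoder_outputs = Z   # keep the FINAL encoder outputs for the decoder to attend to
Z = decoder_in
for N in range(6):
    # Masked (causal) self-attention over the decoder inputs...
    Z = keras.layers.Attention(use_scale=True, causal=True)([Z, Z])
    # ...followed by cross-attention over the final encoder outputs
    Z = keras.layers.Attention(use_scale=True)([Z, encoder_outputs])

outputs = keras.layers.TimeDistributed(
    keras.layers.Dense(vocab_size, activation="softmax"))(Z)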