Sync notebook code with book code (rename max_dims to embed_size)

main
Aurélien Geron 2022-04-16 17:30:08 +12:00
parent d5d16c3202
commit b67e51af2c
1 changed files with 10 additions and 21 deletions

View File

@ -1544,13 +1544,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"704/704 [==============================] - 280s 395ms/step - loss: 0.5038 - accuracy: 0.7496 - val_loss: 0.6706 - val_accuracy: 0.6752\n",
"Epoch 2/5\n",
"704/704 [==============================] - 277s 393ms/step - loss: 0.4499 - accuracy: 0.7892 - val_loss: 0.3494 - val_accuracy: 0.8500\n",
@ -2416,11 +2410,12 @@
"class PositionalEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, max_length, embed_size, dtype=tf.float32, **kwargs):\n",
" super().__init__(dtype=dtype, **kwargs)\n",
" max_dims = (embed_size + 1) // 2 * 2 # round up to nearest even number\n",
" p, i = np.meshgrid(np.arange(max_length), 2 * np.arange(max_dims // 2))\n",
" pos_emb = np.empty((1, max_length, max_dims))\n",
" pos_emb[0, :, ::2] = np.sin(p / 10_000 ** (i / max_dims)).T\n",
" pos_emb[0, :, 1::2] = np.cos(p / 10_000 ** (i / max_dims)).T\n",
" assert embed_size % 2 == 0, \"embed_size must be even\"\n",
" p, i = np.meshgrid(np.arange(max_length),\n",
" 2 * np.arange(embed_size // 2))\n",
" pos_emb = np.empty((1, max_length, embed_size))\n",
" pos_emb[0, :, ::2] = np.sin(p / 10_000 ** (i / embed_size)).T\n",
" pos_emb[0, :, 1::2] = np.cos(p / 10_000 ** (i / embed_size)).T\n",
" self.pos_encodings = tf.constant(pos_emb.astype(self.dtype))\n",
" self.supports_masking = True\n",
"\n",
@ -3251,13 +3246,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"313/313 [==============================] - 4s 8ms/step - loss: 0.6910 - accuracy: 0.5095 - val_loss: 0.6825 - val_accuracy: 0.5645\n",
"Epoch 2/20\n",
"313/313 [==============================] - 2s 7ms/step - loss: 0.6678 - accuracy: 0.5659 - val_loss: 0.6635 - val_accuracy: 0.6105\n",
@ -4313,6 +4302,7 @@
}
],
"metadata": {
"accelerator": "GPU",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
@ -4339,8 +4329,7 @@
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
},
"accelerator": "GPU"
}
},
"nbformat": 4,
"nbformat_minor": 4