Add missing math import and n_epochs = 20

main
Aurélien Geron 2023-11-15 21:23:37 +13:00
parent b38aff05a3
commit 873e1a986c
1 changed files with 7 additions and 4 deletions

View File

@ -2236,6 +2236,8 @@
} }
], ],
"source": [ "source": [
"n_epochs = 20\n",
"\n",
"lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)\n", "lr_scheduler = tf.keras.callbacks.LearningRateScheduler(exponential_decay_fn)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n", "history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n", " validation_data=(X_valid, y_valid),\n",
@ -2366,9 +2368,10 @@
} }
], ],
"source": [ "source": [
"n_epochs = 25\n", "import math\n",
"\n",
"batch_size = 32\n", "batch_size = 32\n",
"n_steps = n_epochs * np.ceil(len(X_train) / batch_size)\n", "n_steps = n_epochs * math.ceil(len(X_train) / batch_size)\n",
"exp_decay = ExponentialDecay(n_steps)\n", "exp_decay = ExponentialDecay(n_steps)\n",
"history = model.fit(X_train, y_train, epochs=n_epochs,\n", "history = model.fit(X_train, y_train, epochs=n_epochs,\n",
" validation_data=(X_valid, y_valid),\n", " validation_data=(X_valid, y_valid),\n",
@ -4561,7 +4564,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -4575,7 +4578,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.10" "version": "3.10.13"
}, },
"nav_menu": { "nav_menu": {
"height": "360px", "height": "360px",