Use AdamW from tf.keras.optimizers instead of TensorFlow-Addons

main
Aurélien Geron 2023-11-14 18:20:45 +13:00
parent bde6c1704e
commit de0eb33694
1 changed files with 89 additions and 101 deletions

View File

@ -1780,36 +1780,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"On Colab or Kaggle, we need to install the TensorFlow-Addons library:"
"Note: Since TF 1.12, `AdamW` is no longer experimental. It is available at `tf.keras.optimizers.AdamW` instead of `tf.keras.optimizers.experimental.AdamW`."
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules:\n",
" %pip install -q -U tensorflow-addons"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import tensorflow_addons as tfa\n",
"\n",
"optimizer = tfa.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n",
" beta_1=0.9, beta_2=0.999)"
"optimizer = tf.keras.optimizers.AdamW(weight_decay=1e-5, learning_rate=0.001,\n",
" beta_1=0.9, beta_2=0.999)"
]
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 63,
"metadata": {},
"outputs": [
{
@ -1845,7 +1833,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 64,
"metadata": {},
"outputs": [
{
@ -1917,7 +1905,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
@ -1926,7 +1914,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 66,
"metadata": {},
"outputs": [
{
@ -1962,7 +1950,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 67,
"metadata": {},
"outputs": [
{
@ -2017,7 +2005,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
@ -2027,7 +2015,7 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@ -2041,7 +2029,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@ -2056,7 +2044,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 71,
"metadata": {},
"outputs": [
{
@ -2125,7 +2113,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 72,
"metadata": {},
"outputs": [
{
@ -2162,7 +2150,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
@ -2179,7 +2167,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
@ -2203,7 +2191,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -2216,7 +2204,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 76,
"metadata": {},
"outputs": [
{
@ -2288,7 +2276,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 77,
"metadata": {
"scrolled": true
},
@ -2330,7 +2318,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@ -2345,7 +2333,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
@ -2364,7 +2352,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 80,
"metadata": {},
"outputs": [
{
@ -2442,7 +2430,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 81,
"metadata": {},
"outputs": [
{
@ -2479,7 +2467,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -2493,7 +2481,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 83,
"metadata": {},
"outputs": [
{
@ -2562,7 +2550,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 84,
"metadata": {},
"outputs": [
{
@ -2606,7 +2594,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@ -2622,7 +2610,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 86,
"metadata": {},
"outputs": [
{
@ -2666,7 +2654,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@ -2692,7 +2680,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
@ -2727,7 +2715,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@ -2755,7 +2743,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
@ -2779,7 +2767,7 @@
},
{
"cell_type": "code",
"execution_count": 92,
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
@ -2798,7 +2786,7 @@
},
{
"cell_type": "code",
"execution_count": 93,
"execution_count": 92,
"metadata": {},
"outputs": [
{
@ -2844,7 +2832,7 @@
},
{
"cell_type": "code",
"execution_count": 94,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
@ -2885,7 +2873,7 @@
},
{
"cell_type": "code",
"execution_count": 95,
"execution_count": 94,
"metadata": {},
"outputs": [
{
@ -2974,7 +2962,7 @@
},
{
"cell_type": "code",
"execution_count": 96,
"execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
@ -2992,7 +2980,7 @@
},
{
"cell_type": "code",
"execution_count": 97,
"execution_count": 96,
"metadata": {},
"outputs": [],
"source": [
@ -3001,7 +2989,7 @@
},
{
"cell_type": "code",
"execution_count": 98,
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
@ -3022,7 +3010,7 @@
},
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 98,
"metadata": {},
"outputs": [
{
@ -3054,7 +3042,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 99,
"metadata": {},
"outputs": [],
"source": [
@ -3063,7 +3051,7 @@
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 100,
"metadata": {},
"outputs": [],
"source": [
@ -3082,7 +3070,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 101,
"metadata": {},
"outputs": [
{
@ -3130,7 +3118,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 102,
"metadata": {},
"outputs": [
{
@ -3146,7 +3134,7 @@
"[0.30816400051116943, 0.8849090933799744]"
]
},
"execution_count": 103,
"execution_count": 102,
"metadata": {},
"output_type": "execute_result"
}
@ -3157,7 +3145,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 103,
"metadata": {},
"outputs": [
{
@ -3173,7 +3161,7 @@
"[0.3628920316696167, 0.8700000047683716]"
]
},
"execution_count": 104,
"execution_count": 103,
"metadata": {},
"output_type": "execute_result"
}
@ -3198,7 +3186,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 104,
"metadata": {},
"outputs": [],
"source": [
@ -3207,7 +3195,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 105,
"metadata": {},
"outputs": [],
"source": [
@ -3218,7 +3206,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 106,
"metadata": {},
"outputs": [
{
@ -3228,7 +3216,7 @@
" 0.844]], dtype=float32)"
]
},
"execution_count": 107,
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
@ -3239,7 +3227,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 107,
"metadata": {},
"outputs": [
{
@ -3249,7 +3237,7 @@
" 0.723], dtype=float32)"
]
},
"execution_count": 108,
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
@ -3260,7 +3248,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 108,
"metadata": {},
"outputs": [
{
@ -3270,7 +3258,7 @@
" 0.183], dtype=float32)"
]
},
"execution_count": 109,
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
@ -3282,7 +3270,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 109,
"metadata": {},
"outputs": [
{
@ -3291,7 +3279,7 @@
"0.8717"
]
},
"execution_count": 110,
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
@ -3304,7 +3292,7 @@
},
{
"cell_type": "code",
"execution_count": 111,
"execution_count": 110,
"metadata": {},
"outputs": [],
"source": [
@ -3315,7 +3303,7 @@
},
{
"cell_type": "code",
"execution_count": 112,
"execution_count": 111,
"metadata": {},
"outputs": [],
"source": [
@ -3330,7 +3318,7 @@
},
{
"cell_type": "code",
"execution_count": 113,
"execution_count": 112,
"metadata": {},
"outputs": [
{
@ -3375,7 +3363,7 @@
},
{
"cell_type": "code",
"execution_count": 114,
"execution_count": 113,
"metadata": {},
"outputs": [
{
@ -3385,7 +3373,7 @@
" dtype=float32)"
]
},
"execution_count": 114,
"execution_count": 113,
"metadata": {},
"output_type": "execute_result"
}
@ -3406,7 +3394,7 @@
},
{
"cell_type": "code",
"execution_count": 115,
"execution_count": 114,
"metadata": {},
"outputs": [],
"source": [
@ -3417,7 +3405,7 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 115,
"metadata": {},
"outputs": [
{
@ -3512,7 +3500,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 116,
"metadata": {},
"outputs": [],
"source": [
@ -3543,7 +3531,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 117,
"metadata": {},
"outputs": [],
"source": [
@ -3559,7 +3547,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
@ -3578,7 +3566,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -3600,7 +3588,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@ -3616,7 +3604,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 121,
"metadata": {},
"outputs": [
{
@ -3653,7 +3641,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 122,
"metadata": {},
"outputs": [
{
@ -3739,7 +3727,7 @@
"<keras.callbacks.History at 0x7fb9f02fc070>"
]
},
"execution_count": 123,
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
@ -3752,7 +3740,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 123,
"metadata": {},
"outputs": [
{
@ -3768,7 +3756,7 @@
"[1.5061508417129517, 0.4675999879837036]"
]
},
"execution_count": 124,
"execution_count": 123,
"metadata": {},
"output_type": "execute_result"
}
@ -3805,7 +3793,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 124,
"metadata": {},
"outputs": [
{
@ -3891,7 +3879,7 @@
"[1.4236289262771606, 0.5073999762535095]"
]
},
"execution_count": 125,
"execution_count": 124,
"metadata": {},
"output_type": "execute_result"
}
@ -3948,7 +3936,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 125,
"metadata": {
"scrolled": true
},
@ -4042,7 +4030,7 @@
"[1.4607702493667603, 0.5026000142097473]"
]
},
"execution_count": 126,
"execution_count": 125,
"metadata": {},
"output_type": "execute_result"
}
@ -4103,7 +4091,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 126,
"metadata": {},
"outputs": [
{
@ -4183,7 +4171,7 @@
"[1.4779616594314575, 0.498199999332428]"
]
},
"execution_count": 127,
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
@ -4244,7 +4232,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -4262,7 +4250,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
@ -4285,7 +4273,7 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@ -4307,7 +4295,7 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 130,
"metadata": {},
"outputs": [
{
@ -4316,7 +4304,7 @@
"0.4984"
]
},
"execution_count": 131,
"execution_count": 130,
"metadata": {},
"output_type": "execute_result"
}
@ -4348,7 +4336,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [
@ -4372,7 +4360,7 @@
},
{
"cell_type": "code",
"execution_count": 133,
"execution_count": 132,
"metadata": {},
"outputs": [
{
@ -4404,7 +4392,7 @@
},
{
"cell_type": "code",
"execution_count": 134,
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
@ -4428,7 +4416,7 @@
},
{
"cell_type": "code",
"execution_count": 135,
"execution_count": 134,
"metadata": {},
"outputs": [
{
@ -4501,7 +4489,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.9.10"
},
"nav_menu": {
"height": "360px",