Fix hyperparameter tuning on Vertex AI

main
Aurélien Geron 2022-04-16 22:11:30 +12:00
parent b01d1862f3
commit ce104660c6
1 changed files with 129 additions and 130 deletions

View File

@ -60,28 +60,6 @@
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TAlKky09pKzv"
},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "YqCwW7cMpKzw"
},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {
@ -93,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"metadata": {
"id": "0Piq5se2pKzx"
},
@ -115,7 +93,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@ -134,7 +112,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 4,
"metadata": {
"id": "Ekxzo6pOpKzy"
},
@ -193,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@ -263,7 +241,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@ -278,7 +256,7 @@
" 'my_mnist_model/0001/variables/variables.index']"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -296,7 +274,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@ -314,7 +292,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@ -333,7 +311,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"metadata": {},
"outputs": [
{
@ -386,7 +364,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
@ -408,7 +386,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@ -419,7 +397,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@ -461,7 +439,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
@ -476,7 +454,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"outputs": [
{
@ -485,7 +463,7 @@
"'{\"signature_name\": \"serving_default\", \"instances\": [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0..., 0, 0]]]}'"
]
},
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@ -503,7 +481,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
@ -517,7 +495,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"outputs": [
{
@ -528,7 +506,7 @@
" [0. , 0.97, 0.01, 0. , 0. , 0. , 0. , 0.01, 0. , 0. ]])"
]
},
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
@ -549,7 +527,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
@ -564,7 +542,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
@ -585,7 +563,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 19,
"metadata": {
"scrolled": true
},
@ -598,7 +576,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"metadata": {},
"outputs": [
{
@ -610,7 +588,7 @@
" dtype=float32)"
]
},
"execution_count": 21,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
@ -628,7 +606,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 21,
"metadata": {},
"outputs": [
{
@ -639,7 +617,7 @@
" [0. , 0.97, 0.01, 0. , 0. , 0. , 0. , 0.01, 0. , 0. ]])"
]
},
"execution_count": 22,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
@ -662,7 +640,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 22,
"metadata": {
"scrolled": true
},
@ -714,7 +692,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 23,
"metadata": {},
"outputs": [
{
@ -740,7 +718,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 24,
"metadata": {},
"outputs": [
{
@ -762,7 +740,7 @@
" 'my_mnist_model/0002/variables/variables.index']"
]
},
"execution_count": 25,
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
@ -780,7 +758,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
@ -795,7 +773,7 @@
},
{
"cell_type": "code",
"execution_count": 27,
"execution_count": 26,
"metadata": {},
"outputs": [
{
@ -804,7 +782,7 @@
"dict_keys(['predictions'])"
]
},
"execution_count": 27,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
@ -815,7 +793,7 @@
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 27,
"metadata": {},
"outputs": [
{
@ -826,7 +804,7 @@
" [0. , 0.99, 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])"
]
},
"execution_count": 28,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
@ -857,7 +835,7 @@
},
{
"cell_type": "code",
"execution_count": 29,
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
@ -875,7 +853,7 @@
},
{
"cell_type": "code",
"execution_count": 30,
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
@ -891,7 +869,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
@ -907,7 +885,7 @@
},
{
"cell_type": "code",
"execution_count": 32,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
@ -954,7 +932,7 @@
},
{
"cell_type": "code",
"execution_count": 33,
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
@ -963,7 +941,7 @@
},
{
"cell_type": "code",
"execution_count": 34,
"execution_count": 33,
"metadata": {},
"outputs": [
{
@ -1000,7 +978,7 @@
},
{
"cell_type": "code",
"execution_count": 35,
"execution_count": 34,
"metadata": {},
"outputs": [
{
@ -1033,7 +1011,7 @@
},
{
"cell_type": "code",
"execution_count": 36,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@ -1042,7 +1020,7 @@
},
{
"cell_type": "code",
"execution_count": 37,
"execution_count": 36,
"metadata": {},
"outputs": [
{
@ -1053,7 +1031,7 @@
" [0. , 0.97, 0.01, 0. , 0. , 0. , 0. , 0.01, 0. , 0. ]])"
]
},
"execution_count": 37,
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
@ -1066,7 +1044,7 @@
},
{
"cell_type": "code",
"execution_count": 38,
"execution_count": 37,
"metadata": {},
"outputs": [
{
@ -1096,7 +1074,7 @@
},
{
"cell_type": "code",
"execution_count": 39,
"execution_count": 38,
"metadata": {},
"outputs": [
{
@ -1120,7 +1098,7 @@
},
{
"cell_type": "code",
"execution_count": 40,
"execution_count": 39,
"metadata": {},
"outputs": [
{
@ -1175,7 +1153,7 @@
},
{
"cell_type": "code",
"execution_count": 41,
"execution_count": 40,
"metadata": {},
"outputs": [
{
@ -1184,7 +1162,7 @@
"gcs_output_directory: \"gs://my_bucket/my_mnist_predictions/prediction-mnist-2022_04_12T21_30_08_071Z\""
]
},
"execution_count": 41,
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
@ -1195,7 +1173,7 @@
},
{
"cell_type": "code",
"execution_count": 42,
"execution_count": 41,
"metadata": {},
"outputs": [
{
@ -1220,7 +1198,7 @@
},
{
"cell_type": "code",
"execution_count": 43,
"execution_count": 42,
"metadata": {},
"outputs": [],
"source": [
@ -1230,7 +1208,7 @@
},
{
"cell_type": "code",
"execution_count": 44,
"execution_count": 43,
"metadata": {},
"outputs": [
{
@ -1239,7 +1217,7 @@
"0.98"
]
},
"execution_count": 44,
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
@ -1250,7 +1228,7 @@
},
{
"cell_type": "code",
"execution_count": 45,
"execution_count": 44,
"metadata": {},
"outputs": [
{
@ -1276,7 +1254,7 @@
},
{
"cell_type": "code",
"execution_count": 46,
"execution_count": 45,
"metadata": {},
"outputs": [
{
@ -1308,7 +1286,7 @@
},
{
"cell_type": "code",
"execution_count": 47,
"execution_count": 46,
"metadata": {},
"outputs": [
{
@ -1337,7 +1315,7 @@
},
{
"cell_type": "code",
"execution_count": 48,
"execution_count": 47,
"metadata": {},
"outputs": [],
"source": [
@ -1347,7 +1325,7 @@
},
{
"cell_type": "code",
"execution_count": 49,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
@ -1356,7 +1334,7 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 49,
"metadata": {},
"outputs": [
{
@ -1435,7 +1413,7 @@
},
{
"cell_type": "code",
"execution_count": 51,
"execution_count": 50,
"metadata": {},
"outputs": [
{
@ -1474,7 +1452,7 @@
},
{
"cell_type": "code",
"execution_count": 52,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
@ -1494,7 +1472,7 @@
},
{
"cell_type": "code",
"execution_count": 53,
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
@ -1518,7 +1496,7 @@
},
{
"cell_type": "code",
"execution_count": 54,
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
@ -1531,7 +1509,7 @@
},
{
"cell_type": "code",
"execution_count": 55,
"execution_count": 54,
"metadata": {},
"outputs": [
{
@ -1563,7 +1541,7 @@
},
{
"cell_type": "code",
"execution_count": 56,
"execution_count": 55,
"metadata": {},
"outputs": [],
"source": [
@ -1573,7 +1551,7 @@
},
{
"cell_type": "code",
"execution_count": 57,
"execution_count": 56,
"metadata": {},
"outputs": [
{
@ -1582,7 +1560,7 @@
"'/job:localhost/replica:0/task:0/device:GPU:0'"
]
},
"execution_count": 57,
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
@ -1594,7 +1572,7 @@
},
{
"cell_type": "code",
"execution_count": 58,
"execution_count": 57,
"metadata": {},
"outputs": [
{
@ -1603,7 +1581,7 @@
"'/job:localhost/replica:0/task:0/device:CPU:0'"
]
},
"execution_count": 58,
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
@ -1622,7 +1600,7 @@
},
{
"cell_type": "code",
"execution_count": 59,
"execution_count": 58,
"metadata": {},
"outputs": [
{
@ -1631,7 +1609,7 @@
"'/job:localhost/replica:0/task:0/device:CPU:0'"
]
},
"execution_count": 59,
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
@ -1652,7 +1630,7 @@
},
{
"cell_type": "code",
"execution_count": 60,
"execution_count": 59,
"metadata": {},
"outputs": [
{
@ -1661,7 +1639,7 @@
"\"'/job:localhost/replica:0/task:0/device:GPU:0'\""
]
},
"execution_count": 60,
"execution_count": 59,
"metadata": {},
"output_type": "execute_result"
}
@ -1684,7 +1662,7 @@
},
{
"cell_type": "code",
"execution_count": 61,
"execution_count": 60,
"metadata": {},
"outputs": [
{
@ -1724,7 +1702,7 @@
},
{
"cell_type": "code",
"execution_count": 62,
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
@ -1750,7 +1728,7 @@
},
{
"cell_type": "code",
"execution_count": 63,
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
@ -1777,7 +1755,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
@ -1798,7 +1776,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 64,
"metadata": {},
"outputs": [
{
@ -1807,7 +1785,7 @@
"tensorflow.python.distribute.values.MirroredVariable"
]
},
"execution_count": 65,
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
@ -1818,7 +1796,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 65,
"metadata": {},
"outputs": [
{
@ -1837,7 +1815,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 66,
"metadata": {},
"outputs": [
{
@ -1853,7 +1831,7 @@
"tensorflow.python.ops.resource_variable_ops.ResourceVariable"
]
},
"execution_count": 67,
"execution_count": 66,
"metadata": {},
"output_type": "execute_result"
}
@ -1868,7 +1846,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@ -1878,7 +1856,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 68,
"metadata": {},
"outputs": [
{
@ -1902,7 +1880,7 @@
},
{
"cell_type": "code",
"execution_count": 70,
"execution_count": 69,
"metadata": {},
"outputs": [
{
@ -1927,7 +1905,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 70,
"metadata": {},
"outputs": [
{
@ -1952,7 +1930,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 71,
"metadata": {},
"outputs": [
{
@ -1969,7 +1947,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
@ -2008,7 +1986,7 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
@ -2032,7 +2010,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
@ -2058,7 +2036,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 75,
"metadata": {},
"outputs": [
{
@ -2067,7 +2045,7 @@
"ClusterSpec({'ps': ['machine-a.example.com:2221'], 'worker': ['machine-a.example.com:2222', 'machine-b.example.com:2222']})"
]
},
"execution_count": 76,
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
@ -2079,7 +2057,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 76,
"metadata": {},
"outputs": [
{
@ -2088,7 +2066,7 @@
"'worker'"
]
},
"execution_count": 77,
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
@ -2099,7 +2077,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 77,
"metadata": {},
"outputs": [
{
@ -2108,7 +2086,7 @@
"0"
]
},
"execution_count": 78,
"execution_count": 77,
"metadata": {},
"output_type": "execute_result"
}
@ -2130,7 +2108,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 78,
"metadata": {},
"outputs": [
{
@ -2206,7 +2184,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
@ -2220,7 +2198,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -2250,7 +2228,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@ -2260,7 +2238,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -2285,7 +2263,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 83,
"metadata": {},
"outputs": [
{
@ -2356,7 +2334,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@ -2373,7 +2351,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 85,
"metadata": {},
"outputs": [
{
@ -2427,7 +2405,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
@ -2447,7 +2425,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 87,
"metadata": {},
"outputs": [
{
@ -2470,7 +2448,7 @@
"tuner_id = f'{tf_config[\"task\"][\"type\"]}{tf_config[\"task\"][\"index\"]}'\n",
"if tuner_id == \"chief0\":\n",
" tuner_id = \"chief\"\n",
" chief_ip = \"127.0.0.1\"\n",
" chief_ip = \"0.0.0.0\"\n",
" # extra code shows one way to start a worker on the chief machine\n",
" # import subprocess\n",
" # import sys\n",
@ -2487,7 +2465,8 @@
"import keras_tuner as kt\n",
"import tensorflow as tf\n",
"\n",
"gcs_path = \"gs://my_bucket/my_hp_search\" # replace with your bucket's name\n",
"gcs_path = \"gs://my_bucket/my_hp_search\"\n",
"gcs_path = gcs_path.replace(\"gs://\", \"/gcs/\") # uses GCS Fuse\n",
"\n",
"def build_model(hp):\n",
" n_hidden = hp.Int(\"n_hidden\", min_value=0, max_value=8, default=2)\n",
@ -2535,6 +2514,26 @@
" best_model.save(os.getenv(\"AIP_MODEL_DIR\"))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Replace `gs://my_bucket` with your bucket's name:"
]
},
{
"cell_type": "code",
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
"with open(\"my_keras_tuner_search.py\") as f:\n",
" script = f.read()\n",
"\n",
"with open(\"my_keras_tuner_search.py\", \"w\") as f:\n",
" f.write(script.replace(\"gs://my_bucket/\", f\"gs://{bucket_name}/\"))"
]
},
{
"cell_type": "code",
"execution_count": 89,