Fix bug in training_step: target_Q_values must be a column vector

main
Aurélien Geron 2020-03-12 22:47:22 +13:00
parent 0c2c80d89e
commit 49715d4b74
1 changed file with 17 additions and 6 deletions
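The bug: target_Q_values was a 1D array of shape (batch_size,), while the predicted Q-values it is compared against in the loss form a (batch_size, 1) column vector. Keras' mean squared error broadcasts the two into a (batch_size, batch_size) difference matrix and silently returns the wrong loss. A minimal sketch of the mismatch, with made-up values (not from the notebook):

    import numpy as np
    import tensorflow as tf

    targets = np.array([1., 2., 3.], dtype=np.float32)  # shape (3,): hypothetical targets
    predictions = targets.reshape(-1, 1)                 # shape (3, 1): a perfect prediction
    mse = tf.keras.losses.mean_squared_error
    print(mse(targets, predictions).numpy())                 # broadcasts to (3, 3): nonzero loss despite a perfect match
    print(mse(targets.reshape(-1, 1), predictions).numpy())  # column vector vs column vector: all zeros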

@@ -67,7 +67,7 @@
"from tensorflow import keras\n",
"assert tf.__version__ >= \"2.0\"\n",
"\n",
"if not tf.test.is_gpu_available():\n",
"if not tf.config.list_physical_devices('GPU'):\n",
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
" if IS_COLAB:\n",
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
@@ -574,6 +574,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
@@ -638,7 +639,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -882,6 +883,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
@@ -1274,6 +1276,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
@@ -1392,7 +1395,9 @@
" states, actions, rewards, next_states, dones = experiences\n",
" next_Q_values = model.predict(next_states)\n",
" max_next_Q_values = np.max(next_Q_values, axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * max_next_Q_values\n",
" target_Q_values = (rewards +\n",
" (1 - dones) * discount_rate * max_next_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
@@ -1505,6 +1510,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
@@ -1536,7 +1542,9 @@
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
" target_Q_values = (rewards + \n",
" (1 - dones) * discount_rate * next_best_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
@@ -1646,6 +1654,7 @@
"metadata": {},
"outputs": [],
"source": [
"keras.backend.clear_session()\n",
"tf.random.set_seed(42)\n",
"np.random.seed(42)\n",
"\n",
@@ -1681,7 +1690,9 @@
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
" target_Q_values = (rewards + \n",
" (1 - dones) * discount_rate * next_best_Q_values)\n",
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
" mask = tf.one_hot(actions, n_outputs)\n",
" with tf.GradientTape() as tape:\n",
" all_Q_values = model(states)\n",
@@ -2777,7 +2788,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
"version": "3.7.6"
}
},
"nbformat": 4,