From 49715d4b740f2785367a908d3ae785b20857b55d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= <ageron@users.noreply.github.com>
Date: Thu, 12 Mar 2020 22:47:22 +1300
Subject: [PATCH] Fix bug in training_step: target_Q_values must be a column
 vector

---
 18_reinforcement_learning.ipynb | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb
index fb744dd..1db21d5 100644
--- a/18_reinforcement_learning.ipynb
+++ b/18_reinforcement_learning.ipynb
@@ -67,7 +67,7 @@
     "from tensorflow import keras\n",
     "assert tf.__version__ >= \"2.0\"\n",
     "\n",
-    "if not tf.test.is_gpu_available():\n",
+    "if not tf.config.list_physical_devices('GPU'):\n",
     "    print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
     "    if IS_COLAB:\n",
     "        print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
@@ -574,6 +574,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "keras.backend.clear_session()\n",
     "tf.random.set_seed(42)\n",
     "np.random.seed(42)\n",
     "\n",
@@ -638,7 +639,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -882,6 +883,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "keras.backend.clear_session()\n",
     "np.random.seed(42)\n",
     "tf.random.set_seed(42)\n",
     "\n",
@@ -1274,6 +1276,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "keras.backend.clear_session()\n",
     "tf.random.set_seed(42)\n",
     "np.random.seed(42)\n",
     "\n",
@@ -1392,7 +1395,9 @@
     "    states, actions, rewards, next_states, dones = experiences\n",
     "    next_Q_values = model.predict(next_states)\n",
     "    max_next_Q_values = np.max(next_Q_values, axis=1)\n",
-    "    target_Q_values = rewards + (1 - dones) * discount_rate * max_next_Q_values\n",
+    "    target_Q_values = (rewards +\n",
+    "                       (1 - dones) * discount_rate * max_next_Q_values)\n",
+    "    target_Q_values = target_Q_values.reshape(-1, 1)\n",
     "    mask = tf.one_hot(actions, n_outputs)\n",
     "    with tf.GradientTape() as tape:\n",
     "        all_Q_values = model(states)\n",
@@ -1505,6 +1510,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "keras.backend.clear_session()\n",
     "tf.random.set_seed(42)\n",
     "np.random.seed(42)\n",
     "\n",
@@ -1536,7 +1542,9 @@
     "    best_next_actions = np.argmax(next_Q_values, axis=1)\n",
     "    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
     "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
-    "    target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
+    "    target_Q_values = (rewards + \n",
+    "                       (1 - dones) * discount_rate * next_best_Q_values)\n",
+    "    target_Q_values = target_Q_values.reshape(-1, 1)\n",
     "    mask = tf.one_hot(actions, n_outputs)\n",
     "    with tf.GradientTape() as tape:\n",
     "        all_Q_values = model(states)\n",
@@ -1646,6 +1654,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
+    "keras.backend.clear_session()\n",
     "tf.random.set_seed(42)\n",
     "np.random.seed(42)\n",
     "\n",
@@ -1681,7 +1690,9 @@
     "    best_next_actions = np.argmax(next_Q_values, axis=1)\n",
     "    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
     "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
-    "    target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
+    "    target_Q_values = (rewards + \n",
+    "                       (1 - dones) * discount_rate * next_best_Q_values)\n",
+    "    target_Q_values = target_Q_values.reshape(-1, 1)\n",
     "    mask = tf.one_hot(actions, n_outputs)\n",
     "    with tf.GradientTape() as tape:\n",
     "        all_Q_values = model(states)\n",
@@ -2777,7 +2788,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.3"
+   "version": "3.7.6"
   }
  },
  "nbformat": 4,