From 8ebdcffc6b65257f76c9197776c6fd3f47d5eb4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 23 Nov 2020 16:52:37 +1300 Subject: [PATCH] Work around TF Agents issue: env.step(1) => env.step(np.array(1)) --- 18_reinforcement_learning.ipynb | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index e6d3717..b17da27 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -1860,13 +1860,20 @@ "env.reset()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning**: since TF Agents 0.4.0, there seems to be an issue with passing an integer to the `env.step()` method (it raises an `AttributeError`). You need to wrap it in a NumPy array, as done below. Please see [TF Agents Issue #520](https://github.com/tensorflow/agents/issues/520) for more details." + ] + }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [], "source": [ - "env.step(1) # Fire" + "env.step(np.array(1)) # Fire" ] }, { @@ -2074,9 +2081,9 @@ "source": [ "env.seed(42)\n", "env.reset()\n", - "time_step = env.step(1) # FIRE\n", + "time_step = env.step(np.array(1)) # FIRE\n", "for _ in range(4):\n", - " time_step = env.step(3) # LEFT" + " time_step = env.step(np.array(3)) # LEFT" ] }, { @@ -2215,7 +2222,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Create the replay buffer:" + "Create the replay buffer (this may use a lot of RAM, so please reduce the buffer size if you get an out-of-memory error):" ] }, { @@ -2521,7 +2528,7 @@ " lives = tf_env.pyenv.envs[0].ale.lives()\n", " if prev_lives != lives:\n", " tf_env.reset()\n", - " tf_env.pyenv.envs[0].step(1)\n", + " tf_env.pyenv.envs[0].step(np.array(1))\n", " prev_lives = lives\n", "\n", "watch_driver = DynamicStepDriver(\n", @@ -2790,7 +2797,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.6" + "version": "3.7.8" } }, "nbformat": 4,