From 3ef350ab4c900baad5037cdf439178def3b60aeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 27 May 2019 20:35:00 +0800 Subject: [PATCH] Fix figure name and clarify a couple code examples --- 18_reinforcement_learning.ipynb | 82 ++++++--------------------------- 1 file changed, 15 insertions(+), 67 deletions(-) diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index 0b3bbef..0ba4cc9 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -233,7 +233,7 @@ "$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n", "```\n", "\n", - "Alternatively, you can install the `pyvirtualdisplay` Python library which wraps Xvfb:\n", + "Alternatively, you can install the [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) Python library which wraps Xvfb:\n", "\n", "```bash\n", "python3 -m pip install -U pyvirtualdisplay\n", @@ -1319,8 +1319,8 @@ " indices = np.random.randint(len(replay_memory), size=batch_size)\n", " batch = [replay_memory[index] for index in indices]\n", " states, actions, rewards, next_states, dones = [\n", - " np.array([experience[index] for experience in batch])\n", - " for index in range(5)]\n", + " np.array([experience[field_index] for experience in batch])\n", + " for field_index in range(5)]\n", " return states, actions, rewards, next_states, dones" ] }, @@ -1580,7 +1580,7 @@ "plt.plot(rewards)\n", "plt.xlabel(\"Episode\", fontsize=14)\n", "plt.ylabel(\"Sum of rewards\", fontsize=14)\n", - "save_fig(\"dqn_rewards_plot\")\n", + "save_fig(\"double_dqn_rewards_plot\")\n", "plt.show()" ] }, @@ -2057,7 +2057,7 @@ " current_frame_delta = np.maximum(obs[..., 3] - obs[..., :3].mean(axis=-1), 0.)\n", " img[..., 0] += current_frame_delta\n", " img[..., 2] += current_frame_delta\n", - " img = (img - img.min()) / (img.max() - img.min())\n", + " img = np.clip(img / 150, 0, 1)\n", " plt.imshow(img)\n", " plt.axis(\"off\")" ] @@ -2459,58 +2459,6 @@ "train_agent(n_iterations=200) # change this to 10 million or more!" ] }, - { - "cell_type": "code", - "execution_count": 119, - "metadata": {}, - "outputs": [], - "source": [ - "num_eval_episodes = 10\n", - "eval_metrics = [\n", - " tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),\n", - " tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 120, - "metadata": {}, - "outputs": [], - "source": [ - "eval_tf_env = suite_atari.load(\n", - " environment_name,\n", - " max_episode_steps=max_episode_steps,\n", - " gym_env_wrappers=[AtariPreprocessing, FrameStack4])\n", - "\n", - "eval_tf_env = TFPyEnvironment(eval_tf_env)" - ] - }, - { - "cell_type": "code", - "execution_count": 121, - "metadata": {}, - "outputs": [], - "source": [ - "from tf_agents.eval import metric_utils\n", - "\n", - "results = metric_utils.eager_compute(\n", - " eval_metrics,\n", - " eval_tf_env,\n", - " agent.policy,\n", - " num_episodes=num_eval_episodes,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": {}, - "outputs": [], - "source": [ - "results" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -2534,7 +2482,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 119, "metadata": {}, "outputs": [], "source": [ @@ -2548,7 +2496,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 120, "metadata": {}, "outputs": [], "source": [ @@ -2557,7 +2505,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 121, "metadata": {}, "outputs": [], "source": [ @@ -2573,7 +2521,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ @@ -2596,7 +2544,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -2608,7 +2556,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -2617,7 +2565,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -2640,7 +2588,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -2689,7 +2637,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -2700,7 +2648,7 @@ }, { "cell_type": "code", - "execution_count": 132, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [