From 4c3b7b9b066e4d8019183815c7e09afdd8f28b70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Tue, 28 May 2019 09:30:16 +0800 Subject: [PATCH] Save agent's breakout performance to an animated gif --- 18_reinforcement_learning.ipynb | 176 ++++++++++++++++++++++---------- 1 file changed, 124 insertions(+), 52 deletions(-) diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index 0ba4cc9..1e72747 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -176,46 +176,6 @@ "An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment)." ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "env.render()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "img = env.render(mode=\"rgb_array\")\n", - "img.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_environment(env, figsize=(5,4)):\n", - " plt.figure(figsize=figsize)\n", - " img = env.render(mode=\"rgb_array\")\n", - " plt.imshow(img)\n", - " plt.axis(\"off\")\n", - " return img" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -244,7 +204,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -255,6 +215,46 @@ " pass" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "env.render()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "img = env.render(mode=\"rgb_array\")\n", + "img.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_environment(env, figsize=(5,4)):\n", + " plt.figure(figsize=figsize)\n", + " img = env.render(mode=\"rgb_array\")\n", + " plt.imshow(img)\n", + " plt.axis(\"off\")\n", + " return img" + ] + }, { "cell_type": "code", "execution_count": 11, @@ -2450,13 +2450,84 @@ " log_metrics(train_metrics)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Run the next cell to train the agent for 10,000 steps. Then look at its behavior by running the following cell. You can run these two cells as many times as you wish. The agent will keep improving!" + ] + }, { "cell_type": "code", "execution_count": 118, "metadata": {}, "outputs": [], "source": [ - "train_agent(n_iterations=200) # change this to 10 million or more!" + "train_agent(n_iterations=10000)" + ] + }, + { + "cell_type": "code", + "execution_count": 119, + "metadata": {}, + "outputs": [], + "source": [ + "frames = []\n", + "def save_frames(trajectory):\n", + " global frames\n", + " frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n", + "\n", + "prev_lives = tf_env.pyenv.envs[0].ale.lives()\n", + "def reset_and_fire_on_life_lost(trajectory):\n", + " global prev_lives\n", + " lives = tf_env.pyenv.envs[0].ale.lives()\n", + " if prev_lives != lives:\n", + " tf_env.reset()\n", + " tf_env.pyenv.envs[0].step(1)\n", + " prev_lives = lives\n", + "\n", + "watch_driver = DynamicStepDriver(\n", + " tf_env,\n", + " agent.policy,\n", + " observers=[save_frames, reset_and_fire_on_life_lost, ShowProgress(1000)],\n", + " num_steps=1000)\n", + "final_time_step, final_policy_state = watch_driver.run()\n", + "\n", + "plot_animation(frames)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to save an animated GIF to show off your agent to your friends, here's one way to do it:" + ] + }, + { + "cell_type": "code", + "execution_count": 120, + "metadata": {}, + "outputs": [], + "source": [ + "import PIL\n", + "\n", + "image_path = os.join(\"images\", \"rl\", \"breakout.gif\")\n", + "frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n", + "frame_images[0].save(image_path, format='GIF',\n", + " append_images=frame_images[1:],\n", + " save_all=True,\n", + " duration=30,\n", + " loop=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 121, + "metadata": {}, + "outputs": [], + "source": [ + "%%html\n", + "" ] }, { @@ -2482,10 +2553,11 @@ }, { "cell_type": "code", - "execution_count": 119, + "execution_count": 122, "metadata": {}, "outputs": [], "source": [ + "from collections import deque\n", "np.random.seed(42)\n", "\n", "mem = deque(maxlen=1000000)\n", @@ -2496,7 +2568,7 @@ }, { "cell_type": "code", - "execution_count": 120, + "execution_count": 123, "metadata": {}, "outputs": [], "source": [ @@ -2505,7 +2577,7 @@ }, { "cell_type": "code", - "execution_count": 121, + "execution_count": 124, "metadata": {}, "outputs": [], "source": [ @@ -2521,7 +2593,7 @@ }, { "cell_type": "code", - "execution_count": 122, + "execution_count": 125, "metadata": {}, "outputs": [], "source": [ @@ -2544,7 +2616,7 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 126, "metadata": {}, "outputs": [], "source": [ @@ -2556,7 +2628,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 127, "metadata": {}, "outputs": [], "source": [ @@ -2565,7 +2637,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 128, "metadata": {}, "outputs": [], "source": [ @@ -2588,7 +2660,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 129, "metadata": {}, "outputs": [], "source": [ @@ -2637,7 +2709,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 130, "metadata": {}, "outputs": [], "source": [ @@ -2648,7 +2720,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 131, "metadata": {}, "outputs": [], "source": [