Save agent's breakout performance to an animated gif

main
Aurélien Geron 2019-05-28 09:30:16 +08:00
parent 3ef350ab4c
commit 4c3b7b9b06
1 changed files with 124 additions and 52 deletions

View File

@ -176,46 +176,6 @@
"An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"env.render()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"img = env.render(mode=\"rgb_array\")\n",
"img.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def plot_environment(env, figsize=(5,4)):\n",
" plt.figure(figsize=figsize)\n",
" img = env.render(mode=\"rgb_array\")\n",
" plt.imshow(img)\n",
" plt.axis(\"off\")\n",
" return img"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -244,7 +204,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
@ -255,6 +215,46 @@
" pass"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"env.render()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"img = env.render(mode=\"rgb_array\")\n",
"img.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"def plot_environment(env, figsize=(5,4)):\n",
" plt.figure(figsize=figsize)\n",
" img = env.render(mode=\"rgb_array\")\n",
" plt.imshow(img)\n",
" plt.axis(\"off\")\n",
" return img"
]
},
{
"cell_type": "code",
"execution_count": 11,
@ -2450,13 +2450,84 @@
" log_metrics(train_metrics)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the next cell to train the agent for 10,000 steps. Then look at its behavior by running the following cell. You can run these two cells as many times as you wish. The agent will keep improving!"
]
},
{
"cell_type": "code",
"execution_count": 118,
"metadata": {},
"outputs": [],
"source": [
"train_agent(n_iterations=200) # change this to 10 million or more!"
"train_agent(n_iterations=10000)"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"frames = []\n",
"def save_frames(trajectory):\n",
" global frames\n",
" frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n",
"\n",
"prev_lives = tf_env.pyenv.envs[0].ale.lives()\n",
"def reset_and_fire_on_life_lost(trajectory):\n",
" global prev_lives\n",
" lives = tf_env.pyenv.envs[0].ale.lives()\n",
" if prev_lives != lives:\n",
" tf_env.reset()\n",
" tf_env.pyenv.envs[0].step(1)\n",
" prev_lives = lives\n",
"\n",
"watch_driver = DynamicStepDriver(\n",
" tf_env,\n",
" agent.policy,\n",
" observers=[save_frames, reset_and_fire_on_life_lost, ShowProgress(1000)],\n",
" num_steps=1000)\n",
"final_time_step, final_policy_state = watch_driver.run()\n",
"\n",
"plot_animation(frames)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to save an animated GIF to show off your agent to your friends, here's one way to do it:"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"import PIL\n",
"\n",
"image_path = os.join(\"images\", \"rl\", \"breakout.gif\")\n",
"frame_images = [PIL.Image.fromarray(frame) for frame in frames[:150]]\n",
"frame_images[0].save(image_path, format='GIF',\n",
" append_images=frame_images[1:],\n",
" save_all=True,\n",
" duration=30,\n",
" loop=0)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"%%html\n",
"<img src=\"images/rl/breakout.gif\" />"
]
},
{
@ -2482,10 +2553,11 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"np.random.seed(42)\n",
"\n",
"mem = deque(maxlen=1000000)\n",
@ -2496,7 +2568,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
@ -2505,7 +2577,7 @@
},
{
"cell_type": "code",
"execution_count": 121,
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
@ -2521,7 +2593,7 @@
},
{
"cell_type": "code",
"execution_count": 122,
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
@ -2544,7 +2616,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -2556,7 +2628,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -2565,7 +2637,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [
@ -2588,7 +2660,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 129,
"metadata": {},
"outputs": [],
"source": [
@ -2637,7 +2709,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 130,
"metadata": {},
"outputs": [],
"source": [
@ -2648,7 +2720,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 131,
"metadata": {},
"outputs": [],
"source": [