Add example preprocessing for Breakout game in notebook for chapter 16

main
Aurélien Geron 2017-09-25 14:08:42 +02:00
parent 02c41c9bc0
commit edf4006ab4
1 changed files with 118 additions and 3 deletions

View File

@ -1652,9 +1652,7 @@
{
"cell_type": "code",
"execution_count": 65,
"metadata": {
"collapsed": true
},
"metadata": {},
"outputs": [],
"source": [
"n_steps = 4000000 # total number of training steps\n",
@ -1815,6 +1813,123 @@
"plot_animation(frames)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra material"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocessing for Breakout"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a preprocessing function you can use to train a DQN for the Breakout-v0 Atari game:"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def preprocess_observation(obs):\n",
" img = obs[34:194:2, ::2] # crop and downsize\n",
" return np.mean(img, axis=2).reshape(80, 80) / 255.0"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\"Breakout-v0\")\n",
"obs = env.reset()\n",
"for step in range(10):\n",
" obs, _, _, _ = env.step(1)\n",
"\n",
"img = preprocess_observation(obs)"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [],
"source": [
"plt.figure(figsize=(11, 7))\n",
"plt.subplot(121)\n",
"plt.title(\"Original observation (160×210 RGB)\")\n",
"plt.imshow(obs)\n",
"plt.axis(\"off\")\n",
"plt.subplot(122)\n",
"plt.title(\"Preprocessed observation (80×80 grayscale)\")\n",
"plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As you can see, a single image does not give you the direction and speed of the ball, which are crucial informations for playing this game. For this reason, it is best to actually combine several consecutive observations to create the environment's state representation. One way to do that is to create a multi-channel image, with one channel per recent observation. Another is to merge all recent observations into a single-channel image, using `np.max()`. In this case, we need to dim the older images so that the DQN can distinguish the past from the present."
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
"def combine_observations_multichannel(preprocessed_observations):\n",
" return np.array(preprocessed_observations).transpose([1, 2, 0])\n",
"\n",
"def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n",
" dimmed_observations = [obs * dim_factor**index\n",
" for index, obs in enumerate(reversed(preprocessed_observations))]\n",
" return np.max(np.array(dimmed_observations), axis=0)\n",
"\n",
"n_observations_per_state = 3\n",
"preprocessed_observations = deque([], maxlen=n_observations_per_state)\n",
"\n",
"obs = env.reset()\n",
"for step in range(10):\n",
" obs, _, _, _ = env.step(1)\n",
" preprocessed_observations.append(preprocess_observation(obs))"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
"img1 = combine_observations_multichannel(preprocessed_observations)\n",
"img2 = combine_observations_singlechannel(preprocessed_observations)\n",
"\n",
"plt.figure(figsize=(11, 7))\n",
"plt.subplot(121)\n",
"plt.title(\"Multichannel state\")\n",
"plt.imshow(img1, interpolation=\"nearest\")\n",
"plt.axis(\"off\")\n",
"plt.subplot(122)\n",
"plt.title(\"Singlechannel state\")\n",
"plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},