From edf4006ab45e72e7a5e3e3c87d322041aa868313 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Mon, 25 Sep 2017 14:08:42 +0200 Subject: [PATCH] Add example preprocessing for Breakout game in notebook for chapter 16 --- 16_reinforcement_learning.ipynb | 121 +++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 3 deletions(-) diff --git a/16_reinforcement_learning.ipynb b/16_reinforcement_learning.ipynb index 7840009..b353f0f 100644 --- a/16_reinforcement_learning.ipynb +++ b/16_reinforcement_learning.ipynb @@ -1652,9 +1652,7 @@ { "cell_type": "code", "execution_count": 65, - "metadata": { - "collapsed": true - }, + "metadata": {}, "outputs": [], "source": [ "n_steps = 4000000 # total number of training steps\n", @@ -1815,6 +1813,123 @@ "plot_animation(frames)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Extra material" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocessing for Breakout" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here is a preprocessing function you can use to train a DQN for the Breakout-v0 Atari game:" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "def preprocess_observation(obs):\n", + " img = obs[34:194:2, ::2] # crop and downsize\n", + " return np.mean(img, axis=2).reshape(80, 80) / 255.0" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": {}, + "outputs": [], + "source": [ + "env = gym.make(\"Breakout-v0\")\n", + "obs = env.reset()\n", + "for step in range(10):\n", + " obs, _, _, _ = env.step(1)\n", + "\n", + "img = preprocess_observation(obs)" + ] + }, + { + "cell_type": "code", + "execution_count": 73, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(11, 7))\n", + "plt.subplot(121)\n", + "plt.title(\"Original observation (160×210 RGB)\")\n", + "plt.imshow(obs)\n", + 
"plt.axis(\"off\")\n", + "plt.subplot(122)\n", + "plt.title(\"Preprocessed observation (80×80 grayscale)\")\n", + "plt.imshow(img, interpolation=\"nearest\", cmap=\"gray\")\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, a single image does not give you the direction and speed of the ball, which are crucial pieces of information for playing this game. For this reason, it is best to actually combine several consecutive observations to create the environment's state representation. One way to do that is to create a multi-channel image, with one channel per recent observation. Another is to merge all recent observations into a single-channel image, using `np.max()`. In this case, we need to dim the older images so that the DQN can distinguish the past from the present." + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [], + "source": [ + "def combine_observations_multichannel(preprocessed_observations):\n", + " return np.array(preprocessed_observations).transpose([1, 2, 0])\n", + "\n", + "def combine_observations_singlechannel(preprocessed_observations, dim_factor=0.5):\n", + " dimmed_observations = [obs * dim_factor**index\n", + " for index, obs in enumerate(reversed(preprocessed_observations))]\n", + " return np.max(np.array(dimmed_observations), axis=0)\n", + "\n", + "n_observations_per_state = 3\n", + "preprocessed_observations = deque([], maxlen=n_observations_per_state)\n", + "\n", + "obs = env.reset()\n", + "for step in range(10):\n", + " obs, _, _, _ = env.step(1)\n", + " preprocessed_observations.append(preprocess_observation(obs))" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [], + "source": [ + "img1 = combine_observations_multichannel(preprocessed_observations)\n", + "img2 = combine_observations_singlechannel(preprocessed_observations)\n", + "\n", + "plt.figure(figsize=(11, 7))\n", + 
"plt.subplot(121)\n", + "plt.title(\"Multichannel state\")\n", + "plt.imshow(img1, interpolation=\"nearest\")\n", + "plt.axis(\"off\")\n", + "plt.subplot(122)\n", + "plt.title(\"Singlechannel state\")\n", + "plt.imshow(img2, interpolation=\"nearest\", cmap=\"gray\")\n", + "plt.axis(\"off\")\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {},