{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Chapter 18 – Reinforcement Learning**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"_This notebook contains all the sample code and solutions to the exercises in chapter 18._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table align=\"left\">\n",
" <td>\n",
" <a href=\"https://colab.research.google.com/github/ageron/handson-ml3/blob/main/18_reinforcement_learning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://kaggle.com/kernels/welcome?src=https://github.com/ageron/handson-ml3/blob/main/18_reinforcement_learning.ipynb\"><img src=\"https://kaggle.com/static/images/open-in-kaggle.svg\" /></a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dFXIv9qNpKzt",
"tags": []
},
"source": [
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "8IPbJEmZpKzu"
},
"source": [
"This project requires Python 3.7 or above:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "TFSU3FCOpKzu"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"assert sys.version_info >= (3, 7)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TAlKky09pKzv"
},
"source": [
"It also requires Scikit-Learn ≥ 1.0.1:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "YqCwW7cMpKzw"
},
"outputs": [],
"source": [
"import sklearn\n",
"\n",
"assert sklearn.__version__ >= \"1.0.1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GJtVEqxfpKzw"
},
"source": [
"And TensorFlow ≥ 2.8:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "0Piq5se2pKzx"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"# compare (major, minor) as integers, since \"2.10.0\" < \"2.8.0\" as strings\n",
"assert tuple(map(int, tf.__version__.split(\".\")[:2])) >= (2, 8)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DDaDoLQTpKzx"
},
"source": [
"As we did in earlier chapters, let's define the default font sizes to make the figures prettier. We will also display some Matplotlib animations; there are several ways to do that, and we will use the JavaScript option."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "8d4TH3NbpKzx"
},
"outputs": [],
"source": [
"import matplotlib.animation\n",
"import matplotlib.pyplot as plt\n",
"\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)\n",
"plt.rc('animation', html='jshtml')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RcoUIRsvpKzy"
},
"source": [
"And let's create the `images/rl` folder (if it doesn't already exist), and define the `save_fig()` function, which is used throughout this notebook to save the figures in high resolution for the book:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "PQFH5Y9PpKzy"
},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"IMAGES_PATH = Path() / \"images\" / \"rl\"\n",
"IMAGES_PATH.mkdir(parents=True, exist_ok=True)\n",
"\n",
"def save_fig(fig_id, tight_layout=True, fig_extension=\"png\", resolution=300):\n",
" path = IMAGES_PATH / f\"{fig_id}.{fig_extension}\"\n",
" if tight_layout:\n",
" plt.tight_layout()\n",
" plt.savefig(path, format=fig_extension, dpi=resolution)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YTsawKlapKzy"
},
"source": [
"This chapter can be very slow without a GPU, so let's make sure there's one, or else issue a warning:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Ekxzo6pOpKzy"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"No GPU was detected. Neural nets can be very slow without a GPU.\n"
]
}
],
"source": [
"if not tf.config.list_physical_devices('GPU'):\n",
"    print(\"No GPU was detected. Neural nets can be very slow without a GPU.\")\n",
"    if \"google.colab\" in sys.modules:\n",
"        print(\"Go to Runtime > Change runtime type and select a GPU hardware \"\n",
"              \"accelerator.\")\n",
"    if \"kaggle_secrets\" in sys.modules:\n",
"        print(\"Go to Settings > Accelerator and select GPU.\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's install the gym library, which provides many environments for Reinforcement Learning. Some of these environments require an X server to plot graphics, so we need to install xvfb on Colab or Kaggle (that's an in-memory X server, since the runtimes are not hooked to a screen). We also need to install pyvirtualdisplay, which provides a Python interface to xvfb. And let's also install the Box2D and Atari environments. By running the following cell, you also accept the Atari ROM license."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"if \"google.colab\" in sys.modules or \"kaggle_secrets\" in sys.modules:\n",
" %pip install -q -U gym\n",
" %pip install -q -U gym[box2d,atari,accept-rom-license]\n",
" !apt update &> /dev/null && apt install -y xvfb &> /dev/null\n",
" %pip install -q -U pyvirtualdisplay"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: some environments (including the Cart-Pole) require access to your display, which opens up a separate window. In general you can safely ignore that window. However, if Jupyter is running on a headless server (i.e., without a screen), it will raise an exception. Examples of headless servers include Colab, Kaggle, or Docker containers. One way to avoid this is to install an X server like [Xvfb](http://en.wikipedia.org/wiki/Xvfb), which performs all graphical operations on a virtual display, in memory. You can then start Jupyter using the `xvfb-run` command:\n",
"\n",
"```bash\n",
"$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter lab\n",
"```\n",
"\n",
"Alternatively, you can install the [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) Python library which wraps Xvfb, and lets you create a virtual display. Let's create a virtual display using `pyvirtualdisplay`, if it is installed:\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" import pyvirtualdisplay\n",
"\n",
" display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()\n",
"except ImportError:\n",
" pass"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Introduction to OpenAI gym"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. It provides many environments for your learning *agents* to interact with. Let's import Gym and make a new environment:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import gym\n",
"\n",
"env = gym.make(\"CartPole-v1\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Cart-Pole (version 1) is a very simple environment composed of a cart that can move left or right, and a pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Tip**: you can use `gym.envs.registry.all()` to get the full list of available environments:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['ALE/Tetris-v5',\n",
" 'ALE/Tetris-ram-v5',\n",
" 'ALE/Asterix-v5',\n",
" 'ALE/Asterix-ram-v5',\n",
" 'ALE/Asteroids-v5',\n",
" '...']"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# extra code – shows the first few environments\n",
"envs = gym.envs.registry.all()\n",
"[spec.id for spec in envs][:5] + [\"...\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's initialize the environment by calling its `reset()` method. This returns an observation:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.0273956 , -0.00611216, 0.03585979, 0.0197368 ], dtype=float32)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obs = env.reset(seed=42)\n",
"obs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Observations vary depending on the environment. In this case it is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the angular velocity."
]
},
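{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make these four components concrete, here is a small sketch (not part of the original notebook) that unpacks the observation returned above into named variables. It copies the values printed by `env.reset(seed=42)`, so it does not touch the environment. It also assumes that the pole angle is expressed in radians, as in Gym's Cart-Pole implementation, and converts the angle to degrees for readability:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Observation returned by env.reset(seed=42) above (values copied from its output)\n",
"obs_example = np.array([0.0273956, -0.00611216, 0.03585979, 0.0197368],\n",
"                       dtype=np.float32)\n",
"\n",
"# Unpack the 4 components, in the order described above\n",
"cart_position, cart_velocity, pole_angle, pole_angular_velocity = obs_example\n",
"\n",
"# Assumption: the angle is in radians (0 = vertical, positive = tilted right)\n",
"pole_angle_degrees = np.degrees(pole_angle)\n",
"print(f'cart position: {cart_position:.4f}, cart velocity: {cart_velocity:.4f}')\n",
"print(f'pole angle: {pole_angle_degrees:.2f} degrees, '\n",
"      f'angular velocity: {pole_angular_velocity:.4f}')"
]
},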
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment)."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(400, 600, 3)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"img = env.render(mode=\"rgb_array\")\n",
"img.shape # height, width, channels (3 = Red, Green, Blue)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAADICAYAAACuyvefAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAGVElEQVR4nO3dzY4bWRmA4a9sd/d0t8RkEpQwAoSQRggh2MEFIHqB2LDOPlJug1vgIlhklQ3bSFmEDdIo7IKIYBDTIWnCQJOQ/nFX1WGRQcLTaTdlV8WfM8+zPGVZZ1F6fapc9qlKKQGQxWjVEwD4X6IEpCJKQCqiBKQiSkAqogSkMrnkuOcFgCFUFx2wUgJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSCVyaonwLujmR5HW09nxkaTrRhvvreiGbGORIlelFLiyce/jr///jcz4ze+/+P4+o9+vqJZsY5EiZ6UaKbH0UyPZkbb+mxF82FduadEP0qJ0tSrngXvAFGiF6WUaFtRYnmiRE+slOiHKNEPl2/0RJTohcs3+iJK9KJtzmL68rPZwaqKra98dTUTYm2JEr0oTR318cuZsaoaxcbuByuaEetKlBhQFaOxR+HoRpQYVDXeWPUUWDOixHCqsFKiM1FiQFWMrJToSJQYVGWlREeixGCqykqJ7kSJXpS2ifKmA1X1tqfCmhMlelHaOuLNWYJORIletE2tSfRClOhFac5CleiDKNGL1j8E0BNRohelqaMUKyWWJ0r0oj49iijtzFg13ojKt290JEr04vgfT6K0zczYe+/fiNHm9opmxLoSJXpy/tJtNJ5YKdGZKDGYajSOCFGiG1FiMNV4ElXlFKMbZwyDqUZjPzOhM1FiMO4psQhRYjDVaBLh8o2OnDEsrZTyxgcnq5HTi+6cNfTioo0oXb7RlSjRg+K3b/RGlFheKVHas1XPgneEKNGLiy7foCtRYmmluHyjP6JEL6yU6IsosbRS2jg7eXlufLy1u4LZsO5EiaWVpo7Tw4MvjFaxc+2bK5kP602UGIYtu1mQKDGYykaULECUGMxoIkp0J0oMxJbdLEaUGEzlnhILECUGY6XEIkSJpZXSvnlvXP8QwAJEiaWVpo6wESU9ESWW9vonJqJEP0SJpbWtLbvpjyixNJdv9EmUWFpp64hoVz0N3hGixNLa5vzlW1VVUdkdlwV4uo25jo6Oommaua85fPpJtGfTmbHJ7tU4qSPql+f/0uS/qqqK3d1dmwswo7rkBqUbBV9yN2/ejAcPHsx9zc9++K249dPvzYz98em/4he/+m28Orn4z9+uXbsW9+/fjytXrvQxVdbLhZ9EVkrM9fz589jf35/7msOP3o/Tdjv+evpR1GUjvrb5SRwf/y3295/Eq5OLNxSYTqfRtu5FMUuUWNpJuxMPX/wk/ll/GBER+yffjTj9LJrWQpvuRImlHUy//XmQXq/IT9rd+PTfP4hWlFiAb98YRN200Xp2iQWIEku7sfnn+GDyLF5/L1Jia/QqvrH5u2jcL2IBcy/fDg8P39I0yKquL9866dOnf4k/fPzLeHLynajLRny49ac4OHh86UPebdvGixcvYjTy2fhlM+8b17lRunPnTt9zYc0cHHxxl5LzHj5+Fg8fP4uIe53e+/T0NO7evRs7Oz
sLzo51dfv27QuPeU6Jufb29uLevW6x+X9dv349Hj16FFevXh3k/UntwueUrJuBVEQJSEWUgFRECUhFlIBU/MyEuW7duhV7e3uDvPfOzk5sb28P8t6sL48EAKvgkQBgPYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkMrkkuPVW5kFwOeslIBURAlIRZSAVEQJSEWUgFRECUjlPyRsVw4Q1yV5AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 360x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – creates a little function to render and plot an environment\n",
"\n",
"def plot_environment(env, figsize=(5, 4)):\n",
" plt.figure(figsize=figsize)\n",
" img = env.render(mode=\"rgb_array\")\n",
" plt.imshow(img)\n",
" plt.axis(\"off\")\n",
" return img\n",
"\n",
"plot_environment(env)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's see how to interact with an environment. Your agent will need to select an action from an \"action space\" (the set of possible actions). Let's see what this environment's action space looks like:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discrete(2)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.action_space"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Yep, just two possible actions: accelerate toward the left (action 0) or toward the right (action 1)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Since the pole is leaning toward the right (`obs[2] > 0`), let's accelerate the cart toward the right:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.02727336, 0.18847767, 0.03625453, -0.26141977], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"action = 1 # accelerate right\n",
"obs, reward, done, info = env.step(action)\n",
"obs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice that the cart is now moving toward the right (`obs[1] > 0`). The pole is still tilted toward the right (`obs[2] > 0`), but its angular velocity is now negative (`obs[3] < 0`), so it will likely be tilted toward the left after the next step."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAVgAAADqCAYAAADnGV2KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAHWElEQVR4nO3dz25cVx3A8d+9M/Y4tZvIIsWKi/jTboAliyzZ8AiE98gTRCiPEPEczRYpbFgiIYMQtF1AiUoEaik2CiGNZ+bew8KVUGrP2E7z89wz/XyWvtb4LK6+OTlz7j1NKSUAeP3aVQ8AYF0JLEASgQVIIrAASQQWIInAAiQZn3PdHi6A5ZpFF8xgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiDJeNUDgMsopY/Dv/w2+tnxqWu77/woxpPtFYwKziawVKXv5vHkN+/F9Nnhl640sbP3rsAyKJYIqErp5hGlrHoYcCECS1VKN48SAksdBJaqlN4MlnoILFXp+3kUgaUSAktVSteZwVINgaUqfTePsAZLJQSWqhRLBFREYKnKi39/cuZDBptvfiNGm1srGBEsJrBUZf7505OdBF+y8cb1aMebKxgRLCawrIW2HUc0bmeGxR3JWmhGo2iaZtXDgJcILGuhMYNlgNyRrIWmNYNleASWtWAGyxC5I1kL1mAZIoGlGsseMLCLgCFyR1KVUvqzLzSNGSyDI7BU5eRdBFAHgaUi5eREA6iEwFKPUqLvZqseBVyYwFIVM1hqIrBUo5RiDZaqCCxVKZYIqIjAUpES/RmvKoShEliqUbouXhz94/SFpomtG3tXPyA4h8BSjVK6mD47PPXzpmljcv2tFYwIlhNY1kI7Gq96CHCKwLIGmmhGG6seBJwisNSviWgFlgESWNaCJQKGSGBZA5YIGCaBZS20YzNYhkdgqV7TmMEyTAJLNUq/+ESDph1d4UjgYgSWapQlj8k6y4AhEliq0XezpedywdAILNU4eReswFIPgaUafTfXV6oisFTj5F2wCks9BJZqOM2A2ggs1Si+5KIyAks1+vk0zlwiaGzSYpgElmo8++SjM0+V3fnmO9FuTFYwIlhOYKnGogcN2o1JRONWZnjclVSvGY2j8SwXAySwVK8dbXhWlkESWKrXjMahsAyRwFK9th1HYycBAySwVM8MlqESWKrXjsb2wjJIAksVSikLX0PgZdsMlcBSiRL9whduN9ZgGSSBpQ6lROm7VY8CLkVgqUIp5czHZGHIBJY6lGVLBDBMAksVSpQonSUC6iKw1KH0S0+VhSESWOpgDZYKCSxVKKUsPjLGFi0GarzqAfD1dZnjX7rj5/H8s49P/bwdT2L7re9e6rPsmeWqNOfcmA5AIs2TJ0/izp07MZ+f/1//m9cn8fOf/TBG7ctxfH48j1/88s/x10+fn/8ZN2/Gw4cP49q1a688ZjjDwn+xzWBZmePj4zg4OIjZbHbu7+7ffDP6n34/SjOJF912NE2JrfZZdF0Xf/zTB/Hhx5+d+xm3bt2Kvu9fx9DhQgSWakzLG/Hhf34c/5p+K5qmi1uTj2I/fh3Tme1bDJ
PAUoV5P4nfP/1JPC1vR0QTUTbiby9+EIezeUzn7616eHAmuwiowqxM4mi+Fy8vdzXxz+nbcTz3VQHDJLBUrZSI6dy6KsMksFRh0v43vrP1fkT8P6ZtzON7134X3Xy6uoHBEtZgqULpZ7Hz+a9iZ/pp/P343WibPr699X5sTv8Qswts84JVWLoP9v79+xa3SHN0dBQPHjy48Napkz2wTZQv1mGbKBFRousvdpvu7OzE3bt3Y2Nj4xVHDKfdu3dv4T7YpYE9OjoSWNI8fvw4bt++faEHDV6Hvb29ODg48KABr9Xu7u6rPWiwu7v7+kcDXzg8PLzSx1bbto0bN27E9vb2lf1Nvt58yQWQRGABkggsQBKBBUgisABJBBYgiSe5WJn9/f149OjRpU4j+Co2Nzdja2vrSv4WRDjRAOCrWriZ2xIBQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQZn3O9uZJRAKwhM1iAJAILkERgAZIILEASgQVIIrAASf4HJtl/NcHAwnIAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 360x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – displays the environment\n",
"plot_environment(env)\n",
"save_fig(\"cart_pole_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looks like it's doing what we're telling it to do!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The environment also tells the agent how much reward it got during the last step:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"reward"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When the game is over, the environment returns `done=True`:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"done"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, `info` is an environment-specific dictionary that can provide some extra information that you may find useful for debugging or for training. For example, in some games it may indicate how many lives the agent has."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{}"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"info"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sequence of steps from the moment the environment is reset until it is done is called an \"episode\". At the end of an episode (i.e., when `step()` returns `done=True`), you should reset the environment before you continue to use it."
]
},
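{
"cell_type": "markdown",
"metadata": {},
"source": [
"The reset/step cycle can be wrapped in a small helper that plays one episode. In the sketch below, `run_episode()` and `CountdownEnv` are hypothetical names, not part of Gym. `CountdownEnv` is a toy stand-in that follows the same API as the environment in this notebook (`reset(seed=...)` returns an observation and `step()` returns `(obs, reward, done, info)`), so the loop can be checked without rendering anything. The same helper should also work with the Cart-Pole `env` created above and any function that maps an observation to an action:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CountdownEnv:\n",
"    '''Hypothetical toy environment mimicking the Gym API used in this notebook.\n",
"    Each episode lasts exactly `length` steps, with a reward of 1 per step.'''\n",
"\n",
"    def __init__(self, length=3):\n",
"        self.length = length\n",
"        self.steps_left = length\n",
"\n",
"    def reset(self, seed=None):\n",
"        self.steps_left = self.length\n",
"        return self.steps_left  # the observation is the number of steps left\n",
"\n",
"    def step(self, action):\n",
"        self.steps_left -= 1\n",
"        done = self.steps_left == 0\n",
"        return self.steps_left, 1.0, done, {}  # obs, reward, done, info\n",
"\n",
"def run_episode(env, policy, n_max_steps=200, seed=None):\n",
"    '''Runs one episode, from reset() until done (or n_max_steps), and returns\n",
"    the total reward.'''\n",
"    obs = env.reset(seed=seed)\n",
"    total_reward = 0.0\n",
"    for step in range(n_max_steps):\n",
"        obs, reward, done, info = env.step(policy(obs))\n",
"        total_reward += reward\n",
"        if done:  # end of episode: the env must be reset before reuse\n",
"            break\n",
"    return total_reward\n",
"\n",
"def always_left(obs):  # trivial policy, for illustration only\n",
"    return 0\n",
"\n",
"print(run_episode(CountdownEnv(length=3), always_left))    # 3.0\n",
"print(run_episode(CountdownEnv(length=500), always_left))  # 200.0 (capped)"
]
},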
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"if done:\n",
" obs = env.reset()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now how can we make the pole remain upright? We will need to define a _policy_ for that. This is the strategy that the agent will use to select an action at each step. It can use all the past actions and observations to decide what to do."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A simple hard-coded policy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's hard-code a simple strategy: if the pole is tilting to the left, then push the cart to the left, and _vice versa_. Let's see if that works:"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"def basic_policy(obs):\n",
"    angle = obs[2]\n",
"    return 0 if angle < 0 else 1\n",
"\n",
"totals = []\n",
"for episode in range(500):\n",
"    episode_rewards = 0\n",
"    obs = env.reset(seed=episode)\n",
"    for step in range(200):\n",
"        action = basic_policy(obs)\n",
"        obs, reward, done, info = env.step(action)\n",
"        episode_rewards += reward\n",
"        if done:\n",
"            break\n",
"\n",
"    totals.append(episode_rewards)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(41.698, 8.389445512070509, 24.0, 63.0)"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np\n",
"\n",
"np.mean(totals), np.std(totals), min(totals), max(totals)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Well, as expected, this strategy is a bit too basic: the best it did was to keep the pole up for only 63 steps. Ideally, the agent should keep the pole up for all 200 steps of each episode in the loop above (CartPole-v1 itself allows episodes of up to 500 steps)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize one episode. You can learn more about Matplotlib animations in the [Matplotlib tutorial notebook](tools_matplotlib.ipynb#Animations)."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"# extra code – this cell displays an animation of one episode\n",
"\n",
"def update_scene(num, frames, patch):\n",
"    patch.set_data(frames[num])\n",
"    return patch,\n",
"\n",
"def plot_animation(frames, repeat=False, interval=40):\n",
"    fig = plt.figure()\n",
"    patch = plt.imshow(frames[0])\n",
"    plt.axis('off')\n",
"    anim = matplotlib.animation.FuncAnimation(\n",
"        fig, update_scene, fargs=(frames, patch),\n",
"        frames=len(frames), repeat=repeat, interval=interval)\n",
"    plt.close()\n",
"    return anim\n",
"\n",
"def show_one_episode(policy, n_max_steps=200, seed=42):\n",
"    frames = []\n",
"    env = gym.make(\"CartPole-v1\")\n",
"    np.random.seed(seed)\n",
"    obs = env.reset(seed=seed)\n",
"    for step in range(n_max_steps):\n",
"        frames.append(env.render(mode=\"rgb_array\"))\n",
"        action = policy(obs)\n",
"        obs, reward, done, info = env.step(action)\n",
"        if done:\n",
"            break\n",
"    env.close()\n",
"    return plot_animation(frames)\n",
"\n",
"show_one_episode(basic_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Clearly the system is unstable, and after just a few wobbles the pole ends up too tilted: game over. We will need to be smarter than that!"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Neural Network Policies"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a neural network that will take observations as inputs, and output the probabilities of actions to take for each observation. To choose an action, the network will estimate a probability for each action, then we will select an action randomly according to the estimated probabilities. In the case of the Cart-Pole environment, there are just two possible actions (left or right), so we only need one output neuron: it will output the probability `p` of the action 0 (left), and of course the probability of action 1 (right) will be `1 - p`."
]
},
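{
"cell_type": "markdown",
"metadata": {},
"source": [
"The sampling step described above can be sketched with NumPy alone. Here `sample_action()` is a hypothetical helper and `left_proba = 0.7` stands in for the network's output. A uniform random number in [0, 1) exceeds `p` with probability `1 - p`, so the helper should return action 1 (right) about 30% of the time and action 0 (left) about 70% of the time:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def sample_action(left_proba, rng):\n",
"    '''Returns 0 (left) with probability left_proba, else 1 (right).'''\n",
"    return int(rng.random() > left_proba)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"left_proba = 0.7  # hypothetical network output for some observation\n",
"actions = np.array([sample_action(left_proba, rng) for _ in range(10_000)])\n",
"print('fraction of actions equal to 1 (right):', actions.mean())  # close to 0.3"
]
},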
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"tf.random.set_seed(42) # extra code – ensures reproducibility on the CPU\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(5, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\"),\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this particular environment, the past actions and observations can safely be ignored, since each observation contains the environment's full state. If there were some hidden state, then you might need to consider past actions and observations in order to infer the hidden state of the environment. For example, if the environment only revealed the position of the cart but not its velocity, you would have to consider not only the current observation but also the previous observation in order to estimate the current velocity. Another example is when the observations are noisy: you may want to use the past few observations to estimate the most likely current state. Our problem is thus as simple as can be: the current observation is noise-free and contains the environment's full state."
]
},
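{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both workarounds mentioned above can be sketched in a few lines of NumPy. The positions and noisy readings below are made-up values, and `TAU = 0.02` seconds is an assumption based on the default time step of Gym's Cart-Pole implementation. The hypothetical helper `estimate_velocities()` recovers velocities from consecutive positions with a finite difference. The hypothetical helper `smooth()` averages the last few observations to reduce noise:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"TAU = 0.02  # assumption: seconds between steps (Cart-Pole's default time step)\n",
"\n",
"def estimate_velocities(positions, dt=TAU):\n",
"    '''Finite-difference velocity estimates from consecutive positions.'''\n",
"    positions = np.asarray(positions, dtype=float)\n",
"    return (positions[1:] - positions[:-1]) / dt\n",
"\n",
"def smooth(observations, window=3):\n",
"    '''Moving average over the last `window` observations, to reduce noise.'''\n",
"    kernel = np.ones(window) / window\n",
"    return np.convolve(observations, kernel, mode='valid')\n",
"\n",
"# Hypothetical episode in which only the cart's position is revealed\n",
"positions = [0.000, 0.004, 0.010, 0.018]\n",
"print(estimate_velocities(positions))  # approximately [0.2 0.3 0.4]\n",
"\n",
"# Hypothetical noisy readings of a quantity whose true value is 1.0\n",
"noisy = [1.1, 0.9, 1.0, 1.2, 0.8]\n",
"print(smooth(noisy))  # approximately [1.0, 1.033, 1.0]"
]
},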
{
"cell_type": "markdown",
"metadata": {},
"source": [
"You may wonder why we plan to pick a random action based on the probability given by the policy network, rather than just picking the action with the highest probability. This approach lets the agent find the right balance between _exploring_ new actions and _exploiting_ the actions that are known to work well. Here's an analogy: suppose you go to a restaurant for the first time, and all the dishes look equally appealing, so you randomly pick one. If it turns out to be good, you can increase the probability of ordering it next time, but you shouldn't increase that probability to 100%, or else you will never try out the other dishes, some of which may be even better than the one you tried."
2016-10-09 11:01:56 +02:00
]
},
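A quick way to see the difference is to simulate both strategies for a fixed output of the policy network. This pure-Python sketch assumes the network outputs `left_proba = 0.7`. It follows the same convention as the notebook's policy function: action 0 is left, action 1 is right. The helper names are hypothetical.

```python
import random

def sample_action(left_proba, rng):
    # Stochastic policy: go right (1) only if a uniform draw exceeds left_proba
    return int(rng.random() > left_proba)

def greedy_action(left_proba):
    # Deterministic policy: always pick the most probable action
    return 0 if left_proba >= 0.5 else 1

rng = random.Random(42)
left_proba = 0.7
samples = [sample_action(left_proba, rng) for _ in range(10_000)]
print("sampled policy, fraction of right moves:", sum(samples) / len(samples))
print("greedy policy, action:", greedy_action(left_proba))
```

The sampled policy still moves right roughly 30% of the time, so it keeps gathering information about that action. The greedy policy never tries it.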
{
2019-05-26 17:30:39 +02:00
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Let's write a small policy function that will use the neural net to get the probability of moving left, then let's use it to run one episode:"
2016-10-08 22:17:45 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 25,
2017-11-09 13:17:24 +01:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"# extra code – a policy that samples actions using the model, and an animation\n",
"\n",
"def pg_policy(obs):\n",
"    left_proba = model.predict(obs[np.newaxis], verbose=0)[0, 0]\n",
"    return int(np.random.rand() > left_proba)\n",
"\n",
"np.random.seed(42)\n",
"show_one_episode(pg_policy)"
2016-10-08 22:17:45 +02:00
]
},
{
2022-04-05 11:47:12 +02:00
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Yeah... pretty bad. The neural network will have to learn to do better. First let's see if it is capable of learning the basic policy we used earlier: go left if the pole is tilting left, and go right if it is tilting right."
2016-10-08 22:17:45 +02:00
]
},
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Let's see if it can learn a better policy on its own: one that does not wobble as much."
2016-10-08 22:17:45 +02:00
]
},
2016-10-09 11:01:56 +02:00
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-09 11:01:56 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"# Policy Gradients"
2016-10-09 11:01:56 +02:00
]
},
2016-10-08 22:17:45 +02:00
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"To train this neural network we will need to define the target probabilities `y`. If an action is good we should increase its probability, and conversely if it is bad we should reduce it. But how do we know whether an action is good or bad? The problem is that most actions have delayed effects, so when you win or lose points in an episode, it is not clear which actions contributed to this result: was it just the last action? Or the last 10? Or just one action 50 steps earlier? This is called the _credit assignment problem_.\n",
"\n",
2022-04-05 11:47:12 +02:00
"The _Policy Gradients_ algorithm tackles this problem by first playing multiple episodes, then making the actions near positive rewards slightly more likely and the actions near negative rewards slightly less likely. First we play, then we go back and think about what we did."
2016-10-08 22:17:45 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "markdown",
2017-11-09 13:17:24 +01:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Let's start by creating a function to play a single step using the model. We will also pretend for now that whatever action it takes is the right one, so we can compute the loss and its gradients. We will just save these gradients, and modify them later depending on how good or bad the action turned out to be."
2016-10-08 22:17:45 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 26,
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"outputs": [],
"source": [
2019-05-26 17:30:39 +02:00
"def play_one_step(env, obs, model, loss_fn):\n",
"    with tf.GradientTape() as tape:\n",
"        left_proba = model(obs[np.newaxis])\n",
"        action = (tf.random.uniform([1, 1]) > left_proba)\n",
"        y_target = tf.constant([[1.]]) - tf.cast(action, tf.float32)\n",
"        loss = tf.reduce_mean(loss_fn(y_target, left_proba))\n",
"\n",
"    grads = tape.gradient(loss, model.trainable_variables)\n",
"    obs, reward, done, info = env.step(int(action[0, 0]))  # scalar action\n",
"    return obs, reward, done, grads"
2016-10-08 22:17:45 +02:00
]
},
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"If `left_proba` is high, then `action` will most likely be `False` (since a random number uniformly sampled between 0 and 1 will probably not be greater than `left_proba`). And `False` means 0 when you cast it to a number, so `y_target` would be equal to 1 - 0 = 1. In other words, we set the target to 1, meaning we pretend that the probability of going left should have been 100% (so we took the right action)."
2016-10-08 22:17:45 +02:00
]
},
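To check the sign of the resulting update, here is a pure-Python sketch of the binary cross-entropy between `y_target` and `left_proba`, along with its derivative with respect to `left_proba`. It mirrors the logic of `play_one_step()`, but it is not the TensorFlow code. It ignores the small numerical clipping Keras applies, and the helper names are hypothetical.

```python
import math

def bce(y_target, left_proba):
    # Binary cross-entropy for a single prediction
    return -(y_target * math.log(left_proba)
             + (1 - y_target) * math.log(1 - left_proba))

def bce_grad(y_target, left_proba):
    # Derivative of bce() with respect to left_proba
    return -y_target / left_proba + (1 - y_target) / (1 - left_proba)

left_proba = 0.7
for action in (0, 1):             # 0 = left, 1 = right
    y_target = 1.0 - action       # pretend the sampled action was the right one
    print(f"action={action} y_target={y_target} "
          f"loss={bce(y_target, left_proba):.3f} "
          f"dloss/dp={bce_grad(y_target, left_proba):+.3f}")
```

Gradient descent moves `left_proba` against the gradient. After a left move the derivative is negative, so a descent step raises the probability of going left. After a right move the derivative is positive, so a descent step lowers it. Later, multiplying these gradients by a positive or negative return either keeps or flips that direction.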
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"Now let's create another function that will rely on the `play_one_step()` function to play multiple episodes, returning all the rewards and gradients, for each episode and each step:"
2016-10-08 22:17:45 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 27,
2017-09-25 14:08:10 +02:00
"metadata": {},
2019-05-26 17:30:39 +02:00
"outputs": [],
2016-10-08 22:17:45 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"def play_multiple_episodes(env, n_episodes, n_max_steps, model, loss_fn):\n",
" all_rewards = []\n",
" all_grads = []\n",
" for episode in range(n_episodes):\n",
" current_rewards = []\n",
" current_grads = []\n",
" obs = env.reset()\n",
" for step in range(n_max_steps):\n",
" obs, reward, done, grads = play_one_step(env, obs, model, loss_fn)\n",
" current_rewards.append(reward)\n",
" current_grads.append(grads)\n",
" if done:\n",
" break\n",
2022-04-05 11:47:12 +02:00
"\n",
2019-05-26 17:30:39 +02:00
" all_rewards.append(current_rewards)\n",
" all_grads.append(current_grads)\n",
2022-04-05 11:47:12 +02:00
"\n",
2019-05-26 17:30:39 +02:00
" return all_rewards, all_grads"
2016-10-08 22:17:45 +02:00
]
},
2017-04-30 10:21:27 +02:00
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2017-04-30 10:21:27 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"The Policy Gradients algorithm uses the model to play the episode several times (e.g., 10 times), then it goes back and looks at all the rewards, discounts them, and normalizes them. So let's create a couple of functions for that: the first will compute discounted rewards; the second will normalize the discounted rewards across many episodes."
2017-04-30 10:21:27 +02:00
]
},
2016-10-08 22:17:45 +02:00
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 28,
2017-11-09 13:17:24 +01:00
"metadata": {},
2016-10-08 22:17:45 +02:00
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"def discount_rewards(rewards, discount_factor):\n",
"    # float dtype so that integer rewards are not silently truncated by +=\n",
"    discounted = np.array(rewards, dtype=np.float64)\n",
"    for step in range(len(rewards) - 2, -1, -1):\n",
"        discounted[step] += discounted[step + 1] * discount_factor\n",
"    return discounted\n",
"\n",
"def discount_and_normalize_rewards(all_rewards, discount_factor):\n",
"    all_discounted_rewards = [discount_rewards(rewards, discount_factor)\n",
"                              for rewards in all_rewards]\n",
"    flat_rewards = np.concatenate(all_discounted_rewards)\n",
"    reward_mean = flat_rewards.mean()\n",
"    reward_std = flat_rewards.std()\n",
"    return [(discounted_rewards - reward_mean) / reward_std\n",
"            for discounted_rewards in all_discounted_rewards]"
2016-10-23 15:32:33 +02:00
]
},
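The backward loop in `discount_rewards()` computes each step's return from the next step's return. As a sanity check, this pure-Python sketch compares that recursion against the direct definition, a sum of future rewards weighted by powers of the discount factor. The function names are hypothetical and do not appear in the notebook.

```python
def discount_rewards_recursive(rewards, discount_factor):
    # Backward pass: R_t = r_t + gamma * R_{t+1}
    discounted = [float(r) for r in rewards]
    for step in range(len(rewards) - 2, -1, -1):
        discounted[step] += discounted[step + 1] * discount_factor
    return discounted

def discount_rewards_direct(rewards, discount_factor):
    # Direct definition: R_t = sum over k of gamma**k * r_{t+k} (quadratic cost)
    return [sum(discount_factor ** k * r for k, r in enumerate(rewards[t:]))
            for t in range(len(rewards))]

rewards = [10, 0, -50]
print(discount_rewards_recursive(rewards, 0.8))
print(discount_rewards_direct(rewards, 0.8))
```

The backward pass does the same job in linear time, which is why the notebook uses it.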
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"Say there were 3 actions, and after each action there was a reward: first 10, then 0, then -50. If we use a discount factor of 80%, then the 3rd action will get -50 (full credit for the last reward), but the 2nd action will only get -40 (80% credit for the last reward), and the 1st action will get 80% of -40 (-32) plus full credit for the first reward (+10), which leads to a discounted reward of -22:"
2016-10-23 15:32:33 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 29,
2017-09-25 14:08:10 +02:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"data": {
"text/plain": [
"array([-22, -40, -50])"
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"discount_rewards([10, 0, -50], discount_factor=0.8)"
2016-10-23 15:32:33 +02:00
]
},
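In equation form, the discounted return of a step is that step's reward plus the discounted return of the next step. Here the last step's return is just its own reward. Writing this out for the three-step example with γ = 0.8 reproduces the values computed above:

```latex
\begin{aligned}
R_t &= r_t + \gamma \, R_{t+1}, \qquad R_T = r_T \\
R_3 &= -50 \\
R_2 &= 0 + 0.8 \times (-50) = -40 \\
R_1 &= 10 + 0.8 \times (-40) = 10 - 32 = -22
\end{aligned}
```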
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"To normalize all discounted rewards across all episodes, we compute the mean and standard deviation of all the discounted rewards, then we subtract the mean from each discounted reward and divide the result by the standard deviation:"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 30,
2019-05-26 17:30:39 +02:00
"metadata": {
"scrolled": true
},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"data": {
"text/plain": [
"[array([-0.28435071, -0.86597718, -1.18910299]),\n",
" array([1.26665318, 1.0727777 ])]"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"discount_and_normalize_rewards([[10, 0, -50], [10, 20]],\n",
" discount_factor=0.8)"
2016-10-23 15:32:33 +02:00
]
},
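To see the normalization step without NumPy, here is an equivalent pure-Python sketch (the function name is hypothetical). It uses the population standard deviation, which matches NumPy's default `ddof=0` in `flat_rewards.std()`. It should reproduce the normalized values printed above.

```python
import math

def normalize_across_episodes(all_discounted_rewards):
    # Flatten all episodes so that a single mean and std are shared by all of them
    flat = [r for episode in all_discounted_rewards for r in episode]
    mean = sum(flat) / len(flat)
    # Population standard deviation, like NumPy's default (ddof=0)
    std = math.sqrt(sum((r - mean) ** 2 for r in flat) / len(flat))
    return [[(r - mean) / std for r in episode]
            for episode in all_discounted_rewards]

# Discounted returns of the two episodes above, computed with a factor of 0.8
normalized = normalize_across_episodes([[-22, -40, -50], [26, 20]])
print([[round(r, 8) for r in episode] for episode in normalized])
```

Because the statistics are shared across episodes, a step's normalized return is positive only if it beats the average over all episodes. In this example, both steps of the second episode come out positive.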
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 31,
2017-09-25 14:08:10 +02:00
"metadata": {},
2019-05-26 17:30:39 +02:00
"outputs": [],
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"n_iterations = 150\n",
"n_episodes_per_update = 10\n",
"n_max_steps = 200\n",
2022-04-05 11:47:12 +02:00
"discount_factor = 0.95"
2016-10-23 15:32:33 +02:00
]
},
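To get a feel for `discount_factor = 0.95`, it helps to compute how far ahead a reward still counts. A reward k steps in the future is weighted by γ to the power k. A rule of thumb (an addition here, not from the notebook) is the "half-life": the number of steps after which that weight drops to one half, which is log(0.5) / log(γ).

```python
import math

def half_life(discount_factor):
    # Number of steps k such that discount_factor ** k == 0.5
    return math.log(0.5) / math.log(discount_factor)

for gamma in (0.9, 0.95, 0.99):
    print(f"gamma={gamma}: half-life of about {half_life(gamma):.1f} steps")
```

With 0.95, a reward about 13 to 14 steps ahead counts half as much as an immediate one. That is short compared to the 200-step episode limit set above.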
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 32,
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"# extra code – let's create the neural net and reset the environment, for\n",
"# reproducibility\n",
"\n",
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(5, activation=\"relu\"),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\"),\n",
"])\n",
"\n",
"obs = env.reset(seed=42)"
2016-10-23 15:32:33 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 33,
2017-09-25 14:08:10 +02:00
"metadata": {},
2019-05-26 17:30:39 +02:00
"outputs": [],
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"optimizer = tf.keras.optimizers.Nadam(learning_rate=0.01)\n",
"loss_fn = tf.keras.losses.binary_crossentropy"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 34,
2017-11-09 13:17:24 +01:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 150/150, mean rewards: 193.1"
]
}
],
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"for iteration in range(n_iterations):\n",
" all_rewards, all_grads = play_multiple_episodes(\n",
" env, n_episodes_per_update, n_max_steps, model, loss_fn)\n",
2022-04-05 11:47:12 +02:00
"\n",
" # extra code – displays some debug info during training\n",
" total_rewards = sum(map(sum, all_rewards))\n",
" print(f\"\\rIteration: {iteration + 1}/{n_iterations},\"\n",
" f\" mean rewards: {total_rewards / n_episodes_per_update:.1f}\", end=\"\")\n",
"\n",
2019-05-26 17:30:39 +02:00
" all_final_rewards = discount_and_normalize_rewards(all_rewards,\n",
2022-04-05 11:47:12 +02:00
" discount_factor)\n",
2019-05-26 17:30:39 +02:00
" all_mean_grads = []\n",
" for var_index in range(len(model.trainable_variables)):\n",
" mean_grads = tf.reduce_mean(\n",
" [final_reward * all_grads[episode_index][step][var_index]\n",
" for episode_index, final_rewards in enumerate(all_final_rewards)\n",
" for step, final_reward in enumerate(final_rewards)], axis=0)\n",
" all_mean_grads.append(mean_grads)\n",
2016-10-08 22:17:45 +02:00
"\n",
2022-04-05 11:47:12 +02:00
" optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))"
2019-05-26 17:30:39 +02:00
]
},
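The nested comprehension in the training loop is the heart of the algorithm. For each trainable variable, it multiplies every step's gradient by that step's normalized return. It then averages these products over all steps of all episodes. This sketch reproduces only that aggregation, with plain Python lists standing in for gradient tensors. The toy values (one variable with two gradient components) are made up for illustration.

```python
def mean_weighted_grads(all_grads, all_final_rewards):
    """all_grads[episode][step][var] is a flat list of floats (one gradient)."""
    n_vars = len(all_grads[0][0])
    all_mean_grads = []
    for var_index in range(n_vars):
        weighted = [
            [final_reward * g for g in all_grads[episode_index][step][var_index]]
            for episode_index, final_rewards in enumerate(all_final_rewards)
            for step, final_reward in enumerate(final_rewards)
        ]
        # Element-wise mean over every step of every episode (like axis=0)
        n = len(weighted)
        all_mean_grads.append([sum(column) / n for column in zip(*weighted)])
    return all_mean_grads

# Toy data: episode 0 has 2 steps, episode 1 has 1 step
all_grads = [
    [[[1.0, 0.0]], [[0.0, 1.0]]],   # episode 0
    [[[1.0, 1.0]]],                 # episode 1
]
all_final_rewards = [[1.0, -1.0], [0.5]]
print(mean_weighted_grads(all_grads, all_final_rewards))
```

Steps with a positive normalized return keep the direction of their gradient, and steps with a negative one have it flipped. The optimizer then applies the average of the result.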
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 35,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"# extra code – displays the animation\n",
"np.random.seed(42)\n",
"show_one_episode(pg_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Extra Material – Markov Chains"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"The following transition probabilities correspond to the Markov Chain represented in Figure 18–7. Let's run this stochastic process a few times to see what it looks like:"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 36,
2017-09-25 14:08:10 +02:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Run #1: 0 0 3 \n",
"Run #2: 0 1 2 1 2 1 2 1 2 1 3 \n",
"Run #3: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
"Run #4: 0 3 \n",
"Run #5: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
"Run #6: 0 1 3 \n",
"Run #7: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
"Run #8: 0 0 0 1 2 1 2 1 3 \n",
"Run #9: 0 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 3 \n",
"Run #10: 0 0 0 1 2 1 3 \n"
]
}
],
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"np.random.seed(42)\n",
2016-10-23 15:32:33 +02:00
"\n",
2019-05-26 17:30:39 +02:00
"transition_probabilities = [ # shape=[s, s']\n",
" [0.7, 0.2, 0.0, 0.1], # from s0 to s0, s1, s2, s3\n",
2022-04-05 11:47:12 +02:00
" [0.0, 0.0, 0.9, 0.1], # from s1 to s0, s1, s2, s3\n",
" [0.0, 1.0, 0.0, 0.0], # from s2 to s0, s1, s2, s3\n",
" [0.0, 0.0, 0.0, 1.0]] # from s3 to s0, s1, s2, s3\n",
2016-10-23 15:32:33 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"n_max_steps = 1000 # to avoid blocking in case of an infinite loop\n",
"terminal_states = [3]\n",
2016-10-23 15:32:33 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"def run_chain(start_state):\n",
" current_state = start_state\n",
2019-05-26 17:30:39 +02:00
" for step in range(n_max_steps):\n",
" print(current_state, end=\" \")\n",
2022-04-05 11:47:12 +02:00
" if current_state in terminal_states:\n",
2019-05-26 17:30:39 +02:00
" break\n",
2022-04-05 11:47:12 +02:00
" current_state = np.random.choice(\n",
" range(len(transition_probabilities)),\n",
" p=transition_probabilities[current_state]\n",
" )\n",
2019-05-26 17:30:39 +02:00
" else:\n",
" print(\"...\", end=\"\")\n",
2022-04-05 11:47:12 +02:00
"\n",
2019-05-26 17:30:39 +02:00
" print()\n",
"\n",
2022-04-05 11:47:12 +02:00
"for idx in range(10):\n",
" print(f\"Run #{idx + 1}: \", end=\"\")\n",
" run_chain(start_state=0)"
2016-10-23 15:32:33 +02:00
]
},
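Since state s3 is terminal, a natural follow-up question is how many transitions a run takes, on average, before reaching it. This sketch is not part of the notebook and the helper name is hypothetical. It answers the question exactly by iterating the equation t(s) = 1 + Σ P(s, s') t(s') until convergence, with t(s3) = 0, using the same transition matrix.

```python
transition_probabilities = [  # shape=[s, s'], same values as above
    [0.7, 0.2, 0.0, 0.1],
    [0.0, 0.0, 0.9, 0.1],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 1.0]]
terminal_states = [3]

def expected_steps_to_termination(probas, terminal_states, n_iterations=2000):
    n_states = len(probas)
    steps = [0.0] * n_states
    for _ in range(n_iterations):
        # One step is always taken, plus the expected steps from the next state
        steps = [0.0 if s in terminal_states
                 else 1.0 + sum(p * steps[sp] for sp, p in enumerate(probas[s]))
                 for s in range(n_states)]
    return steps

print([round(t, 3) for t in expected_steps_to_termination(
    transition_probabilities, terminal_states)])
```

Solving the same equations by hand gives 16 expected transitions from s0, 19 from s1, and 20 from s2. The long runs seen above come from the s1/s2 loop, which continues with probability 0.9 at each visit to s1.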
{
2019-05-26 17:30:39 +02:00
"cell_type": "markdown",
2017-11-09 13:17:24 +01:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"# Markov Decision Process"
2016-10-23 15:32:33 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"Let's define some transition probabilities, rewards, and possible actions. For example, in state s0, if action a0 is chosen, then with probability 0.7 we will go to state s0 with reward +10, with probability 0.3 we will go to state s1 with no reward, and we will never go to state s2 (so the transition probabilities are `[0.7, 0.3, 0.0]`, and the rewards are `[+10, 0, 0]`):"
2016-10-23 15:32:33 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 37,
2017-09-25 14:08:10 +02:00
"metadata": {},
2019-05-26 17:30:39 +02:00
"outputs": [],
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"transition_probabilities = [ # shape=[s, a, s']\n",
" [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]],\n",
" [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],\n",
" [None, [0.8, 0.1, 0.1], None]\n",
"]\n",
"rewards = [ # shape=[s, a, s']\n",
" [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],\n",
" [[0, 0, 0], [0, 0, 0], [0, 0, -50]],\n",
" [[0, 0, 0], [+40, 0, 0], [0, 0, 0]]\n",
"]\n",
2019-05-26 17:30:39 +02:00
"possible_actions = [[0, 1, 2], [0, 2], [1]]"
2016-10-23 15:32:33 +02:00
]
},
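Nested lists like these are easy to get wrong, so a quick consistency check can help. For every possible action, the transition probabilities should sum to 1, and every impossible action should be `None`. The `check_mdp()` helper below is hypothetical and not part of the notebook.

```python
import math

transition_probabilities = [  # shape=[s, a, s'], copied from the cell above
    [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]],
    [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],
    [None, [0.8, 0.1, 0.1], None]
]
possible_actions = [[0, 1, 2], [0, 2], [1]]

def check_mdp(transition_probabilities, possible_actions):
    """Raise ValueError if the MDP tables are inconsistent, else return True."""
    for s, probas_per_action in enumerate(transition_probabilities):
        for a, probas in enumerate(probas_per_action):
            if a in possible_actions[s]:
                if probas is None:
                    raise ValueError(f"missing probabilities for s{s}, a{a}")
                if not math.isclose(sum(probas), 1.0):
                    raise ValueError(f"probabilities for s{s}, a{a} "
                                     f"sum to {sum(probas)}")
            elif probas is not None:
                raise ValueError(f"a{a} is impossible in s{s} "
                                 f"but has probabilities")
    return True

print(check_mdp(transition_probabilities, possible_actions))  # True
```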
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"# Q-Value Iteration"
2016-10-08 22:17:45 +02:00
]
},
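The code cells that follow implement Q-Value Iteration. Every possible action starts with a Q-Value of zero, and each iteration recomputes all Q-Values from the previous estimates. In the notation of the code (T for `transition_probabilities`, R for `rewards`, γ for `gamma`), the update for every possible state-action pair is:

```latex
Q_{k+1}(s, a) \leftarrow \sum_{s'} T(s, a, s') \left[ R(s, a, s') + \gamma \max_{a'} Q_k(s', a') \right]
```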
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 38,
2017-09-25 14:08:10 +02:00
"metadata": {},
2019-05-26 17:30:39 +02:00
"outputs": [],
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Q_values = np.full((3, 3), -np.inf) # -np.inf for impossible actions\n",
2019-05-26 17:30:39 +02:00
"for state, actions in enumerate(possible_actions):\n",
2022-04-05 11:47:12 +02:00
" Q_values[state, actions] = 0.0 # for all possible actions"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 39,
2017-11-09 13:17:24 +01:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"outputs": [],
"source": [
2019-05-26 17:30:39 +02:00
"gamma = 0.90 # the discount factor\n",
2016-10-23 15:32:33 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"history1 = [] # extra code – needed for the figure below\n",
2019-05-26 17:30:39 +02:00
"for iteration in range(50):\n",
" Q_prev = Q_values.copy()\n",
2022-04-05 11:47:12 +02:00
" history1.append(Q_prev) # extra code\n",
2019-05-26 17:30:39 +02:00
" for s in range(3):\n",
" for a in possible_actions[s]:\n",
" Q_values[s, a] = np.sum([\n",
" transition_probabilities[s][a][sp]\n",
2022-02-19 10:09:28 +01:00
" * (rewards[s][a][sp] + gamma * Q_prev[sp].max())\n",
2019-05-26 17:30:39 +02:00
" for sp in range(3)])\n",
2016-10-23 15:32:33 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"history1 = np.array(history1) # extra code"
2016-10-08 22:17:45 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 40,
2017-11-09 13:17:24 +01:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"data": {
"text/plain": [
"array([[18.91891892, 17.02702702, 13.62162162],\n",
" [ 0. , -inf, -4.87971488],\n",
" [ -inf, 50.13365013, -inf]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"Q_values"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 41,
2017-09-25 14:08:10 +02:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"data": {
"text/plain": [
"array([0, 0, 1])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
2016-10-08 22:17:45 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Q_values.argmax(axis=1) # optimal action for each state"
2019-05-26 17:30:39 +02:00
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-04-05 11:47:12 +02:00
"The optimal policy for this MDP, when using a discount factor of 0.90, is to choose action a0 in state s0, action a0 in state s1, and finally action a1 (the only possible action) in state s2. If you try again with a discount factor of 0.95 instead of 0.90, you will find that the optimal action for state s1 becomes a2. This is because the discount factor is larger, so the agent values the future more, and it is therefore ready to pay an immediate penalty in order to get more future rewards."
2019-05-26 17:30:39 +02:00
]
},
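To try the 0.95 experiment without modifying the cells above, here is a self-contained pure-Python version of the same Q-Value Iteration. It is a sketch, not the notebook's NumPy code. It runs more iterations than the notebook's 50 so that the values are essentially converged. This matters because, at γ = 0.95, the two Q-Values of state s1 end up fairly close to each other.

```python
transition_probabilities = [  # shape=[s, a, s']
    [[0.7, 0.3, 0.0], [1.0, 0.0, 0.0], [0.8, 0.2, 0.0]],
    [[0.0, 1.0, 0.0], None, [0.0, 0.0, 1.0]],
    [None, [0.8, 0.1, 0.1], None]
]
rewards = [  # shape=[s, a, s']
    [[+10, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[0, 0, 0], [0, 0, 0], [0, 0, -50]],
    [[0, 0, 0], [+40, 0, 0], [0, 0, 0]]
]
possible_actions = [[0, 1, 2], [0, 2], [1]]

def q_value_iteration(gamma, n_iterations=1000):
    # -inf marks impossible actions; possible actions start at 0.0
    Q = [[0.0 if a in possible_actions[s] else float("-inf") for a in range(3)]
         for s in range(3)]
    for _ in range(n_iterations):
        Q_prev = [row[:] for row in Q]
        for s in range(3):
            for a in possible_actions[s]:
                Q[s][a] = sum(
                    transition_probabilities[s][a][sp]
                    * (rewards[s][a][sp] + gamma * max(Q_prev[sp]))
                    for sp in range(3))
    return Q

def greedy_policy(Q):
    # Best action in each state (impossible actions are -inf, so never chosen)
    return [max(range(3), key=lambda a: Q[s][a]) for s in range(3)]

for gamma in (0.90, 0.95):
    print(gamma, greedy_policy(q_value_iteration(gamma)))
```

This should print `[0, 0, 1]` for 0.90 and `[0, 2, 1]` for 0.95, which matches the claim above.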
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-04-05 11:47:12 +02:00
"# Q-Learning"
2016-10-23 15:32:33 +02:00
]
},
{
2022-04-05 11:47:12 +02:00
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"Q-Learning works by watching an agent play (e.g., randomly) and gradually improving its estimates of the Q-Values. Once it has accurate Q-Value estimates (or close enough), then the optimal policy consists of choosing the action that has the highest Q-Value (i.e., the greedy policy)."
2016-10-23 15:32:33 +02:00
]
},
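The Q-Learning cells below apply the following update after each observed transition (s, a, r, s'). The learning rate α decays with the iteration number t, using the schedule from the code (`alpha0` and `decay`):

```latex
\begin{aligned}
Q(s, a) &\leftarrow (1 - \alpha_t) \, Q(s, a) + \alpha_t \left( r + \gamma \max_{a'} Q(s', a') \right) \\
\alpha_t &= \frac{\alpha_0}{1 + t \cdot \text{decay}}
\end{aligned}
```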
{
2022-04-05 11:47:12 +02:00
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"We will need to simulate an agent moving around in the environment, so let's define a function to perform some action and get the new state and a reward:"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 42,
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"outputs": [],
"source": [
2019-05-26 17:30:39 +02:00
"def step(state, action):\n",
" probas = transition_probabilities[state][action]\n",
" next_state = np.random.choice([0, 1, 2], p=probas)\n",
" reward = rewards[state][action][next_state]\n",
" return next_state, reward"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We also need an exploration policy, which can be any policy, as long as it visits every possible state many times. We will just use a random policy, since the state space is very small:"
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 43,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
"def exploration_policy(state):\n",
" return np.random.choice(possible_actions[state])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's initialize the Q-Values like earlier, and run the Q-Learning algorithm:"
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 44,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"# extra code – initializes the Q-Values, just like earlier\n",
2019-05-26 17:30:39 +02:00
"np.random.seed(42)\n",
"Q_values = np.full((3, 3), -np.inf)\n",
"for state, actions in enumerate(possible_actions):\n",
2022-04-05 11:47:12 +02:00
" Q_values[state][actions] = 0"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"alpha0 = 0.05 # initial learning rate\n",
"decay = 0.005 # learning rate decay\n",
"gamma = 0.90 # discount factor\n",
"state = 0 # initial state\n",
"history2 = [] # extra code – needed for the figure below\n",
2016-10-23 15:32:33 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"for iteration in range(10_000):\n",
" history2.append(Q_values.copy()) # extra code\n",
2019-05-26 17:30:39 +02:00
" action = exploration_policy(state)\n",
" next_state, reward = step(state, action)\n",
2022-04-05 11:47:12 +02:00
" next_value = Q_values[next_state].max() # greedy policy at the next step\n",
2019-05-26 17:30:39 +02:00
" alpha = alpha0 / (1 + iteration * decay)\n",
" Q_values[state, action] *= 1 - alpha\n",
" Q_values[state, action] += alpha * (reward + gamma * next_value)\n",
" state = next_state\n",
"\n",
2022-04-05 11:47:12 +02:00
"history2 = np.array(history2) # extra code"
2016-10-23 15:32:33 +02:00
]
},
{
2019-05-26 17:30:39 +02:00
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 46,
2017-09-25 14:08:10 +02:00
"metadata": {},
2022-04-05 11:47:12 +02:00
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 720x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
2019-05-26 17:30:39 +02:00
"source": [
2022-04-05 11:47:12 +02:00
"# extra code – this cell generates and saves Figure 18–9\n",
"\n",
2019-05-26 17:30:39 +02:00
"true_Q_value = history1[-1, 0, 0]\n",
2016-10-23 15:32:33 +02:00
"\n",
2019-05-26 17:30:39 +02:00
"fig, axes = plt.subplots(1, 2, figsize=(10, 4), sharey=True)\n",
"axes[0].set_ylabel(\"Q-Value$(s_0, a_0)$\", fontsize=14)\n",
"axes[0].set_title(\"Q-Value Iteration\", fontsize=14)\n",
"axes[1].set_title(\"Q-Learning\", fontsize=14)\n",
"for ax, width, history in zip(axes, (50, 10000), (history1, history2)):\n",
" ax.plot([0, width], [true_Q_value, true_Q_value], \"k--\")\n",
" ax.plot(np.arange(width), history[:, 0, 0], \"b-\", linewidth=2)\n",
" ax.set_xlabel(\"Iterations\", fontsize=14)\n",
" ax.axis([0, width, 0, 24])\n",
2022-04-05 11:47:12 +02:00
" ax.grid(True)\n",
2019-05-26 17:30:39 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"save_fig(\"q_value_plot\")\n",
"plt.show()"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"# Deep Q-Network"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's build the DQN. Given a state, it will estimate, for each possible action, the sum of discounted future rewards it can expect after it plays that action (but before it sees its outcome):"
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 47,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
2022-04-05 11:47:12 +02:00
"tf.random.set_seed(42) # extra code – ensures reproducibility on the CPU\n",
2019-05-26 17:30:39 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"input_shape = [4] # == env.observation_space.shape\n",
"n_outputs = 2 # == env.action_space.n\n",
2019-05-26 17:30:39 +02:00
"\n",
2021-10-17 04:04:08 +02:00
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(32, activation=\"elu\", input_shape=input_shape),\n",
" tf.keras.layers.Dense(32, activation=\"elu\"),\n",
" tf.keras.layers.Dense(n_outputs)\n",
2019-05-26 17:30:39 +02:00
"])"
2016-10-23 15:32:33 +02:00
]
},
{
"cell_type": "markdown",
2017-09-25 14:08:10 +02:00
"metadata": {},
2016-10-23 15:32:33 +02:00
"source": [
2019-05-26 17:30:39 +02:00
"To select an action using this DQN, we just pick the action with the largest predicted Q-value. However, to ensure that the agent explores the environment, we choose a random action with probability `epsilon`."
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 48,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
"def epsilon_greedy_policy(state, epsilon=0):\n",
" if np.random.rand() < epsilon:\n",
2022-04-05 11:47:12 +02:00
" return np.random.randint(n_outputs) # random action\n",
2019-05-26 17:30:39 +02:00
" else:\n",
2022-04-05 11:47:12 +02:00
" Q_values = model.predict(state[np.newaxis])[0]\n",
" return Q_values.argmax() # optimal action according to the DQN"
2019-05-26 17:30:39 +02:00
]
},
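In `epsilon_greedy_policy()`, the random branch can also pick the greedy action. With two actions, the greedy action is therefore chosen with probability 1 - ε + ε/2, not just 1 - ε. This pure-Python sketch checks that empirically for ε = 0.1. Its assumptions: the Q-Values are passed in as a list instead of being predicted by the DQN, and the helper name is hypothetical.

```python
import random

def epsilon_greedy(q_values, epsilon, rng):
    # With probability epsilon, pick any action uniformly (possibly the best one)
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    # Otherwise, pick the action with the highest estimated Q-Value
    return max(range(len(q_values)), key=lambda a: q_values[a])

rng = random.Random(42)
q_values = [1.5, 0.3]   # hypothetical Q-Values for one state
epsilon = 0.1
actions = [epsilon_greedy(q_values, epsilon, rng) for _ in range(20_000)]
greedy_fraction = actions.count(0) / len(actions)
expected = 1 - epsilon + epsilon / len(q_values)
print(f"greedy fraction: {greedy_fraction:.3f} (expected about {expected:.3f})")
```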
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-04-05 11:47:12 +02:00
"We will also need a replay buffer. It will contain the agent's experiences, in the form of tuples: `(obs, action, reward, next_obs, done)`. We can use the `deque` class for that:"
2019-05-26 17:30:39 +02:00
]
},
{
"cell_type": "code",
2022-04-05 11:47:12 +02:00
"execution_count": 49,
2019-05-26 17:30:39 +02:00
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
2022-04-05 11:47:12 +02:00
"replay_buffer = deque(maxlen=2000)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Note**: for very large replay buffers, you may want to use a circular buffer instead, as random access time will be O(1) instead of O(N). Or you can check out DeepMind's [Reverb library](https://github.com/deepmind/reverb)."
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"# extra code – a basic circular buffer implementation\n",
"\n",
"class ReplayBuffer:\n",
"    def __init__(self, max_size):\n",
"        self.buffer = np.empty(max_size, dtype=object)\n",
"        self.max_size = max_size\n",
"        self.index = 0\n",
"        self.size = 0\n",
"\n",
"    def append(self, obj):\n",
"        self.buffer[self.index] = obj\n",
"        self.size = min(self.size + 1, self.max_size)\n",
"        self.index = (self.index + 1) % self.max_size\n",
"\n",
"    def sample(self, batch_size):\n",
"        indices = np.random.randint(self.size, size=batch_size)\n",
"        return self.buffer[indices]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And let's create a function to sample experiences from the replay buffer. It will return 5 NumPy arrays: `[obs, actions, rewards, next_obs, dones]`."
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"def sample_experiences(batch_size):\n",
"    indices = np.random.randint(len(replay_buffer), size=batch_size)\n",
"    batch = [replay_buffer[index] for index in indices]\n",
"    states, actions, rewards, next_states, dones = [\n",
"        np.array([experience[field_index] for experience in batch])\n",
"        for field_index in range(5)\n",
"    ]\n",
"    return states, actions, rewards, next_states, dones"
]
},
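{
"cell_type": "markdown",
"metadata": {},
"source": [
"The list comprehension in `sample_experiences()` transposes a list of experience tuples into one NumPy array per field. As a sketch of an equivalent formulation (not the notebook's own code), `zip(*batch)` performs the same transposition; the three experiences below are made up for illustration:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"# hypothetical batch of 3 experiences: (state, action, reward, next_state, done)\n",
"batch = [\n",
"    (np.array([0.0, 0.1]), 0, 1.0, np.array([0.1, 0.2]), False),\n",
"    (np.array([0.1, 0.2]), 1, 1.0, np.array([0.2, 0.3]), False),\n",
"    (np.array([0.2, 0.3]), 0, 1.0, np.array([0.3, 0.4]), True),\n",
"]\n",
"states, actions, rewards, next_states, dones = [\n",
"    np.array(field) for field in zip(*batch)\n",
"]\n",
"print(states.shape)  # (3, 2): one row per experience\n",
"print(actions)       # [0 1 0]\n",
"print(dones)         # [False False  True]\n",
"```\n",
"\n",
"Either way, the first dimension of every array is the batch size, which is what `training_step()` expects."
]
},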
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can create a function that will use the DQN to play one step, and record its experience in the replay buffer:"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"def play_one_step(env, state, epsilon):\n",
"    action = epsilon_greedy_policy(state, epsilon)\n",
"    next_state, reward, done, info = env.step(action)\n",
"    replay_buffer.append((state, action, reward, next_state, done))\n",
"    return next_state, reward, done, info"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lastly, let's create a function that will sample some experiences from the replay buffer and perform a training step:"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [],
"source": [
"# extra code – for reproducibility, and to generate the next figure\n",
"env.reset(seed=42)\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"rewards = []\n",
"best_score = 0"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"batch_size = 32\n",
"discount_factor = 0.95\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=1e-2)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"\n",
"def training_step(batch_size):\n",
"    experiences = sample_experiences(batch_size)\n",
"    states, actions, rewards, next_states, dones = experiences\n",
"    next_Q_values = model.predict(next_states)\n",
"    max_next_Q_values = next_Q_values.max(axis=1)\n",
"    target_Q_values = (rewards +\n",
"                       (1 - dones) * discount_factor * max_next_Q_values)\n",
"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
"    mask = tf.one_hot(actions, n_outputs)\n",
"    with tf.GradientTape() as tape:\n",
"        all_Q_values = model(states)\n",
"        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n",
"        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n",
"\n",
"    grads = tape.gradient(loss, model.trainable_variables)\n",
"    optimizer.apply_gradients(zip(grads, model.trainable_variables))"
]
},
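{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the target computation in `training_step()` concrete, here is the same arithmetic on a made-up batch of two transitions. The Q-values are invented numbers standing in for the model's predictions, and `np.eye(2)[actions]` is used as a NumPy stand-in for `tf.one_hot(actions, 2)`. Each target is `reward + (1 - done) * discount_factor * max(next_Q_values)`, and the mask keeps only the Q-value of the action that was actually taken:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"discount_factor = 0.95\n",
"rewards = np.array([1.0, 1.0])\n",
"dones = np.array([0.0, 1.0])  # the second transition ends the episode\n",
"next_Q_values = np.array([[2.0, 3.0],  # made-up Q(s', a') for 2 actions\n",
"                          [4.0, 1.0]])\n",
"max_next_Q_values = next_Q_values.max(axis=1)  # [3.0, 4.0]\n",
"target_Q_values = rewards + (1 - dones) * discount_factor * max_next_Q_values\n",
"print(target_Q_values)  # [3.85 1.  ]\n",
"\n",
"# masking: keep only the predicted Q-value of the action actually taken\n",
"all_Q_values = np.array([[0.5, 0.7],\n",
"                         [0.2, 0.9]])\n",
"actions = np.array([1, 0])\n",
"mask = np.eye(2)[actions]  # one-hot rows: [0, 1] and [1, 0]\n",
"Q_values = (all_Q_values * mask).sum(axis=1, keepdims=True)\n",
"print(Q_values.ravel())  # [0.7 0.2]\n",
"```\n",
"\n",
"Note that the terminal transition's target is just its reward, since there is no future state to bootstrap from."
]
},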
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, let's train the model!"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode: 600, Steps: 200, eps: 0.010"
]
}
],
"source": [
"for episode in range(600):\n",
"    obs = env.reset()\n",
"    for step in range(200):\n",
"        epsilon = max(1 - episode / 500, 0.01)\n",
"        obs, reward, done, info = play_one_step(env, obs, epsilon)\n",
"        if done:\n",
"            break\n",
"\n",
"    # extra code – displays debug info, stores data for the next figure, and\n",
"    # keeps track of the best model weights so far\n",
"    print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n",
"          end=\"\")\n",
"    rewards.append(step)\n",
"    if step >= best_score:\n",
"        best_weights = model.get_weights()\n",
"        best_score = step\n",
"\n",
"    if episode > 50:\n",
"        training_step(batch_size)\n",
"\n",
"model.set_weights(best_weights)  # extra code – restores the best model weights"
]
},
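{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exploration rate in this loop decays linearly from 1.0 at episode 0 to 0.01 at episode 495, then stays at that floor. Evaluating the same expression for a few episode numbers shows the schedule (`epsilon_at()` is just an illustrative helper name):\n",
"\n",
"```python\n",
"def epsilon_at(episode):\n",
"    # same expression as in the training loop\n",
"    return max(1 - episode / 500, 0.01)\n",
"\n",
"for episode in [0, 250, 450, 500, 599]:\n",
"    print(episode, round(epsilon_at(episode), 3))\n",
"# 0 1.0\n",
"# 250 0.5\n",
"# 450 0.1\n",
"# 500 0.01\n",
"# 599 0.01\n",
"```\n",
"\n",
"So the last hundred or so episodes are almost entirely greedy, which is why the debug output ends with `eps: 0.010`."
]
},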
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjAAAAEQCAYAAACutU7EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACvsElEQVR4nO29d5Qcx3U9fF9P2oxFziBAEASYAJAEc1qKpESJsnK0LUu2JVqWHCT9bGVZkmXJlKz02UqmlWVZwZJoBeaAZQ4IBIic0yItdrF5d2LX90d3dVdVV/f07O5sAOqes2dmOlRV18x2vb7vvveIMQYDAwMDAwMDg8kEa7wHYGBgYGBgYGBQKYwBY2BgYGBgYDDpYAwYAwMDAwMDg0kHY8AYGBgYGBgYTDoYA8bAwMDAwMBg0sEYMAYGBgYGBgaTDsnxHsBoYsaMGWzx4sVVaXtgYAD19fVVaftMhJmvymDmq3KYOasMZr4qg5mvylCt+dqwYUMHY2ymbt8ZZcAsXrwY69evr0rbra2taGlpqUrbZyLMfFUGM1+Vw8xZZTDzVRnMfFWGas0XER0K22dcSAYGBgYGBgaTDsaAMTAwMDAwMJh0MAaMgYGBgYGBwaTDmBkwRLSQiNYS0Q4i2kZEf+9un0ZEDxPRHvd1qnDOx4hoLxHtIqJXjNVYDQwMDAwMDCY2xpKBKQL4f4yxCwBcDeD9RHQhgI8CeJQxtgzAo+5nuPveBuAiALcD+BYRJcZwvAYGBgYGBgYTFGNmwDDGjjPGNrrv+wDsADAfwGsB/Mg97EcAXue+fy2AnzPGcoyxAwD2ArhyrMZrYGBgYGBgMHExLhoYIloM4FIAzwOYzRg7DjhGDoBZ7mHzARwRTmtztxkYGBgYGBic5SDG2Nh2SNQA4HEAn2eM/YaIuhljzcL+LsbYVCL6JoBnGWP/7W7/HoD7GGO/Vtq7E8CdADB79uzLf/7zn1dl3P39/WhoaKhK22cizHxVhtGar0cOFXDFnCSmZEja/tTRAs6fmsCsutF9ZnmirYCjfTZevTSNxjSFHpctMvx2XwHTMoTbFqcq7meoyPDbvXnkS8DqWQmsnJnE5mP9SKRrsKDRwsaTJbxskdPu/p4SnjpaxLVzkzhvqux13tRexEunSsgkCa9bmkIm6Y/5ybYC2vpsXDIziVyJoa3PRk9Ovj8uaLRQtIETAzYAgAhYNTOBl06VcP7UBI702xjIM5w/NYG+PMM185JoSBMeO1zA7DoLG9uLuHx2EjNqCbtOl3DDgsrnIgzPHStiUZOFeQ3675j/xvIlhkcOF/CKc1JIWOHfWRgeO1zApbMSmFpj4Ym2Ai6cnsCMWrlPmzE8eLCIloVJ1Caj+2gftANzwc+/eWESNZrzj/XbONJn46q58VOZHeotoWOI4Wi/jUIJeP2yFCyS2z45YOOhQwVcMiOB8+qygf9JxhgeOVTE3Abnu1w9M4G2PhsLGi1kEoTl05zfW9FmePBgAQsbLWxqLyFlAX+0NI2XOkrY21WS2qxLEc5psrCjU97OcfGMBBZPsXD/gQJKtrwvYTn7N58qYdXMBI722ZjbYGGL+xtf1mxha4e+3dFETZLwqvn5qtzzb7755g2MsTW6fWOayI6IUgB+DeCnjLHfuJtPEtFcxthxIpoLoN3d3gZgoXD6AgDH1DYZY3cDuBsA1qxZw6qVeMgkNaoMZr4qw2jMV0d/Du964BEsP38ZXnvNYmnfuz56L6bXp7HhUy8bUR8q3vXRewEAd1x7CVpWzQs97um9Hbj/kecBAP/yzltBVNnCee9Lx/HAIxtBBAymmvB3b77G7TuHf37tRfjx2m344BtvwtT6NO771WY8drgNjdNm4d0tl0rtfPUbT2HL0R4wBrzjlstw7XkzAtfy4KGit60+nUBNylmUBvJFFI+WULQZalMJ1KUT6BzI48mjNvIlG48cds4jAh474rzvSEzFJ151Id71wFqvzVLtVGze1o
2uwQI++SeVz0UY+PgP3nWHdj//ja3d1Y5fPrwOb7vlCly2aKr22DCc7M3iXQ88io09TbjnfdfhXZ+4H//w8vPxppZl0nGP7jiJXzy4HtQ0G3e9cWVkm9f+66M41pPHR952C9JJxxB6aNsJ/OLBDUg0z8EXXn9J4JzF7rV+5O23xh47P4fjb15zDZbPaZS2/fuje/Do4d3osOux+uJk4H/ypbZu/PTBp73PrUeKsAUbl8/9Nx7bg//dvVs69zXXrcTvN2xH54CNxoyz9OaLNvpyBaQSzm+gqUY2aHuzBbTbdXjzwoV4+NA2TK3zjS6bMXQNFvDoYWcMjx0uQsXDroE6pXb0DGUdptWn0dCQHvN7/pgZMOT8l34PwA7G2FeFXb8D8E4Ad7mvvxW2/w8RfRXAPADLALwwVuM1MJhssN07aa4oP6ZxlrVzIF+1vovuo+FD207g+mUzUJeWby0l4S7PmLPIV4KDnQMAgJXzp8BWnkKLJaftknudvCuxT45TfTmcM60OBzsHodkdwKf/6CK85QrnOepLD+zEt1r3AQD+9pbz8L6W87DiU/cjW5AHxNsHgIFcCdmi/ARcsp2FBxjeXIwUJXe+svnKn8z5UE/0ZL35LZSCE8nnpGeoULbNvmzRO3ZmYwaAw2AAwOn+6v1mc8Xg9fPfcTHkxzGkzFnYb+iZfZ3e+6l1KXQNFlAo2SiUGN50+QLPKHt4+0m858frUSgx3HrBbHz3nTLR8Nf/vQH7TvWj4I7r8Q/f7Bk52UIJKz71gHYMvM+izXDl4mn45Xuv0Q90FNHa2lr1PlSMpQbmOgDvAPAyItrk/r0KjuFyGxHtAXCb+xmMsW0AfglgO4AHALyfMVZ9LszAYJKC38fUm2+chXqkKNoMO0/04s6fbMDHf7MldGzq+7g40DGAWY0Z1GeSYFCvz/nMveHqZ69fxtDZn8esxhrpuCiIxoXobkm4O1QXBABvEQaAVII8A0sdL+AbXWMJ3v9QofLbKf8t5Yq2Z8Do5pFPS5zLa6xxjN2eId9YSbpzXVB9JqOIfDHYNv8+1O/M2x/zn2nzkW7vPTfmizZDyba9awMcF5DuPUcqYaFQYl6/4rmZpIUwD2BtynedJhNjbCGPIcaMgWGMPQXfgFdxS8g5nwfw+aoNysDgDAJfLAohDEw1YdsMAznnSfrw6UHN2Pwx2IwhEXor0ONQ5wAWT68HUdAg4017ho1iyHD0DhWRL9megcH3vni4C6sWNGv7FQ0U8T03ZsobMFZg0RMZpDhG1GiDD0dljuKg6A4+X7S9xV63qPNZUY1NHZpqUzjWk/VYKQBIua6kQhWtb5WpBHzjv6DSfMr+chgQmJq6tGNMlGwbRZtJhjCF/L44UgkL+aLt9aueW5dOoj8XdB3VpH0DZjg6p8kCk4nXwOAMAV8M1Zt+tZYA0Sgpd2OXGJhhDOhAxyAWz6iDRRQwyFTGhe9V+znVnwPgGxiMMTy55xRe/61n8INnDmr7taQn5OBio1saZjYIBkzS8hZ9DpF1CVknqwo+f9lhMDCcmciXbM9lqWORKnGLcQamWzBg0i4dURxjBoZf00gYGJU14gYMZ1ISFGTyAMDSGBrpJKFQ8tmuhDKxdWl9ajSRgUnpqJ0zBGfulRkYnGXgt1b1Blqth3zRaCl7Yxd2V8o6ZAsldPTnsHBqnXu+vJ9/9gwYTwsjH9ihGjDw2aK97X3avnWsC+DT8rqFeoZgwKQTVsC4U9mo0UAlLJvHwGg0IOXAr4Ux/zu3I777eC4kR9PRNei7kPi0VtOFFMXAhBlOcRgY1TCs9RgY14BJ6H9TqnECcBeSYMBYlRswScPAGBgYTHQwpr/5xqHxhwPRaBFv7LreRrJQ80WsJpVwGJiQtj0mxtsuH6caMGDlF9gwA8ZjYJRFpyZleYwCEKaB8d+PlgamkmZsj4Gp3DgQv3PfhaQ70pmXOMNq4hoYgYHhbesEwqMFnYjXEyaHGCqlMpSZbbPAvHINDD
dEkprfEaB39aRdF1LJdT2pvzdVLM9Rmz47NDDGgDEwOEPgaWCUm361GBjx6di5sYffKMUxVDoevmY4N/Ag28AUw0V
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell generates and saves Figure 18–10\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(rewards)\n",
"plt.xlabel(\"Episode\", fontsize=14)\n",
"plt.ylabel(\"Sum of rewards\", fontsize=14)\n",
"plt.grid(True)\n",
"save_fig(\"dqn_rewards_plot\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"# extra code – shows an animation of the trained DQN playing one episode\n",
"show_one_episode(epsilon_greedy_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Not bad at all! 😀"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fixed Q-Value Targets"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create the online DQN:"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [],
"source": [
"# extra code – creates the same DQN model as earlier\n",
"\n",
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(32, activation=\"elu\", input_shape=input_shape),\n",
"    tf.keras.layers.Dense(32, activation=\"elu\"),\n",
"    tf.keras.layers.Dense(n_outputs)\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's create the target DQN, which is just a clone of the online DQN:"
]
},
{
"cell_type": "code",
"execution_count": 59,
"metadata": {},
"outputs": [],
"source": [
"target = tf.keras.models.clone_model(model)  # clone the model's architecture\n",
"target.set_weights(model.get_weights())  # copy the weights"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we use the same code as above except for the line marked with `# <= CHANGED`:"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {},
"outputs": [],
"source": [
"env.reset(seed=42)\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"rewards = []\n",
"best_score = 0\n",
"\n",
"batch_size = 32\n",
"discount_factor = 0.95\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=1e-2)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"\n",
"replay_buffer = deque(maxlen=2000)  # resets the replay buffer\n",
"\n",
"def training_step(batch_size):\n",
"    experiences = sample_experiences(batch_size)\n",
"    states, actions, rewards, next_states, dones = experiences\n",
"    next_Q_values = target.predict(next_states)  # <= CHANGED\n",
"    max_next_Q_values = next_Q_values.max(axis=1)\n",
"    target_Q_values = (rewards +\n",
"                       (1 - dones) * discount_factor * max_next_Q_values)\n",
"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
"    mask = tf.one_hot(actions, n_outputs)\n",
"    with tf.GradientTape() as tape:\n",
"        all_Q_values = model(states)\n",
"        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n",
"        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n",
"\n",
"    grads = tape.gradient(loss, model.trainable_variables)\n",
"    optimizer.apply_gradients(zip(grads, model.trainable_variables))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, this is the same code as earlier, except for the lines marked with `# <= CHANGED`:"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode: 600, Steps: 200, eps: 0.010"
]
}
],
"source": [
"for episode in range(600):\n",
"    obs = env.reset()\n",
"    for step in range(200):\n",
"        epsilon = max(1 - episode / 500, 0.01)\n",
"        obs, reward, done, info = play_one_step(env, obs, epsilon)\n",
"        if done:\n",
"            break\n",
"\n",
"    # extra code – displays debug info, stores data for the next figure, and\n",
"    # keeps track of the best model weights so far\n",
"    print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n",
"          end=\"\")\n",
"    rewards.append(step)\n",
"    if step >= best_score:\n",
"        best_weights = model.get_weights()\n",
"        best_score = step\n",
"\n",
"    if episode > 50:\n",
"        training_step(batch_size)\n",
"        if episode % 50 == 0:  # <= CHANGED\n",
"            target.set_weights(model.get_weights())  # <= CHANGED\n",
"\n",
"    # Alternatively, you can do soft updates at each step:\n",
"    #if episode > 50:\n",
"        #training_step(batch_size)\n",
"        #target_weights = target.get_weights()\n",
"        #online_weights = model.get_weights()\n",
"        #for index, online_weight in enumerate(online_weights):\n",
"        #    target_weights[index] = (0.99 * target_weights[index]\n",
"        #                             + 0.01 * online_weight)\n",
"        #target.set_weights(target_weights)\n",
"\n",
"model.set_weights(best_weights)  # extra code – restores the best model weights"
]
},
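{
"cell_type": "markdown",
"metadata": {},
"source": [
"The commented-out alternative in the loop above performs a soft update (Polyak averaging): after each training step, every target weight moves 1% of the way toward the corresponding online weight, instead of being overwritten every 50 episodes. Here is that update rule on standalone NumPy arrays; `soft_update()` and `tau` are illustrative names, not part of Keras:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def soft_update(target_weights, online_weights, tau=0.01):\n",
"    # hypothetical helper: new_target = (1 - tau) * target + tau * online\n",
"    return [(1 - tau) * t + tau * o\n",
"            for t, o in zip(target_weights, online_weights)]\n",
"\n",
"target_weights = [np.zeros(3)]  # stand-in for target.get_weights()\n",
"online_weights = [np.ones(3)]   # stand-in for model.get_weights()\n",
"for _ in range(100):\n",
"    target_weights = soft_update(target_weights, online_weights)\n",
"print(target_weights[0].round(3))  # [0.634 0.634 0.634], i.e. 1 - 0.99**100\n",
"```\n",
"\n",
"With the notebook's models, the equivalent call would be `target.set_weights(soft_update(target.get_weights(), model.get_weights()))`, since `get_weights()` returns a list of NumPy arrays."
]
},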
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfgAAAEKCAYAAAD+ckdtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACShUlEQVR4nO19d5wcR5n2887MBu0qp1VOlixZlmTZlqNkeeQANtxh4MgcmTMcmHDAgQkfRzg4ouG44+AMJh7YBAM2tnH2OEi2ZOWcs7S7WoXd1eYJ9f3RXd1V1dU9PWlndrce/VbTXV1dXV3dXW+9mRhjMDAwMDAwMBhciJS7AwYGBgYGBgbFhyHwBgYGBgYGgxCGwBsYGBgYGAxCGAJvYGBgYGAwCGEIvIGBgYGBwSBErNwdKCbGjx/PZs2aVdQ2Ozs7UV9fX9Q2ByrMWMgw4yHDjIcMMx4uzFjIKPZ4bNiw4TRjbIJaPqgI/KxZs7B+/fqitplIJBCPx4va5kCFGQsZZjxkmPGQYcbDhRkLGcUeDyI6ois3InoDAwMDA4NBCEPgDQwMDAwMBiEMgTcwMDAwMBiEMATewMDAwMBgEKLfCDwRTSeiZ4hoFxHtIKKP2eVjiegJItpn/44RzvksEe0noj1E9Mr+6quBgYGBgcFAR39y8CkAn2SMXQTgagAfJqKFAO4E8BRjbB6Ap+x92MfeAuBiALcA+B8iivZjfw0MDAwMDAYs+o3AM8YaGWMb7e3zAHYBmArgNgC/tKv9EsBr7e3bANzHGOtljB0CsB/Alf3VXwMDAwMDg4EMKke6WCKaBeA5AIsAHGWMjRaOnWOMjSGi/wbwEmPs/+zyewD8jTH2R6Wt2wHcDgANDQ2X33fffUXta0dHB4YPH17UNgcqBuNY7DqTxqgawpTh2de6vWmG9U0pXDslBiIq2nicOJ9BR5Jh/tiBLaDKdzw2n0phxsgIRlYT1py0xnfNyRSWT4khGiHtOS+dTOGSiVEwBmxtSWPpxCg2nkrj2in9E9rjVFcGLV0MF4/3f2YdHR1AdT12nk0DABaOjWL32TQuHBPFgbY0Zo+M4HB7BtVRAgEYXUN4qTGFqybH0FBHeOJICqNrCPPGRPD8iRQyylS9eHwUEQK2tKQD+zpzZATTR0Sw+kQK5UoO3tfXh0smDcOIakJbL8PZngxWTLW+o740w5NHkugRbuPKSTGc68lgX2tGaicWAaaPiOBQm1U+c2QEJzoySAnVhlcRVs2IoUp4dxhjeOFECuOGRTCqmjBlOOGpoym09zHEIkB8ehVGVuvftXywtjGFReOt9/PpY0mnf2NrCS1dDNeM78O0ccWbS1etWrWBMbZMLe/3QDdENBzA/QA+zhhrJ/IdVN0Bz/vJGLsbwN0AsGzZMlbsYAomQIOLwTgW777zYQDA4W+8Omvdz/95G36z7ShWXXUprp07vmjj8ZF7N2Ff83k8+vqVBbdVTuQ7Hu++82FMHFGD966YjZ9t342NbXXYfKwLU2ZegPdfN8dTf1djO3786PN41eJJyGSAR3c0YfHUUdh2ogu3rrgcl84Yo7lKcTErxHuTSCTw0/3D8ML+0wCARVNHYvuJLiyZNgpbj3dhzoR6HGzpduq/85qZeODAEdSPn4zFF87AfY+/AAD45M0X4i/790KcKhkDTrERqK2KInGgBX7TKGPA+OE1eN2lk/DAgUO+9UoNxghNqMXmY61O2ZJFC3Hb0qlYc+A0fv/EWgAAkdXnmtEN2Hj0HA60dDp9zsaLivVGTZqOz9yywDn2t22NuOexjc7+6jtvwHsee9rZv2LxArzmyhmF3aSNQ6c78e5HE7hhwUTcumgS/vT0Vk+dldPq+mUu7VcCT0RVsIj7bxhjf7KLm4loMmOskYgmAzhllx8HMF04fRqAk/3XWwMDGc3tvQCAjt5UUdtNpjJIpjPZKw5inDrfi1P2+B450wkAONfVp63b1WeNf2Nbj8MF7Gk+bx
8L5mb7G8fPdTnb+0912GUWUT/Y0inVTaYtCtbVm0JvKi2UZ0AEHPoPdzHx9p++hN5kBslIBpfPHIP7//la7fX/7YHteGDLSSTTDKOGVWHLv72iODeWI17z3UeRVkQQ20+04balU537vv+fr8HlM8fium89jXSGobsvjTdePg3ffuMlAIBT53tw5deeAgBcOXssJo2sxYNbLJKw/gs3YfzwGgDAP/1qPf688YRE4Nu6k9K10/Y1P3vrAvzH33YjpYpHCkBP0np2J851O/e85s4bsPNkO97/KyvSakN9/2jH+9OKngDcA2AXY+wu4dCDAN5lb78LwANC+VuIqIaIZgOYB2Bdf/XXwKC/kGGsbKLTSkLGZtH8xPI6xKLWFMYXSGViUD0435PErjNpiBLKlE1UIr5stE3g+9IO0QOAvjRDVDknQoQMY0hnvMdEEBHSGateDsNadBDc58tx6LS1wMlkvOPCAPSmMqip8idR9TWueqS+2uVVp4yqRU8qeKHHvzjnkkVUVTuSBMjf9Yp54zGiNobPvWqB9rxSoD85+OUA3gFgGxFttss+B+AbAH5PRO8DcBTAGwGAMbaDiH4PYCcsC/wPM8Yqa3luYFAEZBg0yqehAdEGiBMAfwIogwBU2wS+DKZEgXjPz1/G+iM9mDKq1injXGLUh2bxe+hOpiWJTl8qg4hCnYkIaQaLwAdQ7miEkMkwZFhwvVIjQt5ndNAm8OmMvLAjEBhjFoGPuUSchOUbAaiziToRUCssBGqqouhNyhIx9fXgfeHvWjFfH95PxtzrWH2MYtuXLG/vROJYEa/oj34j8IyxF+C/wL7R55yvAfhayTplYFARGLocvDjppzWcXLZzqqJK3Qph4dcfOQdATzj87o8vcLr70khlXALVm0p7uPQoWYujdAYSEVQRjRAyzGo77MKpVFA5+JOtlqoirSzsiKxx60mmJcItdp8IqKu27ruuKipJSmpiEfSm0mCMwc/Gi/eEHy/mAtHl4AVJQZleTBPJzsCgzMgwmZMdSshoOPh8RPQDCX6Elg9FV18afSl3XHpTGc+YRLjoncHD3YsgsghoNk6/1NBx8JzoZTwcvCXtSGWYwsHL53IOXr3/mlgEGQZJr65emzmLCnm/GOBtZhiTOPhyYOB9HQYGgwxDWQcv3rfDwWeZlcTJsloh8OXilHKBH6Hl9KgnKXPwfToCb3Pm6UwGqhBDuhZZIvp0JrzqoxQgeOVU/J7SysKOiNBrG6rJHLwgoidZBy+CLwp6U/6Gq7wnpRDRO0sRJkgKitp+eBgCb2BQZmRY5emQ+wsiB8/VzkFGY0AWEX2FQfdc/RhpJhjZpUQjOy0HD4dwRwNWRJaIvvw6eAAeP37+mFXVDMGyQwDgz8GTq4NXwQ3zznX2oamtB4B3ceHq4OX9YkAU0bssfPHazwWGwBsYlBlDVTwPyBNrzkZ2RB4RfZnVzABcNylAb13hJ1IXjez6RCO7dMYzJg7hzjBfoz3AGqNMCGO8UiNCXh08vyePaoaAHttIribmo4MHob7aj4O3znnNf7+Aq//jKZ8ecSv6UhjZ2VcQJHPlkiz1e6AbAwMDGYwNZTM7F3yiJ2eezz4pVlWgDr7lfG/gcT8JBV/odSscfG8q7SHilhW99dYEWtHb10plMmV3k1Nf8YjDwVu/UZGD7+Mieh8regLqanw4eJvrP9fl+r57dfC8D8UfFHHRUG4dvCHwBgZlhmiMM9Qgi+it7Zgtcg6z6KlWRPQVwMBnDZrib0Vv/falM5IUoDeZ0VjRExiz9NfBInq7zVR5OXjScPC8PxnF9oKIHD92kYNXH242Dj4MSmFkp3OtL9fIGwJvYJAjik2LhzKB14rosxAicagq0YpeJWQqfEX0wnZ7j8t99qW9fvAREvzHA4aLJA6+vEZ23nWPj5EdgB4dBy+5yVFWHTyHLkqkamRXTEiBbhyplBHRGxhUNEr1jRYxSuaAg+QmZ8/DYRlNgldEX66JVITIDe
povX+gG7dye7cbDrkvlUFMa0VvLQyDFkScaCY1evz+BBEh5SG01v26CxXXD74npdHBi+0huxU9R28q4xvoJmyc+1z
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots the learning curve\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(rewards)\n",
"plt.xlabel(\"Episode\", fontsize=14)\n",
"plt.ylabel(\"Sum of rewards\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [],
"source": [
"# extra code – shows an animation of the trained DQN playing one episode\n",
"show_one_episode(epsilon_greedy_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Double DQN"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The code is exactly the same as for fixed Q-Value targets, except for the section marked as changed in the `training_step()` function:"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode: 600, Steps: 200, eps: 0.010"
]
}
],
"source": [
"tf.random.set_seed(42)\n",
"\n",
"model = tf.keras.Sequential([\n",
"    tf.keras.layers.Dense(32, activation=\"elu\", input_shape=input_shape),\n",
"    tf.keras.layers.Dense(32, activation=\"elu\"),\n",
"    tf.keras.layers.Dense(n_outputs)\n",
"])\n",
"\n",
"target = tf.keras.models.clone_model(model)  # clone the model's architecture\n",
"target.set_weights(model.get_weights())  # copy the weights\n",
"\n",
"env.reset(seed=42)\n",
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"rewards = []\n",
"best_score = 0\n",
"\n",
"batch_size = 32\n",
"discount_factor = 0.95\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=1e-2)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"\n",
"def training_step(batch_size):\n",
"    experiences = sample_experiences(batch_size)\n",
"    states, actions, rewards, next_states, dones = experiences\n",
"\n",
"    #################### CHANGED SECTION ####################\n",
"    next_Q_values = model.predict(next_states)  # not target.predict(...)\n",
"    best_next_actions = next_Q_values.argmax(axis=1)\n",
"    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
"    max_next_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
"    #########################################################\n",
"\n",
"    target_Q_values = (rewards +\n",
"                       (1 - dones) * discount_factor * max_next_Q_values)\n",
"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
"    mask = tf.one_hot(actions, n_outputs)\n",
"    with tf.GradientTape() as tape:\n",
"        all_Q_values = model(states)\n",
"        Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)\n",
"        loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))\n",
"\n",
"    grads = tape.gradient(loss, model.trainable_variables)\n",
"    optimizer.apply_gradients(zip(grads, model.trainable_variables))\n",
"\n",
"replay_buffer = deque(maxlen=2000)\n",
"\n",
"for episode in range(600):\n",
"    obs = env.reset()\n",
"    for step in range(200):\n",
"        epsilon = max(1 - episode / 500, 0.01)\n",
"        obs, reward, done, info = play_one_step(env, obs, epsilon)\n",
"        if done:\n",
"            break\n",
"\n",
"    print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n",
"          end=\"\")\n",
"    rewards.append(step)\n",
"    if step >= best_score:\n",
"        best_weights = model.get_weights()\n",
"        best_score = step\n",
"\n",
"    if episode > 50:\n",
"        training_step(batch_size)\n",
"        if episode % 50 == 0:\n",
"            target.set_weights(model.get_weights())\n",
"\n",
"model.set_weights(best_weights)"
]
},
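{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the changed section above, the online model picks the best next action, but the target model estimates that action's value, which reduces the overestimation caused by taking a max over noisy estimates. The NumPy sketch below contrasts the two ways of computing the next-state value; both Q-value arrays are invented, and `np.take_along_axis()` stands in for the one-hot mask used in `training_step()`:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"online_next_Q = np.array([[1.0, 2.0],  # made-up online-model Q(s', a')\n",
"                          [3.0, 0.5]])\n",
"target_next_Q = np.array([[1.5, 1.2],  # made-up target-model Q(s', a')\n",
"                          [2.0, 2.5]])\n",
"\n",
"best_next_actions = online_next_Q.argmax(axis=1)  # [1, 0], chosen by the online model\n",
"double_dqn_values = np.take_along_axis(\n",
"    target_next_Q, best_next_actions[:, np.newaxis], axis=1).ravel()\n",
"plain_max_values = target_next_Q.max(axis=1)\n",
"\n",
"print(double_dqn_values)  # [1.2 2. ], evaluated by the target model\n",
"print(plain_max_values)   # [1.5 2.5], plain max over the target's estimates\n",
"```"
]
},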
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAfgAAAEKCAYAAAD+ckdtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAACDvElEQVR4nO2dd5gcxZn/v++EzUF5lbMEKIAMAgESYmTAgHG68zknjG1snH2+s/HZP9vndE7gdD5sbGzMHQbbYBwAk4RGAQklJJRzllbSale72jixfn90V3d1d3VPz07PbKrP8+yzMx2qa2p6+q03FjHGoFAoFAqFYnAR6usOKBQKhUKhCB4l4BUKhUKhGIQoAa9QKBQKxSBECXiFQqFQKAYhSsArFAqFQjEIifR1B4Jk1KhRbOrUqYG22dnZierq6kDbHKiosbCixsOKGg8rajxM1FhYCXo8Nm/efI4xNtq+fVAJ+KlTp2LTpk2BthmPxxGLxQJtc6CixsKKGg8rajysqPEwUWNhJejxIKKjsu3KRK9QKBQKxSBECXiFQqFQKAYhSsArFAqFQjEIUQJeoVAoFIpBSMkEPBFNIqIVRLSbiHYS0Wf07SOI6Hki2q//Hy6c8yUiOkBEe4no5lL1VaFQKBSKgU4pNfg0gM8zxi4BcDWATxDRHAB3A1jOGJsFYLn+Hvq+dwKYC+AWAP9DROES9lehUCgUigFLyQQ8Y6yRMfaK/rodwG4AEwC8GcDv9MN+B+At+us3A3iUMZZgjB0GcADAVaXqr0KhUCgUAxnqi+ViiWgqgFUA5gE4xhgbJuw7zxgbTkT/DeBlxtj/6dsfAPAPxthjtrbuBHAnADQ0NFzx6KOPBtrXjo4O1NTUBNrmQEWNhRU1HlYKHY9MlmHtqTQWjYvghWMp9KS17eOqQ7iQZKgtIzDGsLAhglfOZlAdBQ60ZqVtzRoWwqjKEJq6szjRnkUiA0wfFsKh1ixmDgshw4ApdSEcbstiWn0Iq06kkQn4UVgbSuKaSdXY1ZxBV5ohy4BF4yJYfiyFdBaojhJumhJBiMhxLmMMa06mcc34CCIhbf+FBMOelgyuGqeVL4kfTyHLgMXjI9h0Jo1rx0fQ2MlwrjuLY+1ZTKwJYcGYYEudbG9KY2x1CKOr8tMNi/FbOdSawfZzGcwYFkKICM3dWSyZEAEJ4/nKmTSOXNDukavGRjCxVus3YwzLj6VxIal96cPLCURAS0/w8rCujDC8ghANmffr1SOTGD8iuPFYtmzZZsbYQvv2khe6IaIaAI8D+Cxj7AJJbm5+qGSbY/QZY/cDuB8AFi5cyIIupqAKNJiosbCixsNKoePxi5UH8cCOPWgrH4PH9p5wPe7llkpsP9llvLc/QhgDZo6pwWsmDcOfNru3M31UNQ6d68YXb7kYTxzYI22rt2h6E+EYqrB6/zlj+8Sp0/Hn/XuM93fcejVmNdQ6zn9qWyMeePYVVI6ejH+7+SIAwLIfxnH4XAKf+OcY0hmG2595DgBwKFmHNQfO4ZYlC/HB+9Za2jny3RuD+UA6t9/9FMoiIez71q15nVeM38pvfrMBqw40WbYtXngpll3cYLz/t2+9gHMdKQBAtL4B733jZQCAE+e78MFnV0jbDeoeAPh94Gx/YUNlSZ4dJRXwRBSFJtwfZoz9Wd98hojGMcYaiWgcgLP69hMAJgmnTwRwqnS9VSgUpaS5IwEAaGrX/v/xo9dgzf4m/PTFA5bjDpztMF5fN2sU/vdDiyz7P/PoFrx6vBU9abl2zznWok0SulMZAMD+b9+KaDgYr+UDaw7jm0/uQktn0rK9I6Fd6/+9YQ6++eQuJDPyPrb3aELpbHuPse3wuU4AQCKdRVaQHPvOtAMAkjk+b1CU6jq5yGSd/eDja75P4aNLp+PZnaeREsY6pZtrfvyOBRhWFcXtv90IAPjV+xfipjkNCIp/bG/EXQ+/YryfN6EOT37qOsTj8cCu4UUpo+gJwA
MAdjPG7hV2/Q3AB/TXHwDwV2H7O4monIimAZgFYEOp+qtQKEpLlvH/2otwiBAOOR9RGUG4hUNOdStEhCwDUjkEUUg/N6EL+HCAqhs3qw+rilq2c+FYFtb2u3lIed+4TEoLwimZziKbNU/sTGi+jLJIgKrnAIAkRl7xO8xkGXpSWVSVRRAOETLCmPHJQThEGFdfaWwfV18RaB+ryq06dERyPxeTUmrwiwG8D8B2Itqqb/sPAN8F8Eci+hCAYwDeBgCMsZ1E9EcAu6BF4H+CMZZxtKpQKAYV/EEcCREiYedDXHxQR1wEfCbLkJZoeCJcGCTSWRCZQjUIeL+HVZZZtnMBzy0F4meR9Y1PdhrbTE0+kc6gPGImFHUmM5Y2hwqy+Zg4BF1JbeJTXR52CPi0cI+NFYT62IAFfHWZNfErKrmfi0nJBDxjbA3kfnUAuMHlnG8D+HbROqVQKPoNXJvlD19Ng/cW8HINXhOMqRxRc/zUnlQmUO0dAKK6plZv1+AzVmGccVHh+efiAl78zMl0FhGJMJeNxVBDjOnq0ic+mgYfMu6rjkQa2060AdAmdXUVphgcUWWdkBVKVZlVxJZ6EjaoVpNTKBQDF6bH0HLzcyRMUg1dRGbyDIcIWZZbgzdM9OlsoNo77wPgNPsbGnxE63fWRYPnp3HBLh6VSGdREXWe1wcJUX2K7PPy8WaMYd3BZgBcgzcnS3f932Yj8DESIsukIOj7oLrcqsHLJmbFRAl4hULRL7Br8JFQKKeAl2mtRIRMFjk1+LAh4DM5r5Mv3ESfsgXR8T6V5TLR2zR4MZ3ZHmTHkW0bzDBnUpVx7zy1vRGf/cNWAE4NfuORFuN4Ps6jasqQSAUfPOjQ4EtsZVECXqFQ9Cu4MIuECOEcGo9MMIdDWhtplwh14zhdc+tJZYM30ev9tk8yTB+8HkTnZqInHmTn1OCT6ax0YuA2WRhKcKvNaSFmobosjDCZ1hKx7gAX8Gu++Nqi9MepwZdWwA+tqAyFQtFv4YJd9MHn0njcougzjBntuEFkavDFMtFzDf6et11mXAswJwB+o+jF45KZLGQfbchp8JKPy8d7mOBLryqPIBIKGcJf/Kb591QRDaMiGnwl9IqIPciutCJXCXiFQtEv4M/rjOCDzxU4JtOIQkTIZnMH2XFFrieVDdxEzzV0LlRq9UCupN6nfKPoRR0+kcq4mOgL7/dAQirg0wxnL/RYAueqy8IIhQAekiFq8MVOW7NPHJWAVygUQxL+wM6IPvgcJk2vPPhcJnp+ajE0eC44UjaBntQ1+LKIvyh6w0Tv0OCViV7mg//HjkZc9Z3lWKsH2AFODV5U4UudWRj0RDIXSsArFIp+AX9gZwQffC4NSx5Fz9PkssI2+UQAABJF0OAjNhM9n6hwHzzf7xZFb2QU8CA7YZ+bD16Z6IG9p7Wqfs/uPG1s0zR4MtYasPrgSxzVrjR4hUIxFOEPbC70wj7S5KSCWy9qIpropdH2+v+edEa64EshRIwgO7NiGqBp32J+v5vWzZVNmQafSGchywDMkRU46JCN3IgazfcuFgaqKtMW7OHV68RbodQadZkKslMoFEMRwwcvRtHnzIOXa+aMwZIHL4uSJ0GDD7pIjJkmZzfRa9cKOXzsVvgYmBq8tdCN1EQ/xDR4mYSXpbqVRUJ6dUPtvTiZC3pilwulwSsUiiGJ4YPPmD74XEFJMsEc5lH0uTR4wwdfAhN9yBT4YRI1ePn5PKMgK4uiT2elwnzImeglEp5Xr7MjavCiTC912ppKk1MoFEOafDR471K1/nzwPaliBtllLe9NDV47zk0o2xfeEUlmspbCN8Y5RQ6yk12zL5F1p1OvP8/5+bsvBwBbLXpnHnwxWfXvyzBhmLagTbTEPn8l4BUKRT+B112HsfhLb33wjFmLzMjaETX44Avd6Bp72kz549cKCQvbuAt4a8ChxQefykg1/2JH0fcz+S6ly7Zc7GWT6gFYBbx4KwT9vcuYPLIKF42tBa
A0eIVCMUTh5ugsY4ZAzuWzdPPBA7Cste6lwbvtLwSj0E3WWrkumc4gEg45KtXZMTR4o5KdkAfvkiZX7Cy5/ibfZf1
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots the learning curve\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(rewards)\n",
"plt.xlabel(\"Episode\", fontsize=14)\n",
"plt.ylabel(\"Sum of rewards\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 66,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# extra code – shows an animation of the trained DQN playing one episode\n",
"show_one_episode(epsilon_greedy_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dueling Double DQN"
]
},
{
"cell_type": "code",
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
"tf.random.set_seed(42)  # extra code – ensures reproducibility on the CPU\n",
"\n",
"input_states = tf.keras.layers.Input(shape=[4])\n",
"hidden1 = tf.keras.layers.Dense(32, activation=\"elu\")(input_states)\n",
"hidden2 = tf.keras.layers.Dense(32, activation=\"elu\")(hidden1)\n",
"state_values = tf.keras.layers.Dense(1)(hidden2)\n",
"raw_advantages = tf.keras.layers.Dense(n_outputs)(hidden2)\n",
"advantages = raw_advantages - tf.reduce_max(raw_advantages, axis=1,\n",
"                                            keepdims=True)\n",
"Q_values = state_values + advantages\n",
"model = tf.keras.Model(inputs=[input_states], outputs=[Q_values])"
]
},
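{
"cell_type": "markdown",
"metadata": {},
"source": [
"The dueling head combines a state value V(s) with advantages A(s, a), subtracting the largest advantage so that the Q-value of the best action equals V(s) exactly. Here is a quick NumPy check of that aggregation, with made-up numbers standing in for the outputs of the `state_values` and `raw_advantages` layers:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"state_values = np.array([[10.0],  # made-up V(s) for a batch of 2 states\n",
"                         [5.0]])\n",
"raw_advantages = np.array([[1.0, 3.0],\n",
"                           [-2.0, 0.0]])\n",
"advantages = raw_advantages - raw_advantages.max(axis=1, keepdims=True)\n",
"Q_values = state_values + advantages\n",
"print(Q_values)\n",
"# [[ 8. 10.]\n",
"#  [ 3.  5.]]\n",
"```"
]
},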
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The rest is the same code as earlier:"
]
},
{
"cell_type": "code",
"execution_count": 68,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Episode: 600, Steps: 190, eps: 0.010"
]
}
],
"source": [
"# extra code – trains the model\n",
"\n",
"batch_size = 32\n",
"discount_factor = 0.95\n",
"optimizer = tf.keras.optimizers.Nadam(learning_rate=5e-3)\n",
"loss_fn = tf.keras.losses.mean_squared_error\n",
"\n",
"target = tf.keras.models.clone_model(model) # clone the model's architecture\n",
"target.set_weights(model.get_weights()) # copy the weights\n",
2019-05-26 17:30:39 +02:00
"\n",
2022-04-05 11:47:12 +02:00
"env.reset(seed=42)\n",
"replay_buffer = deque(maxlen=2000)\n",
2019-05-26 17:30:39 +02:00
"rewards = []\n",
"best_score = 0\n",
"\n",
"for episode in range(600):\n",
" obs = env.reset() \n",
" for step in range(200):\n",
" epsilon = max(1 - episode / 500, 0.01)\n",
" obs, reward, done, info = play_one_step(env, obs, epsilon)\n",
" if done:\n",
" break\n",
2022-04-05 11:47:12 +02:00
"\n",
" print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n",
" end=\"\")\n",
2019-05-26 17:30:39 +02:00
" rewards.append(step)\n",
2021-03-09 10:21:08 +01:00
" if step >= best_score:\n",
2019-05-26 17:30:39 +02:00
" best_weights = model.get_weights()\n",
" best_score = step\n",
2022-04-05 11:47:12 +02:00
"\n",
" if episode > 50:\n",
2019-05-26 17:30:39 +02:00
" training_step(batch_size)\n",
2021-03-09 10:21:08 +01:00
" if episode % 50 == 0:\n",
" target.set_weights(model.get_weights())\n",
"\n",
"model.set_weights(best_weights)"
]
},
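{
"cell_type": "markdown",
"metadata": {},
"source": [
"The training loop above decays ε linearly from 1 down to 0.01 over the first 500 episodes, then keeps it constant. Here is that schedule as a tiny standalone function (the `epsilon_schedule()` name and its parameters are illustrative, not part of the notebook), which makes it easy to check how much exploration the agent still does at a given episode:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def epsilon_schedule(episode, decay_episodes=500, min_epsilon=0.01):\n",
"    # Same formula as in the training loop: linear decay, floored at min_epsilon\n",
"    return max(1 - episode / decay_episodes, min_epsilon)\n",
"\n",
"for episode in [0, 100, 250, 499, 500, 599]:\n",
"    print(episode, round(epsilon_schedule(episode), 3))"
]
},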
{
"cell_type": "code",
"execution_count": 69,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 576x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# extra code – this cell plots the learning curve\n",
"plt.figure(figsize=(8, 4))\n",
"plt.plot(rewards)\n",
"plt.xlabel(\"Episode\", fontsize=14)\n",
"plt.ylabel(\"Sum of rewards\", fontsize=14)\n",
"plt.grid(True)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 70,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# extra code – shows an animation of the trained DQN playing one episode\n",
"show_one_episode(epsilon_greedy_policy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This looks like a pretty robust agent!"
]
},
{
"cell_type": "code",
"execution_count": 71,
"metadata": {},
"outputs": [],
"source": [
"env.close()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Exercise Solutions"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. to 7."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Reinforcement Learning is an area of Machine Learning aimed at creating agents capable of taking actions in an environment in a way that maximizes rewards over time. There are many differences between RL and regular supervised and unsupervised learning. Here are a few:\n",
" * In supervised and unsupervised learning, the goal is generally to find patterns in the data and use them to make predictions. In Reinforcement Learning, the goal is to find a good policy.\n",
" * Unlike in supervised learning, the agent is not explicitly given the \"right\" answer. It must learn by trial and error.\n",
" * Unlike in unsupervised learning, there is a form of supervision, through rewards. We do not tell the agent how to perform the task, but we do tell it when it is making progress or when it is failing.\n",
" * A Reinforcement Learning agent needs to find the right balance between exploring the environment, looking for new ways of getting rewards, and exploiting sources of rewards that it already knows. In contrast, supervised and unsupervised learning systems generally don't need to worry about exploration; they just feed on the training data they are given.\n",
" * In supervised and unsupervised learning, training instances are typically independent (in fact, they are generally shuffled). In Reinforcement Learning, consecutive observations are generally _not_ independent. An agent may remain in the same region of the environment for a while before it moves on, so consecutive observations will be very correlated. In some cases a replay buffer (memory) is used to ensure that the training algorithm gets fairly independent observations.\n",
"2. Here are a few possible applications of Reinforcement Learning, other than those mentioned in Chapter 18:\n",
" * Music personalization: The environment is a user's personalized web radio. The agent is the software deciding what song to play next for that user. Its possible actions are to play any song in the catalog (it must try to choose a song the user will enjoy) or to play an advertisement (it must try to choose an ad that the user will be interested in). It gets a small reward every time the user listens to a song, a larger reward every time the user listens to an ad, a negative reward when the user skips a song or an ad, and a very negative reward if the user leaves.\n",
" * Marketing: The environment is your company's marketing department. The agent is the software that defines which customers a mailing campaign should be sent to, given their profile and purchase history (for each customer it has two possible actions: send or don't send). It gets a negative reward for the cost of the mailing campaign, and a positive reward for estimated revenue generated from this campaign.\n",
" * Product delivery: Let the agent control a fleet of delivery trucks, deciding what they should pick up at the depots, where they should go, what they should drop off, and so on. It will get positive rewards for each product delivered on time, and negative rewards for late deliveries.\n",
"3. When estimating the value of an action, Reinforcement Learning algorithms typically sum all the rewards that this action led to, giving more weight to immediate rewards and less weight to later rewards (considering that an action has more influence on the near future than on the distant future). To model this, a discount factor is typically applied at each time step. For example, with a discount factor of 0.9, a reward of 100 that is received two time steps later is counted as only 0.9<sup>2</sup> × 100 = 81 when you are estimating the value of the action. You can think of the discount factor as a measure of how much the future is valued relative to the present: if it is very close to 1, then the future is valued almost as much as the present; if it is close to 0, then only immediate rewards matter. Of course, this impacts the optimal policy tremendously: if you value the future, you may be willing to put up with a lot of immediate pain for the prospect of eventual rewards, while if you don't value the future, you will just grab any immediate reward you can find, never investing in the future.\n",
"4. To measure the performance of a Reinforcement Learning agent, you can simply sum up the rewards it gets. In a simulated environment, you can run many episodes and look at the total rewards it gets on average (and possibly look at the min, max, standard deviation, and so on).\n",
"5. The credit assignment problem is the fact that when a Reinforcement Learning agent receives a reward, it has no direct way of knowing which of its previous actions contributed to this reward. It typically occurs when there is a large delay between an action and the resulting reward (e.g., during a game of Atari's _Pong_, there may be a few dozen time steps between the moment the agent hits the ball and the moment it wins the point). One way to alleviate it is to provide the agent with shorter-term rewards, when possible. This usually requires prior knowledge about the task. For example, if we want to build an agent that will learn to play chess, instead of giving it a reward only when it wins the game, we could give it a reward every time it captures one of the opponent's pieces.\n",
"6. An agent can often remain in the same region of its environment for a while, so all of its experiences will be very similar for that period of time. This can introduce some bias in the learning algorithm. It may tune its policy for this region of the environment, but it will not perform well as soon as it moves out of this region. To solve this problem, you can use a replay buffer; instead of using only the most immediate experiences for learning, the agent will learn based on a buffer of its past experiences, recent and not so recent (perhaps this is why we dream at night: to replay our experiences of the day and better learn from them?).\n",
"7. An off-policy RL algorithm learns the value of the optimal policy (i.e., the sum of discounted rewards that can be expected for each state if the agent acts optimally) while the agent follows a different policy. Q-Learning is a good example of such an algorithm. In contrast, an on-policy algorithm learns the value of the policy that the agent actually executes, including both exploration and exploitation."
]
},
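{
"cell_type": "markdown",
"metadata": {},
"source": [
"The arithmetic in answer 3 can be checked with a few lines of Python. In this sketch (the `discounted_return()` helper is hypothetical, for illustration only), the reward received k steps after the action is weighted by the discount factor raised to the power k:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def discounted_return(rewards, discount_factor):\n",
"    # rewards[k] is the reward received k steps after the action\n",
"    return sum(reward * discount_factor ** k for k, reward in enumerate(rewards))\n",
"\n",
"# A reward of 100 received two time steps later, with a discount factor of 0.9\n",
"print(discounted_return([0, 0, 100], 0.9))  # ≈ 81\n",
"# With a discount factor close to 0, almost only the immediate reward counts\n",
"print(discounted_return([10, 0, 100], 0.01))  # ≈ 10.01"
]
},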
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 8.\n",
"_Exercise: Use policy gradients to solve OpenAI Gym's LunarLander-v2 environment._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's start by creating a LunarLander-v2 environment:"
]
},
{
"cell_type": "code",
"execution_count": 72,
"metadata": {},
"outputs": [],
"source": [
"env = gym.make(\"LunarLander-v2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The inputs are 8-dimensional:"
]
},
{
"cell_type": "code",
"execution_count": 73,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Box(-inf, inf, (8,), float32)"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.observation_space"
]
},
{
"cell_type": "code",
"execution_count": 74,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([ 0.00229702, 1.4181306 , 0.2326471 , 0.3204666 , -0.00265488,\n",
" -0.05269805, 0. , 0. ], dtype=float32)"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"obs = env.reset(seed=42)\n",
"obs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the [source code](https://github.com/openai/gym/blob/master/gym/envs/box2d/lunar_lander.py), we can see that each 8D observation (x, y, h, v, a, w, l, r) contains:\n",
"* x, y: the coordinates of the spaceship. It starts at a random location near (0, 1.4) and must land near the target at (0, 0).\n",
"* h, v: the horizontal and vertical speed of the spaceship. It starts with a small random speed.\n",
"* a, w: the spaceship's angle and angular velocity.\n",
"* l, r: whether the left or right leg touches the ground (1.0) or not (0.0)."
]
},
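{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make this layout concrete, here is a small sketch that labels each component of the observation shown above. The values are copied from that output, and the `OBS_FIELDS` names are our own labels, not part of Gym's API:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"OBS_FIELDS = [\"x\", \"y\", \"h\", \"v\", \"a\", \"w\", \"l\", \"r\"]  # our own labels\n",
"\n",
"# Copied from the output of env.reset(seed=42) above\n",
"sample_obs = np.array([0.00229702, 1.4181306, 0.2326471, 0.3204666,\n",
"                       -0.00265488, -0.05269805, 0.0, 0.0], dtype=np.float32)\n",
"\n",
"named_obs = dict(zip(OBS_FIELDS, sample_obs.tolist()))\n",
"for name, value in named_obs.items():\n",
"    print(f\"{name}: {value:+.4f}\")\n",
"\n",
"legs_touching = named_obs[\"l\"] == 1.0 or named_obs[\"r\"] == 1.0\n",
"print(\"At least one leg touches the ground:\", legs_touching)"
]
},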
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The action space is discrete, with 4 possible actions:"
]
},
{
"cell_type": "code",
"execution_count": 75,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Discrete(4)"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"env.action_space"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Looking at the [LunarLander-v2 description](https://gym.openai.com/envs/LunarLander-v2/), these actions are, in order (action 0 to action 3):\n",
"* do nothing\n",
"* fire left orientation engine\n",
"* fire main engine\n",
"* fire right orientation engine"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple policy network with 4 output neurons (one per possible action):"
]
},
{
"cell_type": "code",
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"tf.random.set_seed(42)\n",
"\n",
"n_inputs = env.observation_space.shape[0]\n",
"n_outputs = env.action_space.n\n",
"\n",
"model = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(32, activation=\"relu\", input_shape=[n_inputs]),\n",
" tf.keras.layers.Dense(32, activation=\"relu\"),\n",
" tf.keras.layers.Dense(n_outputs, activation=\"softmax\"),\n",
"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that we're using the softmax activation function in the output layer, instead of the sigmoid activation function as we did for the CartPole-v1 environment. This is because CartPole-v1 has only two possible actions, so a binary classification model worked fine. However, since we now have more than two possible actions, we need a multiclass classification model."
]
},
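{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a NumPy sketch of the two ideas involved. The logits are made up, and the code is independent of the model above. First, the softmax output is a probability distribution over the 4 actions. Second, applying softmax to the log-probabilities gives back the same probabilities, which is why `tf.math.log(probas)` can be passed as logits to `tf.random.categorical()` in the `lander_play_one_step()` function below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def softmax(logits):\n",
"    # Subtracting the max improves numerical stability without changing the result\n",
"    exp = np.exp(logits - logits.max())\n",
"    return exp / exp.sum()\n",
"\n",
"logits = np.array([0.5, -1.0, 2.0, 0.0])  # hypothetical scores for the 4 actions\n",
"probas = softmax(logits)\n",
"print(probas.round(3), probas.sum())\n",
"\n",
"# Log-probabilities work as logits: softmax(log(p)) == p\n",
"print(np.allclose(softmax(np.log(probas)), probas))  # True\n",
"\n",
"# Sampling actions: empirical frequencies approach the probabilities\n",
"rng = np.random.default_rng(42)\n",
"actions = rng.choice(len(probas), size=10_000, p=probas)\n",
"print(np.bincount(actions, minlength=4) / len(actions))"
]
},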
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, let's reuse the `play_one_step()` and `play_multiple_episodes()` functions we defined for the CartPole-v1 Policy Gradient code above, with small tweaks. The new `lander_play_one_step()` function accounts for the fact that the model is now a multiclass classification model rather than a binary classification model: it samples the action from the categorical distribution given by the model's output probabilities. The new `lander_play_multiple_episodes()` function is identical to the original one, except that it calls `lander_play_one_step()`."
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"def lander_play_one_step(env, obs, model, loss_fn):\n",
" with tf.GradientTape() as tape:\n",
" probas = model(obs[np.newaxis])\n",
" logits = tf.math.log(probas + tf.keras.backend.epsilon())\n",
" action = tf.random.categorical(logits, num_samples=1)\n",
" loss = tf.reduce_mean(loss_fn(action, probas))\n",
" grads = tape.gradient(loss, model.trainable_variables)\n",
" obs, reward, done, info = env.step(action[0, 0].numpy())\n",
" return obs, reward, done, grads\n",
"\n",
"def lander_play_multiple_episodes(env, n_episodes, n_max_steps, model, loss_fn):\n",
" all_rewards = []\n",
" all_grads = []\n",
" for episode in range(n_episodes):\n",
" current_rewards = []\n",
" current_grads = []\n",
" obs = env.reset()\n",
" for step in range(n_max_steps):\n",
" obs, reward, done, grads = lander_play_one_step(env, obs, model, loss_fn)\n",
" current_rewards.append(reward)\n",
" current_grads.append(grads)\n",
" if done:\n",
" break\n",
" all_rewards.append(current_rewards)\n",
" all_grads.append(current_grads)\n",
" return all_rewards, all_grads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll keep exactly the same `discount_rewards()` and `discount_and_normalize_rewards()` functions as earlier:"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"def discount_rewards(rewards, discount_factor):\n",
" discounted = np.array(rewards)\n",
" for step in range(len(rewards) - 2, -1, -1):\n",
" discounted[step] += discounted[step + 1] * discount_factor\n",
" return discounted\n",
"\n",
"def discount_and_normalize_rewards(all_rewards, discount_factor):\n",
" all_discounted_rewards = [discount_rewards(rewards, discount_factor)\n",
" for rewards in all_rewards]\n",
" flat_rewards = np.concatenate(all_discounted_rewards)\n",
" reward_mean = flat_rewards.mean()\n",
" reward_std = flat_rewards.std()\n",
" return [(discounted_rewards - reward_mean) / reward_std\n",
" for discounted_rewards in all_discounted_rewards]"
]
},
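{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a small worked example. With rewards [10, 0, -50] and a discount factor of 0.8, the discounted rewards are computed backwards: -50, then 0 + 0.8 × (-50) = -40, then 10 + 0.8 × (-40) = -22.\n",
"\n",
"So that the cell can run on its own, it repeats the two functions above with one small change: it converts the rewards to floats explicitly. This matters because `np.array(rewards)` keeps the dtype of its input. If the rewards were Python ints, the in-place `+=` would silently truncate the discounted rewards to integers. LunarLander's rewards are floats, so the code above is not affected."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def discount_rewards(rewards, discount_factor):\n",
"    discounted = np.array(rewards, dtype=np.float64)  # explicit float dtype\n",
"    for step in range(len(rewards) - 2, -1, -1):\n",
"        discounted[step] += discounted[step + 1] * discount_factor\n",
"    return discounted\n",
"\n",
"def discount_and_normalize_rewards(all_rewards, discount_factor):\n",
"    all_discounted_rewards = [discount_rewards(rewards, discount_factor)\n",
"                              for rewards in all_rewards]\n",
"    flat_rewards = np.concatenate(all_discounted_rewards)\n",
"    reward_mean = flat_rewards.mean()\n",
"    reward_std = flat_rewards.std()\n",
"    return [(discounted_rewards - reward_mean) / reward_std\n",
"            for discounted_rewards in all_discounted_rewards]\n",
"\n",
"print(discount_rewards([10, 0, -50], discount_factor=0.8))  # [-22. -40. -50.]\n",
"# After normalization, all discounted rewards (across episodes) have mean 0 and std 1\n",
"print(discount_and_normalize_rewards([[10, 0, -50], [10, 20]], discount_factor=0.8))"
]
},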
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's define some hyperparameters:"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"n_iterations = 200\n",
"n_episodes_per_update = 16\n",
"n_max_steps = 1000\n",
"discount_factor = 0.99"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Again, since the model is a multiclass classification model, we must use the categorical cross-entropy rather than the binary cross-entropy. Moreover, since the `lander_play_one_step()` function sets the targets as class indices rather than class probabilities, we must use the `sparse_categorical_crossentropy()` loss function:"
]
},
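{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what this loss computes, here is a NumPy sketch with made-up probabilities. For a sampled action index a, the sparse categorical cross-entropy is just -log(p<sub>a</sub>), the same value the regular categorical cross-entropy gives with a one-hot target. A gradient descent step on this loss would make the sampled action more likely. The training loop then scales each gradient by the corresponding normalized discounted reward, so actions that led to good outcomes become more likely and the others less likely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"probas = np.array([0.1, 0.2, 0.6, 0.1])  # hypothetical policy output for one state\n",
"action = 2  # sampled action index, used as the target class\n",
"\n",
"sparse_ce = -np.log(probas[action])\n",
"one_hot = np.eye(len(probas))[action]\n",
"categorical_ce = -np.sum(one_hot * np.log(probas))\n",
"print(sparse_ce, categorical_ce)  # both ≈ 0.511"
]
},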
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
"optimizer = tf.keras.optimizers.Nadam(learning_rate=0.005)\n",
"loss_fn = tf.keras.losses.sparse_categorical_crossentropy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We're ready to train the model. Let's go!"
]
},
{
"cell_type": "code",
"execution_count": 81,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iteration: 200/200, mean reward: 139.7 "
]
}
],
"source": [
"env.reset(seed=42)\n",
"\n",
"mean_rewards = []\n",
"\n",
"for iteration in range(n_iterations):\n",
" all_rewards, all_grads = lander_play_multiple_episodes(\n",
" env, n_episodes_per_update, n_max_steps, model, loss_fn)\n",
" mean_reward = sum(map(sum, all_rewards)) / n_episodes_per_update\n",
" print(f\"\\rIteration: {iteration + 1}/{n_iterations},\"\n",
" f\" mean reward: {mean_reward:.1f} \", end=\"\")\n",
" mean_rewards.append(mean_reward)\n",
" all_final_rewards = discount_and_normalize_rewards(all_rewards,\n",
" discount_factor)\n",
" all_mean_grads = []\n",
" for var_index in range(len(model.trainable_variables)):\n",
" mean_grads = tf.reduce_mean(\n",
" [final_reward * all_grads[episode_index][step][var_index]\n",
" for episode_index, final_rewards in enumerate(all_final_rewards)\n",
" for step, final_reward in enumerate(final_rewards)], axis=0)\n",
" all_mean_grads.append(mean_grads)\n",
" optimizer.apply_gradients(zip(all_mean_grads, model.trainable_variables))"
]
},
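{
"cell_type": "markdown",
"metadata": {},
"source": [
"For each trainable variable, the inner loop above computes a mean over every step of every episode. Each step contributes its gradient, weighted by its normalized discounted reward. Here is the same computation on toy NumPy arrays, checked against an explicit weighted sum divided by the total number of steps. The shapes and values are made up: 2 episodes of 3 and 2 steps, and a single 2×2 variable."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(42)\n",
"# toy_grads[episode_index][step][var_index], with a single toy 2x2 variable\n",
"toy_grads = [[[rng.normal(size=(2, 2))] for _ in range(3)],\n",
"             [[rng.normal(size=(2, 2))] for _ in range(2)]]\n",
"toy_final_rewards = [np.array([0.5, -1.0, 1.2]), np.array([-0.3, 0.8])]\n",
"\n",
"var_index = 0\n",
"toy_mean_grads = np.mean(\n",
"    [final_reward * toy_grads[episode_index][step][var_index]\n",
"     for episode_index, final_rewards in enumerate(toy_final_rewards)\n",
"     for step, final_reward in enumerate(final_rewards)], axis=0)\n",
"\n",
"n_steps = sum(len(rewards) for rewards in toy_final_rewards)  # 5\n",
"weighted_sum = sum(reward * grads[var_index]\n",
"                   for rewards, episode_grads in zip(toy_final_rewards, toy_grads)\n",
"                   for reward, grads in zip(rewards, episode_grads))\n",
"print(np.allclose(toy_mean_grads, weighted_sum / n_steps))  # True"
]
},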
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the learning curve:"
]
},
{
"cell_type": "code",
"execution_count": 82,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(mean_rewards)\n",
"plt.xlabel(\"Iteration\")\n",
"plt.ylabel(\"Mean reward\")\n",
"plt.grid()\n",
"plt.show()"
]
},
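{
"cell_type": "markdown",
"metadata": {},
"source": [
"The learning curve shows the mean reward per training iteration. As mentioned in the answer to exercise 4, it can also help to look at the spread of the episode returns, for example `[sum(rewards) for rewards in all_rewards]` after a call to `lander_play_multiple_episodes()`. Here is a minimal sketch of such a summary. The `summarize_returns()` helper is hypothetical, and the returns below are placeholder numbers, not results from this notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def summarize_returns(episode_returns):\n",
"    # Summary statistics of the total reward of each episode\n",
"    returns = np.asarray(episode_returns, dtype=np.float64)\n",
"    return {\"mean\": returns.mean(), \"std\": returns.std(),\n",
"            \"min\": returns.min(), \"max\": returns.max()}\n",
"\n",
"placeholder_returns = [120.0, 250.5, -30.2, 180.0]  # made-up values\n",
"stats = summarize_returns(placeholder_returns)\n",
"print({name: round(float(value), 1) for name, value in stats.items()})"
]
},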
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's look at the result!"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
"def lander_render_policy_net(model, n_max_steps=500, seed=42):\n",
" frames = []\n",
" env = gym.make(\"LunarLander-v2\")\n",
" tf.random.set_seed(seed)\n",
" np.random.seed(seed)\n",
" obs = env.reset(seed=seed)\n",
" for step in range(n_max_steps):\n",
" frames.append(env.render(mode=\"rgb_array\"))\n",
" probas = model(obs[np.newaxis])\n",
" logits = tf.math.log(probas + tf.keras.backend.epsilon())\n",
" action = tf.random.categorical(logits, num_samples=1)\n",
" obs, reward, done, info = env.step(action[0, 0].numpy())\n",
" if done:\n",
" break\n",
" env.close()\n",
" return frames"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
"frames = lander_render_policy_net(model, seed=42)\n",
"plot_animation(frames)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"That's pretty good. You can try training it for longer and/or tweaking the hyperparameters to see if you can get the mean reward above 200, the score at which LunarLander-v2 is generally considered solved."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 9.\n",
"_Exercise: Use a Double Dueling DQN to train an agent that can achieve a superhuman level at the famous Atari Breakout game (`\"ALE/Breakout-v5\"`). The observations are images. To simplify the task, you should convert them to grayscale (i.e., average over the channels axis), crop them and downsample them, so they're just large enough to play, but not much more. An individual image does not tell you which way the ball and the paddle are going, so you should merge two or three consecutive images to form each state. Lastly, the DQN should be composed mostly of convolutional layers._"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
2022-04-05 11:47:12 +02:00
"TODO"
]
},
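{
"cell_type": "markdown",
"metadata": {},
"source": [
"The full solution is not written yet, but the preprocessing described in the exercise can be sketched with NumPy. This sketch assumes raw frames of shape 210×160×3 (`uint8`), as Atari environments typically return. The crop boundaries (rows 34 to 193) and the downsampling factor of 2 are assumptions to tune, not values from the book.\n",
"\n",
"Each frame is converted to grayscale by averaging over the channels axis, then cropped and downsampled. The last few frames are stacked along a new channels axis so that the DQN can infer the direction of motion. Random arrays stand in for real frames:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
"import numpy as np\n",
"\n",
"def preprocess_frame(frame, top=34, bottom=194, factor=2):\n",
"    # frame: (210, 160, 3) uint8 RGB image (assumed Atari frame shape)\n",
"    gray = frame.mean(axis=2)  # grayscale: average over the channels axis\n",
"    cropped = gray[top:bottom]  # keep the play area (assumed row range)\n",
"    small = cropped[::factor, ::factor]  # naive downsampling by striding\n",
"    return (small / 255.0).astype(np.float32)  # scale pixels to [0, 1]\n",
"\n",
"class FrameStacker:\n",
"    # Keeps the last n_frames preprocessed frames and stacks them as channels\n",
"    def __init__(self, n_frames=3):\n",
"        self.frames = deque(maxlen=n_frames)\n",
"\n",
"    def reset(self, frame):\n",
"        processed = preprocess_frame(frame)\n",
"        for _ in range(self.frames.maxlen):  # fill the stack with the first frame\n",
"            self.frames.append(processed)\n",
"        return self.state()\n",
"\n",
"    def step(self, frame):\n",
"        self.frames.append(preprocess_frame(frame))\n",
"        return self.state()\n",
"\n",
"    def state(self):\n",
"        return np.stack(list(self.frames), axis=-1)  # (height, width, n_frames)\n",
"\n",
"rng = np.random.default_rng(42)\n",
"fake_frames = rng.integers(0, 256, size=(5, 210, 160, 3), dtype=np.uint8)\n",
"stacker = FrameStacker(n_frames=3)\n",
"state = stacker.reset(fake_frames[0])\n",
"for frame in fake_frames[1:]:\n",
"    state = stacker.step(frame)\n",
"print(state.shape)  # (80, 80, 3)"
]
},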
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Check out the [State-of-the-Art for Atari Games on paperswithcode.com](https://paperswithcode.com/task/atari-games)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 10.\n",
"_Exercise: If you have about $100 to spare, you can purchase a Raspberry Pi 3 plus some cheap robotics components, install TensorFlow on the Pi, and go wild! For an example, check out this [fun post](https://homl.info/2) by Lukas Biewald, or take a look at GoPiGo or BrickPi. Start with simple goals, like making the robot turn around to find the brightest angle (if it has a light sensor) or the closest object (if it has a sonar sensor), and move in that direction. Then you can start using Deep Learning: for example, if the robot has a camera, you can try to implement an object detection algorithm so it detects people and moves toward them. You can also try to use RL to make the agent learn on its own how to use the motors to achieve that goal. Have fun!_"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It's your turn now: go crazy, be creative, but most of all, be patient and move forward step by step. You can do it!"
]
}
],
"metadata": {
"accelerator": "GPU",
"interpreter": {
"hash": "95c485e91159f3a8b550e08492cb4ed2557284663e79130c96242e7ff9e65ae1"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}