diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index 9ad70dc..ba2379a 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -170,7 +170,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Let's install the gym library, which provides many environments for Reinforcement Learning. Some of these environments require an X server to plot graphics, so we need to install xvfb on Colab or Kaggle (that's an in-memory X server, since the runtimes are not hooked to a screen). We also need to install pyvirtualdisplay, which provides a Python interface to xvfb. And let's also install the Box2D and Atari environments. By running the following cell, you also accept the Atari ROM license." + "Let's install the gym library, which provides many environments for Reinforcement Learning. We'll also install the extra libraries needed for classic control environments (including CartPole, which we will use shortly), as well as for Box2D and Atari environments, which are needed for the exercises.\n", + "\n", + "**Note:** by running the following cell, you accept the Atari ROM license." ] }, { @@ -181,36 +183,7 @@ "source": [ "if \"google.colab\" in sys.modules or \"kaggle_secrets\" in sys.modules:\n", " %pip install -q -U gym\n", - " %pip install -q -U gym[box2d,atari,accept-rom-license]\n", - " !apt update &> /dev/null && apt install -y xvfb &> /dev/null\n", - " %pip install -q -U pyvirtualdisplay" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Warning**: some environments (including the Cart-Pole) require access to your display, which opens up a separate window. In general you can safely ignore that window. However, if Jupyter is running on a headless server (ie. without a screen) it will raise an exception. Examples of headless servers include Colab, Kaggle, or Docker containers. One way to avoid this is to install an X server like [Xvfb](http://en.wikipedia.org/wiki/Xvfb), which performs all graphical operations on a virtual display, in memory. You can then start Jupyter using the `xvfb-run` command:\n", - "\n", - "```bash\n", - "$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter lab\n", - "```\n", - "\n", - "Alternatively, you can install the [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) Python library which wraps Xvfb, and lets you create a virtual display. Let's create a virtual display using `pyvirtualdisplay`, if it is installed:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "try:\n", - " import pyvirtualdisplay\n", - "\n", - " display = pyvirtualdisplay.Display(visible=0, size=(1400, 900)).start()\n", - "except ImportError:\n", - " pass" + " %pip install -q -U gym[classic_control,box2d,atari,accept-rom-license]" ] }, { @@ -224,32 +197,66 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. It provides many environments for your learning *agents* to interact with. Let's import Gym and make a new environment:" + "In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. It provides many environments for your learning *agents* to interact with. 
Let's import Gym and make a new CartPole environment:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import gym\n",
+ "\n",
+ "env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The CartPole (version 1) is a very simple environment composed of a cart that can move left or right, and a pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "**Tip**: `gym.envs.registry` is a dictionary containing all available environments:"
 ]
 },
 {
 "cell_type": "code",
 "execution_count": 8,
 "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['ALE/Adventure-ram-v5',\n",
+ " 'ALE/Adventure-v5',\n",
+ " 'ALE/AirRaid-ram-v5',\n",
+ " 'ALE/AirRaid-v5',\n",
+ " 'ALE/Alien-ram-v5',\n",
+ " '...']"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
 "source": [
- "import gym\n",
- "\n",
- "env = gym.make(\"CartPole-v1\")"
+ "# extra code – shows the first few environments\n",
+ "envs = gym.envs.registry\n",
+ "sorted(envs.keys())[:5] + [\"...\"]"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "The Cart-Pole (version 1) is a very simple environment composed of a cart that can move left or right, and pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Tip**: you can use `gym.envs.registry.all()` to get the full list of available environments:"
+ "The registry values are environment specifications:"
 ]
 },
 {
 "cell_type": "code",
@@ -260,12 +267,7 @@
 {
 "data": {
 "text/plain": [
- "['ALE/Tetris-v5',\n",
- " 'ALE/Tetris-ram-v5',\n",
- " 'ALE/Asterix-v5',\n",
- " 'ALE/Asterix-ram-v5',\n",
- " 'ALE/Asteroids-v5',\n",
- " '...']"
+ "EnvSpec(id='CartPole-v1', entry_point='gym.envs.classic_control.cartpole:CartPoleEnv', reward_threshold=475.0, nondeterministic=False, max_episode_steps=500, order_enforce=True, autoreset=False, disable_env_checker=False, apply_api_compatibility=False, kwargs={}, namespace=None, name='CartPole', version=1)"
 ]
 },
 "execution_count": 9,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
- "# extra code – shows the first few environments\n",
- "envs = gym.envs.registry.all()\n",
- "[env.id for env in envs][:5] + [\"...\"]"
+ "# extra code – shows the specification for the CartPole-v1 environment\n",
+ "envs[\"CartPole-v1\"]"
 ]
 },
 {
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "Let's initialize the environment by calling is `reset()` method. This returns an observation:"
+ "Let's initialize the environment by calling its `reset()` method. This returns an observation, as well as a dictionary that may contain extra information. Both are environment-specific."
 ]
 },
 {
 "cell_type": "code",
@@ -303,24 +304,10 @@
 }
 ],
 "source": [
- "obs = env.reset(seed=42)\n",
+ "obs, info = env.reset(seed=42)\n",
 "obs"
 ]
 },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Observations vary depending on the environment. In this case it is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the angular velocity."
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment)." - ] - }, { "cell_type": "code", "execution_count": 11, @@ -329,7 +316,7 @@ { "data": { "text/plain": [ - "(400, 600, 3)" + "{}" ] }, "execution_count": 11, @@ -338,8 +325,21 @@ } ], "source": [ - "img = env.render(mode=\"rgb_array\")\n", - "img.shape # height, width, channels (3 = Red, Green, Blue)" + "info" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the case of the CartPole, each observation is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (0 = vertical), and the angular velocity." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "An environment can be visualized by calling its `render()` method. If you set `render_mode` to `\"rgb_array\"` when creating the environment, then this will return a NumPy array." ] }, { @@ -349,14 +349,33 @@ "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAASUAAADICAYAAACuyvefAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAGVElEQVR4nO3dzY4bWRmA4a9sd/d0t8RkEpQwAoSQRggh2MEFIHqB2LDOPlJug1vgIlhklQ3bSFmEDdIo7IKIYBDTIWnCQJOQ/nFX1WGRQcLTaTdlV8WfM8+zPGVZZ1F6fapc9qlKKQGQxWjVEwD4X6IEpCJKQCqiBKQiSkAqogSkMrnkuOcFgCFUFx2wUgJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSAVUQJSESUgFVECUhElIBVRAlIRJSCVyaonwLujmR5HW09nxkaTrRhvvreiGbGORIlelFLiyce/jr///jcz4ze+/+P4+o9+vqJZsY5EiZ6UaKbH0UyPZkbb+mxF82FduadEP0qJ0tSrngXvAFGiF6WUaFtRYnmiRE+slOiHKNEPl2/0RJTohcs3+iJK9KJtzmL68rPZwaqKra98dTUTYm2JEr0oTR318cuZsaoaxcbuByuaEetKlBhQFaOxR+HoRpQYVDXeWPUUWDOixHCqsFKiM1FiQFWMrJToSJQYVGWlREeixGCqykqJ7kSJXpS2ifKmA1X1tqfCmhMlelHaOuLNWYJORIletE2tSfRClOhFac5CleiDKNGL1j8E0BNRohelqaMUKyWWJ0r0oj49iijtzFg13ojKt290JEr04vgfT6K0zczYe+/fiNHm9opmxLoSJXpy/tJtNJ5YKdGZKDGYajSOCFGiG1FiMNV4ElXlFKMbZwyDqUZjPzOhM1FiMO4psQhRYjDVaBLh8o2OnDEsrZTyxgcnq5HTi+6cNfTioo0oXb7RlSjRg+K3b/RGlFheKVHas1XPgneEKNGLiy7foCtRYmmluHyjP6JEL6yU6IsosbRS2jg7eXlufLy1u4LZsO5EiaWVpo7Tw4MvjFaxc+2bK5kP602UGIYtu1mQKDGYykaULECUGMxoIkp0J0oMxJbdLEaUGEzlnhILECUGY6XEIkSJpZXSvnlvXP8QwAJEiaWVpo6wESU9ESWW9vonJqJEP0SJpbWtLbvpjyixNJdv9EmUWFpp64hoVz0N3hGixNLa5vzlW1VVUdkdlwV4uo25jo6Oommaua85fPpJtGfTmbHJ7tU4qSPql+f/0uS/qqqK3d1dmwswo7rkBqUbBV9yN2/ejAcPHsx9zc9++K249dPvzYz98em/4he/+m28Orn4z9+uXbsW9+/fjytXrvQxVdbLhZ9EVkrM9fz589jf35/7msOP3o/Tdjv+evpR1GUjvrb5SRwf/y3295/Eq5OLNxSYTqfRtu5FMUuUWNpJuxMPX/wk/ll/GBER+yffjTj9LJrWQpvuRImlHUy//XmQXq/IT9rd+PTfP4hWlFiAb98YRN200Xp2iQWIEku7sfnn+GDyLF5/L1Jia/QqvrH5u2jcL2IBcy/fDg8P39I0yKquL9866dOnf4k/fPzLeHLynajLRny49ac4OHh86UPebdvGixcvYjTy2fhlM+8b17lRunPnTt9zYc0cHHxxl5LzHj5+Fg8fP4uIe53e+/T0NO7evRs7OzsLzo51dfv27QuPeU6Jufb29uLevW6x+X9dv349Hj16FFevXh3k/UntwueUrJuBVEQJSEWUgFRECUhFlIBU/MyEuW7duhV7e3uDvPfOzk5sb28P8t6sL48EAKvgkQBgPYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkIooAamIEpCKKAGpiBKQiigBqYgSkMrkkuPVW5kFwOeslIBURAlIRZSAVEQJSEWUgFRECUjlPyRsVw4Q1yV5AAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "(400, 600, 3)" ] }, - "metadata": { - "needs_background": "light" + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "img = env.render()\n", + "img.shape # height, width, channels (3 = Red, Green, Blue)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAEWCAYAAACqitpwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAIQ0lEQVR4nO3dT49kVR3H4d+91T1/mXEACUPUmKiBiQlLN0gyJi7cGN6AL4DEN+C7cM/ed2EMezDEaIIYDWExDI1EBgdm6OmqusfFwHQPds+cgu90VTPPs71V1b9N5ZNzTvW9Q2utFQAEjeseAIBvH3EBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIgTFwDixAWAOHEBIE5cAIjbWvcAcJK01urDv/6xbv/n2qHXv/vCS3Xxe1eOeSrYPOICq2hT/ffa23Xz2tuHXn7i2R+LC5RtMVhJm6Zq07TuMWDjiQusoE3LqiYu8DDiAitobWnlAh3EBVbQpqlaW657DNh44gIraNOymm0xeChxgRW0aVllWwweSlxgBVYu0EdcYAWt+Sky9BAXWMHyzu1azncPvTbMtmt26uwxTwSbSVxgBXduflTzWzcOvbZ97mKdefLyMU8Em0lcIGQYxhpGd1SCKnGBmGEYaxxn6x4DNoK4QMo41iAuUFXiAjF3t8XEBarEBXLEBe4RFwgZhkFc4AviAiG2xWCfuEDKONYwExeoEhfo1lp74HUrF9gnLrCCaXrQs1yGGgZfKagSF1hJWy7WPQKcCOICK5iW83WPACeCuEC3ZuUCncQFejUrF+glLtDNygV6iQuswMoF+ogLrKCJC3QRF+jUWqvJthh0ERfo1mrx+c0jr26dPnuMs8BmExfo1KZl3Xz/H0dev/j9nx7jNLDZxAVCxtn2ukeAjSEuEDKIC9wjLhAybokLfElcIMTKBfaJC4SMs611jwAbQ1wgxIE+7BMXCHHmAvvEBUKcucA+cYEQ22KwT1ygU5taVbUjrw/j7PiGgQ0nLtCpTe6IDL3EBTp5lgv0ExfoNC0WVe3obTFgn7hAp7acP+DEBThIXKCTB4VBP3GBTncfcWztAj3EBTpZuUA/cYFObTm3cIFO4gKd7q5c1AV6iAt0+vzG9WptOvTa6YvPuP0LHCAu0Gn3k50j/8/l7JPPuSsyHCAuEDCMW1XDsO4xYGOICwQMs62qEhf4krhAwDDbqsHKBe4RFwgYZ7bF4CBxgYBhtC0GB4kLBIy2xeA+4gIBVi5wP3GBgMGZC9xHXKBDa+2BDwobx9kxTgObT1ygR5uOvPXLXYMzFzhAXKBDm6Zq03LdY8CJIS7QobWp2vSglQtwkLhAhzYtrVxgBeICHVqbqjVxgV7iAj2cucBKxAU6tDZVOXOBbuICHZy5wGrEBTrc/bWYuEAvcYEOy91bNd/99NBr49ap2j7/nWOeCDbb1roHgOO2u7tbOzs7K71n76N/1fzWJ4dea7NTdeN2q8/ee6/7886cOVOXL19eaQY4ScSFx85bb71VL7/88krv+fmLP6jf//ZXh17b2dmp37zySv3z2sfdn3f16tV6/fXXV5oBThJx4bHUHnATykNfP919/bLNanc6X1Mba3vYq9Pj7ZqmVnvz5Uqfuerfh5NGXKDTso31989eqg/3fljzdrouzD6u58+/WVN7u+YLh/1wkLhAh3k7XX/77Bf1wZ0f15cPBbu5fKb+8ukv6+ndT2ux9D8wcJBfi0GHTxbP1gd3flJffdrkop2ud279rOYLcYGDxAW+oWlqNV/aFoODxAU6DNVqqCNWJ21RCysXuI+4QIent9+v58+98X+BOT+7US8+8ScrF/gKB/rQYblc1JPTm/Vc3apruy/U3nS2Lm1/WD/a/nP9e/ejWiz9tBgO6o7La6+99ijngGPz7rvvrvyeN955v379uz9U+2KDrGq4t1X2dbJy/fp13ylOrFdfffWhr+mOy5UrV77RMLApll9jC6u1qnnw58bnzp3zneJbrTsuV69efZRzwLHZ3t5e9wh16dIl3ym+1RzoAxAnLgDEiQsAceICQJy4ABAnLgDEiQsAcW7/wmNna2urnnrqqbXOcOHChbX+fXjUhuZ5qzxmpmmqvb29tc4wjmOdOnVqrTPAoyQuAMQ5cwEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOLEBYA4cQEgTlwAiBMXAOL+B2BEo3liTWDwAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] }, + "metadata": {}, "output_type": "display_data" } ], @@ -365,7 +384,7 @@ "\n", "def plot_environment(env, figsize=(5, 4)):\n", " plt.figure(figsize=figsize)\n", - " img = env.render(mode=\"rgb_array\")\n", + " img = env.render()\n", " plt.imshow(img)\n", " plt.axis(\"off\")\n", " return img\n", @@ -383,7 +402,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -392,7 +411,7 @@ "Discrete(2)" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -405,7 +424,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Yep, just two possible actions: accelerate towards the left or towards the right." + "Yep, just two possible actions: accelerate towards the left (0) or towards the right (1)." ] }, { @@ -417,7 +436,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -426,14 +445,14 @@ "array([ 0.02727336, 0.18847767, 0.03625453, -0.26141977], dtype=float32)" ] }, - "execution_count": 14, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "action = 1 # accelerate right\n", - "obs, reward, done, info = env.step(action)\n", + "obs, reward, done, truncated, info = env.step(action)\n", "obs" ] }, @@ -446,19 +465,17 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVgAAADqCAYAAADnGV2KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAHWElEQVR4nO3dz25cVx3A8d+9M/Y4tZvIIsWKi/jTboAliyzZ8AiE98gTRCiPEPEczRYpbFgiIYMQtF1AiUoEaik2CiGNZ+bew8KVUGrP2E7z89wz/XyWvtb4LK6+OTlz7j1NKSUAeP3aVQ8AYF0JLEASgQVIIrAASQQWIInAAiQZn3PdHi6A5ZpFF8xgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiDJeNUDgMsopY/Dv/w2+tnxqWu77/woxpPtFYwKziawVKXv5vHkN+/F9Nnhl640sbP3rsAyKJYIqErp5hGlrHoYcCECS1VKN48SAksdBJaqlN4MlnoILFXp+3kUgaUSAktVSteZwVINgaUqfTePsAZLJQSWqhRLBFREYKnKi39/cuZDBptvfiNGm1srGBEsJrBUZf7505OdBF+y8cb1aMebKxgRLCawrIW2HUc0bmeGxR3JWmhGo2iaZtXDgJcILGuhMYNlgNyRrIWmNYNleASWtWAGyxC5I1kL1mAZIoGlGsseMLCLgCFyR1KVUvqzLzSNGSyDI7BU5eRdBFAHgaUi5eREA6iEwFKPUqLvZqseBVyYwFIVM1hqIrBUo5RiDZaqCCxVKZYIqIjAUpES/RmvKoShEliqUbouXhz94/SFpomtG3tXPyA4h8BSjVK6mD47PPXzpmljcv2tFYwIlhNY1kI7Gq96CHCKwLIGmmhGG6seBJwisNSviWgFlgESWNaCJQKGSGBZA5YIGCaBZS20YzNYhkdgqV7TmMEyTAJLNUq/+ESDph1d4UjgYgSWapQlj8k6y4AhEliq0XezpedywdAILNU4eReswFIPgaUafTfXV6oisFTj5F2wCks9BJZqOM2A2ggs1Si+5KIyAks1+vk0zlwiaGzSYpgElmo8++SjM0+V3fnmO9FuTFYwIlhOYKnGogcN2o1JRONWZnjclVSvGY2j8SwXAySwVK8dbXhWlkESWKrXjMahsAyRwFK9th1HYycBAySwVM8MlqESWKrXjsb2wjJIAksVSikLX0PgZdsMlcBSiRL9whduN9ZgGSSBpQ6lROm7VY8CLkVgqUIp5czHZGHIBJY6lGVLBDBMAksVSpQonSUC6iKw1KH0S0+VhSESWOpgDZYKCSxVKKUsPjLGFi0GarzqAfD1dZnjX7rj5/H8s49P/bwdT2L7re9e6rPsmeWqNOfcmA5AIs2TJ0/izp07MZ+f/1//m9cn8fOf/TBG7ctxfH48j1/88s/x10+fn/8ZN2/Gw4cP49q1a688ZjjDwn+xzWBZmePj4zg4OIjZbHbu7+7ffDP6n34/SjOJF912NE2JrfZZdF0Xf/zTB/Hhx5+d+xm3bt2Kvu9fx9DhQgSWakzLG/Hhf34c/5p+K5qmi1uTj2I/fh3Tme1bDJPAUoV5P4nfP/1JPC1vR0QTUTbiby9+EIezeUzn7616eHAmuwiowqxM4mi+Fy8vdzXxz+nbcTz3VQHDJLBU
rZSI6dy6KsMksFRh0v43vrP1fkT8P6ZtzON7134X3Xy6uoHBEtZgqULpZ7Hz+a9iZ/pp/P343WibPr699X5sTv8Qswts84JVWLoP9v79+xa3SHN0dBQPHjy48Napkz2wTZQv1mGbKBFRousvdpvu7OzE3bt3Y2Nj4xVHDKfdu3dv4T7YpYE9OjoSWNI8fvw4bt++faEHDV6Hvb29ODg48KABr9Xu7u6rPWiwu7v7+kcDXzg8PLzSx1bbto0bN27E9vb2lf1Nvt58yQWQRGABkggsQBKBBUgisABJBBYgiSe5WJn9/f149OjRpU4j+Co2Nzdja2vrSv4WRDjRAOCrWriZ2xIBQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQRWIAkAguQRGABkggsQBKBBUgisABJBBYgicACJBFYgCQCC5BEYAGSCCxAEoEFSCKwAEkEFiCJwAIkEViAJAILkERgAZIILEASgQVIIrAASQQWIInAAiQZn3O9uZJRAKwhM1iAJAILkERgAZIILEASgQVIIrAASf4HJtl/NcHAwnIAAAAASUVORK5CYII=\n", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAFFCAYAAACpCDdAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAJ1UlEQVR4nO3dS48c1RnH4be624yxiRPAtpRgZxWJCCTIig0blEiRIr4EEt8sS8QqnwCEkqBcTJA3kAQDloJkLgKDbbD7crIwQk5gZtoD/ned8vOsLHf16N2UfjqnqquG1lorAOCem+16AAC4X4guAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhIguAISILgCEiC4AhCx2PQBwdNc/er+ufXhp389n82N1+vFnaxiG4FTAfkQXOtVaq88uX6wP/vaHfY9Z7J2s048/G5wKOIjtZehV21Rbr3Y9BXAXRBc61VqrtlnvegzgLogu9KptRBc6I7rQqdsrXdvL0BPRhV61TW2sdKErogudck0X+iO60Ct3L0N3RBc6ZaUL/RFd6FXbVFuLLvREdKFTVrrQH9GFTm3Wq1rdunHgMfO9k6FpgG2ILnRqeeOzun7lnQOPeeQXz4SmAbYhujBhs7l3msCYiC5M2GxxbNcjAHcQXZiw2fyBXY8A3EF0YcIGK10YFdGFCZvNRRfGRHRhwgbRhVERXZgwN1LBuIguTJjtZRgX0YUJE10YF9GFCbO9DOMiutCh1tpWx4kujIvoQqe2e63fUMMw3PNZgO2ILnRqs17uegTgLokudKqtV7seAbhLogudstKF/ogudEp0oT+iC11q1UQXuiO60KnNSnShN6ILPWpVGzdSQXdEFzrV1rd2PQJwl0QXOmWlC/0RXeiSG6mgR6ILHWqt1fWP3j/wmL0fn/XsZRgZ0YUetVbXP3z3wENOPHKuZou90EDANkQXJmqYL6q87ABGRXRhombzRUkujIvowkQNMytdGBvRhYka5osqa10YFdGFiRrmCy+wh5ERXZio2cxKF8ZGdGGi3L0M4yO6MFEz0YXREV2YqGG2qMH2MoyK6MJE3d5e3vUUwJ1EF7rUDj3CjVQwPqILHdps1ocfNAx+MgQjI7rQoeZdutAl0YUObda3dj0CcASiCx3arJZV7fDrusC4iC50qK1XW9xKBYyN6EKHNqvlrkcAjkB0oUObtehCj0QXOtREF7okutChjZ8MQZdEFzrU/GQIuiS60KHNalXbPAoSGBfRhQ6tV4evdD0BEsZHdKFDn/zzzwc+HOPEo+frxOmfBycCtiG60KO2OfDjYTavYbYIDQNsS3RhimazGganN4yNsxImaBjmVTOnN4yNsxImaJjNapjNdz0G8H9EFyZoGGwvwxg5K2GKZnMrXRgh0YUJur297PSGsXFWwgTd3l620oWxEV2YoNu/03V6w9g4K2GChmFWZaULoyO60Jl2wOMfv+GaLoySsxJ60zaHhneooQZvPIDREV3ozO0X2B/87GVgnEQXOtM26+22mIHREV3oTNusD3ytHzBeogudaetVtUNe7QeMk+hCZzZWutAt0YXOtM3KNV3olOhCZ25f07W9DD0SXehMW69sL0OnRBc64ydD0C/Rhc60zcr2MnRKdKEzyxtf1Ga93P+AYVaLB3+UGwjYmuhCZ65/9F6tvrq27+fzY8fr1M9+GZwI2NbQXByCuBdeeKEuXbp0pO8+/6sz9esnH9338+s3V/X71/5T/77y5ZH+/ssvv1xnz5490neBgy12PQDcjy5cuFAXL1480nefPv1M1QHRXS5XdeGNN+vNd64c6e/fvHnzSN8DDie60Knl5oH6ePlY3dycrMVwq36yuFIPLa5Wa62WKzdawRiJLnRo02b1jy9+U1dXp2vZjte8VnVy/lk98dCfqrX36tZqvesRge/gRirozK3NXv3l6vP18fJcLduDVTXUuo7V5+sz9dfPf1dXV4/UUnRhlEQXOvPul0/Xp6ufVtXwrc/W7YH6+9Xf2l6GkRJdmJjWyvYyjJTowsS0araXYaREFzrz0OLTWgz7/ayn1cOLD0QXRkp0oTOP7f2rnjj5x6r69nXbc3tv1xMnX6tbrunCKPnJEHRmuVrXw8Nb9eTetXrny6fri9WjtTe7UeePv1Xnjr1dn1z7qlZr0YUx2voxkK+88so9HgXuHy+++OKRHwM5DFXD8PWdy22o9s3/f/2vVr
X5Hk93femll+rMmTNH/j7cr5577rlDj9l6pfvqq69+n1mAO1y7tv8LCw7TWt3xPt32P///Q3j99dfr1KlTP8wfg/vINtH1wgPYgaeeeurIz16+1y5fvlznz5/f9RgwSW6kAoAQ0QWAENEFgBDRBYAQ0QWAENEFgBDRBYAQ0QWAENEFgBAvPIAdOH78eJ04cWLXY3ynb57rDPzgPAYSdmC9XtdYT735fC68cI+ILgCEuKYLACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACGiCwAhogsAIaILACH/BZqbyhO+q2x5AAAAAElFTkSuQmCC\n", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -485,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -494,7 +511,7 @@ "1.0" ] }, - "execution_count": 16, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -507,12 +524,12 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "When the game is over, the environment returns `done=True`:" + "When the game is over, the environment returns `done=True`. In this case, it's not over yet:" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -521,7 +538,7 @@ "False" ] }, - "execution_count": 17, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -530,6 +547,33 @@ "done" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Some environment wrappers may want to interrupt the environment early. For example, when a time limit is reached or when an object goes out of bounds. In this case, `truncated` will be set to `True`. In this case, it's not truncated yet:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "truncated" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -539,7 +583,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -548,7 +592,7 @@ "{}" ] }, - "execution_count": 18, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -561,17 +605,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The sequence of steps between the moment the environment is reset until it is done is called an \"episode\". At the end of an episode (i.e., when `step()` returns `done=True`), you should reset the environment before you continue to use it." + "The sequence of steps between the moment the environment is reset until it is done or truncated is called an \"episode\". At the end of an episode (i.e., when `step()` returns `done=True` or `truncated=True`), you should reset the environment before you continue to use it." 
 ]
 },
 {
 "cell_type": "code",
- "execution_count": 19,
+ "execution_count": 21,
 "metadata": {},
 "outputs": [],
 "source": [
- "if done:\n",
- " obs = env.reset()"
+ "if done or truncated:\n",
+ " obs, info = env.reset()"
 ]
 },
 {
@@ -597,7 +641,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 20,
+ "execution_count": 22,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -608,12 +652,12 @@
 "totals = []\n",
 "for episode in range(500):\n",
 " episode_rewards = 0\n",
- " obs = env.reset(seed=episode)\n",
+ " obs, info = env.reset(seed=episode)\n",
 " for step in range(200):\n",
 " action = basic_policy(obs)\n",
- " obs, reward, done, info = env.step(action)\n",
+ " obs, reward, done, truncated, info = env.step(action)\n",
 " episode_rewards += reward\n",
- " if done:\n",
+ " if done or truncated:\n",
 " break\n",
 "\n",
 " totals.append(episode_rewards)"
 ]
 },
 {
 "cell_type": "code",
- "execution_count": 21,
+ "execution_count": 23,
 "metadata": {},
 "outputs": [
 {
 "data": {
 "text/plain": [
 "(41.698, 8.389445512070509, 24.0, 63.0)"
 ]
 },
- "execution_count": 21,
+ "execution_count": 23,
 "metadata": {},
 "output_type": "execute_result"
 }
 ],
 "source": [
@@ -645,7 +689,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "Well, as expected, this strategy is a bit too basic: the best it did was to keep the poll up for only 68 steps. This environment is considered solved when the agent keeps the poll up for 200 steps."
+ "Well, as expected, this strategy is a bit too basic: the best it did was to keep the pole up for only 63 steps. This environment is considered solved when the agent keeps the pole up for 200 steps."
 ]
 },
 {
@@ -657,7 +701,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 22,
+ "execution_count": 24,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -679,14 +723,14 @@
 "\n",
 "def show_one_episode(policy, n_max_steps=200, seed=42):\n",
 " frames = []\n",
- " env = gym.make(\"CartPole-v1\")\n",
+ " env = gym.make(\"CartPole-v1\", render_mode=\"rgb_array\")\n",
 " np.random.seed(seed)\n",
- " obs = env.reset(seed=seed)\n",
+ " obs, info = env.reset(seed=seed)\n",
 " for step in range(n_max_steps):\n",
- " frames.append(env.render(mode=\"rgb_array\"))\n",
+ " frames.append(env.render())\n",
 " action = policy(obs)\n",
- " obs, reward, done, info = env.step(action)\n",
- " if done:\n",
+ " obs, reward, done, truncated, info = env.step(action)\n",
+ " if done or truncated:\n",
 " break\n",
 " env.close()\n",
 " return plot_animation(frames)\n",
@@ -712,12 +756,12 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
- "Let's create a neural network that will take observations as inputs, and output the probabilities of actions to take for each observation. To choose an action, the network will estimate a probability for each action, then we will select an action randomly according to the estimated probabilities. In the case of the Cart-Pole environment, there are just two possible actions (left or right), so we only need one output neuron: it will output the probability `p` of the action 0 (left), and of course the probability of action 1 (right) will be `1 - p`."
+ "Let's create a neural network that will take observations as inputs, and output the probabilities of actions to take for each observation. To choose an action, the network will estimate a probability for each action, then we will select an action randomly according to the estimated probabilities. 
In the case of the CartPole environment, there are just two possible actions (left or right), so we only need one output neuron: it will output the probability `p` of the action 0 (left), and of course the probability of action 1 (right) will be `1 - p`." ] }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -754,14 +798,14 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "# extra code – a function that creates an animation for a given policy model\n", "\n", "def pg_policy(obs):\n", - " left_proba = model.predict(obs[np.newaxis])\n", + " left_proba = model.predict(obs[np.newaxis], verbose=0)\n", " return int(np.random.rand() > left_proba)\n", "\n", "np.random.seed(42)\n", @@ -793,7 +837,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To train this neural network we will need to define the target probabilities `y`. If an action is good we should increase its probability, and conversely if it is bad we should reduce it. But how do we know whether an action is good or bad? The problem is that most actions have delayed effects, so when you win or lose points in an episode, it is not clear which actions contributed to this result: was it just the last action? Or the last 10? Or just one action 50 steps earlier? This is called the _credit assignment problem_.\n", + "To train this neural network we will need to define the target probabilities **y**. If an action is good we should increase its probability, and conversely if it is bad we should reduce it. But how do we know whether an action is good or bad? The problem is that most actions have delayed effects, so when you win or lose points in an episode, it is not clear which actions contributed to this result: was it just the last action? Or the last 10? Or just one action 50 steps earlier? This is called the _credit assignment problem_.\n", "\n", "The _Policy Gradients_ algorithm tackles this problem by first playing multiple episodes, then making the actions near positive rewards slightly more likely, while actions near negative rewards are made slightly less likely. First we play, then we go back and think about what we did." 
] @@ -807,7 +851,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -819,8 +863,8 @@ " loss = tf.reduce_mean(loss_fn(y_target, left_proba))\n", "\n", " grads = tape.gradient(loss, model.trainable_variables)\n", - " obs, reward, done, info = env.step(int(action))\n", - " return obs, reward, done, grads" + " obs, reward, done, truncated, info = env.step(int(action))\n", + " return obs, reward, done, truncated, grads" ] }, { @@ -839,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -849,12 +893,13 @@ " for episode in range(n_episodes):\n", " current_rewards = []\n", " current_grads = []\n", - " obs = env.reset()\n", + " obs, info = env.reset()\n", " for step in range(n_max_steps):\n", - " obs, reward, done, grads = play_one_step(env, obs, model, loss_fn)\n", + " obs, reward, done, truncated, grads = play_one_step(\n", + " env, obs, model, loss_fn)\n", " current_rewards.append(reward)\n", " current_grads.append(grads)\n", - " if done:\n", + " if done or truncated:\n", " break\n", "\n", " all_rewards.append(current_rewards)\n", @@ -872,7 +917,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 29, "metadata": {}, "outputs": [], "source": [ @@ -901,7 +946,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -910,7 +955,7 @@ "array([-22, -40, -50])" ] }, - "execution_count": 28, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -928,7 +973,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 31, "metadata": { "scrolled": true }, @@ -940,7 +985,7 @@ " array([1.26665318, 1.0727777 ])]" ] }, - "execution_count": 29, + "execution_count": 31, "metadata": {}, "output_type": "execute_result" } @@ -952,7 +997,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 32, "metadata": {}, "outputs": [], "source": [ @@ -964,7 +1009,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -978,12 +1023,12 @@ " tf.keras.layers.Dense(1, activation=\"sigmoid\"),\n", "])\n", "\n", - "obs = env.reset(seed=42)" + "obs, info = env.reset(seed=42)" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 34, "metadata": {}, "outputs": [], "source": [ @@ -993,14 +1038,14 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 35, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Iteration: 150/150, mean rewards: 193.1" + "Iteration: 150/150, mean rewards: 190.3" ] } ], @@ -1029,7 +1074,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 36, "metadata": {}, "outputs": [], "source": [ @@ -1054,7 +1099,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 37, "metadata": {}, "outputs": [ { @@ -1122,7 +1167,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 38, "metadata": {}, "outputs": [], "source": [ @@ -1148,18 +1193,18 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "Q_values = np.full((3, 3), -np.inf) # -np.inf for impossible actions\n", "for state, actions in enumerate(possible_actions):\n", - " Q_values[state, actions] = 0.0 # for all possible actions" + " Q_values[state, actions] = 0.0 # for all possible actions" ] 
}, { "cell_type": "code", - "execution_count": 38, + "execution_count": 40, "metadata": {}, "outputs": [], "source": [ @@ -1181,7 +1226,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 41, "metadata": {}, "outputs": [ { @@ -1192,7 +1237,7 @@ " [ -inf, 50.13365013, -inf]])" ] }, - "execution_count": 39, + "execution_count": 41, "metadata": {}, "output_type": "execute_result" } @@ -1203,7 +1248,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 42, "metadata": {}, "outputs": [ { @@ -1212,7 +1257,7 @@ "array([0, 0, 1])" ] }, - "execution_count": 40, + "execution_count": 42, "metadata": {}, "output_type": "execute_result" } @@ -1251,7 +1296,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 43, "metadata": {}, "outputs": [], "source": [ @@ -1271,7 +1316,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 44, "metadata": {}, "outputs": [], "source": [ @@ -1288,7 +1333,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 45, "metadata": {}, "outputs": [], "source": [ @@ -1301,7 +1346,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ @@ -1326,7 +1371,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 47, "metadata": {}, "outputs": [ { @@ -1378,7 +1423,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ @@ -1403,7 +1448,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ @@ -1411,7 +1456,7 @@ " if np.random.rand() < epsilon:\n", " return np.random.randint(n_outputs) # random action\n", " else:\n", - " Q_values = model.predict(state[np.newaxis])[0]\n", + " Q_values = model.predict(state[np.newaxis], verbose=0)[0]\n", " return Q_values.argmax() # optimal action according to the DQN" ] }, @@ -1424,7 +1469,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": 50, "metadata": {}, "outputs": [], "source": [ @@ -1442,7 +1487,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": 51, "metadata": {}, "outputs": [], "source": [ @@ -1469,23 +1514,22 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "And let's create a function to sample experiences from the replay buffer. It will return 5 NumPy arrays: `[obs, actions, rewards, next_obs, dones]`." + "And let's create a function to sample experiences from the replay buffer. It will return 6 NumPy arrays: `[obs, actions, rewards, next_obs, dones, truncateds]`." 
] }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "def sample_experiences(batch_size):\n", " indices = np.random.randint(len(replay_buffer), size=batch_size)\n", " batch = [replay_buffer[index] for index in indices]\n", - " states, actions, rewards, next_states, dones = [\n", + " return [\n", " np.array([experience[field_index] for experience in batch])\n", - " for field_index in range(5)\n", - " ]\n", - " return states, actions, rewards, next_states, dones" + " for field_index in range(6)\n", + " ] # [states, actions, rewards, next_states, dones, truncateds]" ] }, { @@ -1497,15 +1541,15 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "def play_one_step(env, state, epsilon):\n", " action = epsilon_greedy_policy(state, epsilon)\n", - " next_state, reward, done, info = env.step(action)\n", - " replay_buffer.append((state, action, reward, next_state, done))\n", - " return next_state, reward, done, info" + " next_state, reward, done, truncated, info = env.step(action)\n", + " replay_buffer.append((state, action, reward, next_state, done, truncated))\n", + " return next_state, reward, done, truncated, info" ] }, { @@ -1517,7 +1561,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": 54, "metadata": {}, "outputs": [], "source": [ @@ -1531,7 +1575,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 55, "metadata": {}, "outputs": [], "source": [ @@ -1542,11 +1586,11 @@ "\n", "def training_step(batch_size):\n", " experiences = sample_experiences(batch_size)\n", - " states, actions, rewards, next_states, dones = experiences\n", - " next_Q_values = model.predict(next_states)\n", + " states, actions, rewards, next_states, dones, truncateds = experiences\n", + " next_Q_values = model.predict(next_states, verbose=0)\n", " max_next_Q_values = next_Q_values.max(axis=1)\n", - " target_Q_values = (rewards +\n", - " (1 - dones) * discount_factor * max_next_Q_values)\n", + " runs = 1.0 - (dones | truncateds) # episode is not done or truncated\n", + " target_Q_values = rewards + runs * discount_factor * max_next_Q_values\n", " target_Q_values = target_Q_values.reshape(-1, 1)\n", " mask = tf.one_hot(actions, n_outputs)\n", " with tf.GradientTape() as tape:\n", @@ -1567,7 +1611,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 56, "metadata": { "scrolled": true }, @@ -1582,11 +1626,11 @@ ], "source": [ "for episode in range(600):\n", - " obs = env.reset() \n", + " obs, info = env.reset() \n", " for step in range(200):\n", " epsilon = max(1 - episode / 500, 0.01)\n", - " obs, reward, done, info = play_one_step(env, obs, epsilon)\n", - " if done:\n", + " obs, reward, done, truncated, info = play_one_step(env, obs, epsilon)\n", + " if done or truncated:\n", " break\n", "\n", " # extra code – displays debug info, stores data for the next figure, and\n", @@ -1606,7 +1650,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 57, "metadata": {}, "outputs": [ { @@ -1635,7 +1679,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 58, "metadata": {}, "outputs": [], "source": [ @@ -1666,7 +1710,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 59, "metadata": {}, "outputs": [], "source": [ @@ -1690,7 +1734,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 60, "metadata": {}, "outputs": [], "source": [ @@ 
-1707,7 +1751,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 59,
+ "execution_count": 61,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -1726,11 +1770,11 @@
 "\n",
 "def training_step(batch_size):\n",
 " experiences = sample_experiences(batch_size)\n",
- " states, actions, rewards, next_states, dones = experiences\n",
- " next_Q_values = target.predict(next_states) # <= CHANGED\n",
+ " states, actions, rewards, next_states, dones, truncateds = experiences\n",
+ " next_Q_values = target.predict(next_states, verbose=0) # <= CHANGED\n",
 " max_next_Q_values = next_Q_values.max(axis=1)\n",
- " target_Q_values = (rewards +\n",
- " (1 - dones) * discount_factor * max_next_Q_values)\n",
+ " runs = 1.0 - (dones | truncateds) # episode is not done or truncated\n",
+ " target_Q_values = rewards + runs * discount_factor * max_next_Q_values\n",
 " target_Q_values = target_Q_values.reshape(-1, 1)\n",
 " mask = tf.one_hot(actions, n_outputs)\n",
 " with tf.GradientTape() as tape:\n",
@@ -1751,7 +1795,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 60,
+ "execution_count": 62,
 "metadata": {},
 "outputs": [
 {
@@ -1764,11 +1808,11 @@
 ],
 "source": [
 "for episode in range(600):\n",
- " obs = env.reset() \n",
+ " obs, info = env.reset() \n",
 " for step in range(200):\n",
 " epsilon = max(1 - episode / 500, 0.01)\n",
- " obs, reward, done, info = play_one_step(env, obs, epsilon)\n",
- " if done:\n",
+ " obs, reward, done, truncated, info = play_one_step(env, obs, epsilon)\n",
+ " if done or truncated:\n",
 " break\n",
 "\n",
 " # extra code – displays debug info, stores data for the next figure, and\n",
@@ -1800,7 +1844,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 61,
+ "execution_count": 63,
 "metadata": {},
 "outputs": [
 {
@@ -1828,7 +1872,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 62,
+ "execution_count": 64,
 "metadata": {
 "scrolled": true,
 "tags": []
@@ -1855,17 +1899,9 @@
 },
 {
 "cell_type": "code",
- "execution_count": 63,
+ "execution_count": 65,
 "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Episode: 600, Steps: 200, eps: 0.010"
- ]
- }
- ],
+ "outputs": [],
 "source": [
 "tf.random.set_seed(42)\n",
 "\n",
@@ -1891,17 +1927,18 @@
 "\n",
 "def training_step(batch_size):\n",
 " experiences = sample_experiences(batch_size)\n",
- " states, actions, rewards, next_states, dones = experiences\n",
+ " states, actions, rewards, next_states, dones, truncateds = experiences\n",
 "\n",
 " #################### CHANGED SECTION ####################\n",
- " next_Q_values = model.predict(next_states) # not target.predict(...)\n",
+ " next_Q_values = model.predict(next_states, verbose=0) # ≠ target.predict()\n",
 " best_next_actions = next_Q_values.argmax(axis=1)\n",
 " next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
- " max_next_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
+ " max_next_Q_values = (target.predict(next_states, verbose=0) * next_mask\n",
+ " ).sum(axis=1)\n",
 " #########################################################\n",
 "\n",
- " target_Q_values = (rewards +\n",
- " (1 - dones) * discount_factor * max_next_Q_values)\n",
+ " runs = 1.0 - (dones | truncateds) # episode is not done or truncated\n",
+ " target_Q_values = rewards + runs * discount_factor * max_next_Q_values\n",
 " target_Q_values = target_Q_values.reshape(-1, 1)\n",
 " mask = tf.one_hot(actions, n_outputs)\n",
 " with tf.GradientTape() as tape:\n",
@@ -1915,11 +1952,11 @@
 "replay_buffer = deque(maxlen=2000)\n",
 "\n",
 "for episode in range(600):\n",
- " obs 
= env.reset() \n", + " obs, info = env.reset() \n", " for step in range(200):\n", " epsilon = max(1 - episode / 500, 0.01)\n", - " obs, reward, done, info = play_one_step(env, obs, epsilon)\n", - " if done:\n", + " obs, reward, done, info, truncated = play_one_step(env, obs, epsilon)\n", + " if done or truncated:\n", " break\n", "\n", " print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n", @@ -1939,7 +1976,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 66, "metadata": {}, "outputs": [ { @@ -1967,7 +2004,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": 67, "metadata": { "scrolled": true }, @@ -1986,7 +2023,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ @@ -2012,14 +2049,14 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Episode: 600, Steps: 190, eps: 0.010" + "Episode: 600, Steps: 137, eps: 0.010" ] } ], @@ -2040,11 +2077,11 @@ "best_score = 0\n", "\n", "for episode in range(600):\n", - " obs = env.reset() \n", + " obs, info = env.reset() \n", " for step in range(200):\n", " epsilon = max(1 - episode / 500, 0.01)\n", - " obs, reward, done, info = play_one_step(env, obs, epsilon)\n", - " if done:\n", + " obs, reward, done, info, truncated = play_one_step(env, obs, epsilon)\n", + " if done or truncated:\n", " break\n", "\n", " print(f\"\\rEpisode: {episode + 1}, Steps: {step + 1}, eps: {epsilon:.3f}\",\n", @@ -2064,7 +2101,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 70, "metadata": {}, "outputs": [ { @@ -2092,7 +2129,7 @@ }, { "cell_type": "code", - "execution_count": 69, + "execution_count": 71, "metadata": { "scrolled": true }, @@ -2111,7 +2148,7 @@ }, { "cell_type": "code", - "execution_count": 70, + "execution_count": 72, "metadata": {}, "outputs": [], "source": [ @@ -2170,11 +2207,11 @@ }, { "cell_type": "code", - "execution_count": 71, + "execution_count": 73, "metadata": {}, "outputs": [], "source": [ - "env = gym.make(\"LunarLander-v2\")" + "env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")" ] }, { @@ -2186,16 +2223,18 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "Box(-inf, inf, (8,), float32)" + "Box([-1.5 -1.5 -5. -5. -3.1415927 -5.\n", + " -0. -0. ], [1.5 1.5 5. 5. 3.1415927 5. 1.\n", + " 1. ], (8,), float32)" ] }, - "execution_count": 72, + "execution_count": 74, "metadata": {}, "output_type": "execute_result" } @@ -2206,7 +2245,7 @@ }, { "cell_type": "code", - "execution_count": 73, + "execution_count": 75, "metadata": {}, "outputs": [ { @@ -2216,13 +2255,13 @@ " -0.05269805, 0. , 0. 
], dtype=float32)" ] }, - "execution_count": 73, + "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "obs = env.reset(seed=42)\n", + "obs, info = env.reset(seed=42)\n", "obs" ] }, @@ -2246,7 +2285,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 76, "metadata": {}, "outputs": [ { @@ -2255,7 +2294,7 @@ "Discrete(4)" ] }, - "execution_count": 74, + "execution_count": 76, "metadata": {}, "output_type": "execute_result" } @@ -2284,7 +2323,7 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 77, "metadata": {}, "outputs": [], "source": [ @@ -2318,7 +2357,7 @@ }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ @@ -2329,8 +2368,8 @@ " action = tf.random.categorical(logits, num_samples=1)\n", " loss = tf.reduce_mean(loss_fn(action, probas))\n", " grads = tape.gradient(loss, model.trainable_variables)\n", - " obs, reward, done, info = env.step(action[0, 0].numpy())\n", - " return obs, reward, done, grads\n", + " obs, reward, done, info, truncated = env.step(action[0, 0].numpy())\n", + " return obs, reward, done, truncated, grads\n", "\n", "def lander_play_multiple_episodes(env, n_episodes, n_max_steps, model, loss_fn):\n", " all_rewards = []\n", @@ -2338,12 +2377,13 @@ " for episode in range(n_episodes):\n", " current_rewards = []\n", " current_grads = []\n", - " obs = env.reset()\n", + " obs, info = env.reset()\n", " for step in range(n_max_steps):\n", - " obs, reward, done, grads = lander_play_one_step(env, obs, model, loss_fn)\n", + " obs, reward, done, truncated, grads = lander_play_one_step(\n", + " env, obs, model, loss_fn)\n", " current_rewards.append(reward)\n", " current_grads.append(grads)\n", - " if done:\n", + " if done or truncated:\n", " break\n", " all_rewards.append(current_rewards)\n", " all_grads.append(current_grads)\n", @@ -2359,7 +2399,7 @@ }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 79, "metadata": {}, "outputs": [], "source": [ @@ -2388,7 +2428,7 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 80, "metadata": {}, "outputs": [], "source": [ @@ -2407,7 +2447,7 @@ }, { "cell_type": "code", - "execution_count": 79, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ @@ -2424,14 +2464,14 @@ }, { "cell_type": "code", - "execution_count": 80, + "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Iteration: 200/200, mean reward: 139.7 " + "Iteration: 200/200, mean reward: 167.1 " ] } ], @@ -2468,7 +2508,7 @@ }, { "cell_type": "code", - "execution_count": 81, + "execution_count": 83, "metadata": {}, "outputs": [ { @@ -2501,7 +2541,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 84, "metadata": {}, "outputs": [], "source": [ @@ -2510,14 +2550,14 @@ " env = gym.make(\"LunarLander-v2\")\n", " tf.random.set_seed(seed)\n", " np.random.seed(seed)\n", - " obs = env.reset(seed=seed)\n", + " obs, info = env.reset(seed=seed)\n", " for step in range(n_max_steps):\n", " frames.append(env.render(mode=\"rgb_array\"))\n", " probas = model(obs[np.newaxis])\n", " logits = tf.math.log(probas + tf.keras.backend.epsilon())\n", " action = tf.random.categorical(logits, num_samples=1)\n", - " obs, reward, done, info = env.step(action[0, 0].numpy())\n", - " if done:\n", + " obs, reward, done, truncated, info = env.step(action[0, 0].numpy())\n", + " if done or truncated:\n", " break\n", " 
env.close()\n", " return frames" @@ -2525,7 +2565,7 @@ }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 85, "metadata": {}, "outputs": [], "source": [ @@ -2598,7 +2638,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.9.10" } }, "nbformat": 4, diff --git a/apt.txt b/apt.txt index eccdbaa..533d3c2 100644 --- a/apt.txt +++ b/apt.txt @@ -10,6 +10,5 @@ sudo swig unzip xorg-dev -xvfb zip zlib1g-dev diff --git a/docker/Dockerfile b/docker/Dockerfile index 9d045ff..4fb6857 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -12,15 +12,13 @@ RUN apt-get update && apt-get install -y \ sudo \ unzip \ xorg-dev \ - xvfb \ zip \ zlib1g-dev \ && apt clean \ && rm -rf /var/lib/apt/lists/* COPY environment.yml /tmp/ -RUN echo ' - pyvirtualdisplay' >> /tmp/environment.yml \ - && conda env create -f /tmp/environment.yml \ +RUN conda env create -f /tmp/environment.yml \ && conda clean -afy \ && find /opt/conda/ -follow -type f -name '*.a' -delete \ && find /opt/conda/ -follow -type f -name '*.pyc' -delete \ diff --git a/docker/Dockerfile.gpu b/docker/Dockerfile.gpu index 71bb42a..a8dc73e 100644 --- a/docker/Dockerfile.gpu +++ b/docker/Dockerfile.gpu @@ -49,7 +49,6 @@ RUN apt-get update -q && apt-get install -q -y --no-install-recommends \ swig \ wget \ xorg-dev \ - xvfb \ zip \ zlib1g-dev \ build-essential \ @@ -129,8 +128,7 @@ RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-${CONDA_VERSION} # Now we're ready to create our conda environment COPY environment.yml /tmp/ -RUN echo ' - pyvirtualdisplay' >> /tmp/environment.yml \ - && conda env create -f /tmp/environment.yml \ +RUN conda env create -f /tmp/environment.yml \ && conda clean -afy \ && find /opt/conda/ -follow -type f -name '*.a' -delete \ && find /opt/conda/ -follow -type f -name '*.pyc' -delete \ diff --git a/environment.yml b/environment.yml index a9f6047..449f573 100644 --- a/environment.yml +++ b/environment.yml @@ -24,7 +24,6 @@ dependencies: - pyglet=1.5 # used only in chapter 18 to render environments - pyopengl=3.1 # used only in chapter 18 to render environments - python=3.10 # your beloved programming language! :) - #- pyvirtualdisplay=3.0 # used only in chapter 18 if on headless server - requests=2.28 # used only in chapter 19 for REST API queries - scikit-learn=1.1 # machine learning library - scipy=1.9 # scientific/technical computing library diff --git a/requirements.txt b/requirements.txt index ee1556c..0afdbfe 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,10 +59,6 @@ Pillow~=9.2.0 graphviz~=0.20.1 pyglet~=1.5.26 -#pyvirtualdisplay # needed in chapter 18, if on a headless server - # (i.e., without screen, e.g., Colab or VM) - - ##### Google Cloud Platform - used only in chapter 19 google-cloud-aiplatform~=1.17.0 google-cloud-storage~=2.5.0