Add reinforcement learning notebook (in progress)
parent
25da041cf4
commit
05ffb99e10
|
@ -0,0 +1,908 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Chapter 16 – Reinforcement Learning**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This notebook contains all the sample code and solutions to the exercices in chapter 16."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"First, let's make sure this notebook works well in both python 2 and 3, import a few common modules, ensure MatplotLib plots figures inline and prepare a function to save the figures:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# To support both python 2 and python 3\n",
|
||||
"from __future__ import division, print_function, unicode_literals\n",
|
||||
"\n",
|
||||
"# Common imports\n",
|
||||
"import numpy as np\n",
|
||||
"import numpy.random as rnd\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"# to make this notebook's output stable across runs\n",
|
||||
"rnd.seed(42)\n",
|
||||
"\n",
|
||||
"# To plot pretty figures and animations\n",
|
||||
"%matplotlib nbagg\n",
|
||||
"import matplotlib\n",
|
||||
"import matplotlib.animation as animation\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"plt.rcParams['axes.labelsize'] = 14\n",
|
||||
"plt.rcParams['xtick.labelsize'] = 12\n",
|
||||
"plt.rcParams['ytick.labelsize'] = 12\n",
|
||||
"\n",
|
||||
"# Where to save the figures\n",
|
||||
"PROJECT_ROOT_DIR = \".\"\n",
|
||||
"CHAPTER_ID = \"rl\"\n",
|
||||
"\n",
|
||||
"def save_fig(fig_id, tight_layout=True):\n",
|
||||
" path = os.path.join(PROJECT_ROOT_DIR, \"images\", CHAPTER_ID, fig_id + \".png\")\n",
|
||||
" print(\"Saving figure\", fig_id)\n",
|
||||
" if tight_layout:\n",
|
||||
" plt.tight_layout()\n",
|
||||
" plt.savefig(path, format='png', dpi=300)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Introduction to OpenAI gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this notebook we will be using [OpenAI gym](https://gym.openai.com/), a great toolkit for developing and comparing Reinforcement Learning algorithms. It provides many environments for your learning *agents* to interact with. Let's start by importing `gym`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import gym"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next we will load the CartPole environment, version 0. This environment contains a cart that can move left and right, and a pole standing vertically on top of it. Your agent can apply some force to the cart, pushing it left or right: its goal is to control it so that the pole remains upright."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('CartPole-v0')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's initialize the environment by calling is `reset()` method. This returns an observation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Observations vary depending on the environment. In this case it returns a 1D NumPy array containing 4 floats, but in other cases it will return different types of objects (eg. for Atari games it returns an image of the screen, as we will see below). The 4 floats represent the position of the cart, its velocity, the angle of the pole and its angular velocity."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment). In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array.\n",
|
||||
"\n",
|
||||
"Note: unfortunately some environments (including the CartPole) draw on your screen even if you specify the `rgb_array` mode, opening up a separate window. In general you can safely ignore it. However, if Jupyter is running on a headless server (ie. without a screen), or if you just can't stand having a window pop up for no good reason, you can use a fake X server like Xvfb. You need to install Xvfb and start Jupyter using the `xvfb-run` command (if you are running this notebook using binder, this has been taken care of for you):\n",
|
||||
"\n",
|
||||
" $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = env.render(mode=\"rgb_array\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's plot this image:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plt.figure(figsize=(5,4))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once you have finished playing with an environment, you should close it to free up resources:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's try MsPacman! This requires the [Atari dependencies](https://github.com/openai/gym#atari).\n",
|
||||
"\n",
|
||||
" pip install --upgrade \"gym[atari]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('MsPacman-v0')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that the observation is now a 3D numpy array of shape [width, height, channels] representing an image:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following command renders the environment to a Numpy array. Luckily, the Atari environments do not open separate windows when you use the `\"rgb_array\"` mode. :)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = env.render(mode=\"rgb_array\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this case, the rendering is simply equal to the observation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(img == obs).all()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's plot it. Welcome back to the 1980s!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(5,4))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a little helper function to plot an environment:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_environment(env, figsize=(5,4)):\n",
|
||||
" plt.figure(figsize=figsize)\n",
|
||||
" img = env.render(mode=\"rgb_array\")\n",
|
||||
" plt.imshow(img)\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's see how to interact with an environment. Your agent will need to select an action from an \"action space\" (the set of possible actions). Let's see what this environment's action space looks like:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.action_space"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"`Discrete(9)` means that the possible actions are integers 0 through 8, which represents the 9 possible positions of the joystick (0=center, 1=up, 2=right, 3=left, 4=down, 5=upper-right, 6=upper-left, 7=lower-right, 8=lower-left)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next we need to tell the environment which action to play, and it will compute the next step of the game. Let's go left for 110 steps, then lower left for 40 steps:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"for step in range(110):\n",
|
||||
" env.step(3) #left\n",
|
||||
"for step in range(40):\n",
|
||||
" env.step(8) #lower-left"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Where are we now?"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_environment(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The `step()` function actually returns several important objects:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs, reward, done, info = env.step(0)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The observation tells the agent what the environment looks like, as discussed earlier. This is a 210x160 RGB image:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The environment also tells the agent how much reward it got during the last step:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"reward"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"When the game is over, the environment returns `done=True`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"done"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, `info` is an environment-specific dictionary that can provide some extra information about the internal state of the environment. This is useful for debugging, but your agent should not use this information for learning (it would be cheating)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"info"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's play one full game (with 3 lives), by moving in random directions for 10 steps at a time, recording each frame:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"frames = []\n",
|
||||
"\n",
|
||||
"n_max_iterations = 1000\n",
|
||||
"n_change_steps = 10\n",
|
||||
"\n",
|
||||
"obs = env.reset()\n",
|
||||
"for iteration in range(n_max_iterations):\n",
|
||||
" img = env.render(mode=\"rgb_array\")\n",
|
||||
" frames.append(img)\n",
|
||||
" if iteration % n_change_steps == 0:\n",
|
||||
" action = env.action_space.sample() # play randomly\n",
|
||||
" obs, reward, done, info = env.step(action)\n",
|
||||
" if done:\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now show the animation (it's a bit jittery within Jupyter):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def update_scene(num, frames, patch):\n",
|
||||
" patch.set_data(frames[num])\n",
|
||||
" return patch,\n",
|
||||
"\n",
|
||||
"def plot_animation(frames, repeat=False, interval=50):\n",
|
||||
" fig = plt.figure()\n",
|
||||
" patch = plt.imshow(frames[0])\n",
|
||||
" plt.axis('off')\n",
|
||||
" return animation.FuncAnimation(fig, update_scene, fargs=(frames, patch), frames=len(frames), repeat=repeat, interval=interval)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"video = plot_animation(frames)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ok, let's go back to the CartPole environment, it is much simpler to start with. But don't forget to close the MsPacman environment first:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# A simple hard-coded policy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create and initialize the CartPole environment again:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make(\"CartPole-v0\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's look at the action space:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.action_space"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Yep, just two possible actions: accelerate towards the left or towards the right. Let's push the cart left until the pole falls:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"while True:\n",
|
||||
" obs, reward, done, info = env.step(0)\n",
|
||||
" if done:\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_environment(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Notice that the game is over when the pole tilts too much, not when it actually falls. Now let's reset the environment and push the cart to right instead:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"\n",
|
||||
"while True:\n",
|
||||
" obs, reward, done, info = env.step(1)\n",
|
||||
" if done:\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_environment(env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Looks like it's doing what we're telling it to do. Now how can we make the poll remain upright? We will need to define a _policy_ for that. This is the strategy that the agent will use to select an action at each step. It can use all the past actions and observations to decide what to do."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's hard code a simple strategy: if the pole is tilting to the left, then push the cart to the left, and _vice versa_. Let's see if that works:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"frames = []\n",
|
||||
"\n",
|
||||
"n_max_iterations = 1000\n",
|
||||
"n_change_steps = 10\n",
|
||||
"\n",
|
||||
"obs = env.reset()\n",
|
||||
"for iteration in range(n_max_iterations):\n",
|
||||
" img = env.render(mode=\"rgb_array\")\n",
|
||||
" frames.append(img)\n",
|
||||
" \n",
|
||||
" # hard-coded policy\n",
|
||||
" position, velocity, angle, angular_velocity = obs\n",
|
||||
" if angle < 0:\n",
|
||||
" action = 0\n",
|
||||
" else:\n",
|
||||
" action = 1\n",
|
||||
"\n",
|
||||
" obs, reward, done, info = env.step(action)\n",
|
||||
" if done:\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"video = plot_animation(frames)\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Nope, the system is unstable and after just a few wobbles, the pole ends up too tilted: game over. We will need to be smarter than that!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Policy Gradients"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a neural network that will take the observations as inputs, and output the action to take. More precisely, it will output a probability for each action, and we will sample an action based on those probabilities. For example, if it says that the probability of pushing left should be 70%, and the probability of pushing right should be 30%, then we will pick a random number between 0 and 1 and if it is lower than 0.7 we will push left, or else we will push right. This approach lets the agent find the right balance between exploring new actions and exploiting the actions that are known to work well. Suppose you go to the same restaurant every week, and the first time you really enjoyed the caesar salad, you could order the same thing every week and be guaranteed to enjoy your meal. But you may be missing out on another great dish. Once in a while, you should try out something new."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tensorflow as tf\n",
|
||||
"\n",
|
||||
"from tensorflow.contrib.layers import fully_connected\n",
|
||||
"\n",
|
||||
"n_inputs = 4\n",
|
||||
"n_hidden = 4\n",
|
||||
"n_outputs = 1\n",
|
||||
"\n",
|
||||
"learning_rate=0.01\n",
|
||||
"\n",
|
||||
"X = tf.placeholder(tf.float32, shape=[None, n_inputs])\n",
|
||||
"y = tf.placeholder(tf.float32, shape=[None, n_outputs])\n",
|
||||
"hidden = fully_connected(X, n_hidden, activation_fn=tf.nn.elu)\n",
|
||||
"logits = fully_connected(hidden, n_outputs, activation_fn=None)\n",
|
||||
"outputs = tf.nn.softmax(logits)\n",
|
||||
"cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits, y)\n",
|
||||
"optimizer = tf.train.AdamOptimizer(learning_rate)\n",
|
||||
"training_op = optimizer.minimize(cross_entropy)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now to train this network we will need to feed the input batches `X` and the targets `y`. The inputs are easy enough, these will be the observations.\n",
|
||||
"\n",
|
||||
"_Note_: in this particular environment, the past actions and observations can safely be ignored, since you can observe the environment's full state. If there were some hidden state then you may need to consider all past actions and observations in order to try to infer the hidden state of the environment. For example, if the environment only revealed the position of the cart but not its velocity, you would have to consider not only the current observation but also the previous observation in order to estimate the current velocity. Another example is if the observations are noisy: you may want to use the past few observations to estimate the most likely current state. Our problem is thus as simple as can be: the current observation is noise-free and contains the environment's full state.\n",
|
||||
"\n",
|
||||
"But what about the labels? How can we tell what the target probabilities should be? One option is to let this policy network play the game say 100 times. Then rank the games according to the total reward they get. The actions taken during the best games were good, on average, so they should be made a bit more likely, while the actions taken during the worst games were bad, on average, so they should be made less likely. Of course, perhaps the policy network made a few good moves during a very bad game, and unfortunately these good moves will be made less likely, but that's ok because if we repeat the process many times, after a while the good moves should on average get more and more likely, and the bad moves should get less and less likely. A good basketball player sometimes plays in a really bad team: this obviously damages his reputation, but if he stars in a sufficient number of movies, overall his reputation should correspond to his talent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()\n",
|
||||
"while True:\n",
|
||||
" obs, reward, done, info = env.step(env.action_space.sample())\n",
|
||||
" print(reward)\n",
|
||||
" if done:\n",
|
||||
" break"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Work in progress – more content coming soon...**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise solutions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Coming soon..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 2",
|
||||
"language": "python",
|
||||
"name": "python2"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.11"
|
||||
},
|
||||
"nav_menu": {},
|
||||
"toc": {
|
||||
"navigate_menu": true,
|
||||
"number_sections": true,
|
||||
"sideBar": true,
|
||||
"threshold": 6,
|
||||
"toc_cell": false,
|
||||
"toc_section_display": "block",
|
||||
"toc_window_display": false
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
|
@ -38,7 +38,7 @@
|
|||
"13. [Convolutional Neural Networks](13_convolutional_neural_networks.ipynb)\n",
|
||||
"14. [Recurrent Neural Networks](14_recurrent_neural_networks.ipynb)\n",
|
||||
"15. [Autoencoders](15_autoencoders.ipynb)\n",
|
||||
"16. Reinforcement Learning (coming soon)\n",
|
||||
"16. [Reinforcement Learning](16_reinforcement_learning.ipynb)\n",
|
||||
"\n",
|
||||
"## Scientific Python tutorials\n",
|
||||
"* [NumPy](tools_numpy.ipynb)\n",
|
||||
|
|
Loading…
Reference in New Issue