Fix CartPole rendering issue in chapter 16
parent
a24ad685ec
commit
f6298cb03b
|
@ -30,7 +30,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 42,
|
||||
"execution_count": 1,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -97,7 +97,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next we will load the CartPole environment, version 0. This environment contains a cart that can move left and right, and a pole standing vertically on top of it. Your agent can apply some force to the cart, pushing it left or right: its goal is to control it so that the pole remains upright."
|
||||
"Next we will load the MsPacman environment, version 0."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -108,7 +108,7 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('CartPole-v0')"
|
||||
"env = gym.make('MsPacman-v0')"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -133,7 +133,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Observations vary depending on the environment. In this case it returns a 1D NumPy array containing 4 floats, but in other cases it will return different types of objects (eg. for Atari games it returns an image of the screen, as we will see below). The 4 floats represent the position of the cart, its velocity, the angle of the pole and its angular velocity."
|
||||
"Observations vary depending on the environment. In this case it is an RGB image represented as a 3D NumPy array of shape [width, height, channels] (with 3 channels: Red, Green and Blue). In other environments it may return different objects, as we will see later."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -144,25 +144,21 @@
|
|||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs"
|
||||
"obs.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment). In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array.\n",
|
||||
"\n",
|
||||
"Note: unfortunately some environments (including the CartPole) draw on your screen even if you specify the `rgb_array` mode, opening up a separate window. In general you can safely ignore it. However, if Jupyter is running on a headless server (ie. without a screen), or if you just can't stand having a window pop up for no good reason, you can use a fake X server like Xvfb. You need to install Xvfb and start Jupyter using the `xvfb-run` command (if you are running this notebook using binder, this has been taken care of for you):\n",
|
||||
"\n",
|
||||
" $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook"
|
||||
"An environment can be visualized by calling its `render()` method, and you can pick the rendering mode (the rendering options depend on the environment). In this example we will set `mode=\"rgb_array\"` to get an image of the environment as a NumPy array:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -195,7 +191,14 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Once you have finished playing with an environment, you should close it to free up resources:"
|
||||
"Welcome back to the 1980s! :)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this environment, the rendered image is simply equal to the observation (but in many environments this is not the case):"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -205,127 +208,10 @@
|
|||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.close()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's try MsPacman! This requires the [Atari dependencies](https://github.com/openai/gym#atari).\n",
|
||||
"\n",
|
||||
" pip install --upgrade \"gym[atari]\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = gym.make('MsPacman-v0')\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Note that the observation is now a 3D numpy array of shape [width, height, channels] representing an image:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"type(obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs.shape"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The following command renders the environment to a Numpy array. Luckily, the Atari environments do not open separate windows when you use the `\"rgb_array\"` mode. :)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = env.render(mode=\"rgb_array\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"In this case, the rendering is simply equal to the observation:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"(img == obs).all()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now let's plot it. Welcome back to the 1980s!"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"fig = plt.figure(figsize=(5,4))\n",
|
||||
"plt.imshow(img)\n",
|
||||
"plt.axis(\"off\")\n",
|
||||
"plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -335,13 +221,14 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def plot_environment(env, figsize=(5,4)):\n",
|
||||
" plt.close() # or else nbagg sometimes plots in the previous cell\n",
|
||||
" plt.figure(figsize=figsize)\n",
|
||||
" img = env.render(mode=\"rgb_array\")\n",
|
||||
" plt.imshow(img)\n",
|
||||
|
@ -358,7 +245,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -383,7 +270,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
|
@ -405,7 +292,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -423,7 +310,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
|
@ -441,7 +328,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -459,7 +346,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -477,7 +364,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -495,7 +382,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -513,7 +400,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -544,7 +431,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -554,7 +441,8 @@
|
|||
" patch.set_data(frames[num])\n",
|
||||
" return patch,\n",
|
||||
"\n",
|
||||
"def plot_animation(frames, repeat=False, interval=50):\n",
|
||||
"def plot_animation(frames, repeat=False, interval=40):\n",
|
||||
" plt.close() # or else nbagg sometimes plots in the previous cell\n",
|
||||
" fig = plt.figure()\n",
|
||||
" patch = plt.imshow(frames[0])\n",
|
||||
" plt.axis('off')\n",
|
||||
|
@ -563,7 +451,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -577,14 +465,14 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Ok, let's go back to the CartPole environment, it is much simpler to start with. But don't forget to close the MsPacman environment first:"
|
||||
"Once you have finished playing with an environment, you should close it to free up resources:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -595,19 +483,26 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# A simple hard-coded policy"
|
||||
"To code our first learning agent, we will be using a simpler environment: the Cart-Pole. "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create and initialize the CartPole environment again:"
|
||||
"# A simple environment: the Cart-Pole"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The Cart-Pole is a very simple environment composed of a cart that can move left or right, and pole placed vertically on top of it. The agent must move the cart left or right to keep the pole upright."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -618,7 +513,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 24,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
|
@ -627,6 +522,105 @@
|
|||
"obs = env.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The observation is a 1D NumPy array composed of 4 floats: they represent the cart's horizontal position, its velocity, the angle of the pole (O = vertical), and the angular velocity. Let's render the environment... unfortunately we need to fix an annoying rendering issue first."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Fixing the rendering issue"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Some environments (including the CartPole) require access to your display, which opens up a separate window, even if you specify the `rgb_array` mode. In general you can safely ignore that window. However, if Jupyter is running on a headless server (ie. without a screen) it will raise an exception. One way to avoid this is to install a fake X server like Xvfb. You can start Jupyter using the `xvfb-run` command:\n",
|
||||
"\n",
|
||||
" $ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n",
|
||||
"\n",
|
||||
"This does not seem to be possible using binder, so unfortunately we cannot use OpenAI gym's rendering function, we need to define our own."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from PIL import Image, ImageDraw\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" from pyglet.gl import gl_info\n",
|
||||
" openai_cart_pole_rendering = True # no problem, let's use OpenAI gym's rendering function\n",
|
||||
"except ImportError:\n",
|
||||
" openai_cart_pole_rendering = False # probably running on binder, let's use our own rendering function\n",
|
||||
"\n",
|
||||
"def render_cart_pole(env, obs):\n",
|
||||
" if openai_cart_pole_rendering:\n",
|
||||
" # use OpenAI gym's rendering function\n",
|
||||
" return env.render(mode=\"rgb_array\")\n",
|
||||
" else:\n",
|
||||
" # basic rendering for the cart pole environment if OpenAI can't render it\n",
|
||||
" img_w = 100\n",
|
||||
" img_h = 50\n",
|
||||
" cart_w = 20\n",
|
||||
" pole_len = 30\n",
|
||||
" x_width = 2\n",
|
||||
" max_ang = 0.2\n",
|
||||
" bg_col = (255, 255, 255)\n",
|
||||
" cart_col = 0x000000 # Blue Green Red\n",
|
||||
" pole_col = 0x0000FF # Blue Green Red\n",
|
||||
"\n",
|
||||
" pos, vel, ang, ang_vel = obs\n",
|
||||
" img = Image.new('RGB', (img_w, img_h), bg_col)\n",
|
||||
" draw = ImageDraw.Draw(img)\n",
|
||||
" cart_x = pos * img_w // x_width + img_w // x_width\n",
|
||||
" cart_y = img_h * 95 // 100\n",
|
||||
" top_pole_x = cart_x + pole_len * np.sin(ang)\n",
|
||||
" top_pole_y = cart_y - pole_len * np.cos(ang)\n",
|
||||
" pole_col = int(np.minimum(np.abs(ang / max_ang), 1) * pole_col)\n",
|
||||
" draw.line((cart_x, cart_y, top_pole_x, top_pole_y), fill=pole_col) # draw pole\n",
|
||||
" draw.line((cart_x - cart_w // 2, cart_y, cart_x + cart_w // 2, cart_y), fill=cart_col) # draw cart\n",
|
||||
" return np.array(img)\n",
|
||||
"\n",
|
||||
"def plot_cart_pole(env, obs):\n",
|
||||
" plt.close() # or else nbagg sometimes plots in the previous cell\n",
|
||||
" img = render_cart_pole(env, obs)\n",
|
||||
" plt.imshow(img)\n",
|
||||
" plt.axis(\"off\")\n",
|
||||
" plt.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_cart_pole(env, obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -636,7 +630,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 33,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -654,12 +648,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 35,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"obs = env.reset()\n",
|
||||
"while True:\n",
|
||||
" obs, reward, done, info = env.step(0)\n",
|
||||
" if done:\n",
|
||||
|
@ -668,13 +663,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 36,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_environment(env)"
|
||||
"plot_cart_pole(env, obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -686,14 +681,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 37,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"\n",
|
||||
"obs = env.reset()\n",
|
||||
"while True:\n",
|
||||
" obs, reward, done, info = env.step(1)\n",
|
||||
" if done:\n",
|
||||
|
@ -702,13 +696,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 38,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"plot_environment(env)"
|
||||
"plot_cart_pole(env, obs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -718,6 +712,13 @@
|
|||
"Looks like it's doing what we're telling it to do. Now how can we make the poll remain upright? We will need to define a _policy_ for that. This is the strategy that the agent will use to select an action at each step. It can use all the past actions and observations to decide what to do."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# A simple hard-coded policy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
|
@ -727,7 +728,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 39,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -740,9 +741,9 @@
|
|||
"\n",
|
||||
"obs = env.reset()\n",
|
||||
"for iteration in range(n_max_iterations):\n",
|
||||
" img = env.render(mode=\"rgb_array\")\n",
|
||||
" img = render_cart_pole(env, obs)\n",
|
||||
" frames.append(img)\n",
|
||||
" \n",
|
||||
"\n",
|
||||
" # hard-coded policy\n",
|
||||
" position, velocity, angle, angular_velocity = obs\n",
|
||||
" if angle < 0:\n",
|
||||
|
@ -757,7 +758,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 40,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -790,7 +791,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 41,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -829,7 +830,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 42,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
|
@ -863,15 +864,6 @@
|
|||
"source": [
|
||||
"Coming soon..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
|
Loading…
Reference in New Issue