Use int8 instead of float64 to represent pixel values: divides RAM footprint by 8

main
Aurélien Geron 2017-11-09 13:17:24 +01:00
parent 2a02668e5e
commit 7686839b36
1 changed files with 42 additions and 97 deletions

View File

@ -31,9 +31,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 1,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# To support both python 2 and python 3\n", "# To support both python 2 and python 3\n",
@ -95,9 +93,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import gym" "import gym"
@ -129,9 +125,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 4,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"obs = env.reset()" "obs = env.reset()"
@ -163,9 +157,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 6,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"img = env.render(mode=\"rgb_array\")" "img = env.render(mode=\"rgb_array\")"
@ -226,9 +218,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 9,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def plot_environment(env, figsize=(5,4)):\n", "def plot_environment(env, figsize=(5,4)):\n",
@ -273,9 +263,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 11,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"env.reset()\n", "env.reset()\n",
@ -311,9 +299,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 13,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"obs, reward, done, info = env.step(0)" "obs, reward, done, info = env.step(0)"
@ -393,9 +379,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 18,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"frames = []\n", "frames = []\n",
@ -424,9 +408,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 19,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def update_scene(num, frames, patch):\n", "def update_scene(num, frames, patch):\n",
@ -461,9 +443,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 21,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"env.close()" "env.close()"
@ -502,9 +482,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 23,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"obs = env.reset()" "obs = env.reset()"
@ -547,9 +525,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 25, "execution_count": 25,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from PIL import Image, ImageDraw\n", "from PIL import Image, ImageDraw\n",
@ -633,9 +609,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 28,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"obs = env.reset()\n", "obs = env.reset()\n",
@ -677,9 +651,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 31, "execution_count": 31,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"obs = env.reset()\n", "obs = env.reset()\n",
@ -722,9 +694,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 33, "execution_count": 33,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"frames = []\n", "frames = []\n",
@ -795,9 +765,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 35, "execution_count": 35,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import tensorflow as tf\n", "import tensorflow as tf\n",
@ -846,9 +814,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 36, "execution_count": 36,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"n_max_steps = 1000\n", "n_max_steps = 1000\n",
@ -895,9 +861,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 38, "execution_count": 38,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import tensorflow as tf\n", "import tensorflow as tf\n",
@ -965,9 +929,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 40, "execution_count": 40,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def render_policy_net(model_path, action, X, n_max_steps = 1000):\n", "def render_policy_net(model_path, action, X, n_max_steps = 1000):\n",
@ -1024,9 +986,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 42,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import tensorflow as tf\n", "import tensorflow as tf\n",
@ -1069,9 +1029,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 43,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def discount_rewards(rewards, discount_rate):\n", "def discount_rewards(rewards, discount_rate):\n",
@ -1157,9 +1115,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 47, "execution_count": 47,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"env.close()" "env.close()"
@ -1309,9 +1265,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 51, "execution_count": 51,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"n_states = 3\n", "n_states = 3\n",
@ -1336,9 +1290,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 52, "execution_count": 52,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"def optimal_policy(state):\n", "def optimal_policy(state):\n",
@ -1439,23 +1391,28 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 57, "execution_count": 57,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"mspacman_color = np.array([210, 164, 74]).mean()\n", "mspacman_color = 210 + 164 + 74\n",
"\n", "\n",
"def preprocess_observation(obs):\n", "def preprocess_observation(obs):\n",
" img = obs[1:176:2, ::2] # crop and downsize\n", " img = obs[1:176:2, ::2] # crop and downsize\n",
" img = img.mean(axis=2) # to greyscale\n", " img = img.sum(axis=2) # to greyscale\n",
" img[img==mspacman_color] = 0 # Improve contrast\n", " img[img==mspacman_color] = 0 # Improve contrast\n",
" img = (img - 128) / 128 - 1 # normalize from -1. to 1.\n", " img = (img // 3 - 128).astype(np.int8) # normalize from -128 to 127\n",
" return img.reshape(88, 80, 1)\n", " return img.reshape(88, 80, 1)\n",
"\n", "\n",
"img = preprocess_observation(obs)" "img = preprocess_observation(obs)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: the `preprocess_observation()` function is slightly different from the one in the book: instead of representing pixels as 64-bit floats from -1.0 to 1.0, it represents them as 8-bit integers from -128 to 127. The benefit is that the replay memory will take up about 6.5 GB of RAM instead of 52 GB. The reduced precision has no visible impact on training."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 58, "execution_count": 58,
@ -1498,9 +1455,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 59, "execution_count": 59,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"reset_graph()\n", "reset_graph()\n",
@ -1545,9 +1500,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 60, "execution_count": 60,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n", "X_state = tf.placeholder(tf.float32, shape=[None, input_height, input_width,\n",
@ -1572,9 +1525,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 62, "execution_count": 62,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"learning_rate = 0.001\n", "learning_rate = 0.001\n",
@ -1608,9 +1559,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 63,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"from collections import deque\n", "from collections import deque\n",
@ -1632,9 +1581,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 64, "execution_count": 64,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"eps_min = 0.1\n", "eps_min = 0.1\n",
@ -1678,9 +1625,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 66, "execution_count": 66,
"metadata": { "metadata": {},
"collapsed": true
},
"outputs": [], "outputs": [],
"source": [ "source": [
"loss_val = np.infty\n", "loss_val = np.infty\n",
@ -1970,7 +1915,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.5.2" "version": "3.6.3"
}, },
"nav_menu": {}, "nav_menu": {},
"toc": { "toc": {