Update ReplayMemory class and add a short description

main
Aurélien Geron 2018-05-09 16:54:17 +02:00
parent c6b424931a
commit 3af8bcd1bb
1 changed files with 69 additions and 52 deletions

View File

@ -1535,7 +1535,7 @@
" X_action = tf.placeholder(tf.int32, shape=[None])\n",
" y = tf.placeholder(tf.float32, shape=[None, 1])\n",
" q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n",
" axis=1, keep_dims=True)\n",
" axis=1, keepdims=True)\n",
" error = tf.abs(y - q_value)\n",
" clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n",
" linear_error = 2 * (error - clipped_error)\n",
@ -1556,43 +1556,58 @@
"Note: in the first version of the book, the loss function was simply the squared error between the target Q-Values (`y`) and the estimated Q-Values (`q_value`). However, because the experiences are very noisy, it is better to use a quadratic loss only for small errors (below 1.0) and a linear loss (twice the absolute error) for larger errors, which is what the code above computes. This way large errors don't push the model parameters around as much. Note that we also tweaked some hyperparameters (using a smaller learning rate, and using Nesterov Accelerated Gradients rather than Adam optimization, since adaptive gradient algorithms may sometimes be bad, according to this [paper](https://arxiv.org/abs/1705.08292)). We also tweaked a few other hyperparameters below (a larger replay memory, longer decay for the $\\epsilon$-greedy policy, larger discount rate, less frequent copies of the online DQN to the target DQN, etc.)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use this `ReplayMemory` class instead of a `deque` because it is much faster for random access (thanks to @NileshPS who contributed it). Moreover, we default to sampling with replacement, which is much faster than sampling without replacement for large replay memories."
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 63,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
"\n",
"class ReplayMemory:\n",
" \n",
" def __init__(self, size):\n",
" self.size = size\n",
" self.buf= [None] * self.size\n",
" self.index, self.count = 0, 0\n",
" def __init__(self, maxlen):\n",
" self.maxlen = maxlen\n",
" self.buf = np.empty(shape=maxlen, dtype=np.object)\n",
" self.index = 0\n",
" self.length = 0\n",
" \n",
" def append(self, data):\n",
" self.buf[self.index] = data\n",
" self.count = min(self.count + 1, self.size)\n",
" self.index = (self.index + 1) % self.size\n",
" \n",
" def __getitem__(self, idx):\n",
" return self.buf[idx]\n",
" \n",
" def __len__(self):\n",
" return self.count\n",
"\n",
"\n",
" self.length = min(self.length + 1, self.maxlen)\n",
" self.index = (self.index + 1) % self.maxlen\n",
" \n",
" def sample(self, batch_size, with_replacement=True):\n",
" if with_replacement:\n",
" indices = np.random.randint(self.length, size=batch_size) # faster\n",
" else:\n",
" indices = np.random.permutation(self.length)[:batch_size]\n",
" return self.buf[indices]"
]
},
{
"cell_type": "code",
"execution_count": 64,
"metadata": {},
"outputs": [],
"source": [
"replay_memory_size = 500000\n",
"replay_memory = ReplayMemory(replay_memory_size)\n",
"\n",
"replay_memory = ReplayMemory(replay_memory_size)"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [],
"source": [
"def sample_memories(batch_size):\n",
" indices = np.random.permutation(len(replay_memory))[:batch_size]\n",
" cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
" for idx in indices:\n",
" memory = replay_memory[idx]\n",
" for memory in replay_memory.sample(batch_size):\n",
" for col, value in zip(cols, memory):\n",
" col.append(value)\n",
" cols = [np.array(col) for col in cols]\n",
@ -1601,7 +1616,7 @@
},
{
"cell_type": "code",
"execution_count": 64,
"execution_count": 66,
"metadata": {},
"outputs": [],
"source": [
@ -1619,7 +1634,7 @@
},
{
"cell_type": "code",
"execution_count": 65,
"execution_count": 67,
"metadata": {},
"outputs": [],
"source": [
@ -1645,7 +1660,7 @@
},
{
"cell_type": "code",
"execution_count": 66,
"execution_count": 68,
"metadata": {},
"outputs": [],
"source": [
@ -1664,7 +1679,7 @@
},
{
"cell_type": "code",
"execution_count": 67,
"execution_count": 69,
"metadata": {},
"outputs": [],
"source": [
@ -1740,7 +1755,7 @@
},
{
"cell_type": "code",
"execution_count": 68,
"execution_count": 70,
"metadata": {},
"outputs": [],
"source": [
@ -1770,7 +1785,7 @@
},
{
"cell_type": "code",
"execution_count": 69,
"execution_count": 71,
"metadata": {
"scrolled": true
},
@ -1802,7 +1817,7 @@
},
{
"cell_type": "code",
"execution_count": 71,
"execution_count": 73,
"metadata": {
"collapsed": true
},
@ -1815,7 +1830,7 @@
},
{
"cell_type": "code",
"execution_count": 72,
"execution_count": 74,
"metadata": {},
"outputs": [],
"source": [
@ -1829,7 +1844,7 @@
},
{
"cell_type": "code",
"execution_count": 73,
"execution_count": 75,
"metadata": {},
"outputs": [],
"source": [
@ -1854,10 +1869,12 @@
},
{
"cell_type": "code",
"execution_count": 74,
"execution_count": 76,
"metadata": {},
"outputs": [],
"source": [
"from collections import deque\n",
"\n",
"def combine_observations_multichannel(preprocessed_observations):\n",
" return np.array(preprocessed_observations).transpose([1, 2, 0])\n",
"\n",
@ -1877,7 +1894,7 @@
},
{
"cell_type": "code",
"execution_count": 75,
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
@ -1933,7 +1950,7 @@
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
@ -1942,7 +1959,7 @@
},
{
"cell_type": "code",
"execution_count": 77,
"execution_count": 79,
"metadata": {
"scrolled": true
},
@ -1965,7 +1982,7 @@
},
{
"cell_type": "code",
"execution_count": 78,
"execution_count": 80,
"metadata": {},
"outputs": [],
"source": [
@ -1974,7 +1991,7 @@
},
{
"cell_type": "code",
"execution_count": 79,
"execution_count": 81,
"metadata": {},
"outputs": [],
"source": [
@ -1983,7 +2000,7 @@
},
{
"cell_type": "code",
"execution_count": 80,
"execution_count": 82,
"metadata": {},
"outputs": [],
"source": [
@ -1994,7 +2011,7 @@
},
{
"cell_type": "code",
"execution_count": 81,
"execution_count": 83,
"metadata": {},
"outputs": [],
"source": [
@ -2010,7 +2027,7 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 84,
"metadata": {},
"outputs": [],
"source": [
@ -2019,7 +2036,7 @@
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 85,
"metadata": {},
"outputs": [],
"source": [
@ -2028,7 +2045,7 @@
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 86,
"metadata": {},
"outputs": [],
"source": [
@ -2044,7 +2061,7 @@
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 87,
"metadata": {},
"outputs": [],
"source": [
@ -2053,7 +2070,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 88,
"metadata": {},
"outputs": [],
"source": [
@ -2064,7 +2081,7 @@
},
{
"cell_type": "code",
"execution_count": 87,
"execution_count": 89,
"metadata": {},
"outputs": [],
"source": [
@ -2117,7 +2134,7 @@
},
{
"cell_type": "code",
"execution_count": 88,
"execution_count": 90,
"metadata": {},
"outputs": [],
"source": [
@ -2144,7 +2161,7 @@
},
{
"cell_type": "code",
"execution_count": 89,
"execution_count": 91,
"metadata": {},
"outputs": [],
"source": [
@ -2162,7 +2179,7 @@
},
{
"cell_type": "code",
"execution_count": 90,
"execution_count": 92,
"metadata": {},
"outputs": [],
"source": [
@ -2208,7 +2225,7 @@
},
{
"cell_type": "code",
"execution_count": 91,
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [