Update ReplayMemory class and add a short description
parent
c6b424931a
commit
3af8bcd1bb
|
@ -1535,7 +1535,7 @@
|
||||||
" X_action = tf.placeholder(tf.int32, shape=[None])\n",
|
" X_action = tf.placeholder(tf.int32, shape=[None])\n",
|
||||||
" y = tf.placeholder(tf.float32, shape=[None, 1])\n",
|
" y = tf.placeholder(tf.float32, shape=[None, 1])\n",
|
||||||
" q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n",
|
" q_value = tf.reduce_sum(online_q_values * tf.one_hot(X_action, n_outputs),\n",
|
||||||
" axis=1, keep_dims=True)\n",
|
" axis=1, keepdims=True)\n",
|
||||||
" error = tf.abs(y - q_value)\n",
|
" error = tf.abs(y - q_value)\n",
|
||||||
" clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n",
|
" clipped_error = tf.clip_by_value(error, 0.0, 1.0)\n",
|
||||||
" linear_error = 2 * (error - clipped_error)\n",
|
" linear_error = 2 * (error - clipped_error)\n",
|
||||||
|
@ -1556,43 +1556,58 @@
|
||||||
"Note: in the first version of the book, the loss function was simply the squared error between the target Q-Values (`y`) and the estimated Q-Values (`q_value`). However, because the experiences are very noisy, it is better to use a quadratic loss only for small errors (below 1.0) and a linear loss (twice the absolute error) for larger errors, which is what the code above computes. This way large errors don't push the model parameters around as much. Note that we also tweaked some hyperparameters (using a smaller learning rate, and using Nesterov Accelerated Gradients rather than Adam optimization, since adaptive gradient algorithms may sometimes be bad, according to this [paper](https://arxiv.org/abs/1705.08292)). We also tweaked a few other hyperparameters below (a larger replay memory, longer decay for the $\\epsilon$-greedy policy, larger discount rate, less frequent copies of the online DQN to the target DQN, etc.)."
|
"Note: in the first version of the book, the loss function was simply the squared error between the target Q-Values (`y`) and the estimated Q-Values (`q_value`). However, because the experiences are very noisy, it is better to use a quadratic loss only for small errors (below 1.0) and a linear loss (twice the absolute error) for larger errors, which is what the code above computes. This way large errors don't push the model parameters around as much. Note that we also tweaked some hyperparameters (using a smaller learning rate, and using Nesterov Accelerated Gradients rather than Adam optimization, since adaptive gradient algorithms may sometimes be bad, according to this [paper](https://arxiv.org/abs/1705.08292)). We also tweaked a few other hyperparameters below (a larger replay memory, longer decay for the $\\epsilon$-greedy policy, larger discount rate, less frequent copies of the online DQN to the target DQN, etc.)."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"We use this `ReplayMemory` class instead of a `deque` because it is much faster for random access (thanks to @NileshPS who contributed it). Moreover, we default to sampling with replacement, which is much faster than sampling without replacement for large replay memories."
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 63,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from collections import deque\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
"class ReplayMemory:\n",
|
"class ReplayMemory:\n",
|
||||||
" \n",
|
" def __init__(self, maxlen):\n",
|
||||||
" def __init__(self, size):\n",
|
" self.maxlen = maxlen\n",
|
||||||
" self.size = size\n",
|
" self.buf = np.empty(shape=maxlen, dtype=np.object)\n",
|
||||||
" self.buf= [None] * self.size\n",
|
" self.index = 0\n",
|
||||||
" self.index, self.count = 0, 0\n",
|
" self.length = 0\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def append(self, data):\n",
|
" def append(self, data):\n",
|
||||||
" self.buf[self.index] = data\n",
|
" self.buf[self.index] = data\n",
|
||||||
" self.count = min(self.count + 1, self.size)\n",
|
" self.length = min(self.length + 1, self.maxlen)\n",
|
||||||
" self.index = (self.index + 1) % self.size\n",
|
" self.index = (self.index + 1) % self.maxlen\n",
|
||||||
" \n",
|
|
||||||
" def __getitem__(self, idx):\n",
|
|
||||||
" return self.buf[idx]\n",
|
|
||||||
" \n",
|
|
||||||
" def __len__(self):\n",
|
|
||||||
" return self.count\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
|
||||||
" \n",
|
" \n",
|
||||||
|
" def sample(self, batch_size, with_replacement=True):\n",
|
||||||
|
" if with_replacement:\n",
|
||||||
|
" indices = np.random.randint(self.length, size=batch_size) # faster\n",
|
||||||
|
" else:\n",
|
||||||
|
" indices = np.random.permutation(self.length)[:batch_size]\n",
|
||||||
|
" return self.buf[indices]"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 64,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
"replay_memory_size = 500000\n",
|
"replay_memory_size = 500000\n",
|
||||||
"replay_memory = ReplayMemory(replay_memory_size)\n",
|
"replay_memory = ReplayMemory(replay_memory_size)"
|
||||||
"\n",
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 65,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
"def sample_memories(batch_size):\n",
|
"def sample_memories(batch_size):\n",
|
||||||
" indices = np.random.permutation(len(replay_memory))[:batch_size]\n",
|
|
||||||
" cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
|
" cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
|
||||||
" for idx in indices:\n",
|
" for memory in replay_memory.sample(batch_size):\n",
|
||||||
" memory = replay_memory[idx]\n",
|
|
||||||
" for col, value in zip(cols, memory):\n",
|
" for col, value in zip(cols, memory):\n",
|
||||||
" col.append(value)\n",
|
" col.append(value)\n",
|
||||||
" cols = [np.array(col) for col in cols]\n",
|
" cols = [np.array(col) for col in cols]\n",
|
||||||
|
@ -1601,7 +1616,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 64,
|
"execution_count": 66,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1619,7 +1634,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 65,
|
"execution_count": 67,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1645,7 +1660,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 66,
|
"execution_count": 68,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1664,7 +1679,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 67,
|
"execution_count": 69,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1740,7 +1755,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 68,
|
"execution_count": 70,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1770,7 +1785,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 69,
|
"execution_count": 71,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"scrolled": true
|
"scrolled": true
|
||||||
},
|
},
|
||||||
|
@ -1802,7 +1817,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 71,
|
"execution_count": 73,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": true
|
"collapsed": true
|
||||||
},
|
},
|
||||||
|
@ -1815,7 +1830,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 72,
|
"execution_count": 74,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1829,7 +1844,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 73,
|
"execution_count": 75,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1854,10 +1869,12 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 74,
|
"execution_count": 76,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"from collections import deque\n",
|
||||||
|
"\n",
|
||||||
"def combine_observations_multichannel(preprocessed_observations):\n",
|
"def combine_observations_multichannel(preprocessed_observations):\n",
|
||||||
" return np.array(preprocessed_observations).transpose([1, 2, 0])\n",
|
" return np.array(preprocessed_observations).transpose([1, 2, 0])\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
@ -1877,7 +1894,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 75,
|
"execution_count": 77,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1933,7 +1950,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 76,
|
"execution_count": 78,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1942,7 +1959,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 77,
|
"execution_count": 79,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"scrolled": true
|
"scrolled": true
|
||||||
},
|
},
|
||||||
|
@ -1965,7 +1982,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 78,
|
"execution_count": 80,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1974,7 +1991,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 79,
|
"execution_count": 81,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1983,7 +2000,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 80,
|
"execution_count": 82,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -1994,7 +2011,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 81,
|
"execution_count": 83,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2010,7 +2027,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 82,
|
"execution_count": 84,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2019,7 +2036,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 83,
|
"execution_count": 85,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2028,7 +2045,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 84,
|
"execution_count": 86,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2044,7 +2061,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 85,
|
"execution_count": 87,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2053,7 +2070,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 86,
|
"execution_count": 88,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2064,7 +2081,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 87,
|
"execution_count": 89,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2117,7 +2134,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 88,
|
"execution_count": 90,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2144,7 +2161,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 89,
|
"execution_count": 91,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2162,7 +2179,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 90,
|
"execution_count": 92,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
@ -2208,7 +2225,7 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 91,
|
"execution_count": 93,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
|
Loading…
Reference in New Issue