Fix bug in training_step: target_Q_values must be a column vector
parent
0c2c80d89e
commit
49715d4b74
|
@ -67,7 +67,7 @@
|
|||
"from tensorflow import keras\n",
|
||||
"assert tf.__version__ >= \"2.0\"\n",
|
||||
"\n",
|
||||
"if not tf.test.is_gpu_available():\n",
|
||||
"if not tf.config.list_physical_devices('GPU'):\n",
|
||||
" print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
|
||||
" if IS_COLAB:\n",
|
||||
" print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
|
||||
|
@ -574,6 +574,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
|
@ -638,7 +639,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -882,6 +883,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"np.random.seed(42)\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"\n",
|
||||
|
@ -1274,6 +1276,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
|
@ -1392,7 +1395,9 @@
|
|||
" states, actions, rewards, next_states, dones = experiences\n",
|
||||
" next_Q_values = model.predict(next_states)\n",
|
||||
" max_next_Q_values = np.max(next_Q_values, axis=1)\n",
|
||||
" target_Q_values = rewards + (1 - dones) * discount_rate * max_next_Q_values\n",
|
||||
" target_Q_values = (rewards +\n",
|
||||
" (1 - dones) * discount_rate * max_next_Q_values)\n",
|
||||
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
|
||||
" mask = tf.one_hot(actions, n_outputs)\n",
|
||||
" with tf.GradientTape() as tape:\n",
|
||||
" all_Q_values = model(states)\n",
|
||||
|
@ -1505,6 +1510,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
|
@ -1536,7 +1542,9 @@
|
|||
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
|
||||
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
|
||||
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
|
||||
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
|
||||
" target_Q_values = (rewards + \n",
|
||||
" (1 - dones) * discount_rate * next_best_Q_values)\n",
|
||||
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
|
||||
" mask = tf.one_hot(actions, n_outputs)\n",
|
||||
" with tf.GradientTape() as tape:\n",
|
||||
" all_Q_values = model(states)\n",
|
||||
|
@ -1646,6 +1654,7 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"keras.backend.clear_session()\n",
|
||||
"tf.random.set_seed(42)\n",
|
||||
"np.random.seed(42)\n",
|
||||
"\n",
|
||||
|
@ -1681,7 +1690,9 @@
|
|||
" best_next_actions = np.argmax(next_Q_values, axis=1)\n",
|
||||
" next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
|
||||
" next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
|
||||
" target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
|
||||
" target_Q_values = (rewards + \n",
|
||||
" (1 - dones) * discount_rate * next_best_Q_values)\n",
|
||||
" target_Q_values = target_Q_values.reshape(-1, 1)\n",
|
||||
" mask = tf.one_hot(actions, n_outputs)\n",
|
||||
" with tf.GradientTape() as tape:\n",
|
||||
" all_Q_values = model(states)\n",
|
||||
|
@ -2777,7 +2788,7 @@
|
|||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.3"
|
||||
"version": "3.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
|
Loading…
Reference in New Issue