Fix bug in training_step: target_Q_values must be a column vector
parent 0c2c80d89e
commit 49715d4b74
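The reshape this commit adds guards against a silent broadcasting bug: the targets come out of the Bellman update with shape (batch_size,), while the predicted Q-values the loss compares them against form a (batch_size, 1) column vector. Subtracting the two broadcasts to (batch_size, batch_size), so the loss averages over the wrong tensor without raising any error. A minimal sketch of the pitfall, using hypothetical values rather than code from the notebook:

    import numpy as np

    batch_size = 4
    target_Q_values = np.ones(batch_size)      # shape (4,)
    Q_values = np.ones((batch_size, 1))        # shape (4, 1)

    # (4,) - (4, 1) silently broadcasts to a (4, 4) matrix, so a loss
    # computed from it averages 16 pairwise errors instead of 4.
    print((target_Q_values - Q_values).shape)  # (4, 4)

    # The fix applied in this commit: make the targets a column vector.
    print((target_Q_values.reshape(-1, 1) - Q_values).shape)  # (4, 1)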
@@ -67,7 +67,7 @@
 "from tensorflow import keras\n",
 "assert tf.__version__ >= \"2.0\"\n",
 "\n",
-"if not tf.test.is_gpu_available():\n",
+"if not tf.config.list_physical_devices('GPU'):\n",
 "    print(\"No GPU was detected. CNNs can be very slow without a GPU.\")\n",
 "    if IS_COLAB:\n",
 "        print(\"Go to Runtime > Change runtime and select a GPU hardware accelerator.\")\n",
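This first hunk replaces tf.test.is_gpu_available(), which is deprecated in TensorFlow 2.x, with the supported device query. Unescaped from the notebook cell, the updated check is:

    import tensorflow as tf

    # tf.test.is_gpu_available() is deprecated; listing physical devices
    # is the supported way to detect a GPU in TF 2.x.
    if not tf.config.list_physical_devices('GPU'):
        print("No GPU was detected. CNNs can be very slow without a GPU.")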
@@ -574,6 +574,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"keras.backend.clear_session()\n",
 "tf.random.set_seed(42)\n",
 "np.random.seed(42)\n",
 "\n",
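This and several later hunks prepend keras.backend.clear_session() to the model-building cells, presumably so re-running a cell starts from a clean slate: it frees the previous graph state and resets Keras layer-name counters before the seeds are set. Unescaped, the cell preamble becomes:

    import numpy as np
    import tensorflow as tf
    from tensorflow import keras

    # Drop any state left over from earlier runs, then seed both RNGs,
    # so rebuilding the model is reproducible across cell re-executions.
    keras.backend.clear_session()
    tf.random.set_seed(42)
    np.random.seed(42)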
@@ -638,7 +639,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -882,6 +883,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"keras.backend.clear_session()\n",
 "np.random.seed(42)\n",
 "tf.random.set_seed(42)\n",
 "\n",
@@ -1274,6 +1276,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"keras.backend.clear_session()\n",
 "tf.random.set_seed(42)\n",
 "np.random.seed(42)\n",
 "\n",
@@ -1392,7 +1395,9 @@
 "    states, actions, rewards, next_states, dones = experiences\n",
 "    next_Q_values = model.predict(next_states)\n",
 "    max_next_Q_values = np.max(next_Q_values, axis=1)\n",
-"    target_Q_values = rewards + (1 - dones) * discount_rate * max_next_Q_values\n",
+"    target_Q_values = (rewards +\n",
+"                       (1 - dones) * discount_rate * max_next_Q_values)\n",
+"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
 "    mask = tf.one_hot(actions, n_outputs)\n",
 "    with tf.GradientTape() as tape:\n",
 "        all_Q_values = model(states)\n",
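Putting that hunk back into context, the patched training_step plausibly reads as follows. The lines after all_Q_values = model(states) are not part of this diff; they are reconstructed from the notebook's usual DQN pattern, assuming model, optimizer, loss_fn, discount_rate, n_outputs, and sample_experiences are defined in earlier cells:

    def training_step(batch_size):
        experiences = sample_experiences(batch_size)  # replay-buffer sampler (assumed)
        states, actions, rewards, next_states, dones = experiences
        next_Q_values = model.predict(next_states)
        max_next_Q_values = np.max(next_Q_values, axis=1)
        target_Q_values = (rewards +
                           (1 - dones) * discount_rate * max_next_Q_values)
        # The fix: targets must be a column vector so the loss compares
        # them elementwise with the (batch_size, 1) Q-values below.
        target_Q_values = target_Q_values.reshape(-1, 1)
        mask = tf.one_hot(actions, n_outputs)
        with tf.GradientTape() as tape:
            all_Q_values = model(states)
            Q_values = tf.reduce_sum(all_Q_values * mask, axis=1, keepdims=True)
            loss = tf.reduce_mean(loss_fn(target_Q_values, Q_values))
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))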
@@ -1505,6 +1510,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"keras.backend.clear_session()\n",
 "tf.random.set_seed(42)\n",
 "np.random.seed(42)\n",
 "\n",
@@ -1536,7 +1542,9 @@
 "    best_next_actions = np.argmax(next_Q_values, axis=1)\n",
 "    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
 "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
-"    target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
+"    target_Q_values = (rewards +\n",
+"                       (1 - dones) * discount_rate * next_best_Q_values)\n",
+"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
 "    mask = tf.one_hot(actions, n_outputs)\n",
 "    with tf.GradientTape() as tape:\n",
 "        all_Q_values = model(states)\n",
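The hunk above applies the identical fix to the Double DQN variant of training_step, where the online model picks best_next_actions and the separate target network scores them; only the target computation differs, so the column-vector reshape is needed for the same reason. The hunk at 1681 below repeats the fix for a third copy of the function.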
@@ -1646,6 +1654,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"keras.backend.clear_session()\n",
 "tf.random.set_seed(42)\n",
 "np.random.seed(42)\n",
 "\n",
@@ -1681,7 +1690,9 @@
 "    best_next_actions = np.argmax(next_Q_values, axis=1)\n",
 "    next_mask = tf.one_hot(best_next_actions, n_outputs).numpy()\n",
 "    next_best_Q_values = (target.predict(next_states) * next_mask).sum(axis=1)\n",
-"    target_Q_values = rewards + (1 - dones) * discount_rate * next_best_Q_values\n",
+"    target_Q_values = (rewards +\n",
+"                       (1 - dones) * discount_rate * next_best_Q_values)\n",
+"    target_Q_values = target_Q_values.reshape(-1, 1)\n",
 "    mask = tf.one_hot(actions, n_outputs)\n",
 "    with tf.GradientTape() as tape:\n",
 "        all_Q_values = model(states)\n",
@@ -2777,7 +2788,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.7.3"
+"version": "3.7.6"
 }
 },
 "nbformat": 4,