Speed up training: I tuned learning rate for DQN variants, and added auto-FIRE for Blockout. Fixes #117

main
Aurélien Geron 2021-03-09 22:21:08 +13:00
parent 80eec21242
commit dd94101c5d
1 changed files with 57 additions and 40 deletions

View File

@ -57,7 +57,7 @@
" # %tensorflow_version only exists in Colab.\n", " # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n", " %tensorflow_version 2.x\n",
" !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb\n", " !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb\n",
" !pip install -q -U tf-agents-nightly pyvirtualdisplay gym[atari]\n", " !pip install -q -U tf-agents pyvirtualdisplay gym[atari]\n",
" IS_COLAB = True\n", " IS_COLAB = True\n",
"except Exception:\n", "except Exception:\n",
" IS_COLAB = False\n", " IS_COLAB = False\n",
@ -1378,7 +1378,9 @@
"source": [ "source": [
"Lastly, let's create a function that will sample some experiences from the replay memory and perform a training step:\n", "Lastly, let's create a function that will sample some experiences from the replay memory and perform a training step:\n",
"\n", "\n",
"**Note**: the first 3 releases of the 2nd edition were missing the `reshape()` operation which converts `target_Q_values` to a column vector (this is required by the `loss_fn()`)." "**Notes**:\n",
"* The first 3 releases of the 2nd edition were missing the `reshape()` operation which converts `target_Q_values` to a column vector (this is required by the `loss_fn()`).\n",
"* The book uses a learning rate of 1e-3, but in the code below I use 1e-2, as it significantly improves training. I also tuned the learning rates of the DQN variants below."
] ]
}, },
{ {
@ -1389,7 +1391,7 @@
"source": [ "source": [
"batch_size = 32\n", "batch_size = 32\n",
"discount_rate = 0.95\n", "discount_rate = 0.95\n",
"optimizer = keras.optimizers.Adam(lr=1e-3)\n", "optimizer = keras.optimizers.Adam(lr=1e-2)\n",
"loss_fn = keras.losses.mean_squared_error\n", "loss_fn = keras.losses.mean_squared_error\n",
"\n", "\n",
"def training_step(batch_size):\n", "def training_step(batch_size):\n",
@ -1446,7 +1448,7 @@
" if done:\n", " if done:\n",
" break\n", " break\n",
" rewards.append(step) # Not shown in the book\n", " rewards.append(step) # Not shown in the book\n",
" if step > best_score: # Not shown\n", " if step >= best_score: # Not shown\n",
" best_weights = model.get_weights() # Not shown\n", " best_weights = model.get_weights() # Not shown\n",
" best_score = step # Not shown\n", " best_score = step # Not shown\n",
" print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\") # Not shown\n", " print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\") # Not shown\n",
@ -1496,7 +1498,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Not bad at all!" "Not bad at all! 😀"
] ]
}, },
{ {
@ -1534,7 +1536,7 @@
"source": [ "source": [
"batch_size = 32\n", "batch_size = 32\n",
"discount_rate = 0.95\n", "discount_rate = 0.95\n",
"optimizer = keras.optimizers.Adam(lr=1e-3)\n", "optimizer = keras.optimizers.Adam(lr=6e-3)\n",
"loss_fn = keras.losses.Huber()\n", "loss_fn = keras.losses.Huber()\n",
"\n", "\n",
"def training_step(batch_size):\n", "def training_step(batch_size):\n",
@ -1586,16 +1588,16 @@
" if done:\n", " if done:\n",
" break\n", " break\n",
" rewards.append(step)\n", " rewards.append(step)\n",
" if step > best_score:\n", " if step >= best_score:\n",
" best_weights = model.get_weights()\n", " best_weights = model.get_weights()\n",
" best_score = step\n", " best_score = step\n",
" print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n", " print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n",
" if episode > 50:\n", " if episode >= 50:\n",
" training_step(batch_size)\n", " training_step(batch_size)\n",
" if episode % 50 == 0:\n", " if episode % 50 == 0:\n",
" target.set_weights(model.get_weights())\n", " target.set_weights(model.get_weights())\n",
" # Alternatively, you can do soft updates at each step:\n", " # Alternatively, you can do soft updates at each step:\n",
" #if episode > 50:\n", " #if episode >= 50:\n",
" #target_weights = target.get_weights()\n", " #target_weights = target.get_weights()\n",
" #online_weights = model.get_weights()\n", " #online_weights = model.get_weights()\n",
" #for index in range(len(target_weights)):\n", " #for index in range(len(target_weights)):\n",
@ -1627,7 +1629,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"env.seed(42)\n", "env.seed(43)\n",
"state = env.reset()\n", "state = env.reset()\n",
"\n", "\n",
"frames = []\n", "frames = []\n",
@ -1682,7 +1684,7 @@
"source": [ "source": [
"batch_size = 32\n", "batch_size = 32\n",
"discount_rate = 0.95\n", "discount_rate = 0.95\n",
"optimizer = keras.optimizers.Adam(lr=1e-2)\n", "optimizer = keras.optimizers.Adam(lr=7.5e-3)\n",
"loss_fn = keras.losses.Huber()\n", "loss_fn = keras.losses.Huber()\n",
"\n", "\n",
"def training_step(batch_size):\n", "def training_step(batch_size):\n",
@ -1734,13 +1736,13 @@
" if done:\n", " if done:\n",
" break\n", " break\n",
" rewards.append(step)\n", " rewards.append(step)\n",
" if step > best_score:\n", " if step >= best_score:\n",
" best_weights = model.get_weights()\n", " best_weights = model.get_weights()\n",
" best_score = step\n", " best_score = step\n",
" print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n", " print(\"\\rEpisode: {}, Steps: {}, eps: {:.3f}\".format(episode, step + 1, epsilon), end=\"\")\n",
" if episode > 50:\n", " if episode >= 50:\n",
" training_step(batch_size)\n", " training_step(batch_size)\n",
" if episode % 200 == 0:\n", " if episode % 50 == 0:\n",
" target.set_weights(model.get_weights())\n", " target.set_weights(model.get_weights())\n",
"\n", "\n",
"model.set_weights(best_weights)" "model.set_weights(best_weights)"
@ -2015,6 +2017,15 @@
"limited_repeating_env" "limited_repeating_env"
] ]
}, },
{
"cell_type": "code",
"execution_count": 93,
"metadata": {},
"outputs": [],
"source": [
"limited_repeating_env.unwrapped"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
@ -2023,12 +2034,10 @@
] ]
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": 93,
"metadata": {}, "metadata": {},
"outputs": [],
"source": [ "source": [
"limited_repeating_env.unwrapped" "**Warning**: Breakout requires the player to press the FIRE button at the start of the game and after each life lost. The agent may take a very long time learning this because at first it seems that pressing FIRE just means losing faster. To speed up training considerably, we create and use a subclass of the `AtariPreprocessing` wrapper class called `AtariPreprocessingWithAutoFire` which presses FIRE (i.e., plays action 1) automatically at the start of the game and after each life lost. This is different from the book which uses the regular `AtariPreprocessing` wrapper."
] ]
}, },
{ {
@ -2044,10 +2053,21 @@
"max_episode_steps = 27000 # <=> 108k ALE frames since 1 step = 4 frames\n", "max_episode_steps = 27000 # <=> 108k ALE frames since 1 step = 4 frames\n",
"environment_name = \"BreakoutNoFrameskip-v4\"\n", "environment_name = \"BreakoutNoFrameskip-v4\"\n",
"\n", "\n",
"class AtariPreprocessingWithAutoFire(AtariPreprocessing):\n",
" def reset(self, **kwargs):\n",
" super().reset(**kwargs)\n",
" return self.step(1)[0] # FIRE to start\n",
" def step(self, action):\n",
" lives_before_action = self.ale.lives()\n",
" out = super().step(action)\n",
" if self.ale.lives() < lives_before_action and not done:\n",
" out = super().step(1) # FIRE to start after life lost\n",
" return out\n",
"\n",
"env = suite_atari.load(\n", "env = suite_atari.load(\n",
" environment_name,\n", " environment_name,\n",
" max_episode_steps=max_episode_steps,\n", " max_episode_steps=max_episode_steps,\n",
" gym_env_wrappers=[AtariPreprocessing, FrameStack4])" " gym_env_wrappers=[AtariPreprocessingWithAutoStart, FrameStack4])"
] ]
}, },
{ {
@ -2074,7 +2094,6 @@
"source": [ "source": [
"env.seed(42)\n", "env.seed(42)\n",
"env.reset()\n", "env.reset()\n",
"time_step = env.step(1) # FIRE\n",
"for _ in range(4):\n", "for _ in range(4):\n",
" time_step = env.step(3) # LEFT" " time_step = env.step(3) # LEFT"
] ]
@ -2214,6 +2233,13 @@
"Create the replay buffer (this will use a lot of RAM, so please reduce the buffer size if you get an out-of-memory error):" "Create the replay buffer (this will use a lot of RAM, so please reduce the buffer size if you get an out-of-memory error):"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Warning**: we use a replay buffer of size 100,000 instead of 1,000,000 (as used in the book) since many people were getting OOM (Out-Of-Memory) errors."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 102, "execution_count": 102,
@ -2225,7 +2251,7 @@
"replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n", "replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(\n",
" data_spec=agent.collect_data_spec,\n", " data_spec=agent.collect_data_spec,\n",
" batch_size=tf_env.batch_size,\n", " batch_size=tf_env.batch_size,\n",
" max_length=1000000) # reduce if OOM error\n", " max_length=100000) # reduce if OOM error\n",
"\n", "\n",
"replay_buffer_observer = replay_buffer.add_batch" "replay_buffer_observer = replay_buffer.add_batch"
] ]
@ -2365,9 +2391,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"tf.random.set_seed(93) # chosen to show an example of trajectory at the end of an episode\n", "tf.random.set_seed(9) # chosen to show an example of trajectory at the end of an episode\n",
"\n", "\n",
"#trajectories, buffer_info = replay_buffer.get_next(\n", "#trajectories, buffer_info = replay_buffer.get_next( # get_next() is deprecated\n",
"# sample_batch_size=2, num_steps=3)\n", "# sample_batch_size=2, num_steps=3)\n",
"\n", "\n",
"trajectories, buffer_info = next(iter(replay_buffer.as_dataset(\n", "trajectories, buffer_info = next(iter(replay_buffer.as_dataset(\n",
@ -2500,7 +2526,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Run the next cell to train the agent for 10,000 steps. Then look at its behavior by running the following cell. You can run these two cells as many times as you wish. The agent will keep improving!" "Run the next cell to train the agent for 50,000 steps. Then look at its behavior by running the following cell. You can run these two cells as many times as you wish. The agent will keep improving! It will likely take over 200,000 iterations for the agent to become reasonably good."
] ]
}, },
{ {
@ -2509,7 +2535,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"train_agent(n_iterations=10000)" "train_agent(n_iterations=50000)"
] ]
}, },
{ {
@ -2523,19 +2549,10 @@
" global frames\n", " global frames\n",
" frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n", " frames.append(tf_env.pyenv.envs[0].render(mode=\"rgb_array\"))\n",
"\n", "\n",
"prev_lives = tf_env.pyenv.envs[0].ale.lives()\n",
"def reset_and_fire_on_life_lost(trajectory):\n",
" global prev_lives\n",
" lives = tf_env.pyenv.envs[0].ale.lives()\n",
" if prev_lives != lives:\n",
" tf_env.reset()\n",
" tf_env.pyenv.envs[0].step(1)\n",
" prev_lives = lives\n",
"\n",
"watch_driver = DynamicStepDriver(\n", "watch_driver = DynamicStepDriver(\n",
" tf_env,\n", " tf_env,\n",
" agent.policy,\n", " agent.policy,\n",
" observers=[save_frames, reset_and_fire_on_life_lost, ShowProgress(1000)],\n", " observers=[save_frames, ShowProgress(1000)],\n",
" num_steps=1000)\n", " num_steps=1000)\n",
"final_time_step, final_policy_state = watch_driver.run()\n", "final_time_step, final_policy_state = watch_driver.run()\n",
"\n", "\n",
@ -2798,7 +2815,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.9" "version": "3.7.10"
} }
}, },
"nbformat": 4, "nbformat": 4,