Fix figure name and clarify a couple code examples

main
Aurélien Geron 2019-05-27 20:35:00 +08:00
parent c5f4b41cf5
commit 3ef350ab4c
1 changed files with 15 additions and 67 deletions

View File

@ -233,7 +233,7 @@
"$ xvfb-run -s \"-screen 0 1400x900x24\" jupyter notebook\n",
"```\n",
"\n",
"Alternatively, you can install the `pyvirtualdisplay` Python library which wraps Xvfb:\n",
"Alternatively, you can install the [pyvirtualdisplay](https://github.com/ponty/pyvirtualdisplay) Python library which wraps Xvfb:\n",
"\n",
"```bash\n",
"python3 -m pip install -U pyvirtualdisplay\n",
@ -1319,8 +1319,8 @@
" indices = np.random.randint(len(replay_memory), size=batch_size)\n",
" batch = [replay_memory[index] for index in indices]\n",
" states, actions, rewards, next_states, dones = [\n",
" np.array([experience[index] for experience in batch])\n",
" for index in range(5)]\n",
" np.array([experience[field_index] for experience in batch])\n",
" for field_index in range(5)]\n",
" return states, actions, rewards, next_states, dones"
]
},
@ -1580,7 +1580,7 @@
"plt.plot(rewards)\n",
"plt.xlabel(\"Episode\", fontsize=14)\n",
"plt.ylabel(\"Sum of rewards\", fontsize=14)\n",
"save_fig(\"dqn_rewards_plot\")\n",
"save_fig(\"double_dqn_rewards_plot\")\n",
"plt.show()"
]
},
@ -2057,7 +2057,7 @@
" current_frame_delta = np.maximum(obs[..., 3] - obs[..., :3].mean(axis=-1), 0.)\n",
" img[..., 0] += current_frame_delta\n",
" img[..., 2] += current_frame_delta\n",
" img = (img - img.min()) / (img.max() - img.min())\n",
" img = np.clip(img / 150, 0, 1)\n",
" plt.imshow(img)\n",
" plt.axis(\"off\")"
]
@ -2459,58 +2459,6 @@
"train_agent(n_iterations=200) # change this to 10 million or more!"
]
},
{
"cell_type": "code",
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
"num_eval_episodes = 10\n",
"eval_metrics = [\n",
" tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),\n",
" tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
"eval_tf_env = suite_atari.load(\n",
" environment_name,\n",
" max_episode_steps=max_episode_steps,\n",
" gym_env_wrappers=[AtariPreprocessing, FrameStack4])\n",
"\n",
"eval_tf_env = TFPyEnvironment(eval_tf_env)"
]
},
{
"cell_type": "code",
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
"from tf_agents.eval import metric_utils\n",
"\n",
"results = metric_utils.eager_compute(\n",
" eval_metrics,\n",
" eval_tf_env,\n",
" agent.policy,\n",
" num_episodes=num_eval_episodes,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
"results"
]
},
{
"cell_type": "markdown",
"metadata": {},
@ -2534,7 +2482,7 @@
},
{
"cell_type": "code",
"execution_count": 123,
"execution_count": 119,
"metadata": {},
"outputs": [],
"source": [
@ -2548,7 +2496,7 @@
},
{
"cell_type": "code",
"execution_count": 124,
"execution_count": 120,
"metadata": {},
"outputs": [],
"source": [
@ -2557,7 +2505,7 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 121,
"metadata": {},
"outputs": [],
"source": [
@ -2573,7 +2521,7 @@
},
{
"cell_type": "code",
"execution_count": 126,
"execution_count": 122,
"metadata": {},
"outputs": [],
"source": [
@ -2596,7 +2544,7 @@
},
{
"cell_type": "code",
"execution_count": 127,
"execution_count": 123,
"metadata": {},
"outputs": [],
"source": [
@ -2608,7 +2556,7 @@
},
{
"cell_type": "code",
"execution_count": 128,
"execution_count": 124,
"metadata": {},
"outputs": [],
"source": [
@ -2617,7 +2565,7 @@
},
{
"cell_type": "code",
"execution_count": 129,
"execution_count": 125,
"metadata": {},
"outputs": [],
"source": [
@ -2640,7 +2588,7 @@
},
{
"cell_type": "code",
"execution_count": 130,
"execution_count": 126,
"metadata": {},
"outputs": [],
"source": [
@ -2689,7 +2637,7 @@
},
{
"cell_type": "code",
"execution_count": 131,
"execution_count": 127,
"metadata": {},
"outputs": [],
"source": [
@ -2700,7 +2648,7 @@
},
{
"cell_type": "code",
"execution_count": 132,
"execution_count": 128,
"metadata": {},
"outputs": [],
"source": [