Fix render code for LunarLander-v2
parent
df58dd2a70
commit
4ca56568ca
|
@ -2547,12 +2547,12 @@
|
||||||
"source": [
|
"source": [
|
||||||
"def lander_render_policy_net(model, n_max_steps=500, seed=42):\n",
|
"def lander_render_policy_net(model, n_max_steps=500, seed=42):\n",
|
||||||
" frames = []\n",
|
" frames = []\n",
|
||||||
" env = gym.make(\"LunarLander-v2\")\n",
|
" env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")\n",
|
||||||
" tf.random.set_seed(seed)\n",
|
" tf.random.set_seed(seed)\n",
|
||||||
" np.random.seed(seed)\n",
|
" np.random.seed(seed)\n",
|
||||||
" obs, info = env.reset(seed=seed)\n",
|
" obs, info = env.reset(seed=seed)\n",
|
||||||
" for step in range(n_max_steps):\n",
|
" for step in range(n_max_steps):\n",
|
||||||
" frames.append(env.render(mode=\"rgb_array\"))\n",
|
" frames.append(env.render())\n",
|
||||||
" probas = model(obs[np.newaxis])\n",
|
" probas = model(obs[np.newaxis])\n",
|
||||||
" logits = tf.math.log(probas + tf.keras.backend.epsilon())\n",
|
" logits = tf.math.log(probas + tf.keras.backend.epsilon())\n",
|
||||||
" action = tf.random.categorical(logits, num_samples=1)\n",
|
" action = tf.random.categorical(logits, num_samples=1)\n",
|
||||||
|
|
Loading…
Reference in New Issue