Fix render code for LunarLander-v2
parent
df58dd2a70
commit
4ca56568ca
|
@ -2547,12 +2547,12 @@
|
|||
"source": [
|
||||
"def lander_render_policy_net(model, n_max_steps=500, seed=42):\n",
|
||||
" frames = []\n",
|
||||
" env = gym.make(\"LunarLander-v2\")\n",
|
||||
" env = gym.make(\"LunarLander-v2\", render_mode=\"rgb_array\")\n",
|
||||
" tf.random.set_seed(seed)\n",
|
||||
" np.random.seed(seed)\n",
|
||||
" obs, info = env.reset(seed=seed)\n",
|
||||
" for step in range(n_max_steps):\n",
|
||||
" frames.append(env.render(mode=\"rgb_array\"))\n",
|
||||
" frames.append(env.render())\n",
|
||||
" probas = model(obs[np.newaxis])\n",
|
||||
" logits = tf.math.log(probas + tf.keras.backend.epsilon())\n",
|
||||
" action = tf.random.categorical(logits, num_samples=1)\n",
|
||||
|
|
Loading…
Reference in New Issue