diff --git a/18_reinforcement_learning.ipynb b/18_reinforcement_learning.ipynb index ed44866..b723f04 100644 --- a/18_reinforcement_learning.ipynb +++ b/18_reinforcement_learning.ipynb @@ -1306,7 +1306,7 @@ "source": [ "def epsilon_greedy_policy(state, epsilon=0):\n", " if np.random.rand() < epsilon:\n", - " return np.random.randint(2)\n", + " return np.random.randint(n_outputs)\n", " else:\n", " Q_values = model.predict(state[np.newaxis])\n", " return np.argmax(Q_values[0])"