diff --git a/16_reinforcement_learning.ipynb b/16_reinforcement_learning.ipynb index b8aea98..c68a538 100644 --- a/16_reinforcement_learning.ipynb +++ b/16_reinforcement_learning.ipynb @@ -1567,31 +1567,35 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", - "replay_memory_size = 500000\n", - "replay_memory = deque([], maxlen=replay_memory_size)\n", - "\n", "\n", "class ReplayMemory:\n", " \n", " def __init__(self, size):\n", " self.size = size\n", " self.buf= [None] * self.size\n", - " self.index = 0\n", + " self.index, self.count = 0, 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", + " self.count = min(self.count + 1, self.size)\n", " self.index = (self.index + 1) % self.size\n", " \n", " def __getitem__(self, idx):\n", " return self.buf[idx]\n", + " \n", + " def __len__(self):\n", + " return self.count\n", + "\n", "\n", " \n", + "replay_memory_size = 500000\n", + "replay_memory = ReplayMemory(replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n",