class ReplayMemory:
    """Fixed-capacity ring buffer that overwrites its oldest entry when full.

    Storage is pre-allocated with ``None`` placeholders and written in a
    circular fashion: once ``size`` items have been appended, each further
    append replaces the oldest stored item.

    Indexing addresses physical slots of the underlying buffer, not
    insertion order. Every slot below ``len(memory)`` holds appended data;
    slots at or above it still hold the ``None`` placeholder.

    Attributes:
        size: Maximum number of items the buffer can hold.
        buf: Underlying list of slots (``None`` where nothing was written).
        index: Slot that the next ``append`` will write to.
        length: Number of slots that currently hold appended data.
    """

    def __init__(self, size):
        """Create an empty buffer with room for ``size`` items.

        Raises:
            ValueError: If ``size`` is not strictly positive. A zero-sized
                buffer would otherwise only fail later, inside ``append``,
                with a ZeroDivisionError from the wrap-around modulo.
        """
        if size <= 0:
            raise ValueError("size must be a positive integer, got %r" % (size,))
        self.size = size
        self.buf = [None] * self.size
        self.index = 0
        self.length = 0

    def __len__(self):
        """Return how many slots hold appended data (at most ``size``)."""
        return self.length

    def append(self, data):
        """Store ``data`` in the next slot, overwriting the oldest if full."""
        self.buf[self.index] = data
        # The fill count saturates at capacity; afterwards appends overwrite.
        if self.length < self.size:
            self.length += 1
        # Wrap around so the write position cycles through the buffer.
        self.index = (self.index + 1) % self.size

    def __getitem__(self, idx):
        """Return the content of slot ``idx`` (``None`` if never written)."""
        return self.buf[idx]