From aaa5246d9c7a66cf9150fa753df49633b32f0f6d Mon Sep 17 00:00:00 2001 From: Nilesh PS Date: Wed, 28 Mar 2018 20:07:39 +0530 Subject: [PATCH] use ReplayMemory --- 16_reinforcement_learning.ipynb | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/16_reinforcement_learning.ipynb b/16_reinforcement_learning.ipynb index b8aea98..c68a538 100644 --- a/16_reinforcement_learning.ipynb +++ b/16_reinforcement_learning.ipynb @@ -1567,31 +1567,35 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", - "replay_memory_size = 500000\n", - "replay_memory = deque([], maxlen=replay_memory_size)\n", - "\n", "\n", "class ReplayMemory:\n", " \n", " def __init__(self, size):\n", " self.size = size\n", " self.buf= [None] * self.size\n", - " self.index = 0\n", + " self.index, self.count = 0, 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", + " self.count = min(self.count + 1, self.size)\n", " self.index = (self.index + 1) % self.size\n", " \n", " def __getitem__(self, idx):\n", " return self.buf[idx]\n", + " \n", + " def __len__(self):\n", + " return self.count\n", + "\n", "\n", " \n", + "replay_memory_size = 500000\n", + "replay_memory = ReplayMemory(replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n",