From 16e8a8cf6102a8cfccb6a2c6cd50c4b86d034714 Mon Sep 17 00:00:00 2001 From: Nilesh PS Date: Wed, 28 Mar 2018 19:57:26 +0530 Subject: [PATCH 1/2] use list with circular indexing instead of deque as the replay buffer --- 16_reinforcement_learning.ipynb | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/16_reinforcement_learning.ipynb b/16_reinforcement_learning.ipynb index 30f2ab2..b8aea98 100644 --- a/16_reinforcement_learning.ipynb +++ b/16_reinforcement_learning.ipynb @@ -1576,6 +1576,23 @@ "replay_memory_size = 500000\n", "replay_memory = deque([], maxlen=replay_memory_size)\n", "\n", + "\n", + "class ReplayMemory:\n", + " \n", + " def __init__(self, size):\n", + " self.size = size\n", + " self.buf= [None] * self.size\n", + " self.index = 0\n", + " \n", + " def append(self, data):\n", + " self.buf[self.index] = data\n", + " self.index = (self.index + 1) % self.size\n", + " \n", + " def __getitem__(self, idx):\n", + " return self.buf[idx]\n", + "\n", + " \n", + "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n", " cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n", From aaa5246d9c7a66cf9150fa753df49633b32f0f6d Mon Sep 17 00:00:00 2001 From: Nilesh PS Date: Wed, 28 Mar 2018 20:07:39 +0530 Subject: [PATCH 2/2] use ReplayMemory --- 16_reinforcement_learning.ipynb | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/16_reinforcement_learning.ipynb b/16_reinforcement_learning.ipynb index b8aea98..c68a538 100644 --- a/16_reinforcement_learning.ipynb +++ b/16_reinforcement_learning.ipynb @@ -1567,31 +1567,35 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from collections import deque\n", "\n", - "replay_memory_size = 500000\n", - "replay_memory = deque([], maxlen=replay_memory_size)\n", - "\n", "\n", "class ReplayMemory:\n", " \n", " def __init__(self, size):\n", " self.size = size\n", " self.buf= [None] * self.size\n", - " self.index = 0\n", + " self.index, self.count = 0, 0\n", " \n", " def append(self, data):\n", " self.buf[self.index] = data\n", + " self.count = min(self.count + 1, self.size)\n", " self.index = (self.index + 1) % self.size\n", " \n", " def __getitem__(self, idx):\n", " return self.buf[idx]\n", + " \n", + " def __len__(self):\n", + " return self.count\n", + "\n", "\n", " \n", + "replay_memory_size = 500000\n", + "replay_memory = ReplayMemory(replay_memory_size)\n", "\n", "def sample_memories(batch_size):\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n",