use ReplayMemory

main
Nilesh PS 2018-03-28 20:07:39 +05:30
parent 16e8a8cf61
commit aaa5246d9c
1 changed files with 9 additions and 5 deletions

View File

@ -1567,31 +1567,35 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 63, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from collections import deque\n", "from collections import deque\n",
"\n", "\n",
"replay_memory_size = 500000\n",
"replay_memory = deque([], maxlen=replay_memory_size)\n",
"\n",
"\n", "\n",
"class ReplayMemory:\n", "class ReplayMemory:\n",
" \n", " \n",
" def __init__(self, size):\n", " def __init__(self, size):\n",
" self.size = size\n", " self.size = size\n",
" self.buf= [None] * self.size\n", " self.buf= [None] * self.size\n",
" self.index = 0\n", " self.index, self.count = 0, 0\n",
" \n", " \n",
" def append(self, data):\n", " def append(self, data):\n",
" self.buf[self.index] = data\n", " self.buf[self.index] = data\n",
" self.count = min(self.count + 1, self.size)\n",
" self.index = (self.index + 1) % self.size\n", " self.index = (self.index + 1) % self.size\n",
" \n", " \n",
" def __getitem__(self, idx):\n", " def __getitem__(self, idx):\n",
" return self.buf[idx]\n", " return self.buf[idx]\n",
" \n",
" def __len__(self):\n",
" return self.count\n",
"\n",
"\n", "\n",
" \n", " \n",
"replay_memory_size = 500000\n",
"replay_memory = ReplayMemory(replay_memory_size)\n",
"\n", "\n",
"def sample_memories(batch_size):\n", "def sample_memories(batch_size):\n",
" indices = np.random.permutation(len(replay_memory))[:batch_size]\n", " indices = np.random.permutation(len(replay_memory))[:batch_size]\n",