use ReplayMemory
parent
16e8a8cf61
commit
aaa5246d9c
|
@ -1567,32 +1567,36 @@
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 63,
|
"execution_count": 1,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"from collections import deque\n",
|
"from collections import deque\n",
|
||||||
"\n",
|
"\n",
|
||||||
"replay_memory_size = 500000\n",
|
|
||||||
"replay_memory = deque([], maxlen=replay_memory_size)\n",
|
|
||||||
"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"class ReplayMemory:\n",
|
"class ReplayMemory:\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def __init__(self, size):\n",
|
" def __init__(self, size):\n",
|
||||||
" self.size = size\n",
|
" self.size = size\n",
|
||||||
" self.buf= [None] * self.size\n",
|
" self.buf= [None] * self.size\n",
|
||||||
" self.index = 0\n",
|
" self.index, self.count = 0, 0\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def append(self, data):\n",
|
" def append(self, data):\n",
|
||||||
" self.buf[self.index] = data\n",
|
" self.buf[self.index] = data\n",
|
||||||
|
" self.count = min(self.count + 1, self.size)\n",
|
||||||
" self.index = (self.index + 1) % self.size\n",
|
" self.index = (self.index + 1) % self.size\n",
|
||||||
" \n",
|
" \n",
|
||||||
" def __getitem__(self, idx):\n",
|
" def __getitem__(self, idx):\n",
|
||||||
" return self.buf[idx]\n",
|
" return self.buf[idx]\n",
|
||||||
" \n",
|
" \n",
|
||||||
|
" def __len__(self):\n",
|
||||||
|
" return self.count\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" \n",
|
||||||
|
"replay_memory_size = 500000\n",
|
||||||
|
"replay_memory = ReplayMemory(replay_memory_size)\n",
|
||||||
|
"\n",
|
||||||
"def sample_memories(batch_size):\n",
|
"def sample_memories(batch_size):\n",
|
||||||
" indices = np.random.permutation(len(replay_memory))[:batch_size]\n",
|
" indices = np.random.permutation(len(replay_memory))[:batch_size]\n",
|
||||||
" cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
|
" cols = [[], [], [], [], []] # state, action, reward, next_state, continue\n",
|
||||||
|
|
Loading…
Reference in New Issue