Fix MyRNN shape issues, fixes #457

main
Aurélien Geron 2021-08-19 12:30:12 +12:00
parent 341d8fe792
commit d1a577314a
1 changed files with 9 additions and 6 deletions

View File

@ -852,20 +852,23 @@
" self.get_initial_state = getattr(\n", " self.get_initial_state = getattr(\n",
" self.cell, \"get_initial_state\", self.fallback_initial_state)\n", " self.cell, \"get_initial_state\", self.fallback_initial_state)\n",
" def fallback_initial_state(self, inputs):\n", " def fallback_initial_state(self, inputs):\n",
" return [tf.zeros([self.cell.state_size], dtype=inputs.dtype)]\n", " batch_size = tf.shape(inputs)[0]\n",
" return [tf.zeros([batch_size, self.cell.state_size], dtype=inputs.dtype)]\n",
" @tf.function\n", " @tf.function\n",
" def call(self, inputs):\n", " def call(self, inputs):\n",
" states = self.get_initial_state(inputs)\n", " states = self.get_initial_state(inputs)\n",
" n_steps = tf.shape(inputs)[1]\n", " shape = tf.shape(inputs)\n",
" if self.return_sequences:\n", " batch_size = shape[0]\n",
" sequences = tf.TensorArray(inputs.dtype, size=n_steps)\n", " n_steps = shape[1]\n",
" outputs = tf.zeros(shape=[n_steps, self.cell.output_size], dtype=inputs.dtype)\n", " sequences = tf.TensorArray(\n",
" inputs.dtype, size=(n_steps if self.return_sequences else 0))\n",
" outputs = tf.zeros(shape=[batch_size, self.cell.output_size], dtype=inputs.dtype)\n",
" for step in tf.range(n_steps):\n", " for step in tf.range(n_steps):\n",
" outputs, states = self.cell(inputs[:, step], states)\n", " outputs, states = self.cell(inputs[:, step], states)\n",
" if self.return_sequences:\n", " if self.return_sequences:\n",
" sequences = sequences.write(step, outputs)\n", " sequences = sequences.write(step, outputs)\n",
" if self.return_sequences:\n", " if self.return_sequences:\n",
" return sequences.stack()\n", " return tf.transpose(sequences.stack(), [1, 0, 2])\n",
" else:\n", " else:\n",
" return outputs" " return outputs"
] ]