From d1a577314aa912838a09a2e54b14d12c99095ddd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Geron?= Date: Thu, 19 Aug 2021 12:30:12 +1200 Subject: [PATCH] Fix MyRNN shape issues, fixes #457 --- 15_processing_sequences_using_rnns_and_cnns.ipynb | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/15_processing_sequences_using_rnns_and_cnns.ipynb b/15_processing_sequences_using_rnns_and_cnns.ipynb index e530c36..70fd374 100644 --- a/15_processing_sequences_using_rnns_and_cnns.ipynb +++ b/15_processing_sequences_using_rnns_and_cnns.ipynb @@ -852,20 +852,23 @@ " self.get_initial_state = getattr(\n", " self.cell, \"get_initial_state\", self.fallback_initial_state)\n", " def fallback_initial_state(self, inputs):\n", - " return [tf.zeros([self.cell.state_size], dtype=inputs.dtype)]\n", + " batch_size = tf.shape(inputs)[0]\n", + " return [tf.zeros([batch_size, self.cell.state_size], dtype=inputs.dtype)]\n", " @tf.function\n", " def call(self, inputs):\n", " states = self.get_initial_state(inputs)\n", - " n_steps = tf.shape(inputs)[1]\n", - " if self.return_sequences:\n", - " sequences = tf.TensorArray(inputs.dtype, size=n_steps)\n", - " outputs = tf.zeros(shape=[n_steps, self.cell.output_size], dtype=inputs.dtype)\n", + " shape = tf.shape(inputs)\n", + " batch_size = shape[0]\n", + " n_steps = shape[1]\n", + " sequences = tf.TensorArray(\n", + " inputs.dtype, size=(n_steps if self.return_sequences else 0))\n", + " outputs = tf.zeros(shape=[batch_size, self.cell.output_size], dtype=inputs.dtype)\n", " for step in tf.range(n_steps):\n", " outputs, states = self.cell(inputs[:, step], states)\n", " if self.return_sequences:\n", " sequences = sequences.write(step, outputs)\n", " if self.return_sequences:\n", - " return sequences.stack()\n", + " return tf.transpose(sequences.stack(), [1, 0, 2])\n", " else:\n", " return outputs" ]