Fix MyRNN shape issues, fixes #457
parent
341d8fe792
commit
d1a577314a
|
@ -852,20 +852,23 @@
|
|||
" self.get_initial_state = getattr(\n",
|
||||
" self.cell, \"get_initial_state\", self.fallback_initial_state)\n",
|
||||
" def fallback_initial_state(self, inputs):\n",
|
||||
" return [tf.zeros([self.cell.state_size], dtype=inputs.dtype)]\n",
|
||||
" batch_size = tf.shape(inputs)[0]\n",
|
||||
" return [tf.zeros([batch_size, self.cell.state_size], dtype=inputs.dtype)]\n",
|
||||
" @tf.function\n",
|
||||
" def call(self, inputs):\n",
|
||||
" states = self.get_initial_state(inputs)\n",
|
||||
" n_steps = tf.shape(inputs)[1]\n",
|
||||
" if self.return_sequences:\n",
|
||||
" sequences = tf.TensorArray(inputs.dtype, size=n_steps)\n",
|
||||
" outputs = tf.zeros(shape=[n_steps, self.cell.output_size], dtype=inputs.dtype)\n",
|
||||
" shape = tf.shape(inputs)\n",
|
||||
" batch_size = shape[0]\n",
|
||||
" n_steps = shape[1]\n",
|
||||
" sequences = tf.TensorArray(\n",
|
||||
" inputs.dtype, size=(n_steps if self.return_sequences else 0))\n",
|
||||
" outputs = tf.zeros(shape=[batch_size, self.cell.output_size], dtype=inputs.dtype)\n",
|
||||
" for step in tf.range(n_steps):\n",
|
||||
" outputs, states = self.cell(inputs[:, step], states)\n",
|
||||
" if self.return_sequences:\n",
|
||||
" sequences = sequences.write(step, outputs)\n",
|
||||
" if self.return_sequences:\n",
|
||||
" return sequences.stack()\n",
|
||||
" return tf.transpose(sequences.stack(), [1, 0, 2])\n",
|
||||
" else:\n",
|
||||
" return outputs"
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue