Fix MyRNN shape issues, fixes #457
parent
341d8fe792
commit
d1a577314a
|
@ -852,20 +852,23 @@
|
||||||
" self.get_initial_state = getattr(\n",
|
" self.get_initial_state = getattr(\n",
|
||||||
" self.cell, \"get_initial_state\", self.fallback_initial_state)\n",
|
" self.cell, \"get_initial_state\", self.fallback_initial_state)\n",
|
||||||
" def fallback_initial_state(self, inputs):\n",
|
" def fallback_initial_state(self, inputs):\n",
|
||||||
" return [tf.zeros([self.cell.state_size], dtype=inputs.dtype)]\n",
|
" batch_size = tf.shape(inputs)[0]\n",
|
||||||
|
" return [tf.zeros([batch_size, self.cell.state_size], dtype=inputs.dtype)]\n",
|
||||||
" @tf.function\n",
|
" @tf.function\n",
|
||||||
" def call(self, inputs):\n",
|
" def call(self, inputs):\n",
|
||||||
" states = self.get_initial_state(inputs)\n",
|
" states = self.get_initial_state(inputs)\n",
|
||||||
" n_steps = tf.shape(inputs)[1]\n",
|
" shape = tf.shape(inputs)\n",
|
||||||
" if self.return_sequences:\n",
|
" batch_size = shape[0]\n",
|
||||||
" sequences = tf.TensorArray(inputs.dtype, size=n_steps)\n",
|
" n_steps = shape[1]\n",
|
||||||
" outputs = tf.zeros(shape=[n_steps, self.cell.output_size], dtype=inputs.dtype)\n",
|
" sequences = tf.TensorArray(\n",
|
||||||
|
" inputs.dtype, size=(n_steps if self.return_sequences else 0))\n",
|
||||||
|
" outputs = tf.zeros(shape=[batch_size, self.cell.output_size], dtype=inputs.dtype)\n",
|
||||||
" for step in tf.range(n_steps):\n",
|
" for step in tf.range(n_steps):\n",
|
||||||
" outputs, states = self.cell(inputs[:, step], states)\n",
|
" outputs, states = self.cell(inputs[:, step], states)\n",
|
||||||
" if self.return_sequences:\n",
|
" if self.return_sequences:\n",
|
||||||
" sequences = sequences.write(step, outputs)\n",
|
" sequences = sequences.write(step, outputs)\n",
|
||||||
" if self.return_sequences:\n",
|
" if self.return_sequences:\n",
|
||||||
" return sequences.stack()\n",
|
" return tf.transpose(sequences.stack(), [1, 0, 2])\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" return outputs"
|
" return outputs"
|
||||||
]
|
]
|
||||||
|
|
Loading…
Reference in New Issue