TF mostly fixed an issue so remove workaround for ReconstructingRegressor
parent
fdb5d1695e
commit
495de15361
|
@ -2209,7 +2209,9 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Note**: due to an issue introduced in TF 2.2 ([#46858](https://github.com/tensorflow/tensorflow/issues/46858)), it is currently not possible to use `add_loss()` along with the `build()` method. So the following code differs from the book: I create the `reconstruct` layer in the constructor instead of the `build()` method. Unfortunately, this means that the number of units in this layer must be hard-coded (alternatively, it could be passed as an argument to the constructor)."
|
||||
"**Note**: the following code has two differences with the code in the book:\n",
|
||||
"1. It creates a `keras.metrics.Mean()` metric in the constructor and uses it in the `call()` method to track the mean reconstruction loss. Since we only want to do this during training, we add a `training` argument to the `call()` method, and if `training` is `True`, then we update `reconstruction_mean` and we call `self.add_metric()` to ensure it's displayed properly.\n",
|
||||
"2. Due to an issue introduced in TF 2.2 ([#46858](https://github.com/tensorflow/tensorflow/issues/46858)), we must not call `super().build()` inside the `build()` method."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -2218,21 +2220,19 @@
|
|||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ReconstructingRegressor(keras.models.Model):\n",
|
||||
"class ReconstructingRegressor(keras.Model):\n",
|
||||
" def __init__(self, output_dim, **kwargs):\n",
|
||||
" super().__init__(**kwargs)\n",
|
||||
" self.hidden = [keras.layers.Dense(30, activation=\"selu\",\n",
|
||||
" kernel_initializer=\"lecun_normal\")\n",
|
||||
" for _ in range(5)]\n",
|
||||
" self.out = keras.layers.Dense(output_dim)\n",
|
||||
" self.reconstruct = keras.layers.Dense(8) # workaround for TF issue #46858\n",
|
||||
" self.reconstruction_mean = keras.metrics.Mean(name=\"reconstruction_error\")\n",
|
||||
"\n",
|
||||
" #Commented out due to TF issue #46858, see the note above\n",
|
||||
" #def build(self, batch_input_shape):\n",
|
||||
" # n_inputs = batch_input_shape[-1]\n",
|
||||
" # self.reconstruct = keras.layers.Dense(n_inputs)\n",
|
||||
" # super().build(batch_input_shape)\n",
|
||||
" def build(self, batch_input_shape):\n",
|
||||
" n_inputs = batch_input_shape[-1]\n",
|
||||
" self.reconstruct = keras.layers.Dense(n_inputs)\n",
|
||||
" #super().build(batch_input_shape)\n",
|
||||
"\n",
|
||||
" def call(self, inputs, training=None):\n",
|
||||
" Z = inputs\n",
|
||||
|
|
Loading…
Reference in New Issue