Add a basic implementation of MultiHeadAttention

main
Aurélien Geron 2019-04-16 19:52:49 +08:00
parent d24fee1576
commit 61bb73074b
1 changed files with 75 additions and 5 deletions

View File

@ -2109,6 +2109,13 @@
"decoder_in = positional_encoding(decoder_embeddings)" "decoder_in = positional_encoding(decoder_embeddings)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here is a (very) simplified Transformer (the actual architecture has skip connections, layer norm, dense nets, and most importantly it uses Multi-Head Attention instead of regular Attention):"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 131, "execution_count": 131,
@ -2128,6 +2135,69 @@
"outputs = output_layer(final_enc)" "outputs = output_layer(final_enc)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a basic implementation of the `MultiHeadAttention` layer. One will likely be added to `keras.layers` in the near future. Note that `Conv1D` layers with `kernel_size=1` (and the default `padding=\"valid\"` and `strides=1`) is equivalent to a `TimeDistributed(Dense(...))` layer."
]
},
{
"cell_type": "code",
"execution_count": 132,
"metadata": {},
"outputs": [],
"source": [
"K = keras.backend\n",
"\n",
"class MultiHeadAttention(keras.layers.Layer):\n",
" def __init__(self, n_heads, causal=False, use_scale=False, **kwargs):\n",
" self.n_heads = n_heads\n",
" self.causal = causal\n",
" self.use_scale = use_scale\n",
" super().__init__(**kwargs)\n",
" def build(self, batch_input_shape):\n",
" self.dims = batch_input_shape[0][-1]\n",
" self.q_dims, self.v_dims, self.k_dims = [self.dims // self.n_heads] * 3 # could be hyperparameters instead\n",
" self.q_linear = keras.layers.Conv1D(self.n_heads * self.q_dims, kernel_size=1, use_bias=False)\n",
" self.v_linear = keras.layers.Conv1D(self.n_heads * self.v_dims, kernel_size=1, use_bias=False)\n",
" self.k_linear = keras.layers.Conv1D(self.n_heads * self.k_dims, kernel_size=1, use_bias=False)\n",
" self.attention = keras.layers.Attention(causal=self.causal, use_scale=self.use_scale)\n",
" self.out_linear = keras.layers.Conv1D(self.dims, kernel_size=1, use_bias=False)\n",
" super().build(batch_input_shape)\n",
" def _multi_head_linear(self, inputs, linear):\n",
" shape = K.concatenate([K.shape(inputs)[:-1], [self.n_heads, -1]])\n",
" projected = K.reshape(linear(inputs), shape)\n",
" perm = K.permute_dimensions(projected, [0, 2, 1, 3])\n",
" return K.reshape(perm, [shape[0] * self.n_heads, shape[1], -1])\n",
" def call(self, inputs):\n",
" q = inputs[0]\n",
" v = inputs[1]\n",
" k = inputs[2] if len(inputs) > 2 else v\n",
" shape = K.shape(q)\n",
" q_proj = self._multi_head_linear(q, self.q_linear)\n",
" v_proj = self._multi_head_linear(v, self.v_linear)\n",
" k_proj = self._multi_head_linear(k, self.k_linear)\n",
" multi_attended = self.attention([q_proj, v_proj, k_proj])\n",
" shape_attended = K.shape(multi_attended)\n",
" reshaped_attended = K.reshape(multi_attended, [shape[0], self.n_heads, shape_attended[1], shape_attended[2]])\n",
" perm = K.permute_dimensions(reshaped_attended, [0, 2, 1, 3])\n",
" concat = K.reshape(perm, [shape[0], shape_attended[1], -1])\n",
" return self.out_linear(concat)"
]
},
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"Q = np.random.rand(2, 50, 512)\n",
"V = np.random.rand(2, 80, 512)\n",
"multi_attn = MultiHeadAttention(8)\n",
"multi_attn([Q, V]).shape"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
@ -2167,7 +2237,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 132, "execution_count": 134,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2212,7 +2282,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 133, "execution_count": 135,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2229,7 +2299,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 134, "execution_count": 136,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2246,7 +2316,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 135, "execution_count": 137,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@ -2267,7 +2337,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 136, "execution_count": 138,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [