MultiHeadAttention, Transformer
import math

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import constraints, initializers, regularizers
from tensorflow.keras.layers import Layer

# DenseEinsum and MaskedSoftmax ship with the TensorFlow Model Garden
# (pip package tf-models-official); the import path below matches the BERT
# reference implementation this code comes from and may differ across
# versions. The MaskedSoftmax implementation is also reproduced further down.
from official.nlp.modeling.layers import dense_einsum, masked_softmax


class MultiHeadAttention(Layer):
""" ayer
refer to "Attention is all you Need" .
If `query`, `key,` `value` are the same, then this is self-attention.
Each timestep in `query` attends to the corresponding sequence in `key`, and returns a fixed-width vector.
"""
    def __init__(self,
                 num_heads,
                 head_size,
                 dropout_rate=0.0,
                 kernel_initializer="glorot_uniform",
                 bias_initializer="zeros",
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self._num_heads = num_heads
        self._head_size = head_size
        self._dropout_rate = dropout_rate
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
        self._kernel_regularizer = regularizers.get(kernel_regularizer)
        self._bias_regularizer = regularizers.get(bias_regularizer)
        self._kernel_constraint = constraints.get(kernel_constraint)
        self._bias_constraint = constraints.get(bias_constraint)
        self._activity_regularizer = regularizers.get(activity_regularizer)
        # Separate per-head linear projections for query, key and value; each
        # maps [..., hidden_size] to [..., num_heads, head_size].
        self._query_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="attention/query")
        self._key_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="attention/key")
        self._value_dense = dense_einsum.DenseEinsum(
            output_shape=(self._num_heads, self._head_size),
            kernel_initializer=self._kernel_initializer,
            bias_initializer=self._bias_initializer,
            kernel_regularizer=self._kernel_regularizer,
            bias_regularizer=self._bias_regularizer,
            activity_regularizer=self._activity_regularizer,
            kernel_constraint=self._kernel_constraint,
            bias_constraint=self._bias_constraint,
            name="attention/value")
        # Softmax with optional masking; the mask is expanded on axis 1 so it
        # broadcasts across the heads dimension.
        self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
        self._dropout = keras.layers.Dropout(rate=self._dropout_rate)
    def build(self, input_shape):
        # `input_shape` is a list: [from_tensor shape, to_tensor shape, ...].
        # Key and value projections are both built on the `to_tensor` shape.
        self._query_dense.build(input_shape[0])
        self._key_dense.build(input_shape[1])
        self._value_dense.build(input_shape[1])
        super().build(input_shape)
    def get_config(self):
        config = {
            "num_heads": self._num_heads,
            "head_size": self._head_size,
            "dropout_rate": self._dropout_rate,
            "kernel_initializer": tf.keras.initializers.serialize(self._kernel_initializer),
            "bias_initializer": tf.keras.initializers.serialize(self._bias_initializer),
            "kernel_regularizer": tf.keras.regularizers.serialize(self._kernel_regularizer),
            "bias_regularizer": tf.keras.regularizers.serialize(self._bias_regularizer),
            "activity_regularizer": tf.keras.regularizers.serialize(self._activity_regularizer),
            "kernel_constraint": tf.keras.constraints.serialize(self._kernel_constraint),
            "bias_constraint": tf.keras.constraints.serialize(self._bias_constraint)
        }
        base_config = super(MultiHeadAttention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
    def call(self, inputs):
        from_tensor = inputs[0]
        to_tensor = inputs[1]
        attention_mask = inputs[2] if len(inputs) == 3 else None

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_heads`
        #   H = `head_size`

        # `query_tensor` = [B, F, N, H]
        query_tensor = self._query_dense(from_tensor)
        # `key_tensor` = [B, T, N, H]
        key_tensor = self._key_dense(to_tensor)
        # `value_tensor` = [B, T, N, H]
        value_tensor = self._value_dense(to_tensor)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores: the einsum contracts over H within each head N,
        # yielding one [F, T] score matrix per head.
        attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
        # Scale by 1/sqrt(H) so the magnitude of the dot products stays
        # stable as the head size grows.
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / math.sqrt(float(self._head_size)))

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]; the mask (if given) is expanded
        # to [B, 1, F, T] inside MaskedSoftmax so it broadcasts over heads.
        attention_probs = self._masked_softmax([attention_scores, attention_mask])
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self._dropout(attention_probs)

        # `context_layer` = [B, F, N, H]
        return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
class MaskedSoftmax(Layer):
    """Performs a softmax with optional masking on a tensor.

    Arguments:
      mask_expansion_axes: Any axes that should be padded on the mask tensor.
      normalization_axes: The axes over which the softmax is computed
        (defaults to the last axis).
    """

    def __init__(self, mask_expansion_axes=None, normalization_axes=None, **kwargs):
        self._mask_expansion_axes = mask_expansion_axes
        if normalization_axes is None:
            self._normalization_axes = (-1,)
        else:
            self._normalization_axes = normalization_axes
        super(MaskedSoftmax, self).__init__(**kwargs)
    def call(self, inputs):
        if isinstance(inputs, list) and len(inputs) == 2:
            scores, mask = inputs
        else:
            scores, mask = (inputs, None)

        if mask is not None:
            if self._mask_expansion_axes is not None:
                # `axis` of tf.expand_dims must be a scalar, so insert the
                # requested axes one at a time.
                for axis in self._mask_expansion_axes:
                    mask = tf.expand_dims(mask, axis=axis)
            # In `attention_mask`, positions to attend to are 1.0 and masked
            # positions are 0.0. Turn the masked positions into -10000.0 so
            # that their probability after the softmax is effectively zero.
            adder = (1.0 - tf.cast(mask, scores.dtype)) * -10000.0
            # Add the penalty to the raw scores before the softmax.
            scores += adder

        if len(self._normalization_axes) == 1:
            return tf.nn.softmax(scores, axis=self._normalization_axes[0])
        else:
            # Softmax over several axes at once: exp(x - logsumexp(x)) is the
            # numerically stable softmax generalized to multiple axes.
            return tf.math.exp(scores - tf.math.reduce_logsumexp(
                scores, axis=self._normalization_axes, keepdims=True))
    def get_config(self):
        config = {
            "mask_expansion_axes": self._mask_expansion_axes,
            "normalization_axes": self._normalization_axes
        }
        base_config = super(MaskedSoftmax, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
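To see the masking behave as described, here is a small standalone sketch (again with made-up shapes): the masked key position receives a -10000.0 penalty and ends up with near-zero probability, while the remaining positions share the mass:

# Hypothetical shapes: B=1 sequence, N=1 head, F=2 query steps, T=3 key steps.
softmax = MaskedSoftmax(mask_expansion_axes=[1])
scores = tf.zeros([1, 1, 2, 3])          # [B, N, F, T], uniform raw scores
mask = tf.constant([[[1.0, 1.0, 0.0],
                     [1.0, 1.0, 0.0]]])  # [B, F, T]: last key step masked
probs = softmax([scores, mask])
print(probs.numpy().round(2))            # last column ~0.0, others ~0.5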