
    炫云云: MultiHeadAttention, Transformer

    Posted: 2021-09-10 10:16

    MultiHeadAttention

    # Assumed imports: DenseEinsum and MaskedSoftmax follow the TF Model Garden
    # layout (official/nlp/modeling/layers); MaskedSoftmax is also reproduced
    # in full further below.
    import math

    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import constraints, initializers, regularizers
    from tensorflow.keras.layers import Layer

    from official.nlp.modeling.layers import dense_einsum, masked_softmax


    class MultiHeadAttention(Layer):
        """MultiHeadAttention layer.

        Refer to "Attention Is All You Need".
        If `query`, `key`, `value` are the same, then this is self-attention.
        Each timestep in `query` attends to the corresponding sequence in `key`,
        and returns a fixed-width vector.
        """
        def __init__(self,
                     num_heads,
                     head_size,
                     dropout_rate=0.0,
                     kernel_initializer="glorot_uniform",
                     bias_initializer="zeros",
                     kernel_regularizer=None,
                     bias_regularizer=None,
                     activity_regularizer=None,
                     kernel_constraint=None,
                     bias_constraint=None,
                     **kwargs):
            super(MultiHeadAttention, self).__init__(**kwargs)
            self._num_heads = num_heads
            self._head_size = head_size
            self._dropout_rate = dropout_rate
            self._kernel_initializer = initializers.get(kernel_initializer)
            self._bias_initializer = initializers.get(bias_initializer)
            self._kernel_regularizer = regularizers.get(kernel_regularizer)
            self._bias_regularizer = regularizers.get(bias_regularizer)
            self._kernel_constraint = constraints.get(kernel_constraint)
            self._bias_constraint = constraints.get(bias_constraint)
            self._activity_regularizer = regularizers.get(activity_regularizer)
    
            self._query_dense = dense_einsum.DenseEinsum(
                output_shape=(self._num_heads, self._head_size),
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                name="attention/query")
    
            self._key_dense = dense_einsum.DenseEinsum(
                output_shape=(self._num_heads, self._head_size),
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                name="attention/key")
    
            self._value_dense = dense_einsum.DenseEinsum(
                output_shape=(self._num_heads, self._head_size),
                kernel_initializer=self._kernel_initializer,
                bias_initializer=self._bias_initializer,
                kernel_regularizer=self._kernel_regularizer,
                bias_regularizer=self._bias_regularizer,
                activity_regularizer=self._activity_regularizer,
                kernel_constraint=self._kernel_constraint,
                bias_constraint=self._bias_constraint,
                name="attention/value")
    
            self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
    
            self._dropout = keras.layers.Dropout(rate=self._dropout_rate)
    
        def build(self, input_shape):
            self._query_dense.build(input_shape[0])
            self._key_dense.build(input_shape[1])
            self._value_dense.build(input_shape[1])
            super().build(input_shape)
    
        def get_config(self):
            config = {
                "num_heads":
                    self._num_heads,
                "head_size":
                    self._head_size,
                "dropout_rate":
                    self._dropout_rate,
                "kernel_initializer":
                    tf.keras.initializers.serialize(self._kernel_initializer),
                "bias_initializer":
                    tf.keras.initializers.serialize(self._bias_initializer),
                "kernel_regularizer":
                    tf.keras.regularizers.serialize(self._kernel_regularizer),
                "bias_regularizer":
                    tf.keras.regularizers.serialize(self._bias_regularizer),
                "activity_regularizer":
                    tf.keras.regularizers.serialize(self._activity_regularizer),
                "kernel_constraint":
                    tf.keras.constraints.serialize(self._kernel_constraint),
                "bias_constraint":
                    tf.keras.constraints.serialize(self._bias_constraint)
            }
            base_config = super(MultiHeadAttention, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
    
        def call(self, inputs):
            from_tensor = inputs[0]
            to_tensor = inputs[1]
            attention_mask = inputs[2] if len(inputs) == 3 else None
    
            # Scalar dimensions referenced here:
            #   B = batch size (number of sequences)
            #   F = `from_tensor` sequence length
            #   T = `to_tensor` sequence length
            #   N = `num_attention_heads`
            #   H = `size_per_head`
            # `query_tensor` = [B, F, N, H]
            query_tensor = self._query_dense(from_tensor)
    
            # `key_tensor` = [B, T, N, H]
            key_tensor = self._key_dense(to_tensor)
    
            # `value_tensor` = [B, T, N, H]
            value_tensor = self._value_dense(to_tensor)
    
            # Take the dot product between "query" and "key" to get the raw
            # attention scores.
            attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
            attention_scores = tf.multiply(attention_scores,
                                           1.0 / math.sqrt(float(self._head_size)))
    
            # Normalize the attention scores to probabilities.
            # `attention_probs` = [B, N, F, T]
            # (the mask, if provided, is expanded to [B, 1, F, T] inside MaskedSoftmax)
            attention_probs = self._masked_softmax([attention_scores, attention_mask])
            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self._dropout(attention_probs)
    
            # `context_layer` = [B, F, N, H]
            return tf.einsum("BNFT,BTNH->BFNH", attention_probs, value_tensor)
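
    A minimal usage sketch (names and shapes here are illustrative, assuming the
    imports above): the first einsum computes the per-head dot products QKᵀ scaled
    by 1/sqrt(H), and the output keeps the per-head layout [B, F, N, H]. For
    self-attention, pass the same tensor as both `inputs[0]` and `inputs[1]`.

    layer = MultiHeadAttention(num_heads=8, head_size=64)
    from_tensor = tf.random.normal([2, 10, 512])    # [B, F, hidden]
    to_tensor = tf.random.normal([2, 12, 512])      # [B, T, hidden]
    attention_mask = tf.ones([2, 10, 12])           # [B, F, T]; 1 = attend, 0 = mask
    context = layer([from_tensor, to_tensor, attention_mask])
    print(context.shape)                            # (2, 10, 8, 64) = [B, F, N, H]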
    

    MaskedSoftmax

    class MaskedSoftmax(Layer):
        """Performs a softmax with optional masking on a tensor.
    
            Arguments:
            mask_expansion_axes: Any axes that should be padded on the mask tensor.
        """
    
        def __init__(self, mask_expansion_axes = None, normalization_axes = None, **kwargs):
            self._mask_expansion_axes = mask_expansion_axes
            if normalization_axes is None:
                self._normalization_axes = (-1,)
            else:
                self._normalization_axes = normalization_axes
            super(MaskedSoftmax, self).__init__(**kwargs)
    
        def call(self, inputs):
            if isinstance(inputs, list) and len(inputs) == 2:
                scores, mask = inputs
            else:
                scores, mask = (inputs, None)
    
            if mask is not None:
                if self._mask_expansion_axes is not None:
                    mask = tf.expand_dims(mask, axis = self._mask_expansion_axes)
    
                # In `attention_mask`, positions to attend to are 1.0 and masked
                # positions are 0.0. Turn masked positions into a large negative
                # value (here -1e6) so that their softmax probability is ~0.
                adder = (1.0 - tf.cast(mask, scores.dtype)) * -1000000.0
                scores += adder  # add to the raw scores before the softmax
    
            if len(self._normalization_axes) == 1:
                return tf.nn.softmax(scores, axis = self._normalization_axes[0])
            else:
                return tf.math.exp(scores - tf.math.reduce_logsumexp(
                        scores, axis = self._normalization_axes, keepdims = True))
    
        def get_config(self):
            config = {
                "mask_expansion_axes": self._mask_expansion_axes,
                "normalization_axes": self._normalization_axes,
            }
            base_config = super(MaskedSoftmax, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
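
    A quick sanity check of the masking behaviour (shapes mirror the attention
    layer above; the values are illustrative): masked positions receive roughly
    zero probability and the remaining mass is renormalized over the rest.

    scores = tf.zeros([1, 1, 1, 4])                      # [B, N, F, T]
    mask = tf.constant([[[1.0, 1.0, 0.0, 0.0]]])         # [B, F, T]
    probs = MaskedSoftmax(mask_expansion_axes=[1])([scores, mask])
    print(probs.numpy())                                 # ~[[[[0.5, 0.5, 0.0, 0.0]]]]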