Conclusion: the overall structure and data flow of T5 in transformers

Keywords: Python PyTorch Deep Learning

To better understand the structure of the T5 model, this post traces the overall flow of data through T5.

T5 overall structure and flow

While T5 runs, the key thing to follow is how key_states and value_states change.

T5LayerSelfAttention in the 6 encoder layers

The input is hidden_states = (1,11,512), i.e. (batch_size, seq_length, d_model).
First, query_states is computed:

query_states = shape(self.q(hidden_states))

which gives

query_states = (1,8,11,64)
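The shape() helper only regroups the last dimension into heads. As a rough pure-Python sketch, assuming shape() is equivalent to view(batch_size, -1, n_heads, d_kv).transpose(1, 2) and using a hypothetical helper name split_heads, the step from (1,11,512) to (1,8,11,64) looks like this:

```python
def split_heads(x, n_heads):
    """Reshape nested lists [batch][seq][n_heads * d_kv] -> [batch][n_heads][seq][d_kv].

    Pure-Python stand-in for shape(): view(batch, -1, n_heads, d_kv).transpose(1, 2).
    """
    out = []
    for seq in x:
        d_kv = len(seq[0]) // n_heads
        # head h owns the feature slice [h * d_kv, (h + 1) * d_kv) of every position
        out.append([[row[h * d_kv:(h + 1) * d_kv] for row in seq] for h in range(n_heads)])
    return out


def dims(x):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


# T5-small-like sizes: batch 1, 11 tokens, d_model 512, 8 heads
hidden = [[[float(t * 512 + c) for c in range(512)] for t in range(11)]]
print(dims(hidden))                  # (1, 11, 512)
print(dims(split_heads(hidden, 8)))  # (1, 8, 11, 64)
```

No values change; each head simply receives a 64-wide slice of every token's features.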

Then key_states and value_states are computed:

# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

Inside project(), the branch executed here is:

def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))

This gives key_states and value_states:

key_states = (1,8,11,64)
value_states = (1,8,11,64)

Next comes the position_bias calculation:

............
else:
    position_bias = self.compute_bias(real_seq_length, key_length)

Note that self.compute_bias calls self._relative_position_bucket with these arguments:

relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

Here, bidirectional is True in the encoder and False in the decoder. This is the encoder, so True is passed.
The resulting position_bias has shape (batch, n_heads, query_length, key_length):

position_bias = (1,8,11,11)
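The bucketing done by self._relative_position_bucket can be sketched as a scalar function. This is a minimal re-implementation for illustration, not the library code; it assumes num_buckets=32 (the t5-small relative_attention_num_buckets) and max_distance=128, and the name relative_position_bucket is hypothetical. relative_position is the key position minus the query position.

```python
import math


def relative_position_bucket(relative_position, bidirectional=True,
                             num_buckets=32, max_distance=128):
    """Scalar sketch of T5-style relative-position bucketing."""
    bucket = 0
    if bidirectional:
        # encoder: half of the buckets for keys after the query, half for keys before it
        num_buckets //= 2
        if relative_position > 0:
            bucket += num_buckets
        n = abs(relative_position)
    else:
        # decoder: only distances to earlier keys matter; future offsets collapse to 0
        n = -min(relative_position, 0)
    max_exact = num_buckets // 2
    if n < max_exact:
        # small distances each get their own bucket
        return bucket + n
    # larger distances share logarithmically sized buckets up to max_distance
    large = max_exact + int(math.log(n / max_exact)
                            / math.log(max_distance / max_exact)
                            * (num_buckets - max_exact))
    return bucket + min(large, num_buckets - 1)


# Encoder buckets for query position 0 against an 11-token input
print([relative_position_bucket(k - 0) for k in range(11)])
# [0, 17, 18, 19, 20, 21, 22, 23, 24, 24, 24]
```

With bidirectional=True, keys to the right of the query fall into the upper half of the buckets; with bidirectional=False (the decoder), every future offset maps to bucket 0.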

Next, the mask is added:

if mask is not None:
    position_bias = position_bias + mask

Here the mask is either all zeros or None, so it changes nothing and can be ignored.
Then the rest of the attention code runs:

scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
............
return outputs
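The tail of this computation (position_bias, with the mask folded in, added to scores, then a softmax over key_length) can be mimicked on a single row. This is a pure-Python sketch with made-up numbers; the -1e9 for a padded key is only an illustrative stand-in for the large negative value placed in the extended attention mask.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row (the key_length axis)."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attend_row(scores, position_bias, mask):
    """position_bias + mask is added to scores, then softmax over the last dimension."""
    return softmax([s + (b + m) for s, b, m in zip(scores, position_bias, mask)])


scores = [0.5, 1.0, -0.3]        # made-up raw scores for one query over 3 keys
bias = [0.2, -0.1, 0.4]          # made-up position_bias row
no_padding = [0.0, 0.0, 0.0]     # all-zero mask: nothing changes
padded_last = [0.0, 0.0, -1e9]   # illustrative large negative value for a padded key

print(attend_row(scores, bias, no_padding))
print(attend_row(scores, bias, padded_last))  # last weight underflows to 0.0
```

An all-zero mask leaves the weights identical to softmax(scores + position_bias), which is why it can be ignored above.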

First call: T5LayerSelfAttention of the six decoder layers

The input is hidden_states = (1,1,512). Next:

query_states = shape(self.q(hidden_states))

which gives

query_states = (1,8,1,64)

Next, key_states and value_states are computed:

# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

Inside project(), the branch executed here is:

if key_value_states is None:
    # self-attn
    # (batch_size, n_heads, seq_length, dim_per_head)
    hidden_states = shape(proj_layer(hidden_states))

The input hidden_states is again (1,1,512); it passes through the two linear layers self.k and self.v, giving key_states and value_states:

key_states = (1,8,1,64)
value_states = (1,8,1,64)

Then comes the position_bias calculation:

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        #if self.gradient_checkpointing and self.training:
        #    position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length)

        # if key and values are already calculated
        # we want only the last query position bias
        if past_key_value is not None:
            position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

Note that self.compute_bias calls self._relative_position_bucket with these arguments:

relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

Again, bidirectional is True in the encoder and False in the decoder. This is the decoder, so False is passed.
The position_bias computed here has shape (batch, n_heads, query_length, key_length):

position_bias = (1,8,1,1)

The values of position_bias on this first decoding step are:

position_bias = 
tensor([[[[ 3.5000]],
         [[ 0.4531]],
         [[ 3.1875]],
         [[ 0.9727]],
         [[-5.4688]],
         [[ 5.1875]],
         [[ 2.1562]],
         [[ 0.5391]]]])

Then position_bias is added to scores, and the usual attention operations produce the output:

scores += position_bias
............
outputs = (attn_output,)+(present_key_value_state,)+(position_bias,)

First call: T5LayerCrossAttention of the six decoder layers

At the start of the forward pass:

batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length

which gives

batch_size = 1, seq_length = 1, real_seq_length = 1

Then:

key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
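The length bookkeeping can be summarized in a small helper (hypothetical name attention_lengths). It is a simplified sketch: it covers self-attention with or without a cache and cross-attention without a cache, which are the cases traced in this post; the library's handling of cross-attention with a cache (via a query_length argument) is not modeled. position_bias is then built with shape (1, n_heads, real_seq_length, key_length).

```python
def attention_lengths(seq_length, past_length=0, encoder_length=None):
    """Return (real_seq_length, key_length) for one attention call (simplified).

    seq_length     -- hidden_states.shape[1], the new query positions in this call
    past_length    -- cached positions, i.e. past_key_value[0].shape[2] (0 without a cache)
    encoder_length -- key_value_states.shape[1] for cross-attention, None for self-attention
    """
    real_seq_length = seq_length + past_length
    key_length = real_seq_length if encoder_length is None else encoder_length
    return real_seq_length, key_length


def position_bias_shape(n_heads, seq_length, past_length=0, encoder_length=None):
    """Shape of position_bias before any slicing: (1, n_heads, real_seq_length, key_length)."""
    real_seq_length, key_length = attention_lengths(seq_length, past_length, encoder_length)
    return (1, n_heads, real_seq_length, key_length)


print(position_bias_shape(8, 11))                    # encoder self-attention: (1, 8, 11, 11)
print(position_bias_shape(8, 1))                     # decoder self-attention, step 1: (1, 8, 1, 1)
print(position_bias_shape(8, 1, encoder_length=11))  # decoder cross-attention, step 1: (1, 8, 1, 11)
print(position_bias_shape(8, 1, past_length=1))      # decoder self-attention, step 2: (1, 8, 2, 2)
```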

Here key_length = 11, the sequence length of the encoder output.
Next, query_states is computed:

query_states = shape(self.q(hidden_states))

which gives

query_states = (1,8,1,64)

(self.q(hidden_states) itself is (1,1,512); shape() splits it into 8 heads of 64.)

Then key_states and value_states are computed:

# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)

The key_value_states input here is the output of the six encoder layers:

key_value_states = (1,11,512)

In this first T5LayerCrossAttention call there is no cache yet, so key_states and value_states are both projected from key_value_states:

elif past_key_value is None:
    # cross-attn
    # (batch_size, n_heads, seq_length, dim_per_head)
    hidden_states = shape(proj_layer(key_value_states))

Then position_bias is computed:

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        ............

Because T5LayerCrossAttention has no relative attention bias, position_bias here is all zeros, with shape (1,8,1,11).
Then the usual operations follow:

scores += position_bias
attn_weights = nn.functional.softmax(scores.float(),dim=-1).type_as(scores)
......

Finally, the output tuple is assembled:

present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
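With use_cache enabled, what a decoder layer hands to the next step is its self-attention and cross-attention key/value states. As a sketch, assuming the layer concatenates the two pairs in the order (self key, self value, cross key, cross value) as in the transformers version traced here, the shapes after each decoding step would be (cache_shapes is a hypothetical helper):

```python
def cache_shapes(step, n_heads=8, d_kv=64, encoder_length=11, batch_size=1):
    """Shapes in one decoder layer's cache after `step` decoding steps (illustrative only).

    Assumed order: (self-attn key, self-attn value, cross-attn key, cross-attn value).
    """
    self_kv = (batch_size, n_heads, step, d_kv)             # grows by one position per step
    cross_kv = (batch_size, n_heads, encoder_length, d_kv)  # projected once from the encoder output
    return (self_kv, self_kv, cross_kv, cross_kv)


for step in (1, 2, 3):
    print(step, cache_shapes(step))
```

Only the self-attention entries grow; the cross-attention entries stay fixed at the encoder length.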

Second call: T5LayerSelfAttention of the six decoder layers

(Here "the second time" means the second pass through the decoder: after the 6 encoder T5LayerSelfAttention layers have run, the T5LayerSelfAttention and T5LayerCrossAttention of the 6 decoder layers are called again.)
The second pass corresponds to the next position: after the first token has been predicted, the decoder runs again for the following position. Here past_key_value[0] is the key_states output by the same layer at the previous position, and past_key_value[1] is the value_states output by the same layer at the previous position. For example, the self-attention of a given decoder layer in the second pass reuses the key_states and value_states produced by that same layer in the first pass.
Next:

key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
# inside project(), the cached states are then merged in:
if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
    else:
        # cross-attn
        hidden_states = past_key_value

The first branch (key_value_states is None) is taken in T5LayerSelfAttention; the else branch is taken in T5LayerCrossAttention, where the cached encoder keys and values are reused unchanged.
For T5LayerSelfAttention, the following code runs inside project():

if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
        ............
return hidden_states
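The branch structure of project() can be mimicked with Python lists standing in for the sequence axis (dim=2). This is a toy model with a hypothetical name (project_sketch), not the library function: strings stand in for projected key/value vectors, the projection itself is omitted, and list concatenation stands in for torch.cat(..., dim=2).

```python
def project_sketch(hidden_states, key_value_states, past_key_value):
    """Mimic project()'s branches; lists stand in for the seq axis (dim=2)."""
    if key_value_states is None:
        # self-attention: "project" the current decoder positions
        states = list(hidden_states)
    elif past_key_value is None:
        # cross-attention, first step: "project" the encoder output
        states = list(key_value_states)
    else:
        states = None  # cross-attention with a cache: nothing new is projected

    if past_key_value is not None:
        if key_value_states is None:
            # self-attention: torch.cat([past_key_value, hidden_states], dim=2)
            states = past_key_value + states
        else:
            # cross-attention: reuse the cached encoder keys/values as-is
            states = past_key_value
    return states


step1 = project_sketch(["k0"], None, None)    # ['k0']
step2 = project_sketch(["k1"], None, step1)   # ['k0', 'k1'], so key_length becomes 2
enc = ["e%d" % i for i in range(11)]
cross1 = project_sketch(["q0"], enc, None)    # the 11 encoder positions
cross2 = project_sketch(["q1"], enc, cross1)  # the same cached list
print(step2, len(cross2))
```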

On this second pass, the outputs are:

key_states.size = torch.Size([1, 8, 2, 64])
value_states.size = torch.Size([1, 8, 2, 64])

Next, the attention scores are computed:

# compute scores
scores = torch.matmul(
    query_states, key_states.transpose(3, 2)
)  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9

which gives

scores = torch.Size([1, 8, 1, 2])
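The matmul above contracts the last dimension of query_states with the last dimension of key_states for every batch and head. A nested-list sketch of the same contraction (hypothetical helper names; as in the quoted code, no division by sqrt(d_kv) is applied):

```python
def attention_scores(q, k):
    """Nested-list version of torch.matmul(q, k.transpose(3, 2)).

    q: [batch][heads][q_len][d], k: [batch][heads][k_len][d]
    returns [batch][heads][q_len][k_len], i.e. einsum("bnqd,bnkd->bnqk").
    """
    return [[[[sum(a * b for a, b in zip(q_vec, k_vec)) for k_vec in k_head]
              for q_vec in q_head]
             for q_head, k_head in zip(q_batch, k_batch)]
            for q_batch, k_batch in zip(q, k)]


def dims(x):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


# Second decoding step: 1 new query against 2 cached keys, 8 heads, d_kv = 64 (made-up values)
q = [[[[0.01] * 64] for _ in range(8)]]
k = [[[[0.02] * 64, [0.03] * 64] for _ in range(8)]]
print(dims(attention_scores(q, k)))  # (1, 8, 1, 2)
```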

Next, the position_bias calculation:

if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length)

Note that self.compute_bias calls self._relative_position_bucket with these arguments:

relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)

Again, bidirectional is True in the encoder and False in the decoder. This is the decoder, so False is passed.
The position_bias computed here has shape (1, n_heads, real_seq_length, key_length):

position_bias = (1,8,2,2)

The next step is explained by a comment in the source:

# if key and values are already calculated
# we want only the last query position bias

The corresponding code is:

if past_key_value is not None:
    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]

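Note that it is the last query row (dimension 2) that is kept, not the last dimension: hidden_states.size(1) is 1 because only the newest position is fed in. Before slicing, the position_bias result is

position_bias = torch.Size([1, 8, 2, 2])

and after slicing it is position_bias = (1,8,1,2).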
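The slice keeps the rows for the newly fed query positions (here just one) and every key column. A nested-list sketch with made-up values and a hypothetical helper name:

```python
def keep_last_queries(position_bias, num_new):
    """Nested-list version of position_bias[:, :, -num_new:, :] for a [batch][heads][q][k] list."""
    return [[head[-num_new:] for head in batch] for batch in position_bias]


# Made-up (1, 2, 2, 2) bias (2 heads instead of 8 for brevity); rows are query positions 0 and 1
bias = [[[[0.1, 0.2],
          [0.3, 0.4]],
         [[0.5, 0.6],
          [0.7, 0.8]]]]
print(keep_last_queries(bias, 1))  # [[[[0.3, 0.4]], [[0.7, 0.8]]]]
```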

The new position_bias extends the original one. For example, the original position_bias (first decoding step) is

position_bias = 
tensor([[[[ 3.5000]],

         [[ 0.4531]],

         [[ 3.1875]],

         [[ 0.9727]],

         [[-5.4688]],

         [[ 5.1875]],

         [[ 2.1562]],

         [[ 0.5391]]]])

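while the current position_bias is shown below. Its second column (the key at the current position, relative distance 0) repeats the first-step values, and the first column is the bias toward the previous position: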

position_bias = 
tensor([[[[ 3.9844,  3.5000]],

         [[ 1.2266,  0.4531]],

         [[ 4.3438,  3.1875]],

         [[ 2.0312,  0.9727]],

         [[ 0.7969, -5.4688]],

         [[ 4.9375,  5.1875]],

         [[ 4.7500,  2.1562]],

         [[ 4.5000,  0.5391]]]])
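Why does the second column repeat the first-step values? Listing the decoder's relative positions makes it visible. The sketch below reuses a scalar version of the unidirectional (decoder) bucketing, assuming num_buckets=32 and max_distance=128; the function names are hypothetical. Each head's bias is looked up by bucket index, so equal buckets give equal bias values.

```python
import math


def decoder_bucket(relative_position, num_buckets=32, max_distance=128):
    """Scalar sketch of the unidirectional (decoder) relative-position bucketing."""
    n = -min(relative_position, 0)  # future offsets map to distance 0
    max_exact = num_buckets // 2
    if n < max_exact:
        return n
    large = max_exact + int(math.log(n / max_exact) / math.log(max_distance / max_exact)
                            * (num_buckets - max_exact))
    return min(large, num_buckets - 1)


def relative_positions(query_length, key_length):
    """relative_position[q][k] = k - q (memory position minus context position)."""
    return [[k - q for k in range(key_length)] for q in range(query_length)]


step1 = relative_positions(1, 1)[-1:]  # [[0]]
step2 = relative_positions(2, 2)[-1:]  # [[-1, 0]] after keeping the last query row
print([[decoder_bucket(r) for r in row] for row in step1])  # [[0]]
print([[decoder_bucket(r) for r in row] for row in step2])  # [[1, 0]]
```

The current key always sits at relative position 0, hence bucket 0, which is the same bucket used for the single entry at step 1.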

Then the following statements run:

scores += position_bias
#scores = (1,8,1,2)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)

# Mask heads if we want to
if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask

So far, attn_weights has shape (1,8,1,2), the same as scores.
Next:

attn_output = unshape(torch.matmul(attn_weights, value_states))

with attn_weights = (1,8,1,2) and value_states = (1,8,2,64).
Multiplying them gives (1,8,1,64); unshape turns this back into (1,1,512) before the output projection self.o:

attn_output = unshape(torch.matmul(attn_weights,value_states))
#attn_output = (1,1,512)
attn_output = self.o(attn_output)

which gives

attn_output = (1,1,512)
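The last two steps, torch.matmul(attn_weights, value_states) followed by unshape(), can be sketched with nested lists. This assumes unshape() is equivalent to transpose(1, 2) followed by a reshape to (batch_size, -1, n_heads * d_kv); the helper names are hypothetical, and the output projection self.o is left out.

```python
def weighted_sum(weights, values):
    """[b][n][q][k] x [b][n][k][d] -> [b][n][q][d], i.e. torch.matmul(attn_weights, value_states)."""
    return [[[[sum(w * v_row[j] for w, v_row in zip(w_row, v_head))
               for j in range(len(v_head[0]))]
              for w_row in w_head]
             for w_head, v_head in zip(w_batch, v_batch)]
            for w_batch, v_batch in zip(weights, values)]


def merge_heads(x):
    """[b][n][q][d] -> [b][q][n * d]: concatenate the heads of each query position."""
    return [[[v for head in heads for v in head[t]] for t in range(len(heads[0]))]
            for heads in x]


def dims(x):
    """Return the shape of a regular nested list as a tuple."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0]
    return tuple(shape)


# Second decoding step: equal weights over 2 keys, 8 heads, d_kv = 64 (made-up values)
attn_weights = [[[[0.5, 0.5]] for _ in range(8)]]
value_states = [[[[h * 1000.0 + k * 100.0 + j for j in range(64)] for k in range(2)]
                 for h in range(8)]]
per_head = weighted_sum(attn_weights, value_states)
merged = merge_heads(per_head)
print(dims(per_head), dims(merged))  # (1, 8, 1, 64) (1, 1, 512)
```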

Second call: T5LayerCrossAttention of the six decoder layers

Posted by EPJS on Tue, 30 Nov 2021 23:13:03 -0800