To better understand the structure of the T5 model, the overall structure and flow of the model is walked through here.
T5 overall structure and flow
The key quantities during a T5 forward pass are the values of key_states and value_states.
T5LayerSelfAttention of the 6 encoder blocks
Input hidden_states = (1, 11, 512).
First, query_states is computed:
query_states = shape(self.q(hidden_states))
which gives
query_states = (1,8,11,64)
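To see where this (1, 8, 11, 64) comes from, here is a minimal sketch (not the library code) of what shape() does, assuming t5-small sizes (d_model = 512, 8 heads of 64 dimensions each); q below stands in for self.q:

import torch

batch, seq_len, d_model, n_heads, d_head = 1, 11, 512, 8, 64
hidden_states = torch.randn(batch, seq_len, d_model)
q = torch.nn.Linear(d_model, n_heads * d_head, bias=False)   # stands in for self.q

def shape_sketch(states):
    # (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_head)
    return states.view(batch, -1, n_heads, d_head).transpose(1, 2)

query_states = shape_sketch(q(hidden_states))
print(query_states.shape)   # torch.Size([1, 8, 11, 64])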
Then key_states and value_states are computed:
# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
The relevant part of the project function is:
def project(hidden_states, proj_layer, key_value_states, past_key_value):
    """projects hidden states correctly to key/query states"""
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, seq_length, dim_per_head)
        hidden_states = shape(proj_layer(hidden_states))
This gives key_states and value_states:
key_states = (1,8,11,64)
value_states = (1,8,11,64)
Next, position_bias calculation
............
else:
    position_bias = self.compute_bias(real_seq_length, key_length)
Note the parameters that self.compute_bias passes to self._relative_position_bucket here:
relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)
The bidirectional parameter is True in the encoder and False in the decoder; since this is the encoder, True is passed here.
The position_bias computed here has shape
position_bias = (1,8,11,11)
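Why (1, 8, 11, 11)? compute_bias builds a relative-position matrix of shape (query_length, key_length), maps it to buckets, looks the buckets up in a learned per-head embedding, and permutes the result. A rough sketch of that flow (the bucket mapping is simplified to a clamp here; the real _relative_position_bucket is more involved):

import torch

n_heads, num_buckets, query_len, key_len = 8, 32, 11, 11
relative_attention_bias = torch.nn.Embedding(num_buckets, n_heads)   # stands in for self.relative_attention_bias

context_position = torch.arange(query_len)[:, None]
memory_position = torch.arange(key_len)[None, :]
relative_position = memory_position - context_position               # (query_len, key_len)
buckets = relative_position.clamp(0, num_buckets - 1)                # crude stand-in for _relative_position_bucket
values = relative_attention_bias(buckets)                            # (query_len, key_len, n_heads)
position_bias = values.permute(2, 0, 1).unsqueeze(0)                 # (1, n_heads, query_len, key_len)
print(position_bias.shape)   # torch.Size([1, 8, 11, 11])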
Next, the mask is applied:
if mask is not None:
    position_bias = position_bias + mask
The mask here is either all zeros or None, so it can be ignored.
Then the rest of the forward pass runs:
scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
............
return outputs
First call to the T5LayerSelfAttention of the 6 decoder blocks
Input hidden_states = (1, 1, 512); next call
query_states = shape(self.q(hidden_states))
which gives query_states:
query_states = (1,8,1,64)
Next, key_states and value_states are computed:
# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
The branch taken inside the project function is:
if key_value_states is None:
    # self-attn
    # (batch_size, n_heads, seq_length, dim_per_head)
    hidden_states = shape(proj_layer(hidden_states))
The input hidden_states here is also (1, 1, 512); it passes through the two linear layers self.k and self.v to produce key_states and value_states:
key_states = (1,8,1,64)
value_states = (1,8,1,64)
Then enter position_bias calculation
if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        # if self.gradient_checkpointing and self.training:
        #     position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length)

    # if key and values are already calculated
    # we want only the last query position bias
    if past_key_value is not None:
        position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
Note the parameters that self.compute_bias passes to self._relative_position_bucket here:
relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)
The bidirectional parameter is True in the encoder and False in the decoder; since this is the decoder, False is passed here.
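To illustrate what the flag changes, here is a small sketch based on the first branch of _relative_position_bucket (the log-spaced handling of large distances is omitted): with bidirectional=False every future position collapses to relative distance 0, while with bidirectional=True positions to the left and right land in separate bucket ranges.

import torch

relative_position = torch.tensor([[-2, -1, 0, 1, 2]])   # memory_position - context_position

# bidirectional=False (decoder): positions to the right are clamped to distance 0
causal = -torch.min(relative_position, torch.zeros_like(relative_position))
print(causal)          # tensor([[2, 1, 0, 0, 0]])

# bidirectional=True (encoder): left and right are kept apart by an offset of num_buckets // 2
num_buckets = 32 // 2
offset = (relative_position > 0).long() * num_buckets
bidirectional = offset + torch.abs(relative_position)
print(bidirectional)   # tensor([[ 2,  1,  0, 17, 18]])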
The position_bias computed here has shape
position_bias = (1,8,1,1)
and its contents at this first decoding step are
position_bias = tensor([[[[ 3.5000]], [[ 0.4531]], [[ 3.1875]], [[ 0.9727]], [[-5.4688]], [[ 5.1875]], [[ 2.1562]], [[ 0.5391]]]])
Then position_bias is added and, after the usual operations, the output is produced:
scores += position_bias
............
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
First call to the T5LayerCrossAttention of the 6 decoder blocks
The code run at the beginning is:
batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length
Results obtained
batch_size = 1, seq_length = 1, real_seq_length = 1
Then call
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
This gives key_length = 11, the sequence length of the encoder output key_value_states.
Next, call query_states
query_states = shape(self.q(hidden_states))
which gives query_states:
query_states = (1,8,1,64)
Then key_states and value_states are computed:
# get key/value states
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
The input here is
key_value_states = (1,11,512)
key_value_states is the output of the 6 encoder blocks. In the first T5LayerCrossAttention call, key_states and value_states are both obtained from key_value_states:
elif past_key_value is None:
    # cross-attn
    # (batch_size, n_heads, seq_length, dim_per_head)
    hidden_states = shape(proj_layer(key_value_states))
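A rough sketch of this cross-attention branch, assuming t5-small sizes: key_states (and likewise value_states) are projected from the encoder output key_value_states rather than from the decoder hidden_states, which is why they keep the encoder's sequence length of 11.

import torch

d_model, n_heads, d_head = 512, 8, 64
key_value_states = torch.randn(1, 11, d_model)                 # encoder output
k = torch.nn.Linear(d_model, n_heads * d_head, bias=False)     # stands in for self.k

key_states = k(key_value_states).view(1, -1, n_heads, d_head).transpose(1, 2)
print(key_states.shape)   # torch.Size([1, 8, 11, 64])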
Then position_bias is computed:
if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
............
Here position_bias is all zeros, since the cross-attention layer has no relative attention bias.
Then perform some routine operations
scores += position_bias
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
......
Finally, the output section:
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
Second call to the T5LayerSelfAttention of the 6 decoder blocks
(the second time here means calling, after the T5LayerSelfAttention of the 6 encoder blocks, the T5LayerSelfAttention and T5LayerCrossAttention of the 6 decoder blocks once more)
The second time is equivalent to moving on to a new position after the first value has been predicted. The past_key_value[0] used here is the key_states output by the same layer at the previous position, and past_key_value[1] is the value_states output by the same layer at the previous position (for example, the self-attention of a given decoder layer in this second wave reuses the key_states and value_states produced by that same layer's self-attention in the first wave).
Next comes
key_states = project(
    hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
)
value_states = project(
    hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
)
if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
    else:
        # cross-attn
        hidden_states = past_key_value
Here the inner if branch is taken for T5LayerSelfAttention, and the else branch is taken for T5LayerCrossAttention.
For T5LayerSelfAttention, the following code runs inside the project function:
if past_key_value is not None:
    if key_value_states is None:
        # self-attn
        # (batch_size, n_heads, key_length, dim_per_head)
        hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
............
return hidden_states
which gives the outputs for the second wave:
key_states.size = torch.Size([1, 8, 2, 64])
value_states.size = torch.Size([1, 8, 2, 64])
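A small sketch of where the length of 2 comes from: in the self-attention branch, the states cached from the previous step are concatenated with the states of the current position along the sequence dimension (dim=2).

import torch

past_key = torch.randn(1, 8, 1, 64)   # key_states cached from the first decoding step
new_key = torch.randn(1, 8, 1, 64)    # key_states of the current position
key_states = torch.cat([past_key, new_key], dim=2)
print(key_states.shape)   # torch.Size([1, 8, 2, 64])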
Next, the scores are computed:
# compute scores
scores = torch.matmul(
    query_states, key_states.transpose(3, 2)
)  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
Results obtained
scores = torch.Size([1, 8, 1, 2])
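A quick shape check of this step: the single query position attends over the two cached key positions.

import torch

query_states = torch.randn(1, 8, 1, 64)
key_states = torch.randn(1, 8, 2, 64)
scores = torch.matmul(query_states, key_states.transpose(3, 2))   # (1, 8, 1, 64) x (1, 8, 64, 2)
print(scores.shape)   # torch.Size([1, 8, 1, 2])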
Next, look at position_bias calculation
if position_bias is None:
    if not self.has_relative_attention_bias:
        position_bias = torch.zeros(
            (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
        )
        if self.gradient_checkpointing and self.training:
            position_bias.requires_grad = True
    else:
        position_bias = self.compute_bias(real_seq_length, key_length)
Note the parameters that self.compute_bias passes to self._relative_position_bucket here:
relative_position_bucket = self._relative_position_bucket(
    relative_position,  # shape (query_length, key_length)
    bidirectional=(not self.is_decoder),
    num_buckets=self.relative_attention_num_buckets,
)
The bidirectional parameter is True in the encoder and False in the decoder; since this is the decoder, False is passed here.
The position_bias computed here has shape
position_bias = (1,8,2,2)
The next operation is explained by a comment in the source:
if key and values are already calculated, we want only the last query position bias.
The corresponding code is
if past_key_value is not None:
    position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
Note that only the last hidden_states.size(1) positions along the query dimension are kept (here the last one). After slicing, position_bias = (1,8,1,2).
Before the slicing, the position_bias returned by compute_bias had size
position_bias = torch.Size([1, 8, 2, 2])
The new position_bias is an extension of the previous step's position_bias. For comparison, the previous position_bias was
position_bias = tensor([[[[ 3.5000]], [[ 0.4531]], [[ 3.1875]], [[ 0.9727]], [[-5.4688]], [[ 5.1875]], [[ 2.1562]], [[ 0.5391]]]])
Current position_bias is
position_bias = tensor([[[[ 3.9844, 3.5000]], [[ 1.2266, 0.4531]], [[ 4.3438, 3.1875]], [[ 2.0312, 0.9727]], [[ 0.7969, -5.4688]], [[ 4.9375, 5.1875]], [[ 4.7500, 2.1562]], [[ 4.5000, 0.5391]]]])
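A quick check of the slicing that produces this: only the bias rows for the last hidden_states.size(1) query positions are kept.

import torch

position_bias = torch.randn(1, 8, 2, 2)   # bias for both query positions
last_n = 1                                # hidden_states.size(1) at this decoding step
sliced = position_bias[:, :, -last_n:, :]
print(sliced.shape)   # torch.Size([1, 8, 1, 2])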
Then the following statements run:
scores += position_bias  # scores = (1,8,1,2)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
    scores
)  # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.dropout(
    attn_weights, p=self.dropout, training=self.training
)  # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to
if layer_head_mask is not None:
    attn_weights = attn_weights * layer_head_mask
At this point, scores has shape (1,8,1,2).
Next call
attn_output = unshape(torch.matmul(attn_weights, value_states))
attn_weights = (1,8,1,2),value_states = (1,8,2,64)
Multiply to get the result (1,8,1,64)
Then unshape is applied and the output is produced:
attn_output = unshape(torch.matmul(attn_weights, value_states))  # attn_output = (1,1,512)
attn_output = self.o(attn_output)
This gives
attn_output = (1,1,512)
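To close the loop, here is a minimal sketch (assuming t5-small sizes) of what unshape() plus the output projection do: the heads are folded back into the model dimension, and self.o (o below stands in for it) maps the result to (1, 1, 512).

import torch

batch, n_heads, seq_len, d_head, d_model = 1, 8, 1, 64, 512
attn = torch.randn(batch, n_heads, seq_len, d_head)            # torch.matmul(attn_weights, value_states)
o = torch.nn.Linear(n_heads * d_head, d_model, bias=False)     # stands in for self.o

unshaped = attn.transpose(1, 2).contiguous().view(batch, -1, n_heads * d_head)
attn_output = o(unshaped)
print(attn_output.shape)   # torch.Size([1, 1, 512])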