A question about using the past_key_values generated by gpt2

Recently I ran into a problem: how to use the past_key_values generated by gpt2. Here is a demo.

# h_s is short for last_hidden_states
# p_k_v is short for past_key_values
# A_h_s.shape = (bs, A_len, hs=768)
(_, A_p_k_v, A_h_s) = gpt2_model(A_input_ids, A_token_type_ids, A_attention_mask, A_position_ids)
# B_h_s.shape = (bs, B_len, hs=768)
(_, B_p_k_v, B_h_s) = gpt2_model(B_input_ids, B_token_type_ids, B_attention_mask, B_position_ids)
# Do some operations on A_h_s, such as integrating some external knowledge
A_h_s = do_something(A_h_s)  # (bs, A_len, hs=768)
The part below is where I am stuck.
During training, I want to use A_h_s and B_h_s to predict C.
How should I do this?
# (bs, A_len + B_len + C_len)
attention_mask = torch.cat([A_attention_mask, B_attention_mask, C_attention_mask], dim=-1)
position_ids = torch.cumsum(attention_mask, dim=-1)[:, -C_len:].type_as(C_input_ids) - 1
past_key_values = ?????
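
One possible way to fill in the last line, as a minimal sketch rather than a confirmed answer: concatenate the two caches layer by layer along the sequence axis. This assumes each cache is a tuple with one (key, value) pair per layer and that every tensor is shaped (bs, n_head, seq_len, head_dim). The helpers concat_past_key_values and make_dummy_cache are hypothetical names introduced here, and the caches are random stand-ins for A_p_k_v and B_p_k_v so the shapes can be checked without loading a model.

```python
import torch


def concat_past_key_values(past_a, past_b):
    """Join two caches layer by layer along the sequence axis.

    Assumes each cache is a tuple with one (key, value) pair per layer and
    that every tensor is shaped (bs, n_head, seq_len, head_dim).
    """
    return tuple(
        (torch.cat([k_a, k_b], dim=-2), torch.cat([v_a, v_b], dim=-2))
        for (k_a, v_a), (k_b, v_b) in zip(past_a, past_b)
    )


def make_dummy_cache(bs, n_layer, n_head, seq_len, head_dim):
    """Random stand-in for a real cache, used only to check shapes."""
    return tuple(
        (torch.randn(bs, n_head, seq_len, head_dim),
         torch.randn(bs, n_head, seq_len, head_dim))
        for _ in range(n_layer)
    )


# Sizes of GPT-2 small: 12 layers, 12 heads, 768 / 12 = 64 dims per head.
bs, n_layer, n_head, head_dim = 2, 12, 12, 64
A_len, B_len = 5, 3

A_p_k_v = make_dummy_cache(bs, n_layer, n_head, A_len, head_dim)
B_p_k_v = make_dummy_cache(bs, n_layer, n_head, B_len, head_dim)

past_key_values = concat_past_key_values(A_p_k_v, B_p_k_v)
print(past_key_values[0][0].shape)  # torch.Size([2, 12, 8, 64])
```

Two caveats, as far as I understand the cache. It stores each layer's keys and values from the original forward pass, so the change that do_something makes to A_h_s is not reflected in A_p_k_v. And because A and B were encoded separately, B's entries were computed without attending to A. Newer transformers releases may also wrap the cache in a dedicated cache object instead of plain tuples, so the layout assumed here should be checked against the installed version.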