import math
from typing import Optional

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(
        self,
        hidden_size: int = 8192,
        num_query_heads: int = 64,
        num_kv_heads: int = 4,   # Key to GQA: far fewer KV heads than query heads
        head_dim: int = 128
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        # The number of query heads must be divisible by the number of KV heads
        assert num_query_heads % num_kv_heads == 0
        self.num_queries_per_kv = num_query_heads // num_kv_heads
        # Projection layers
        self.q_proj = nn.Linear(hidden_size, num_query_heads * head_dim)
        self.k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim)
        self.v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim)
        self.o_proj = nn.Linear(num_query_heads * head_dim, hidden_size)
        # RoPE positional encoding (RotaryPositionalEmbedding is assumed to be defined elsewhere)
        self.rotary_emb = RotaryPositionalEmbedding(head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size, seq_len, _ = hidden_states.shape
        # Compute Q, K, V
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)
        # Reshape into multi-head format: (batch, seq_len, num_heads, head_dim)
        queries = queries.view(batch_size, seq_len, self.num_query_heads, self.head_dim)
        keys = keys.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        values = values.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        # Apply RoPE to queries and keys
        queries, keys = self.rotary_emb(queries, keys, position_ids)
        # Core of GQA: replicate the KV heads to match the number of query heads
        keys = self.repeat_kv(keys, self.num_queries_per_kv)
        values = self.repeat_kv(values, self.num_queries_per_kv)
        # Move the head dimension forward: (batch, num_heads, seq_len, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        # Scaled attention scores: (batch, num_heads, seq_len, seq_len)
        attn_weights = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Apply the additive attention mask (e.g. a causal mask with -inf at masked positions)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        # Softmax over the key dimension
        attn_weights = torch.softmax(attn_weights, dim=-1)
        # Weighted sum of the values: (batch, num_heads, seq_len, head_dim)
        attn_output = torch.matmul(attn_weights, values)
        # Reshape back to (batch, seq_len, hidden_size) and project the output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, -1)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat the KV heads to match the number of query heads."""
        if n_rep == 1:
            return hidden_states
        batch, seq_len, n_kv_heads, head_dim = hidden_states.shape
        # (batch, seq_len, n_kv_heads, head_dim) -> (batch, seq_len, n_kv_heads * n_rep, head_dim)
        hidden_states = hidden_states.unsqueeze(3).repeat(1, 1, 1, n_rep, 1)
        return hidden_states.view(batch, seq_len, n_kv_heads * n_rep, head_dim)
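To exercise the module end to end, the class above needs a RotaryPositionalEmbedding implementation, which the listing does not include. The sketch below is a hypothetical, minimal stand-in (not from the original code): it applies the interleaved rotate-half RoPE to Q and K in the (batch, seq_len, num_heads, head_dim) layout the module expects, followed by a small forward pass with a causal mask.

import torch
import torch.nn as nn


class RotaryPositionalEmbedding(nn.Module):
    # Hypothetical minimal RoPE stand-in; the article's actual implementation may differ.
    def __init__(self, head_dim: int, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, queries, keys, position_ids=None):
        batch, seq_len = queries.shape[0], queries.shape[1]
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=queries.device).unsqueeze(0).expand(batch, -1)
        # Rotation angles: (batch, seq_len, head_dim // 2), broadcast over the head dimension
        angles = position_ids.unsqueeze(-1).float() * self.inv_freq
        cos = torch.cos(angles).unsqueeze(2)
        sin = torch.sin(angles).unsqueeze(2)

        def rotate(x):
            # Pair up even/odd channels and rotate each pair by its position-dependent angle
            x1, x2 = x[..., 0::2], x[..., 1::2]
            rot_even = x1 * cos - x2 * sin
            rot_odd = x1 * sin + x2 * cos
            return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)

        return rotate(queries), rotate(keys)


# Usage example with small, illustrative sizes and a causal (upper-triangular -inf) mask
model = GroupedQueryAttention(hidden_size=512, num_query_heads=8, num_kv_heads=2, head_dim=64)
x = torch.randn(2, 16, 512)
mask = torch.triu(torch.full((16, 16), float("-inf")), diagonal=1)  # broadcasts over batch and heads
out = model(x, attention_mask=mask)
print(out.shape)  # torch.Size([2, 16, 512])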