DeepSeek最近憑借一己之力擊落Nvidia神話,從此硬件不再是AI的決定性條件,低成本訓(xùn)練LLM模型成為現(xiàn)實(shí)。之所以能夠實(shí)現(xiàn)低成本訓(xùn)練大模型,主要?dú)w功于DeepSeek底層采用的MLA和MoE結(jié)構(gòu),MoE結(jié)構(gòu)我們之前已經(jīng)介紹過,相比于訓(xùn)練一個(gè)統(tǒng)一的專業(yè)大模型,MoE通過訓(xùn)練N的專家模型,能夠更高效的實(shí)現(xiàn)推理任務(wù)。 
MLA(Multi-head Latent Attention)結(jié)構(gòu)是DeepSeek的核心創(chuàng)新之一,旨在優(yōu)化Transformer模型中的注意力機(jī)制,以減少推理過程中的KV緩存(鍵值緩存)并提高計(jì)算效率。其核心原理如下: 1. MLA的核心思想:低秩聯(lián)合壓縮MLA的核心思想是對鍵(Key)和值(Value)進(jìn)行低秩聯(lián)合壓縮。具體來說,傳統(tǒng)的Transformer模型在計(jì)算注意力時(shí),需要分別計(jì)算和存儲每個(gè)頭的鍵(K)和值(V),這會導(dǎo)致大量的內(nèi)存消耗和計(jì)算開銷。而MLA通過引入一個(gè)低秩向量(c),將K和V的存儲和計(jì)算過程進(jìn)行壓縮。具體步驟如下: - 計(jì)算低秩向量c:首先計(jì)算一個(gè)低秩向量c,其維度遠(yuǎn)小于K和V的維度。
- 轉(zhuǎn)換回K和V:然后通過兩個(gè)權(quán)重矩陣(W)將c轉(zhuǎn)換回K和V。
這樣,在推理階段,只需要存儲低秩向量c,而不是完整的K和V,從而顯著減少了KV緩存的大小。 2. 權(quán)重矩陣的合并與吸收MLA的另一大特點(diǎn)是權(quán)重矩陣的合并與吸收。傳統(tǒng)的MHA(Multi-head Attention)模型中,每個(gè)頭都有獨(dú)立的Q、K、V投影矩陣,而在MLA中,這些矩陣被合并和吸收,以減少參數(shù)量和計(jì)算量。具體來說: - Q的投影矩陣:Q的投影矩陣可以與后續(xù)的投影計(jì)算合并。
- K和V的投影矩陣:K和V的投影矩陣也可以進(jìn)行類似的合并操作。
通過這種合并,MLA在存儲和計(jì)算上都優(yōu)于傳統(tǒng)的MHA。例如,KV緩存的大小可以從 減少到 ,其中h是head的數(shù)量。 3. 解耦的旋轉(zhuǎn)位置編碼(RoPE)MLA還對旋轉(zhuǎn)位置編碼(RoPE)進(jìn)行了優(yōu)化。RoPE對K和Q的位置敏感,但由于低秩向量c改變了位置,因此無法直接使用原有的位置編碼。為了解決這個(gè)問題,MLA采用了以下方法: - 使用額外的多頭Q和共享的K:引入額外的多頭Q和共享的K來攜帶旋轉(zhuǎn)位置編碼。
- 合并計(jì)算:最后,將這些信息合并在一起進(jìn)行計(jì)算。
這種解耦的方法確保了位置編碼的正確性,同時(shí)不影響MLA的低秩壓縮和權(quán)重合并。 4. MLA的計(jì)算流程MLA的計(jì)算流程可以概括如下: - 低秩壓縮:Q和K路徑中分別進(jìn)行低秩壓縮,生成latent / low-rank部分。
- 解耦RoPE:K路徑中的一部分進(jìn)行解耦的RoPE變換。
- 緩存與合并:推理階段,壓縮后的KV矩陣和位置編碼后的矩陣被緩存,并進(jìn)行合并吸收。
- 輸出計(jì)算:最終,通過合并后的權(quán)重矩陣計(jì)算輸出。
5. MLA的優(yōu)勢- 減少KV緩存:通過低秩聯(lián)合壓縮,MLA顯著減少了KV緩存的大小。
- 降低計(jì)算成本:權(quán)重矩陣的合并與吸收減少了參數(shù)量和計(jì)算量。
- 提高推理效率:在相同的硬件條件下,MLA可以實(shí)現(xiàn)更長的上下文長度或更大的batch size,從而提高推理速度或吞吐量。
6. MLA與MQA的比較MLA的計(jì)算過程在某些方面與MQA(Multi-query Attention)類似,但MLA在QK維度上進(jìn)行了RoPE變換,并且V與未施加RoPE的K共享激活值,這使得MLA在保持高效的同時(shí),仍然能夠保持較好的模型性能。 QuantML-Qlib實(shí)現(xiàn)MoE之前已經(jīng)集成在我們的框架之中,詳見: QuantML-Qlib開發(fā)版 | MoE混合專家系統(tǒng)用于提升Transformer表現(xiàn)【附代碼】
本次我們將MLA結(jié)構(gòu)也集成進(jìn)框架之中,代碼路徑:examples/benchmarks/DeepSeek 
其中MLA結(jié)構(gòu)的代碼為:
class MultiHeadLatentAttention(nn.Module): def __init__(self, config: Config): super().__init__()
assert config.v_head_dim is not None , f"v_head_dim is not defined {config.v_head_dim=}" assert config.q_lora_rank is not None , f"q_lora_rank is not defined {config.q_lora_rank=}" assert config.kv_lora_rank is not None , f"kv_lora_rank is not defined {config.kv_lora_rank=}" assert config.rope_head_dim is not None , f"rope_head_dim is not defined {config.rope_head_dim=}"
self.config = config
self.dim = config.d_model self.num_heads = config.num_heads self.v_head_dim = config.v_head_dim
self.nope_head_dim = config.nope_head_dim self.rope_head_dim = config.rope_head_dim
self.q_lora_rank = config.q_lora_rank self.kv_lora_rank = config.kv_lora_rank
self.dropout = config.dropout
self.value_dim = self.num_heads * self.v_head_dim
# this is dims between wQ and wK self.nope_dim = self.num_heads * self.nope_head_dim self.rope_dim = self.num_heads * self.rope_head_dim
# query compression self.compress_q_linear = nn.Linear(self.dim, self.q_lora_rank, bias=False) # W_DQ self.decompress_q_nope = nn.Linear(self.q_lora_rank, self.nope_dim, bias=False) self.decompress_q_rope = nn.Linear(self.q_lora_rank, self.rope_dim, bias=False) self.q_norm = RMSNorm(dim=self.q_lora_rank) # key and value compression self.compress_kv_linear = nn.Linear(self.dim, self.kv_lora_rank, bias=False) # W_DKV self.decompress_k_nope = nn.Linear(self.kv_lora_rank, self.nope_dim, bias=False) self.decompress_v_linear = nn.Linear(self.kv_lora_rank, self.value_dim, bias=False) self.kv_norm = RMSNorm(dim=self.kv_lora_rank)
self.k_rope_linear = nn.Linear(self.dim, self.rope_head_dim , bias=False) # self.rope_norm = RMSNorm(self.rope_dim) # not in deepseekv2
self.proj = nn.Linear(self.value_dim , self.dim, bias=False) self.res_dropout = nn.Dropout(p=config.dropout)
def forward(self, x: Tensor,mask: torch.Tensor, freqs_cis: Tensor): batch_size, seq_len, _ = x.shape
compressed_q = self.compress_q_linear(x) norm_q = self.q_norm(compressed_q) query_nope:Tensor = self.decompress_q_nope(norm_q) query_rope:Tensor = self.decompress_q_rope(norm_q)
compressed_kv = self.compress_kv_linear(x) norm_kv = self.kv_norm(compressed_kv) key_nope: Tensor = self.decompress_k_nope(norm_kv) value: Tensor = self.decompress_v_linear(norm_kv)
key_rope:Tensor = self.k_rope_linear(x) # norm_rope = self.rope_norm(key_rope)
query_nope = query_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1,2) query_rope = query_rope.view(batch_size, seq_len, self.num_heads, self.rope_head_dim).transpose(1,2)
key_rope = key_rope.view(batch_size, seq_len, 1, self.rope_head_dim).transpose(1,2) key_nope = key_nope.view(batch_size, seq_len, self.num_heads, self.nope_head_dim).transpose(1,2)
value = value.view(batch_size, seq_len, self.num_heads, self.v_head_dim).transpose(1,2)
# *** the line that fixes MLA :) *** key_rope = key_rope/self.num_heads
q_rope,k_rope = apply_rope(query_rope,key_rope, cis=freqs_cis)
q_recombined = torch.empty((batch_size,self.num_heads,seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device) k_recombined = torch.empty((batch_size, self.num_heads, seq_len, self.rope_head_dim + self.nope_head_dim), device=x.device)
q_recombined[:,:,:,:self.nope_head_dim] = query_nope q_recombined[:,:,:,self.nope_head_dim:] = q_rope
# k_rope = torch.repeat_interleave(k_rope, self.num_heads, dim=1) # >> you dont need to do this << # ?? broadcasting will do replication krope to all heads automagically k_recombined[:,:,:,:self.nope_head_dim] = key_nope k_recombined[:,:,:,self.nope_head_dim:] = k_rope
output = F.scaled_dot_product_attention(q_recombined, k_recombined, value, is_causal=True, dropout_p=self.dropout)
output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_heads * self.v_head_dim)
output = self.proj(output) output = self.res_dropout(output) return output
簡單測試了一下,500內(nèi)的IC為0.025,歡迎繼續(xù)測試優(yōu)化代碼結(jié)果。
|