日韩黑丝制服一区视频播放|日韩欧美人妻丝袜视频在线观看|九九影院一级蜜桃|亚洲中文在线导航|青草草视频在线观看|婷婷五月色伊人网站|日本一区二区在线|国产AV一二三四区毛片|正在播放久草视频|亚洲色图精品一区

分享

為什么我還是無法理解transformer?

 時間煮墨 2025-04-15 發(fā)布于江西

像我這種業(yè)余的AI愛好者而不熟悉 NLP 的人來說,第一次看到 Attention is all you need 的論文是一臉懵逼的,它的架構(gòu)遠比CNN、MLP、RNN這些復(fù)雜多了。

不過后來了解了一些前因后果之后,還是摸到了一點線索。Transformer 結(jié)構(gòu)的很多設(shè)計,都是為了解決NLP之前模型(比如RNN、Seq2Seq)的問題,如果不了解這些問題的話就不能理解為什么要這么設(shè)計。為此,有必要對 NLP 任務(wù)的一些特點進行全面回顧。

1 自然語言處理基本概念

自然語言處理技術(shù) (Natural Language Processing, NLP) 的基本目標是讀懂一句話。這里的一句話,在電腦中被建模成一個序列,其中的每個元素是一個詞元 (token)。自然語言處理技術(shù)主要包括以下目標:

  1. 詞法分析:將句子中的詞元進行分類,比如名詞、動詞、代詞等。

  2. 句法分析:將句子中的詞元進行組合,比如名詞短語、句子等。

  3. 語義分析:根據(jù)句子中的詞元,判斷句子的語義。

  4. 機器翻譯:根據(jù)輸入的句子,輸出具有相同含義的另一種語言的句子。

  5. 文本生成:根據(jù)輸入的提示詞續(xù)寫文本。

  6. 機器對話:根據(jù)輸入的用戶提問,輸出對應(yīng)的回答。

如果模型能完全讀懂一個句子,那么以上問題就可以解決。

在計算機視覺領(lǐng)域,人們習(xí)慣于將基本的圖像分類模型作為基礎(chǔ)骨架來完成其他任務(wù)。但是在 NLP 中,對詞法的準確分析難度很高,而且即使能夠做到準確的詞法分析,依然難以理解句子的完整語義。這是因為自然語言往往具有非常長程的關(guān)聯(lián)關(guān)系,某個詞的語義可能由整個輸入文本共同決定

因此,現(xiàn)在人們傾向于拋棄那些著眼于文本局部特征的任務(wù),像是詞法、句法、語義分析之類的,而直接用大模型完成那些需要關(guān)注全局語義的文本生成、機器對話等任務(wù)。實踐表明,這樣反而取到了更好的效果。

1.1 詞元

最簡單的想法是,英語中的一個單詞或者中文中的一個字,就是一個詞元。但實際上,詞元的劃分要更加復(fù)雜一點,一個單詞或字可能由多個詞元表示,某些詞元也可以表示多個單詞或者多個字。

不論如何,詞元可能的種類數(shù)量是有限的,所以可以把它們編制成詞表,起到類似字典的作用。這樣,就可以用詞表中的某個序號來唯一表示某個詞元了,從而完成對自然語言的數(shù)字化。這種方案和 Unicode 編碼表有點類似,只不過編碼表是嚴格的一個字符對應(yīng)一個序號,而詞表則是根據(jù)自然語言的規(guī)律確定的,與具體的模型選擇有關(guān)。

設(shè)詞表為 V = {t_1, t_2, ..., t_N} ,其中 N 為詞表大小, t_i 表示第 i 個詞元。對于輸入文本序列 S = (w_1, w_2, ..., w_T) ,經(jīng)過分詞器處理后得到詞元序列:

S_{\text{tokenized}} = (t_{i_1}, t_{i_2}, ..., t_{i_M}) \quad \text{其中} \ t_{i_k} \in V

例如,“原神啟動”就可以這樣變?yōu)?\{342, 645, 7544\} ,分別為 “原”、“神”、“啟動”三個詞元在詞表中的序號。

僅僅是序號仍不適合作為神經(jīng)網(wǎng)絡(luò)的輸入和輸出,因為這是一種整數(shù)量,而詞表的前后關(guān)系一般來說并不具有特別明顯的數(shù)學(xué)運算意義。

如果把"選擇某個詞元"這件事情看作是分類問題 (而不是回歸問題) 的話,那么生成自然語言的模型輸出,就應(yīng)當是一個離散的概率分布而不是一個序號。和圖像分類一樣,這種輸出是一個向量,每個元素表示對應(yīng)序號的詞元的概率。對于確定的結(jié)果,那么就是對應(yīng)的數(shù)字為 1 而其他數(shù)字為 0。這就自然引出了獨熱 (One-hot) 編碼:用一個和詞表長度一樣長的向量來表示,它哪個元素是1就表示這是哪個詞元:

\mathbf{e}_i = [0,...,0,\underbrace{1}_{\text{第}i\text{位}},0,...,0] \in \{0,1\}^N

其中 \mathbf{e}_i 表示第 i 個詞元的獨熱編碼向量,滿足 |\mathbf{e}_i|_1 = 1 。當模型預(yù)測結(jié)果為概率分布 \mathbf{p} \in [0,1]^N 時,需滿足 \sum_{k=1}^N p_k = 1 ,這通常通過softmax函數(shù)實現(xiàn):

p_i = \dfrac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}

其中 z_i 為模型對第 i 個詞元的原始輸出分數(shù)。

1.2 詞向量

要想讓詞元的數(shù)學(xué)運算關(guān)系更加明確一點的話,就需要使用詞向量。它通過把詞元對應(yīng)到一個固定的向量中,從而讓詞元的語義能和向量的數(shù)學(xué)運算 (加法和數(shù)乘) 對應(yīng)起來。給定詞表 V ,每個詞 w 都會對應(yīng)到詞向量 \vec{v}_w\in\mathbb{R}^d ,其中 d 為詞向量維度。

詞向量空間中的幾何關(guān)系,就表示了詞的語義信息。例如類比關(guān)系可表示為:

\phi(\text{"國王"}) - \phi(\text{"男人"}) + \phi(\text{"女人"}) \approx \phi(\text{"女王"})

其中 \phi 是把詞元映射為詞向量的函數(shù)。

詞向量的有效性源于分布式假設(shè):具有相似上下文的詞具有相似語義。通過最小化預(yù)測誤差,模型迫使語義相近的詞在向量空間中彼此靠近,同時保持詞對間的方向關(guān)系對應(yīng)語義關(guān)系。相比于高度稀疏、占空間的獨熱編碼,詞向量維度較低、緊湊,而且具有一定的數(shù)學(xué)結(jié)構(gòu),天然適合于優(yōu)化計算 (也就是模型訓(xùn)練)。這種表示方法為后續(xù)的神經(jīng)網(wǎng)絡(luò)語言模型奠定了基礎(chǔ)。

在數(shù)學(xué)中,Embedding (嵌入) 是一種特殊的映射,其核心目標是保留原空間的結(jié)構(gòu)或性質(zhì),同時將對象映射到另一個(通常是更低維或更簡單的)空間中。在這里我們也是把高維但稀疏的詞元編號 (獨熱編碼) 映射到低維的詞向量空間中,因此詞向量也常常被成為詞嵌入 (Word Embedding)。

2 循環(huán)神經(jīng)網(wǎng)絡(luò)

循環(huán)神經(jīng)網(wǎng)絡(luò) (Recurrent Neural Network, RNN) 是一種專門用于處理序列數(shù)據(jù)的神經(jīng)網(wǎng)絡(luò)架構(gòu)。給定輸入序列 \vec{X} = (\vec{x}_1, \vec{x}_2, ..., \vec{x}_T) ,其中 \vec{x}_t \in \mathbb{R}^d 表示第 t 個時間步的輸入向量, d 為輸入特征維度, T 為序列長度。在自然語言處理中, \vec{x}_t 就是一個詞元所對應(yīng)的詞向量,而 d 就是詞向量的維度。

RNN通過引入隱藏狀態(tài) \vec{h}_t \in \mathbb{R}^{d_h} 來捕捉序列的時序信息,其中 d_h 為隱藏狀態(tài)維度。在每個時間步 t ,RNN根據(jù)當前輸入 \vec{x}_t 和前一時刻的隱藏狀態(tài) \vec{h}_{t-1} 計算當前隱藏狀態(tài):

\vec{h}_t = \sigma(W^{h} \vec{h}_{t-1} + W^{x} \vec{x}_t + \vec_h)

其中 W^{h} \in \mathbb{R}^{d_h \times d_h} 為隱藏狀態(tài)權(quán)重矩陣, W^{x} \in \mathbb{R}^{d_h \times d} 為輸入權(quán)重矩陣, \vec_h \in \mathbb{R}^{d_h} 為偏置項, \sigma 為激活函數(shù)(通常為tanh或ReLU)。輸出 \vec{y}_t \in \mathbb{R}^m 通過以下公式計算:

\vec{y}_t = W^{y} \vec{h}_t + \vec_y

其中 W^{y} \in \mathbb{R}^{N \times d_h} 為輸出權(quán)重矩陣, \vec_y \in \mathbb{R}^N 為輸出偏置項, N 為輸出維度。

這種結(jié)構(gòu)使得RNN能夠處理任意長度的序列,并通過隱藏狀態(tài)在時間步之間傳遞信息,從而捕捉序列中的長期依賴關(guān)系。一般來說,對于生成字符的應(yīng)用,比如機器翻譯而言, N 是詞表的長度,而 y_t^{(n)} 就對應(yīng)于詞表中第 n 個詞元的概率分數(shù) (因而 \vec{y}_{t} 的意義及維度和輸入詞元的向量 \vec{x}_t 完全不同)。

3 序列到序列模型

傳統(tǒng) RNN 每讀取一個詞元,就要輸出一個詞元,輸入和輸出序列長度必須相同。這限制了其在機器翻譯等任務(wù)中的應(yīng)用,因為這些任務(wù)通常需要將不同長度的源語言序列轉(zhuǎn)換為目標語言序列。序列到序列模型 (Sequence to Sequence, Seq2Seq) 通過編碼器-解碼器架構(gòu)解決了這些問題,它允許輸入和輸出序列具有不同的長度,特別適用于機器翻譯、文本摘要等任務(wù)。

3.1 編碼器

給定輸入序列 \vec{X} = (\vec{x}_1, \vec{x}_2, ..., \vec{x}_{T^{\text{enc}}}) ,其中 \vec{x}_t \in \mathbb{R}^d 表示第 t 個時間步的輸入向量, d 為輸入特征維度, T^{\text{enc}} 為輸入序列長度。編碼器 (Encoder) 將整個輸入序列編碼為固定維度的上下文向量 \vec{c} \in \mathbb{R}^{d_h} ,其中 d_h 為隱藏狀態(tài)維度(為簡化表示,編碼器和解碼器使用相同的隱藏狀態(tài)維度,實際中可以不同)。編碼器在每個時間步 t 計算隱藏狀態(tài) \vec{h}_t :

\vec{h}_t = \sigma(W^{h, \text{enc}} \vec{h}_{t-1} + W^{x, \text{enc}} \vec{x}_t + \vec^{h, \text{enc}})

其中 W^{h, \text{enc}} \in \mathbb{R}^{d_h \times d_h} 為編碼器隱藏狀態(tài)權(quán)重矩陣, W^{x, \text{enc}} \in \mathbb{R}^{d_h \times d} 為編碼器輸入權(quán)重矩陣, \vec^{h, \text{enc}} \in \mathbb{R}^{d_h} 為編碼器偏置項, \sigma 為激活函數(shù)(通常為tanh或ReLU)。編碼過程可以表示為:

\vec{c} = f_{\text{enc}}(\vec{X}) = \vec{h}_{T^{\text{enc}}}

其中 f_{\text{enc}} 表示編碼器把詞向量的序列轉(zhuǎn)化為上下文向量的過程。在這里,最簡單的選擇是,把 RNN 讀取完整個輸入句子的最后一個隱藏狀態(tài) \vec{h}_{T^{\text{enc}}} 作為上下文向量 \vec{c} 。這基本上就要求 RNN 在閱讀句子的過程中,能在隱藏狀態(tài)里記住整個句子的信息。

3.2 解碼器

讀完句子之后,就輪到解碼器 (Decoder) 干活了,它開始一個詞一個詞地進行生成。我們用 \tau 表示解碼器的時間步,以區(qū)別于編碼器的時間步 t 。解碼器根據(jù)上下文向量 \vec{c} 生成輸出序列 \vec{Y} = (\vec{y}_1, \vec{y}_2, ..., \vec{y}_{T^{\text{dec}}}) ,其中 \vec{y}_{\tau} \in \mathbb{R}^N 表示第 \tau 個時間步的輸出向量, N 為輸出特征維度, T^{\text{dec}} 為輸出序列長度。和一般的 RNN 一樣,對于機器翻譯這類生成字符的應(yīng)用而言,輸出的 \vec{y}_{\tau} 就是詞表中各個詞元的概率分數(shù),通過 softmax 函數(shù)即可直接轉(zhuǎn)換為概率。

在每個輸出的時間步 \tau ,解碼器根據(jù)前一時刻的輸出 \vec{y}_{\tau-1} 、隱藏狀態(tài) \vec{s}_{\tau-1} 以及上下文向量 \vec{c} 計算當前輸出:

\vec{y}_{\tau} = f_{\text{dec}}(\vec{y}_{\tau-1}, \vec{s}_{\tau-1}, \vec{c})

解碼器的隱藏狀態(tài) \vec{s}_{\tau} 則和一般的 RNN 一樣,由前一個 \vec{s}_{\tau - 1} 轉(zhuǎn)移而來。只不過需要多加一個上下文向量 \vec{c} ,以及還要讀取上一個輸出的字符 \vec{y}_{\tau - 1} :

\vec{s}_{\tau} = \sigma(W^{s, \text{dec}} \vec{s}_{\tau-1} + W^{y, \text{dec}} \vec{y}_{\tau-1} + W^{c, \text{dec}} \vec{c} + \vec^{s, \text{dec}})

其中 W^{s, \text{dec}} 為解碼器隱藏狀態(tài)權(quán)重矩陣, W^{y, \text{dec}} 為解碼器輸出權(quán)重矩陣, W^{c, \text{dec}} 為解碼器上下文權(quán)重矩陣, \vec^{s, \text{dec}} 為解碼器偏置項, \sigma 為激活函數(shù)。這種計算方式使得解碼器能夠同時考慮前一時刻的隱藏狀態(tài)、輸出以及上下文信息,從而更好地生成當前時刻的輸出。

注意,量 \vec{y}_{\tau - 1} 嚴格來說表示的是前一時刻輸出詞元的概率分布,但是我們這里直接用它來表示已經(jīng)根據(jù)概率分布采樣完畢的、確定了是哪一個詞元的輸出。這種需要把已經(jīng)輸出的內(nèi)容,重新輸入回神經(jīng)網(wǎng)絡(luò)的行為,叫做自回歸生成(Auto-regressive Generation)。

4 Bahdanau 注意力機制

傳統(tǒng)的Seq2Seq模型將整個輸入序列編碼為固定長度的上下文向量 \vec{c} ,這在處理長序列時存在信息瓶頸問題。特別是當輸入序列較長時,RNN難以將所有相關(guān)信息壓縮到單個向量中,導(dǎo)致模型性能下降。

Bahdanau 等人引入的注意力機制 (Attention Mechanism),通過為每個解碼時間步 \tau 動態(tài)生成不同的上下文向量 \vec{c}_{\tau} 來解決這一問題。具體來說,在生成每個輸出 \vec{y}_{\tau} 時,解碼器會關(guān)注輸入序列的不同部分,而不是使用固定的上下文向量。這種關(guān)注程度通過注意力權(quán)重 \alpha_{\tau t} 來表示,即在生成第 \tau 個輸出時對第 t 個輸入編碼狀態(tài)的關(guān)注程度。

換句話說,注意力機制的重點,就是對輸入進行加權(quán)平均。在傳統(tǒng) Seq2Seq 中,要求把所有輸入句子的內(nèi)容,記在一個有限維度的上下文向量 \vec{c} 中,但凡句子稍微長一點,就肯定會出問題——你很難想象一個 2000 維度的上下文向量能表示一篇 1000 字長的文章。那如果在每一步解碼的時候,能有辦法把所有輸入步所對應(yīng)的隱藏狀態(tài) \vec{h}_t 利用起來,這樣對神經(jīng)網(wǎng)絡(luò)的要求就降低了很多,它只需要確保 t 步的隱藏狀態(tài) \vec{h}_t 能表示它周圍局部的語義即可。

所以現(xiàn)在我們明確了,需要用編碼器的所有隱藏狀態(tài) \{\vec{h}_t\}_{t=1}^{T^{\text{enc}}} 來作為解碼器的輸入,代替原有的、只用最后一個隱藏狀態(tài)的上下文向量 \vec{c}=\vec{h}_{T^{\text{enc}}} 。那么,下一個問題就是,怎么替換呢?

用一堆向量來表示一個向量,最簡單的做法就是用線性表示,即

\vec{c}_{\tau}^{\text{attn}} = \sum_{t=1}^{T^{\text{enc}}} \alpha_{\tau t} \vec{h}_t

其中 \alpha_{\tau t} 就是線性表示引入的不同的線性系數(shù)。這樣不同隱藏狀態(tài)的影響就有大有小,很符合直覺。輸入文本中有些字,和當前我們想要輸出的字沒什么關(guān)系;而另一些字則關(guān)系密切,在翻譯的時候必須考慮。

但是,輸入句子的長度 T^{\text{enc}} 是沒法固定的 (不然這也太不優(yōu)雅了,把循環(huán)神經(jīng)網(wǎng)絡(luò)的好處都扔掉了),所以線性系數(shù) \alpha_{\tau t} 不能直接用神經(jīng)網(wǎng)絡(luò)的參數(shù) (因為參數(shù)的數(shù)目在訓(xùn)練的時候就已經(jīng)定死了)。

既然如此,這個系數(shù)就應(yīng)當是一個和 t 以及 \tau 都有關(guān)的、動態(tài)的量。什么量會由時間步?jīng)Q定?那自然就是編碼器和解碼器的隱藏狀態(tài), \vec{h}_t以及\vec{s}_{\tau-1} 了。這里用了 \tau - 1 ,是因為我們還沒計算出來 \vec{s}_{\tau} 呢。當然,也可以考慮直接用詞元 \vec{x}_{t} 和 \vec{y}_{\tau-1} ,但這樣總歸是丟失了"詞在句子中的語義",而只剩下"詞本身的語義"了。注意我們這里為了避免引入過多變量,混淆了輸入詞元和詞向量,以及輸出詞元分布、詞元和詞向量。

于是現(xiàn)在,我們要用兩個隱藏狀態(tài)向量,來計算一個數(shù)字,作為線性表示的系數(shù),或者說注意力權(quán)重。那此時就可以用神經(jīng)網(wǎng)絡(luò)的參數(shù)來計算啦,畢竟隱藏狀態(tài)向量的維度是固定的。所以注意力權(quán)重通過以下方式計算:

\alpha_{\tau t} = \dfrac{\exp(e_{\tau t})}{\sum_{k=1}^{T^{\text{enc}}} \exp(e_{\tau k})}

并且

e_{\tau t} = {\vec{v}^{\text{attn}}} \cdot \tanh(W^{s, \text{attn}} \vec{s}_{\tau-1} + W^{h, \text{attn}} \vec{h}_t + \vec^{\text{attn}})

其中 W^{s, \text{attn}} 和 W^{h, \text{attn}} 為注意力機制中的神經(jīng)網(wǎng)絡(luò)參數(shù)矩陣, \vec{v}^{\text{attn}} 為注意力機制中的參數(shù)列向量, \vec^{\text{attn}} 為注意力機制的偏置項。這里除了用神經(jīng)網(wǎng)絡(luò)來計算注意力權(quán)重之外,還要加一個 softmax 歸一化,這樣所有注意力權(quán)重的和就是 1 了。

解碼器在生成每個輸出、更新隱藏狀態(tài)時,就使用 \vec{c}_{\tau}^{\text{attn}} 這個動態(tài)上下文向量,來替換原有的上下文向量 \vec{c} :

\vec{s}_{\tau} = \sigma(W^{s, \text{dec}} \vec{s}_{\tau-1} + W^{y, \text{dec}} \vec{y}_{\tau-1} + W^{c, \text{dec}} \vec{c}_{\tau} + \vec^{s, \text{dec}})

其中 W^{s, \text{dec}} 、 W^{y, \text{dec}} 和 W^{c, \text{dec}} 為解碼器的參數(shù)矩陣, \vec^{s, \text{dec}} 為解碼器的偏置項。這種機制允許模型在處理長序列時更有效地利用輸入信息,顯著提高了Seq2Seq模型在機器翻譯等任務(wù)中的性能。

5 Transformer

只有在理解了 RNN 、Seq2Seq 以及為何需要在 Seq2Seq 中引入注意力機制之后,才能理解,為什么 Transformer 的論文要取名為 Attention is all you need,以及作者們在論文標題中,暗示我們用注意力機制取代的那種東西,到底是什么?當然,或許可以猜到他們想要干掉 RNN。所以為什么要干掉 RNN ?

在這里,我們對上一節(jié)介紹用 RNN 進行自然語言處理的發(fā)展脈絡(luò),尤其是每提出一種結(jié)構(gòu)所解決的問題以及引入的新問題,做一個回顧。

模型/技術(shù)核心結(jié)構(gòu)優(yōu)勢問題
基礎(chǔ)RNN循環(huán)隱藏狀態(tài)處理變長序列輸出與輸入長度一致
Seq2Seq編碼器-解碼器架構(gòu)輸入輸出長度解耦上下文向量成為信息瓶頸
注意力機制動態(tài)上下文向量長程依賴捕捉前后依賴關(guān)系, 難以并行
Transformer自注意力機制完全并行計算上下文平方增長計算量

5.1 自注意力機制

Transformer 的核心是自注意力機制,它是這種新架構(gòu)能超越RNN的關(guān)鍵,因而它的思想值得細細體會。

5.1.1 自注意力機制的思想

如果說,原始的注意力機制是讓解碼器在輸出時,能夠得以關(guān)注編碼器在處理輸入時的信息。那么,如果把這種機制,用在加強文本自身的表述能力上,就形成了自注意力機制。這是 Transformer 最核心的模塊,所以我們最先討論其原理和思想,隨后再依次介紹詳細的架構(gòu)。

考慮這樣的敘事:

在一段文本中,某個詞元的真正語義,有可能要等到后續(xù)的所有內(nèi)容出現(xiàn)了之后,才能完全確定。

這種情況在英語中尤其常見,比如各種各樣的從句,而在漢語中也常常會遇到這種情況。當然,還有的時候哪怕文本讀完了都不能確定語義,但這一般就是病句或者碰到謎語人了?,F(xiàn)在我們的目標是讓模型能翻譯一般的句子,解讀謎語人有點太強模型所難了,我們暫且忽略這種情況。

那么 RNN 就有點難以處理這種情況。RNN 在讀取文字的時候,必須順序讀取,并且隱藏狀態(tài) \vec{h}_t 只由已經(jīng)讀過的詞元決定。這意味著,RNN 如果在讀到某個詞元時感到困惑了,那么它就不得不將這份困惑所對應(yīng)的隱藏狀態(tài)留在原地;而在后面的內(nèi)容也讀完、能確定這個詞元的語義時,才能在隱藏狀態(tài)中默默地把詞元的語義給更新一下。當然對于 RNN 來說,更新已經(jīng)讀取的詞元的語義并不是完全做不到的,只是確實不那么簡單。

那有沒有更加直觀、有效的方式,能夠直接表示出"詞元在句子中的語義"?自注意力機制,就是這種能夠完整利用整個句子的所有信息,來表達"某個詞元的在句子中語義"的結(jié)構(gòu)。

給定輸入序列 \vec{X} = (\vec{x}_1, \vec{x}_2, ..., \vec{x}_{T}) ,其中 \vec{x}_t \in \mathbb{R}^d 表示第 t 個時間步的輸入向量 (默認是列向量), d 為輸入特征維度, T 為輸入序列長度。下面的推導(dǎo)都很不嚴謹,不過我們不要在意這些,更多的是希望能馬后炮地探究一下,自注意力機制的形式為什么是論文中呈現(xiàn)的樣子。我們用 d 維列向量 \vec{z}_{t} 表示"第 t 個詞元在句子中的語義"這件事,那么其"語義應(yīng)由句子中的所有詞元共同決定"這件事,可以視作:

\vec{z}_{t} = f(\vec{x}_1, \vec{x}_2, ..., \vec{x}_{t}, \vec{x}_{t+1}, ..., \vec{x}_{T})

也就是所有這些向量的函數(shù)。用一堆向量映射到同樣維度的另一個向量,最簡單的映射方法當然是線性表示:

\vec{z}_{t} = \sum_{i=1}^{T} \alpha_{t i} \vec{x}_{i}

這樣這個式子就有點注意力機制那種在序列中加權(quán)求和的思路了。

不過,這個形式的有效性可能會有點問題:線性表示是和基底向量排列的順序無關(guān)的。也就是說理論上如果我把兩個詞元交換一下位置,其結(jié)果并不會改變。這顯然不符合自然語言的情況!不過我們先放下這個問題,后面我們會看到實際上進行這一步計算的向量中本身是帶了順序信息的,方法是位置編碼。這里,我們只需要知道這里的確考慮了順序信息,而依然使用線性表示即可。

現(xiàn)在我們來確定系數(shù)的形式,和原始的注意力機制一樣,因為序列是不定長的,所以不能直接拿模型參數(shù)作為系數(shù)。不過,假如說,我們考慮某個句子,其每個詞元的單獨的語義都和句子里的其他詞元沒什么關(guān)系,它們是正交的。那么此時,系數(shù)就可以用向量分解表示了,并且系數(shù)就是內(nèi)積:

\vec{z}_{t} \sim \vec{x}_t = \sum_{i=1}^{T} (\vec{x}_t \cdot \vec{x}_i) \vec{x}_{i}

這個式子當然不是 Self-Attention 的正確形式,它只是一個平凡的向量展開式而已,而且在大多數(shù)情況下都不滿足正交性的先決條件。但我們至少可以從它的數(shù)學(xué)結(jié)構(gòu) (特別要關(guān)注下角標) 中窺探到一些信息:線性組合的系數(shù)由當前詞元 \vec{z}_t 與其他詞元 \vec{x}_i, i=1,2,\cdots,T 的某個函數(shù) f(\vec{x}_t, \vec{x}_i) 決定。

那么具體這個函數(shù)是什么呢?我們可以考慮一下三個量的地位:

  1. 系數(shù)中的 \vec{x}_t ,表示的是當前詞元。有點像"試探者"的感覺,如果你學(xué)過大學(xué)物理,那么這有點像測試電荷的地位。我們姑且把它比作數(shù)據(jù)庫中的"查詢",也就是 Query。盡管我覺得這個比喻比較牽強,但 Transformer 原作者愿意使用這個比喻,并且以此給相關(guān)的向量或者矩陣來命名。

  2. 系數(shù)中的 \vec{x}_i ,表示的是其他詞元。有點像在大學(xué)物理中,有點像被測試的電荷分布的感覺。在數(shù)據(jù)庫的比喻中,這是"鍵",也就是 Key。

  3. 作為正交基底的 \vec{x}_i ,表示的是組成詞語在句子中的語義的基底。有點像在大學(xué)物理中,得到的各個電荷各自的電場力的感覺。在數(shù)據(jù)庫的比喻中,這是"值",也就是 Value。

既然三者地位不同,而對應(yīng)矢量的值直接表示語義,那么現(xiàn)在這些矢量也不能直接用詞向量 \vec{x} 了。應(yīng)該用什么,把孤立的詞向量的語義,轉(zhuǎn)化為在句子中具有對應(yīng)地位的語義呢?一個簡單的想法是,直接作線性變換:

\vec{q}_t = W^{Q}\vec{x}_t,\quad \vec{k}_t = W^{K}\vec{x}_t,\quad \vec{v}_t = W^{V}\vec{x}_t

就搞定了。我們可以期待,這種把孤立語義,轉(zhuǎn)化為具有在計算中的特定地位語義的行為,是相對固定的,也就是說,三個矩陣 W^{Q}, W^{K}, W^{V} 可以是模型參數(shù)。

現(xiàn)在用變換后的三個向量去替換那個平凡的向量展開式,得到

\vec{z}_t = \sum_{i=1}^{T} \alpha(\vec{q}_t, \vec{k}_i) \vec{v}_{i}

那么系數(shù) \alpha(\vec{q}_t, \vec{k}_i) 是什么形式呢?和上面一樣,就還是作點積,只是加一個 softmax 進行一下歸一化而已:

\alpha(\vec{q}_t, \vec{k}_i) = \text{Softmax}\left( \vec{q}_t \cdot \vec{k}_i / \sqrtgo7n8yo \right) = \dfrac{ \exp \left(\vec{q}_t \cdot \vec{k}_i / \sqrta9fczpo\right) }{ \sum_{j=1}^{T}\exp\left(\vec{q}_t \cdot \vec{k}_j / \sqrtzlmcszy\right) }

這里我們給點積還加了一個 1/\sqrtquch8ci 的因子,主要是希望詞向量的維度比較大的時候,數(shù)值穩(wěn)定性能好一點。另外,盡管這里加 softmax 看上去好像僅僅只是一個平凡的歸一化操作,但它的實際作用是非常大的。后面會更詳細地討論這些細節(jié)。

現(xiàn)在,我們將序列中的所有向量拼接成矩陣形式。令輸入序列的矩陣表示為 X = [\vec{x}_1, \cdots, \vec{x}_{T}] \in \mathbb{R}^{d \times T} ,其中每個 \vec{x}_t 是 d 維的列向量。類似地,我們可以定義:

  • 查詢矩陣 Q = W^Q X \in \mathbb{R}^{d \times T} ,其中 W^Q \in \mathbb{R}^{d \times d} 是查詢對應(yīng)的模型參數(shù)矩陣

  • 鍵矩陣 K = W^K X \in \mathbb{R}^{d \times T} ,其中 W^K \in \mathbb{R}^{d \times d} 是鍵對應(yīng)的模型參數(shù)矩陣

  • 值矩陣 V = W^V X \in \mathbb{R}^{d \times T} ,其中 W^V \in \mathbb{R}^{d \times d} 是值對應(yīng)的模型參數(shù)矩陣

則自注意力機制可以表示為矩陣運算:

Z = \left[ \text{Softmax}\left( \frac{Q^\top K}{\sqrtndixfjp} \right) \right]^\top V

其中 Z \in \mathbb{R}^{d \times T} 是自注意力層的輸出矩陣, \text{Softmax} 函數(shù)按列進行歸一化。這里的形式與原始 Transformer 論文中的形式略有不同,主要是因為我們將輸入向量定義為列向量,而原論文中采用的是行向量的定義方式。兩種定義在數(shù)學(xué)上是等價的,只是矩陣的轉(zhuǎn)置位置有所不同。

現(xiàn)在,我們分析自注意力機制的計算復(fù)雜度。對于輸入序列長度為 T 的情況,計算主要包含以下幾個步驟:

  1. 計算 Q, K, V 矩陣的復(fù)雜度為 O(d^2 T) ,其中 d 是向量維度;

  2. 計算注意力分數(shù)矩陣 Q^\top K 的復(fù)雜度為 O(d T^2) ;

  3. 計算 softmax 歸一化的復(fù)雜度為 O(T^2) ;

  4. 計算加權(quán)和 AV 的復(fù)雜度為 O(d T^2) 。

其中,后三個步驟的復(fù)雜度都與 T^2 成正比,因此自注意力機制的整體計算復(fù)雜度為 O(T^2) ,即隨著輸入上下文長度的平方增長。

這種平方增長的計算復(fù)雜度具有雙重意義:

一方面,每個詞元都需要與序列中所有其他詞元進行交互計算,這保證了每個字的語義以及字與字之間的關(guān)聯(lián)都能被充分理解和捕捉,這種全局的注意力機制是 Transformer 模型效果強大的重要保障;

另一方面,隨著輸入序列長度的增加,計算量會快速膨脹,這限制了 Transformer 模型能夠處理的輸入序列長度。我們和大模型聊天的時候,聊著聊著就發(fā)現(xiàn)大模型忘記了前面的內(nèi)容,原因正在于此。

5.1.2 位置編碼

實際上,在自注意力機制的思想中,可能會注意到一個問題。當時,我們說"詞元在句子中的語義",需要由整個句子中的所有詞元共同決定。我們使用了線性表示的公式,來表示這件事情。

但是,線性表示是與順序無關(guān)的。也就是說,如果交換了兩個詞元的順序,那么計算結(jié)果并不會變化。這顯然不是自然語言中的情況,哪種語言能隨意交換詞的順序,而絕不會改變語義呢?如果我們還想保留線性表示的思想,那就需要讓那些表示語義的向量包含能夠表示位置的信息。

所幸,確實有辦法能給詞向量添加位置信息,這種方法就叫做位置編碼 (Positional Encoding)。對于第 t 個詞向量 O(T^2) ,其經(jīng)過位置編碼之后的向量可以寫為

\tilde{\vec{x}}_t = \vec{x}_t + \vec{p}_t

其中 \vec{p}_t 就是第 t 個位置的位置編碼,它只和其在輸入序列中的位置,也就是 t 有關(guān)。在 Transformer 論文中,位置編碼向量是

\vec{p}_t = \begin{bmatrix} \sin(t/10000^{0/d}) \\ \cos(t/10000^{0/d}) \\ \sin(t/10000^{2/d}) \\ \cos(t/10000^{2/d}) \\ \vdots \\ \sin(t/10000^{(d-2)/d}) \\ \cos(t/10000^{(d-2)/d}) \end{bmatrix}

其中 d 是詞向量的維度, t 是詞在序列中的位置。

位置編碼初看上去還是相當奇怪的:這里和三角函數(shù)怎么扯上了關(guān)系?而且每個分量的頻率還不太一樣。這真的能把位置信息給弄進詞向量嗎?這真的不會破壞掉詞向量原本的信息嗎?

在這里提供一些唯象的、不嚴格的分析。

首先,我們考慮位置編碼會不會抹掉詞向量原本的信息。如果詞向量是一個一維的實數(shù) x_t ,那么讓它與另一個實數(shù) p_t 相加,的確就會抹掉原本的信息。但現(xiàn)在詞向量 \vec{x}_t 是一個維度很高的向量,而不是一個實數(shù)。不論與詞向量 \vec{x}_t 相加的位置編碼 \vec{p}_t 是什么,它都可以比較容易地被歸結(jié)為某個子空間內(nèi)的相加,從而在另一些子空間內(nèi)保留了原本的信息。盡管這種子空間可能是非線性的,但它只和位置有關(guān),仍很容易被神經(jīng)網(wǎng)絡(luò)學(xué)到。

然后,為什么是三角函數(shù)?我們能隱約感覺到,三角函數(shù)和旋轉(zhuǎn)、周期性、循環(huán)群之類的概念有關(guān)。而旋轉(zhuǎn)這件事天然就更加關(guān)注相對位置,而非絕對位置。為了說明這點,我們來計算兩個詞向量之間的內(nèi)積:

\begin{aligned} \vec{q}_{t_1} \cdot \vec{k}_{t_2} =& W^{Q}(\vec{x}_{t_1} + \vec{p}_{t_1}) \cdot W^{K}(\vec{x}_{t_2} + \vec{p}_{t_2}) \\ =& \underbrace{W^{Q}\vec{x}_{t_1} \cdot W^{K}\vec{x}_{t_2}}_{\text{孤立詞向量之間的關(guān)系}} + \underbrace{W^{Q}\vec{x}_{t_1} \cdot W^{K}\vec{p}_{t_2} + W^{Q}\vec{p}_{t_1} \cdot W^{K}\vec{x}_{t_2}}_{\text{詞向量與位置之間的關(guān)系}} + \underbrace{W^{Q}\vec{p}_{t_1} \cdot W^{K}\vec{p}_{t_2}}_{\text{相對位置關(guān)系}} \\ \end{aligned}

特別是,對于位置與位置之間的內(nèi)積,利用余弦和的公式,把奇偶的分量合并計算,我們有

\vec{p}_{t_1} \cdot \vec{p}_{t_2} = \sum_{k=0}^{d/2-1} \cos\left( w_k (t_1 - t_2) \right),\quad w_k = \frac{1}{10000^{2k/d}}

這意味著,任意兩個具有不同距離的位置編碼向量之間的內(nèi)積,都不太一樣,這種位置編碼適合用來區(qū)分不同詞元之間的相對位置。這就很契合自然語言的需求了:當我們在句子前加一些沒有意義的空格,我們當然希望位置編碼可以不要過多地改變詞向量的語義。

每個分量的頻率不一樣的話,就更有利于在不同尺度上考慮詞元與詞元的關(guān)系了。最上面那幾個分量頻率很高,就更加適合關(guān)注句子內(nèi)部相鄰兩個詞之間的關(guān)系;最下面的分量頻率很低,則更加適合表征大范圍的語義關(guān)系,從而在長上下文中非常有用。

當然也存在一些其他的位置編碼選擇,例如 RoPE、ALiBi 等,大多數(shù)方案都會著重表征相對位置而弱化絕對位置。例如,RoPE 對旋轉(zhuǎn)的運用比這里要更激進一些,例如直接把旋轉(zhuǎn)矩陣作用到 \vec{q} 和 \vec{k} 向量頭上了,這樣點乘的時候就會自然地引入角度,而額外的夾角就只和相對位置有關(guān)。

5.1.3 多頭自注意力機制

在介紹了位置編碼之后,我們即可對多頭自注意力部分進行嚴格定義。這里我們首先考慮編碼器中,第一層的多頭注意力 (Multi-Head Attention, MHA) 機制。堆疊之后的多頭注意力機制和這里的區(qū)別,僅僅只是輸入向量從經(jīng)過位置編碼的詞向量,變?yōu)橛汕懊娴膶铀敵龅哪切┫蛄慷选?/p>

接下來考慮什么是"多頭"。設(shè) h 是頭的數(shù)目,一般是向量維度 d 的整數(shù)因子。這樣,向量 \vec{q}_{t} 就可以被分成為 h 個長度為 \frachgmk2mt{h} 的子向量,其中每個子向量記作 {\vec{q}_{t}}^{(j)} \in\mathbb{R}^{d/h} ,其中 j=1,2,\cdots,h 。接下來,根據(jù)自注意力機制的手續(xù),我們對輸入向量作線性變換:

\begin{aligned} {\vec{q}_t}^{(j)} =& {W^{Q}}^{(j)} \vec{x}_{t},\quad j=1,2,\cdots,h \\ {\vec{k}_t}^{(j)} =& {W^{K}}^{(j)} \vec{x}_{t},\quad j=1,2,\cdots,h \\ {\vec{v}_t}^{(j)} =& {W^{V}}^{(j)} \vec{x}_{t},\quad j=1,2,\cdots,h \end{aligned}

其中 {W^{Q}}^{(j)}, {W^{K}}^{(j)}, {W^{V}}^{(j)} 分別是第 j 個頭所對應(yīng)的權(quán)重矩陣,維度是 \mathbb{R}^{\fracfaodrel{h} \times d} 。相當于是說,每個頭都有自己獨立的矩陣,把原本的 d 維向量通過線性變換降維成了 d/h 維的子向量。不過這倒不用擔(dān)心造成信息損失,因為所有頭的結(jié)果最后會拼接起來的,所以最后仍是 d 維向量。

然后,在某個頭內(nèi),自注意力的輸出子向量是

{\vec{z}_{t}}^{(j)} = \sum_{i=1}^{T} \alpha_{ti}^{(j)}{\vec{v}_{i}}^{(j)}

其中 \alpha_{ti}^{(j)} 是一個注意力權(quán)重,也就是線性表示的系數(shù)。老朋友了,按照慣例是點積+Softmax:

\alpha_{ti}^{(j)} = \dfrac{\exp\left( ({\vec{q}_t}^{(j)} \cdot {\vec{k}_i}^{(j)}) / \sqrt{d/h} \right) }{ \sum_{k=1}^{T} \exp \left( ({\vec{q}_t}^{(j)} \cdot {\vec{k}_k}^{(j)}) / \sqrt{d/h} \right) }

最后,把所有頭的輸出子向量拼接起來,再通過一個線性變換,即可得到多頭自注意力層的輸出向量:

\vec{z}_{t} = W^{O} \begin{bmatrix} {\vec{z}_{t}}^{(1)} \\ {\vec{z}_{t}}^{(2)} \\ \vdots \\ {\vec{z}_{t}}^{(h)} \end{bmatrix}

其中 W^{O} \in \mathbb{R}^{d \times d} 是輸出權(quán)重矩陣。

所以,為什么要多頭?八股文的說法是,這種多頭機制允許模型在不同的子空間中學(xué)習(xí)不同的注意力模式,從而捕捉更豐富的語義信息。然后在論文中作者做了消融實驗,發(fā)現(xiàn),總參數(shù)量相等的情況下,用多頭 (權(quán)重是 h 個 \fraclrwciwv{h}\times d 矩陣)確實比單頭 (權(quán)重是 1 個 d\times d 矩陣) 的效果好一點。

但還是很難看出來多頭和單頭的區(qū)別。如果全都是線性變換的話,這樣先劃分再拼接的操作似乎是不會有什么區(qū)別的,畢竟矩陣乘法本來就可以寫成分塊矩陣乘法再拼接的形式。所以我們這里做一些計算。我們設(shè) d=2 ,然后分別計算 h=1 (單頭注意力) 和 h=2 (多頭注意力) 兩種情況下的最終結(jié)果會有什么不同。

  1. 單頭注意力, h=1 。此時,所有計算都在單一子空間中進行。設(shè)輸入向量為 \vec{x}_t = [x_1, x_2]^\top ,則:

\vec{q}_t = W^{Q} \vec{x}_t = \begin{bmatrix} w_{11} & w_{12} \\ w_{21} & w_{22} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}

類似地計算 \vec{k}_t 和 \vec{v}_t 。注意力權(quán)重計算為:

\alpha_{ti} = \dfrac{\exp\left( (\vec{q}_t \cdot \vec{k}_i) / \sqrt{2} \right)}{\sum_{k=1}^{T} \exp\left( (\vec{q}_t \cdot \vec{k}_k) / \sqrt{2} \right)}

最終輸出為:

\vec{z}_t = W^{O} \left( \sum_{i=1}^{T} \alpha_{ti} \vec{v}_i \right)

  1. 多頭注意力, h=2 。此時,輸入向量被分成兩個子空間進行處理。設(shè)輸入向量為 \vec{x}_t = [x_1, x_2]^\top ,則對于每個頭 j=1,2 ,我們有:

{\vec{q}_t}^{(j)} = {W^{Q}}^{(j)} \vec{x}_t = \begin{bmatrix} w_{11}^{(j)} & w_{12}^{(j)} \end{bmatrix} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix}

類似地計算 {\vec{k}_t}^{(j)} 和 {\vec{v}_t}^{(j)} 。每個頭的注意力權(quán)重計算為:

\alpha_{ti}^{(j)} = \dfrac{\exp\left( ({\vec{q}_t}^{(j)} \cdot {\vec{k}_i}^{(j)}) / \sqrt{1} \right)}{\sum_{k=1}^{T} \exp\left( ({\vec{q}_t}^{(j)} \cdot {\vec{k}_k}^{(j)}) / \sqrt{1} \right)}

每個頭的輸出為:

{\vec{z}_{t}}^{(j)} = \sum_{i=1}^{T} \alpha_{ti}^{(j)} {\vec{v}_{i}}^{(j)}

最后將兩個頭的輸出拼接并通過輸出矩陣變換:

\vec{z}_t = W^{O} \begin{bmatrix} {\vec{z}_{t}}^{(1)} \\ {\vec{z}_{t}}^{(2)} \end{bmatrix}

這些式子還是很抽象,不過我們?nèi)绻^察一下注意力權(quán)重 \alpha 的角標,立刻就可以發(fā)現(xiàn),在單頭注意力中,總共只有 t 和 i 兩個角標,意味著有 T^2 個注意力權(quán)重;而在多頭注意力中,則有 t 、 i 和 j 三個角標,意味著有 h \times T^2 個注意力權(quán)重。有多少個頭,注意力權(quán)重就會多多少倍。

那么是什么造成了多頭注意力的權(quán)重數(shù)目更多呢?的確,之前說過,如果一切都是線性變換, \alpha_{ti} = \vec{q}_t\cdot\vec{k}_i ,那么這些把戲不過是把矩陣乘法變成了分塊矩陣的乘法。在最后一步的拼接、左乘 W^{O} 的過程中都會無效的 (因為這一步會把相關(guān)的 \alpha 加起來)。

但是 \alpha 不是通過線性變換求出來的,而是通過 Softmax 求出來的。這是一個非線性的變換,于是其效果就和單頭注意力的效果大不相同了, h\times T^2 個注意力權(quán)重全都是有用的!它真的可以像八股文中講的那樣讓 Transformer 關(guān)注到不同的特性。

由此可以看出,Softmax 函數(shù)并不是白加的。它有著更加重要的作用。

5.1.4 注意力權(quán)重中的 Softmax 函數(shù)

在多頭注意力機制中,我們看到了在計算注意力權(quán)重 \alpha_{ti} 時 Softmax 所發(fā)揮的作用:提供非線性?,F(xiàn)在我們直接看看,如果沒有 Softmax 會發(fā)生什么:

\vec{z}_t = \sum_{i=1}^{n} \alpha_{ti} \vec{v}_i = \sum_{i=1}^{n} (\vec{q}_t\cdot\vec{k}_i) \vec{v}_i

我們令 Z=[\vec{z}_1, \vec{z}_2, \dots, \vec{z}_T] , 那么 Z 就是一個 d\times T 的矩陣,其中每一列 \vec{z}_t 是一個 d 維的列向量。寫成矩陣形式之后:

Z^\top = \begin{bmatrix} \vec{z}_1^\top\\ \vdots\\ \vec{z}_T^\top \end{bmatrix} = \begin{bmatrix} \vec{q}_1^\top\\ \vdots\\ \vec{q}_T^\top \end{bmatrix} \begin{bmatrix} \vec{k}_1, \cdots, \vec{k}_T \end{bmatrix} \begin{bmatrix} \vec{v}_1^\top\\ \vdots\\ \vec{v}_T^\top \end{bmatrix} = (Q^\top K) V^\top

其中, Q^\top \in\mathbb{R}^{T\times d} , K \in \mathbb{R}^{d\times T} , V^\top \in \mathbb{R}^{T\times d} 。

看上去很美好,注意力權(quán)重組成的是 T\times T 的矩陣,它會將 V "完整地"映射到和它相同維度的空間中。但事實果真如此嗎?根據(jù)矩陣乘法結(jié)合律, K V^\top 是 d \times d 的矩陣。用了乘法結(jié)合律之后,這個式子變成了

\vec{z}_t^\top = \vec{q}_t^\top \left(\sum_{i=1}^{T} \vec{k}_i \vec{v}_i^\top \right) = \vec{q}_t^{\top} W^{KV}, \quad t=1,2,\cdots,T

其中

W^{KV} = \sum_{i=1}^{T} \vec{k}_i \vec{v}_i^\top \in \mathbb{R}^{d\times d}

是一個 d \times d 的矩陣,而且是與 t 無關(guān)的。這意味著,我們完全可以首先把 W^{KV} 計算出來 (復(fù)雜度是 O(T) ),然后再逐個應(yīng)用到 \vec{q}_t 上,復(fù)雜度仍是 O(T + T) = O(T) 。和標準的 Attention 不一樣,這是一個線性復(fù)雜度的模型!事實上,這正是 Linear Attention。

既然是線性復(fù)雜度的模型,那從計算復(fù)雜度上考慮,它天然就會和 RNN 有類似的問題:無法完整捕捉每個字與整個句子的關(guān)系。事實上,如果考慮流式的輸入,把 W^{KV} 的求和上限從輸入終點 T 改成 t 的話,那么

\vec{z}_t^\top = \vec{q}_t^\top \left(\sum_{i=1}^{t} \vec{k}_i \vec{v}_i^\top\right) = \vec{q}_t^\top W_t^{KV}

其中 W_t^{KV} 滿足

W_t^{KV} = W_{t-1}^{KV} + \vec{k}_t \vec{v}_t^\top

至此我們就揭開了 Linear Transformer 的真實面目:如果禁止 z_t 參考后文 (從而求和只能到 t 而不能到 T ) 的話,那 W_t^{KV} 不是別的,正是 RNN 的中間狀態(tài)。這件事情就變成了,每讀取一個字 \vec{x}_t ,中間狀態(tài) W_t^{KV} 就更新一次,然后輸出 \vec{z}_t 。

所以問題出在哪里?當我們考慮很長的輸入 T \gg d 時,可以發(fā)現(xiàn),線性注意力的權(quán)重 Q^\top K 的秩不會大于 d ,而遠小于 T 。這種低秩矩陣的狀態(tài)會嚴重影響 \vec{z}_t 的"豐富程度",從而降低其所蘊含的信息量。

那么 Softmax 函數(shù)是怎么解決這個低秩矩陣的問題的?很簡單,對于矩陣的每個元素都求指數(shù)并歸一化,這是個典型的非線性操作。線性代數(shù)可沒研究過這種非線性行為,這種操作完全有可能把一個稠密的低秩矩陣變成滿秩的矩陣。

那只要非線性就可以了嗎?我們最后來看看,到底是 Softmax 的哪一步能提升矩陣的秩??紤]函數(shù) f 滿足

f(Q^\top K) = \phi(Q)^\top \psi(K)

也就是說,它可以拆分成兩個矩陣各自對應(yīng)函數(shù)的乘法。那么,此時在計算注意力權(quán)重時,考慮

\vec{z}_t = \sum_{i=1}^{T} \alpha_{ti} \vec{v}_i = \sum_{i=1}^{T} \dfrac{ \phi(\vec{q}_t)\cdot \psi(\vec{k}_i)}{ \sum_{j=1}^{T} \phi(\vec{q}_t)\cdot \psi(\vec{k}_j) } \vec{v}_i

也就是說,把 e^{\vec{q}_t \cdot \vec{k}_i} 這種內(nèi)積整體的指數(shù)函數(shù),拆分成了先進行非線性函數(shù)計算,然后再進行內(nèi)積計算的形式。注意,原來的指數(shù)函數(shù)本身是不能這么拆的,這里是把指數(shù)函數(shù)替換成了一個能拆的非線性函數(shù)。除了分母上的 \sqrt9xlzyds 之外,它與原始的注意力機制就僅僅只有"可分離性"這點差別,而沒有其他任何差別了。同樣有非線性,也同樣有分母的歸一化。

我們來證明,這仍是 Linear Attention。

令 \Phi(Q) = [\phi(\vec{q}_1), \dots, \phi(\vec{q}_T)]^\top 和 \Psi(K) = [\psi(\vec{k}_1), \dots, \psi(\vec{k}_T)] ,則分子部分可表示為 \phi(\vec{q}_t)^\top \psi(\vec{k}_i) = [\Phi(Q) \Psi(K)]_{ti} 。歸一化因子可展開為:

\sum_{j=1}^{T^{\text{enc}}} \phi(\vec{q}_t) \cdot \psi(\vec{k}_j) = \phi(\vec{q}_t) \cdot S, \quad S = \sum_{j=1}^{T} \psi(\vec{k}_j) \in \mathbb{R}^tmty6ip

這是一個全局統(tǒng)計量,與 t 無關(guān)。將注意力權(quán)重代入輸出計算:

Z^\top = \left[ \frac{\phi(\vec{q}_1)}{\phi(\vec{q}_1) \cdot S} \sum_{i=1}^{T} \psi(\vec{k}_i) \vec{v}_i^\top, \dots, \frac{\phi(\vec{q}_T)}{\phi(\vec{q}_T) \cdot S} \sum_{i=1}^{T} \psi(\vec{k}_i) \vec{v}_i^\top \right]

定義 W = \sum_{i=1}^{T} \psi(\vec{k}_i) \vec{v}_i^\top ,則輸出簡化為:

Z^\top = \tilde{\Phi}(Q) \cdot W, \quad \tilde{\Phi}(Q) = \left[ \frac{\phi(\vec{q}_1)}{\phi(\vec{q}_1) \cdot S}, \dots, \frac{\phi(\vec{q}_T)}{\phi(\vec{q}_T) \cdot S} \right]

現(xiàn)在來看復(fù)雜度。

  1. 計算 S 的復(fù)雜度為 O(Td) ;

  2. 計算 W 的復(fù)雜度為 O(Td^2) ;

  3. 計算 \tilde{\Phi}(Q) 的復(fù)雜度為 O(Td) ;

  4. 最終矩陣乘法的復(fù)雜度為 O(Td) ;

總時間復(fù)雜度為 O(Td^2) ,屬于線性復(fù)雜度。因此,即使引入非線性變換和歸一化,該機制仍屬于 Linear Attention 的范疇。至此,我們可以下結(jié)論了:正是 Softmax 中指數(shù)函數(shù)不可拆分的特性,讓低秩矩陣變得具有更高的秩,從而保障了 Transformer 能正確給出字與字兩兩之間的關(guān)聯(lián)。

當然,在 Softmax 中分母進行歸一化則可以讓數(shù)值更加穩(wěn)定 (注意力權(quán)重處于 0 到 1 之間),而且能盡可能地把那些與當前字 \vec{z}_t 不太關(guān)聯(lián)的字 \vec{x}_i 的權(quán)重趨于 0 (這是 Softmax 函數(shù)的特性,只有最大的幾個值能顯著大于 0),從而讓 Transformer 能夠更加聚焦于少數(shù)幾個關(guān)鍵的概念。

可以注意到,這一點和多頭注意力機制產(chǎn)生了奇妙的聯(lián)動:多頭注意力只有在像 Softmax 這種非線性函數(shù)中才有實際的作用,而另一方面,正是 Softmax 函數(shù)極度稀疏的特性,才使得我們需要引入多頭注意力,讓權(quán)重不要過于稀疏,從而能以多個角度關(guān)注不同的概念。

5.1.5 點積縮放的作用

在自注意力機制中,點積除以 \sqrtvc62ubr (在多頭注意力中是除以 \sqrt{d/h} ) 的主要作用是保持數(shù)值穩(wěn)定性。假設(shè) \vec{q}_t 和 \vec{k}_i 的每個分量都是獨立同分布的隨機變量,均值為 0,方差為 1。那么點積的期望為 0,方差為 d 。因此,對點積除以標準差 \sqrt7ujgwd6 ,即可提升不同維度向量作點積的數(shù)值穩(wěn)定性。

下面我們來證明點積的方差為 d 。我們有

\text{Var}[\vec{q}_t \cdot \vec{k}_i] = \text{Var}\left[\sum_{j=1}^y4cs6pe q_{jt}k_{ji}\right]

由于 q_{jt} 和 k_{ji} 是獨立同分布的隨機變量,且不同維度之間相互獨立,因此方差可以拆分為各維度方差之和:

\text{Var}[\vec{q}_t \cdot \vec{k}_i]= \sum_{j=1}^14rpuks \text{Var}[q_{jt}k_{ji}]

根據(jù)方差的性質(zhì),對于獨立隨機變量 X 和 Y ,有

\text{Var}(XY) = \mathbb{E}[X^2]\mathbb{E}[Y^2] - \mathbb{E}[X]^2\mathbb{E}[Y]^2

因此:

\text{Var}[\vec{q}_t \cdot \vec{k}_i]= \sum_{j=1}^jmaofcs \left(\mathbb{E}[q_{jt}^2]\mathbb{E}[k_{ji}^2] - \mathbb{E}[q_{jt}]^2\mathbb{E}[k_{ji}]^2\right)

根據(jù)題設(shè), q_{jt} 和 k_{ji} 均值為 0,方差為 1,即 \mathbb{E}[q_{jt}^2] = \mathbb{E}[k_{ji}^2] = 1 , \mathbb{E}[q_{jt}] = \mathbb{E}[k_{ji}] = 0 。代入得:

\text{Var}[\vec{q}_t \cdot \vec{k}_i]= \sum_{j=1}^cxujamk (1 \times 1 - 0 \times 0) = \sum_{j=1}^b49639i 1 = d

可以看到,點積的方差隨著維度 d 線性增長。當 d 較大時,點積的絕對值會變得很大,這會導(dǎo)致 softmax 函數(shù)的輸入值過大,使得梯度變得非常?。ㄌ荻认栴})。通過除以 \sqrtwromu7z ,我們可以將點積的方差歸一化為 1:

\text{Var}\left[\dfrac{\vec{q}_t \cdot \vec{k}_i}{\sqrtgay1bpw} \right] = 1

這樣處理后,Softmax 函數(shù)的輸入值保持在合理的范圍內(nèi),有利于梯度的傳播和模型的訓(xùn)練。

再一次提醒,對于 Softmax 函數(shù)而言,用 \sqrtajhe9k7 把向量內(nèi)積的值控制在一定范圍內(nèi)非常重要。因為所有的內(nèi)積都同步擴大某個倍數(shù)的話,它的輸出值會變得極為稀疏,只有最大的那個值會被保留下來,取值為 1 (嗯,從 Softmax 變成 Hardmax 了……)。畢竟這里所擴大的倍數(shù)實際上會作用在 e 的指數(shù)上面。

5.1.6 自注意力模塊的反向傳播

現(xiàn)在我們來正面回答題主所問的問題。當我們按照之前的寫法,把輸入寫成向量序列(而不是Q、K、V三個長度可能會變化的矩陣)的時候,一切就明朗了。

在一切計算開始之前,需要先明確一個概念,即模型的參數(shù)量規(guī)模是固定死的。因而,即便自注意力層看上去能處理任意長的輸入序列,但實際上這里的模型參數(shù)只負責(zé)處理單個詞向量。對每個詞向量而言,維度是固定的 d ,因此模型的參數(shù)量也是固定的 d\times d (不考慮多頭的情況,多頭則是 \frac4urg4ut{h} \times d )。

那么既然每次推理都會按照固定的程序,依照這些固定數(shù)目的參數(shù),進行前向運算;那么我們當然可以依照一般的反向傳播手續(xù)完成訓(xùn)練。計算圖并不會因為序列可以為任意長度而出問題,畢竟在反向傳播的時候我們只需要關(guān)心模型參數(shù)的數(shù)目別變化就行了。

自注意力模塊并不會有“可變數(shù)目的模型參數(shù)”這種奇怪的東西,一切都和普通的卷積神經(jīng)網(wǎng)絡(luò)沒什么不同。實際上,單純的卷積層其實也是可以處理任意長度序列的,此時我們照樣可以安心計算卷積核中的某個參數(shù)所對應(yīng)的梯度。

如果實在不放心,那么我們就來試著計算一個矩陣 W^{Q} 的某個元素 W_{ij}^{Q} ,作為模型參數(shù)的一部分,它的梯度。這部分推導(dǎo)中用帶括號的上角標來表示向量的分量。首先,數(shù)據(jù)的流向大致是:

\vec{x}_t\ \text{and}\ W_{ij}^{Q}\to\vec{q}_t\to \alpha_{ti}\to\vec{z}_t\to \text{Output}\to\text{Loss}

所以反向傳播可以寫為

\frac{\partial L}{\partial W_{ij}^{Q}} = \sum_{t=1}^{T}\sum_{r=1}^fvby937 \frac{\partial L}{\partial z_t^{(r)}}\frac{\partial z_t^{(r)}}{\partial W_{ij}^{Q}}

這其中, \frac{\partial L}{\partial z_t^{(r)}} 是我們討論的自注意力模塊之后的部分,不是我們現(xiàn)在負責(zé)的范圍,所以我們不管它。而另一項可以繼續(xù)展開

\frac{\partial z_t^{(r)}}{\partial W_{ij}^Q} = \sum_{l=1}^izfucpe \frac{\partial z_t^{(r)}}{\partial q_t^{(l)}}\frac{\partial q_t^{(l)}}{\partial W_{ij}^Q}

其中

\frac{\partial q_t^{(l)}}{\partial W_{ij}^Q} = \frac{\partial}{\partial W_{ij}^Q}\sum_{m=1}^leuahem W_{lm}^Q x_t^{(m)}\delta_{mj} = x_t^{(j)},\quad \forall i, l

現(xiàn)在我們只剩下了 \frac{\partial z_t^{(r)}}{\partial q_t^{(l)}} 。這部分的計算也沒什么奇怪的地方,就是很煩,我們一步步算下來:

\frac{\partial z_t^{(r)}}{\partial q_t^{(l)}} = \frac{\partial}{\partial q_t^{(l)}}\sum_{n=1}^{T} \alpha_{tn}v_n^{(r)} = \sum_{n=1}^{T} v_n^{(r)} \frac{\partial \alpha_{tn}}{\partial q_t^{(l)}}

其中

\alpha_{tn} = \frac{\exp\left( \sum_{\nu=1}^hge9zglq_t^{(\nu)}k_n^{(\nu)}/\sqrtnczy2o8\right)}{\sum_{\mu=1}^{T}\exp\left( \sum_{\lambda=1}^wiymubq q_t^{(\lambda)} k_n^{(\lambda)} /\sqrtwuszxed \right)}

令 \beta_{tn}=\sum_{\nu=1}^c7tqfec q_t^{(\nu)}k_n^{(\nu)}/\sqrtrdcigou ,則

\frac{\partial \alpha_{tn}}{\partial q_t^{(l)}} =\frac{\partial \beta_{tn}}{\partial q_t^{(l)}} \frac{\partial}{\partial \beta_{tn}}\frac{\exp(\beta_{tn})}{\sum_{\mu=1}^{T} \exp(\beta_{t\mu})}

其中

\frac{\partial \beta_{tn}}{\partial q_t^{(l)}} = \frac{1}{\sqrtz241f3r}\frac{\partial}{\partial q_t^{(l)}}\sum_{\nu=1}^niqfmka q_t^{(\nu)}k_n^{(\nu)} = \frac{1}{\sqrtjeka3tz}\sum_{\nu=1}^o4xn9v3 k_n^{(\nu)}\delta_{\nu l} = \frac{k_n^{(l)}}{\sqrtw4sh4qw}

而另一方面,令 \gamma_{tn} = \exp(\beta_{tn}) ,有

\frac{\partial}{\partial \beta_{tn}}\frac{\exp(\beta_{tn})}{\sum_{\mu=1}^{T} \exp(\beta_{t\mu})} = \frac{d\gamma_{tn}}{d\beta_{tn}}\frac{\partial}{\partial \gamma_{tn}}\frac{\gamma_{tn}}{\sum_{\mu=1}^{T}\gamma_{t\mu}} = \gamma_{tn}\frac{\sum_{\rho=1}^{T}\gamma_{t\rho} - \gamma_{tn}^2}{\left( \sum_{\mu=1}^{T}\gamma_{t\mu} \right)^2}

其實到這就反向傳播完了,所有變量在前向傳播的時候都已經(jīng)算好值了。比如

\gamma_{tn} = \exp(\beta_{tn}) = \exp\left( \frac{1}{\sqrtrf6jyod}\sum_{\nu=1}^dbhwksz q_t^{(\nu)}k_n^{(\nu)}\right)

以及

q_t^{\nu} = \sum_{m=1}^r9igvec W_{\nu m}^{Q} x_t^{(m)},\quad k_n^{\nu} = \sum_{m=1}^nqnlzgd W_{\nu m}^{K} x_n^{(m)},\quad \forall t, n = 1, \cdots, T

這里的 W_{\nu m}^{Q} 和 W_{\nu m}^{K} 就是此時的模型參數(shù)的值,而 x_t^{(m)} 和 x_n^{(m)} 則是輸入序列的詞向量的分量。至此,所有的部分都由已知的值表示了,就完成了反向傳播。

其實后面的這些具體的反向傳播計算并不重要,也不需要看懂。重要的是,知道反向傳播在數(shù)學(xué)上是針對逐個輸入詞向量或者生成詞向量進行的(而不能按照 Q K V 矩陣的方式看待),這就夠了,它就是對一個輸入向量進行線性變換的層而已,后面再多東西都是花里胡哨、無關(guān)緊要的。

這一通計算可太無聊了,基本上就是順著計算圖進行深度優(yōu)先搜索來的。唯一的感想就是愛因斯坦求和記號確實是偉大的發(fā)明。

5.2 編碼器

完整的 Transformer 架構(gòu)還是比較復(fù)雜的,我們這里先羅列一些名詞。

  1. 它由編碼器 (Encoder) 和解碼器 (Decoder) 兩部分組成,每個部分都包含多個相同的層。

  2. 每個編碼器層由兩個主要子層構(gòu)成: 多頭自注意力機制 (Multi-Head Self-Attention) 和前饋神經(jīng)網(wǎng)絡(luò) (Feed-Forward Neural Network, FFN),每個子層都采用了殘差連接 (Residual Connection) 和層歸一化 (Layer Normalization)。

  3. 解碼器層在編碼器層的基礎(chǔ)上增加了一個額外的多頭交叉注意力機制 (Multi-Head Cross-Attention),用于關(guān)注編碼器的輸出。

  4. 整個架構(gòu)通過位置編碼 (Positional Encoding) 來保留序列的位置信息,并采用縮放點積注意力 (Scaled Dot-Product Attention) 作為核心計算單元。

這種設(shè)計使得 Transformer 能夠并行處理整個序列,同時有效捕捉長距離依賴關(guān)系。

5.2.1 輸入模塊

設(shè)輸入的詞元序列是 \{x_{t}^{\text{token}}\}_{t=1}^{T^{\text{enc}}} , 其中每個 x_t^{\text{token}} 是一個數(shù)字,表示固定詞表中的某個詞。然后,經(jīng)過查表,可以映射為對應(yīng)的詞向量 \{\vec{x}_{t}^{\text{embed}}\}_{t=1}^{T^{\text{enc}}} ,其中 \vec{x}_{t}^{\text{embed}} 是一個維度為 d 的向量。接下來,加上位置編碼向量 \vec{p}_t 之后,即可得到第一個自注意力層的輸入向量

\vec{x}_{t}^{\text{self, enc}} = \vec{x}_{t}^{\text{embed}} + \vec{p}_{t},\quad t=1,2,\cdots,T^{\text{enc}}

5.2.2 自注意力模塊

這里重新敘述一遍自注意力模塊,加上角標以標明這里的量屬于編碼器的自注意力模塊。

我們有

\vec{z}_t^{\text{self, enc}} = \sum_{i=1}^{T^{\text{enc}}} \alpha_{ti}^{\text{self, enc}} \vec{v}_i^{\text{self, enc}},\quad t=1,2,\cdots,T^{\text{enc}}

其中 \alpha_{ti}^{\text{self, enc}} 是注意力系數(shù),滿足

\alpha_{ti}^{\text{self, enc}} = \dfrac{ \exp\left(\vec{q}_t^{\text{self, enc}} \cdot \vec{k}_i^{\text{self, enc}} / \sqrtiifma9v\right) }{ \sum_{j=1}^{T^{\text{enc}}} \exp\left(\vec{q}_t^{\text{self, enc}} \cdot \vec{k}_j^{\text{self, enc}} /\sqrt7puaxed\right) }

而 \vec{q}_t^{\text{self, enc}} , \vec{k}_i^{\text{self, enc}} , \vec{v}_i^{\text{self, enc}} 分別是查詢向量、鍵向量和值向量,滿足

\begin{aligned} \vec{q}_t^{\text{self, enc}} =& W^{Q, \text{self, enc}}\vec{x}_{t}^{\text{self, enc}}\in\mathbb{R}^prpdb22,\\ \vec{k}_i^{\text{self, enc}} =& W^{K, \text{self, enc}}\vec{x}_{i}^{\text{enc}}\in\mathbb{R}^rywub98,\\ \vec{v}_i^{\text{self, enc}} =& W^{V, \text{self, enc}}\vec{x}_{i}^{\text{enc}}\in\mathbb{R}^w9krowv \end{aligned}

其中 W^{Q,\text{self, enc}}\in\mathbb{R}^{d\times d}, W^{K,\text{self, enc}}\in\mathbb{R}^{d\times d}, W^{\text{self, enc}}\in\mathbb{R}^{d\times d} 是編碼器自注意力模塊的參數(shù)。

注意,我們這里為了簡化指標,沒有把"多頭"的部分給寫出來。實際的實現(xiàn)中,不要忘了分成多頭計算,每個頭的矩陣的形狀是 (d/h, d) ,所產(chǎn)生的查詢向量、鍵向量和值向量的維度是 d/h 。然后把 h 個頭進行拼接,再經(jīng)過線性變換,才能得到多頭注意力模塊的輸出 \vec{z}_t^{\text{self, enc}} 。

5.2.3 殘差連接

對于了解計算機視覺的人來說,殘差連接是個再熟悉不過的概念,簡單、優(yōu)雅、深刻。它把任何花里胡哨的模型都變成了一階微擾,從而保證它們在深度堆疊的過程中不會迷失自我。如果堆疊的某一層實在是不起作用,那么由于存在殘差連接,它至少不會造成過于壞的影響,可以安心當它不存在,繼續(xù)堆疊。在計算機視覺中,ResNet 通過殘差連接,成功把卷積神經(jīng)網(wǎng)絡(luò)堆疊到上百層,并且訓(xùn)練效果相當不錯。

所以遇事不決,引入殘差連接就完事了,不要想那么多:

\vec{z}_t^{\text{res1, enc}} = \vec{x}_t^{\text{enc}} + \vec{z}_t^{\text{self, enc}},\quad t=1,2,\cdots,T^{\text{enc}}

其中 \vec{x}_t^{\text{enc}} 是經(jīng)過位置編碼的、第一個多頭注意力的輸入詞向量;而 \vec{z}_t^{\text{self, enc}} 是第一個多頭注意力的輸出。

5.2.4 層歸一化

層歸一化(Layer Normalization)主要用于解決深度神經(jīng)網(wǎng)絡(luò)訓(xùn)練過程中出現(xiàn)的內(nèi)部協(xié)變量偏移問題。與批歸一化(Batch Normalization)不同,層歸一化是對向量的各個分量進行歸一化,而不是在批次維度上。

具體來說,對于輸入向量 \vec{z}_t^{\text{res1, enc}} \in \mathbb{R}^d ,層歸一化的計算過程如下:

\vec{z}_t^{\text{norm1, enc}} = \text{LayerNorm}(\vec{z}_t^{\text{res1, enc}}) = \vec{\gamma}^{\text{norm1, enc}} \odot \frac{1}{\sigma_t}\left(\vec{z}_t^{\text{res1, enc}} - \mu_t\right) + \vec{\beta}^{\text{norm1, enc}}

其中 \mu_t 和 \sigma_t 分別是輸入向量各個分量的均值和標準差:

\mu_t = \frac{1}vmrgwip\sum_{j=1}^d z_{jt}^{\text{res1, enc}}, \quad \sigma_t = \sqrt{\frac{1}mx84dlb\sum_{j=1}^d (z_{jt}^{\text{res1, enc}} - \mu_t)^2 + \epsilon}

而 \vec{\gamma}^{\text{norm1, enc}} \in \mathbb{R}^qtz9vc8 和 \vec{\beta}^{\text{norm1, enc}} \in \mathbb{R}^owti6sq 是可學(xué)習(xí)的縮放和平移參數(shù),其實可以當作是一個簡單的線性層啦。 \epsilon 是一個很小的常數(shù),放在分母上避免除 0 錯誤。

那么這里為啥不用批量歸一化呢?主要是在這里批量歸一化并不好做,每個"批量"、或者同一句話內(nèi),不同的詞向量所表示的語義千差萬別,做批量歸一化意義不大。相對應(yīng)地,對向量的各個分量進行歸一化則有助于平衡各個分量之間的大小差異。

有人發(fā)現(xiàn),層歸一化如果不計算、減去均值 \mu_t ,而只計算均方根 \sigma_t 并縮放,也幾乎不影響效果,這種方法叫做均方根歸一化 (RMSNorm):

\vec{z}_t^{\text{norm1, enc}} = \frac{1}{\sigma_t} \left(\vec{\gamma}^{\text{norm1, enc}} \odot \vec{z}_t^{\text{res1, enc}}\right)

這里的 \odot 表示向量的逐元素乘法??傊竽P吐铮琒caling law 擺在那里,參數(shù)量和數(shù)據(jù)量上去了,有些操作加不加,效果幾乎都沒區(qū)別。那能減少一點計算量就減少一點咯。

另外,關(guān)于到底是像我們這里說的那樣,先計算殘差連接再進行層歸一化 (叫 Post Norm)呢?還是先計算層歸一化再算殘差連接 (叫做 Pre Norm)?一般認為 Post Norm 的效果會更好一點,不過我也不知道為什么……

5.2.5 前饋神經(jīng)網(wǎng)絡(luò)與激活函數(shù)

為了增強模型的表達能力,通常在多頭自注意力層之后會添加一個前饋神經(jīng)網(wǎng)絡(luò) (Feed-Forward Neural Network, FFN),它由兩個全連接層、中間夾著激活函數(shù)組成:

\vec{z}_{t}^{\text{ffn, enc}} = W^{\text{ffn2, enc}} \sigma(W^{\text{ffn1, enc}} \vec{z}_{t}^{\text{norm1, enc}} + \vec^{\text{ffn1, enc}}) + \vec^{\text{ffn2, enc}}

其中 W^{\text{ffn1, enc}} \in \mathbb{R}^{d_{\text{ffn}} \times d} 和 W^{\text{ffn2, enc}} \in \mathbb{R}^{d \times d_{\text{ffn}}} 是 FFN 的權(quán)重矩陣, \vec^{\text{ffn1, enc}} \in \mathbb{R}^{d_{\text{ffn}}} 和 \vec^{\text{ffn2, enc}} \in \mathbb{R}^daxnsho 是偏置項。 d_{\text{ffn}} 是前饋網(wǎng)絡(luò)的隱藏層維度,通常 d_{\text{ff}} = 4d ,意味著這里有一個升維然后降維的過程。

FFN 中的 \sigma 是激活函數(shù),通常為 GELU (Gaussian Error Linear Unit),其定義為:

\text{GELU}(x) = x \Phi(x) = x \cdot \frac{1}{2} \left[1 + \text{erf}\left(\frac{x}{\sqrt{2}}\right)\right]

其中 \Phi(x) 是標準正態(tài)分布的累積分布函數(shù), \text{erf} 是誤差函數(shù)。

GELU 可以近似表示為:

\text{GELU}(x) \approx 0.5x\left(1 + \tanh\left[\sqrt{\frac{2}{\pi}}\left(x + 0.044715x^3\right)\right]\right)

GELU 長得和 ReLU 類似,但比 ReLU 在零點附近更平滑,因此反向傳播的數(shù)值穩(wěn)定性要好一些。

在經(jīng)過 FFN 之后,我們發(fā)揮傳統(tǒng)藝能,遇事不決上殘差和歸一化:

\vec{z}_t^{\text{norm2, enc}} = \text{LayerNorm}\left( \vec{z}_t^{\text{norm1, enc}} + \vec{z}_{t}^{\text{ffn, enc}}\right),\quad t=1,2,\cdots,T^{\text{enc}}

5.2.6 編碼器輸出

至此,便得到了Transformer第一層編碼器的輸出:

\vec{z}_t^{\text{enc}} = \vec{z}_t^{\text{norm2, enc}}, t = 1,2, \cdots, T^{\text{enc}}

基本上,一層編碼器就按照這個順序計算:輸入詞向量,計算多頭自注意力,殘差歸一化,再計算FFN,再殘差歸一化。這種編碼器堆疊好幾層,最終就可以得到編碼器的輸出。這是 T^{\text{enc}} 個 d 維向量,表示了輸入序列中,每個位置處在句子中的語義結(jié)果。

可以看到,編碼器全程沒有涉及到任何遞推的機制,所有的計算對于不同 t 而言都是平權(quán)的、不相互依賴的,只需要保證不同層之間的先后計算順序即可,因而可以實現(xiàn)關(guān)于句子的并行計算。

反觀 RNN,在讀取時,除了不同的層之外,還需要保證不同 t 之間的計算順序。這就意味著,RNN 每次讀取一個詞元的時候,都只能進行小規(guī)模的矩陣計算,卻要從顯存中讀取模型全量的參數(shù),因而負載全在顯存帶寬上了。

一般來說,典型訓(xùn)練所用的 GPU 的算力要比顯存帶寬高很多。比如對于 Nvidia A100 而言,其 BF16 算力是 312TFlops,而顯存帶寬則為 2 TB/s。極端情況按照一次 BF16 浮點運算需要讀取兩個數(shù)字 (4個字節(jié)) 來算, 差距是 624 倍。因此,最能壓榨 GPU 性能的方式,是輸入數(shù)據(jù)只從顯存讀一遍,但是每個數(shù)字都可以復(fù)用,都需要進行連續(xù)的、大量的計算。

n 維矩陣乘法就是能比較有效地利用算力的例子,需要讀取的數(shù)據(jù)量是 O(n^2) ,卻需要進行 O(n^3) 次計算。當 n 較大時,每秒需要進行乘法計算的次數(shù),就超過了需要每秒讀取數(shù)字的數(shù)據(jù)量,對于 GPU 而言這就是計算密集型的而非訪存密集型的。

而逐元素乘法則是無法有效利用算力的反面例子,因為這仍要讀取 O(n^2) 的數(shù)據(jù)量,但卻只進行 O(n^2) 次乘法運算。換句話說,但對于 GPU 而言,當 n 較小時,我們會發(fā)現(xiàn)盡管矩陣乘法需要的運算量明顯更大,其所花費的時間居然是和逐元素乘法是差不多的。只有當 n 很大了,矩陣乘法所多出的額外計算量才會成為 GPU 的負載,使得它所消耗的時間比逐元素乘法要長一些。

所以說,這是 Transformer 相對于 RNN 的一大改進,它基本上把讀取句子時的計算,從訪存密集型轉(zhuǎn)換成了計算密集型,從而能在訓(xùn)練時充分利用 GPU 的算力。因此,盡管 Transformer 的計算量隨著上下文長度的增長呈現(xiàn)平方增長,但 Transformer 仍比 RNN 更容易實現(xiàn)大規(guī)模的訓(xùn)練。同樣規(guī)模的模型,只要在訓(xùn)練時每次讀取的句子長度不太大,這些平方增長的計算量所消耗的時間,都要小于 RNN 因為逐字計算,而必須依靠緩慢的顯存帶寬所消耗的時間。

5.3 解碼器

現(xiàn)在來研究解碼器。解碼器比編碼器還要麻煩一點,因為它需要同時處理自己已經(jīng)生成的詞以及來自編碼器的輸入,同時還要負責(zé)輸出。

5.3.1 解碼器輸入

在編碼器中,我們設(shè)輸入的詞元是 \{x_t^{\text{token}}\}_{t=1}^{T^{\text{enc}}} 。而在解碼器中,我們則設(shè) \{ y_{t}^{\text{token}}\}_{t=1}^{T^{\text{dec}}} ,其中 y_1 是一個固定的詞元 (BOS),表示解碼器預(yù)測的開始?,F(xiàn)在,解碼器已經(jīng)生成了一些詞元了,剛剛生成了 y_{\tau}^{\text{token}} 。

我們這里有必要區(qū)分一下 \tau 和 T^{\text{dec}} 。在解碼器推理的時候, T^{\text{dec}} 是沒有意義的,只有 \tau 表示當前剛剛生成的那個最新的詞元。而 T^{\text{dec}} 作為一個比 \tau 大的數(shù)字,只在訓(xùn)練時才有意義,因為此時我們會給模型一整段文本去做訓(xùn)練,從而會需要考慮 \tau 后面的那些詞元。

我們先考慮推理,所以把 T^{\text{dec}} 給拋諸腦后。

現(xiàn)在,我們和編碼器的輸入類似,設(shè)解碼器已經(jīng)輸出的詞元為 y_t^{\text{token}} ,其中 t=1,2,\cdots,\tau ,其對應(yīng)的詞向量是 \vec{x}_t^{\text{embed, dec}} ,在加了位置編碼之后的詞向量為 \vec{x}_t^{\text{self, dec}} ,作為解碼器的掩碼自注意力層的輸入。解碼器即將生成詞元 y_{\tau + 1}^{\text{token}} ,這是需要通過預(yù)測第 \tau + 1 個詞元的分數(shù) \vec{y}_{\tau + 1}^{\text{ouput}}\in\mathbb{R}^{N} 經(jīng)過 Softmax 歸一化為概率分布之后,采樣得到的。注意這里的 N 是詞表長度。

5.3.2 自注意力模塊推理

這實際上是一個遞推的過程,我們只需要計算 \vec{z}_{\tau} 即可,因為 \vec{z}_{\tau - 1} 以及更加之前的項是不會被用到的。

\vec{z}_{\tau}^{\text{self, dec}} = \sum_{i=1}^{\tau}\dfrac{ \exp\left(\vec{q}_{\tau}^{\text{self, dec}} \cdot \vec{k}_{i}^{\text{self, dec}}/\sqrtomjo42c\right) }{ \sum_{j=1}^{\tau} \exp\left(\vec{q}_{\tau}^{\text{self, dec}} \cdot \vec{k}_j^{\text{self, dec}}/\sqrtvomkjio\right)} \vec{v}_i^{\text{self, dec}}

其中

\begin{aligned} \vec{q}_{\tau}^{\text{self, dec}} &= W^{Q, \text{self, dec}}\vec{x}_{\tau}^{\text{self, dec}}\\ \vec{k}_i^{\text{self, dec}} &= W^{K, \text{self, dec}}\vec{x}_i^{\text{self, dec}},\quad i=1, \cdots, \tau \\ \vec{v}_i^{\text{self, dec}} &= W^{V, \text{self, dec}}\vec{x}_i^{\text{self, dec}},\quad i=1, \cdots, \tau \end{aligned}

形式上和編碼器的自注意力差不多,只不過這里只需要計算一個 \vec{z}_{\tau}^{\text{self, dec}} ,而編碼器的自注意力那里需要計算從 1 到 T^{\text{enc}} 的所有 \vec{z}_{t}^{\text{self, enc}} 。

這里沒有寫上多頭注意力的公式,因為我實在是不想寫各個頭的指標了。只需要記得,這里是需要區(qū)分多頭的, e 指數(shù)上的分母實際應(yīng)該是 \sqrt{d/h} ,算完了之后還得拼接起來再線性變換一下即可。

仔細看一下上面的自注意力公式可以發(fā)現(xiàn),對于 i=1,\cdots, \tau - 1 而言, \vec{k}_{i} 和 \vec{v}_i 在之前的時間步中曾經(jīng)被算過了。所以,如果把這些向量給緩存起來,那么我們就不需要重復(fù)計算 \vec{k}_i = W^K \vec{x}_i 以及 \vec{v}_i = W^V \vec{x}_i 這些式子了,只需要計算 \vec{k}_{\tau} 和 \vec{v}_{\tau} 即可。這個就叫做鍵值緩存 (KV Cache)。

5.3.3 鍵值緩存

所以有個經(jīng)典問題,在解碼器推理的時候,為什么需要緩存 K 和 V 矩陣 (KV Cache),不需要緩存查詢矩陣 Q 呢?

那當然是因為,我們在每一步中,都只需要最新的那個 \vec{z}_{\tau}^{\text{self, dec}} 啦,而它只和最新的 \vec{q}_{\tau} 有關(guān),與之前的 \vec{q}_1,\cdots, \vec{q}_{\tau - 1} 都沒有關(guān)系。

梳理一下,考慮 KV Cache,我們在計算 \vec{z}_{\tau}^{\text{self, dec}} 時,需要進行這些計算:

  1. 計算 \vec{q}_{\tau}  \vec{k}_{\tau} 和 \vec{v}_{\tau} ,關(guān)于時間步的復(fù)雜度 O(1)

  2. 讀取之前已經(jīng)算過了的 \vec{k}_1,\cdots, \vec{k}_{\tau - 1} 以及 \vec{v}_1,\cdots, \vec{v}_{\tau - 1} ,關(guān)于時間步的復(fù)雜度 O(\tau)

  3. 計算點乘注意力分數(shù),即對于 i=1,\cdots,\tau ,計算 \text{Softmax}(\vec{q}_{\tau}^{\text{self, dec}} \cdot \vec{k}_i^{\text{self, dec}}/\sqrtb7q34xu) ,關(guān)于時間步的復(fù)雜度 O(\tau)

  4. 將注意力分數(shù)作為 \vec{v}_i 的系數(shù),計算 \vec{z}_{\tau}^{\text{self, dec}} ,關(guān)于時間步的復(fù)雜度 O(\tau) 。

最終,在推理第 \tau + 1 個詞元時,需要計算 \vec{z}_{\tau}^{\text{self, dec}} ,其復(fù)雜度為 O(\tau) 。因而,若要推理總長度為 T^{\text{dec}} 的句子,總的復(fù)雜度為 O\left({T^{\text{dec}}}^2\right) 。

5.3.4 掩碼自注意力機制

先前我們在解碼器的自注意力機制中,介紹了在推理時的情形。那時,我們很自然地只需要考慮 1 到 \tau 這些解碼器已經(jīng)生成了的詞元,而不需要考慮之后的詞元所對應(yīng)的 \vec{y}_{\tau + 1}^{\text{self, dec}} 及其對應(yīng)的 key 向量和 value 向量等等。畢竟,這些東西根本還不存在呢。

而在訓(xùn)練的時候,我們采用的是 teacher forcing 的辦法。此時,我們會有一串回答的目標文本 y_{\tau}^{\text{target}} ,而解碼器并不需要從頭到尾預(yù)測整個句子 (因為那樣中途錯了一個詞元,后面可能就全錯了),只需要根據(jù)已經(jīng)給出的目標文本,來預(yù)測下一個詞元即可。這樣就算中途預(yù)測錯了,下一個詞元也不會受到這個錯誤預(yù)測的影響。

既然在訓(xùn)練時,我們能拿到一整個句子的文本,意味著我們此時能直接計算一整個句子所對應(yīng)的 \vec{k}_{i}^{\text{target}} 和 \vec{v}_{i}^{\text{target}} 向量,其中 i=1,\cdots,\tau,\cdots,T^{\text{dec}} 。

然而這是不符合解碼器真實面臨的推理任務(wù)的:它只能知道 \tau 以前的那些詞元,并且僅用這些詞元來計算 \vec{z}_{\tau}^{\text{self, dec}} 。如果我們放任解碼器在訓(xùn)練時能看到之后的那些字的話,它在訓(xùn)練的時候就相當于作弊了。

因此,模型預(yù)測第 \tau + 1 個詞元時,仍然需要維持求和的上限為 \tau ,而不是 T^{\text{dec}} :

\vec{z}_{\tau}^{\text{self, dec}} = \sum_{i=1}^{\tau} \alpha_{\tau i}^{\text{self, dec}} \vec{v}_i^{\text{self, dec}}

其中 \alpha_{\tau i}^{\text{self, dec}} 是 \vec{q}_{\tau}^{\text{self, dec}} 和 \vec{k}_i^{\text{target}} 計算注意力系數(shù),也就是 Softmax 函數(shù)的結(jié)果。

在實際的代碼中,我們往往會讓模型去一次性訓(xùn)練整個句子,因此需要同時計算所有詞元位置所對應(yīng)的下一個詞元,并與目標進行對比以進行訓(xùn)練。這時用變化的求和上限 ( \tau ) 就不太討喜了,所以可以引入一個掩碼矩陣,使得求和上限統(tǒng)一為 T^{\text{dec}} ,方便并行計算:

\vec{z}_{\tau}^{\text{self, dec}} = \sum_{i=1}^{T^{\text{dec}}} I^{\text{Mask}}_{\tau i} \alpha_{\tau i}^{\text{self, dec}} \vec{v}_i^{\text{self, dec}},\quad\text{where}\ I^{\text{Mask}}_{\tau i} = \begin{cases} 1, & \text{if } i \leq \tau\\ 0, & \text{otherwise} \end{cases}

但其實如果把整個式子寫開來的話,會發(fā)現(xiàn)這種形式其實是不對的。問題出在 \alpha_{\tau i}^{\text{self, dec}} 上,它是有分母的。這種形式的掩碼矩陣只能確保把對應(yīng)的分子設(shè)為零,但是分母卻無法將對應(yīng)的項屏蔽為零。

所以,更加合理的做法是:

\vec{k}_{i}^{\text{self, dec}} = I^{\text{Mask, Inf}}_{\tau i} \vec{k}_{i}^{\text{target}}, \quad \vec{v}_{i}^{\text{self, dec}} = I^{\text{Mask, Inf}}_{\tau i} \vec{v}_{i}^{\text{target}}

其中:

I^{\text{Mask, Inf}}_{\tau i} = \begin{cases} 1, & \text{if } i \leq \tau\\ -\infty, & \text{otherwise} \end{cases}

這樣,就可以在計算注意力分數(shù)的時候,使得任何對 i>\tau 的 e 指數(shù)項都為零,不論是分子還是分母。

再次提醒,我們這里為了避免指標過于復(fù)雜,忽略了多頭注意力的操作。實際上是需要分多個頭計算的,然后還要拼接起來并進行線性變換。接下來,就是我們在編碼器的自注意力模塊中已經(jīng)熟知的操作,即殘差連接+層歸一化:

\vec{z}_{\tau}^{\text{norm1, dec}} = \text{LayerNorm}\left( \vec{z}_{\tau}^{\text{self, dec}} + \vec{x}_{\tau}^{\text{self, dec}}\right)

至此便完成了掩碼自注意力模塊。

5.3.5 交叉注意力機制

解碼器在掩碼自注意力模塊之后,就走到了交叉注意力模塊了。解碼器正是在這一層接受并融合編碼器的輸出。之前的編碼器的輸出為一串向量 \{\vec{z}_{i}^{\text{enc}}\}_{i=1}^{T^{\text{enc}}} ,而上一個解碼器的掩碼自注意力模塊的輸出為 \vec{z}_{\tau}^{\text{self, dec}} 。于是,很自然地,我們可以把它們拼接在一起,作為交叉注意力模塊的輸入。

具體來說,查詢向量 \vec{q}_{\tau}^{\text{cross, dec}} 來自解碼器當前層的輸出 \vec{z}_{\tau}^{\text{self, dec}}

\vec{q}_{\tau}^{\text{cross, dec}} = W^{Q, \text{cross, dec}}\vec{z}_{\tau}^{\text{self, dec}}

這是單個向量的線性變換。而鍵向量 \vec{k}_{i}^{\text{cross, dec}} 和值向量 \vec{v}_{i}^{\text{cross, dec}} 則來自編碼器的輸出 \{\vec{z}_{i}^{\text{enc}}\}_{i=1}^{T^{\text{enc}}} :

\vec{k}_{i}^{\text{cross, dec}} = W^{K, \text{cross, dec}}\vec{z}_{i}^{\text{enc}},\quad \vec{v}_{i}^{\text{cross, dec}} = W^{V, \text{cross, dec}}\vec{z}_{i}^{\text{enc}},\quad i=1,\cdots,T^{\text{enc}}

這是 T^{\text{enc}} 個向量的線性變換。因此,交叉注意力機制的輸出為

\vec{z}_{\tau}^{\text{cross, dec}} = \sum_{i=1}^{T^{\text{enc}}} \alpha_{\tau i}^{\text{cross, dec}} \vec{v}_{i}^{\text{cross, dec}},\quad \alpha_{\tau i}^{\text{cross, dec}} = \dfrac{\exp(\vec{q}_{\tau}^{\text{cross, dec}} \cdot \vec{k}_{i}^{\text{cross, dec}})}{\sum_{j=1}^{T^{\text{enc}}} \exp(\vec{q}_{\tau}^{\text{cross, dec}} \cdot \vec{k}_{j}^{\text{cross, dec}})}

第三次提醒,為了避免指標過于復(fù)雜,我們這里忽略了多頭注意力的操作。實際上需要分多個頭計算的,然后還要拼接起來并進行線性變換。接下來,仍是殘差連接和歸一化:

\vec{z}_{\tau}^{\text{norm2, dec}} = \text{LayerNorm}\left(\vec{z}_{\tau}^{\text{cross, dec}} + \vec{z}_{\tau}^{\text{norm1, dec}}\right)

5.3.6 解碼器輸出

和編碼器一樣,解碼器在輸出之前,還得經(jīng)過一輪FFN激活函數(shù)、殘差連接和歸一化。首先是 FFN:

\vec{z}_{\tau}^{\text{ffn, dec}} = W^{\text{ffn2, dec}}\sigma(W^{\text{ffn1, dec}} \vec{z}_{\tau}^{\text{norm2, dec}} + \vec^{\text{ffn1, dec}}) + \vec^{\text{ffn2, dec}}

其中 \sigma 為 GELU 激活函數(shù)。然后,是殘差連接和歸一化:

\vec{z}_{\tau}^{\text{norm3, dec}} = \text{LayerNorm}\left(\vec{z}_{\tau}^{\text{norm2, dec}} + \vec{z}_{\tau}^{\text{ffn, dec}} \right)

這樣,一層解碼器層的輸出就準備好了。實際上,由掩碼自注意力模塊、交叉注意力模塊以及 FFN 激活函數(shù)構(gòu)成的層會堆疊多次。最后,我們經(jīng)過一個線性變換:

\vec{y}_{\tau + 1} = W^{\text{out, dec}}\vec{z}_{\tau}^{\text{norm3, dec}}

其中, \vec{y}_{\tau + 1}\in \mathbb{R}^{N} 是各個可能的詞元所對應(yīng)的概率分數(shù), W^{\text{out, dec}}\in \mathbb{R}^{N\times d} 是一個線性變換的矩陣。正如我們在 RNN 或者 Seq2Seq 中熟知的,對概率分數(shù)求 Softmax 歸一化后即可得到下一個詞元的概率:

P(y_{\tau + 1}^{\text{token}} = n | \{x_t^{\text{token}}\}_{t=1}^{T^{\text{enc}}}, \{y_i^{\text{token}}\}_{i=1}^{\tau} ) = \dfrac{\exp(y_{\tau + 1}^{(n)})}{\sum_{j=1}^{N}\exp(y_{\tau + 1}^{(j)})},\quad n = 1,\cdots,N

然后根據(jù)這個概率進行采樣即可。其中, y_{\tau+1}^{(n)} 是向量 \vec{y}_{\tau + 1} 的第 n 個分量,也就是第 \tau + 1 個詞元 (即將要預(yù)測的這個詞元) 是詞表中第 n 號詞元的概率分數(shù)。概率中的條件 \{x_t^{\text{token}}\}_{t=1}^{T^{\text{enc}}} 表示編碼器的輸入句子 (詞元序列) ; \{y_i^{\text{token}}\}_{i=1}^{\tau} 表示解碼器已經(jīng)輸出了的部分句子。

在訓(xùn)練時,和一般的類別預(yù)測任務(wù)一樣,使用交叉熵來訓(xùn)練模型,即正確詞元序號所對應(yīng)概率的對數(shù):

L = \sum_{\tau = 1}^{T^{\text{dec}} - 1} -\log P(y_{\tau + 1}^{\text{token}} = y_{\tau+1}^{\text{target}} | \{x_t^{\text{token}}\}_{t=1}^{T^{\text{enc}}}, \{y_i^{\text{target}}\}_{i=1}^{\tau})

注意我們這里的條件是用的 y_{i}^{\text{target}} 而不是 y_{i}^{\text{token}} ,也就是說拿訓(xùn)練樣本中已經(jīng)輸出的內(nèi)容作為條件。這里仍然體現(xiàn)的 teacher forcing 的思想,防止出現(xiàn)模型中途寫錯了一個字,后面就全錯了,導(dǎo)致難以訓(xùn)練的問題。在實現(xiàn)的時候,前面的 \tau 從 1 到 T^{\text{dec}} - 1 求和,表示我們一般需要一次計算一整個目標句子。

    轉(zhuǎn)藏 分享 獻花(0

    0條評論

    發(fā)表

    請遵守用戶 評論公約

    類似文章 更多