![]() MXNet 作者 / 亞馬遜主任科學(xué)家 李沐 新智元推薦 作者:MXNet 作者 / 亞馬遜主任科學(xué)家 李沐 【新智元導(dǎo)讀】PyTorch是一個純命令式的深度學(xué)習(xí)框架。它因為提供簡單易懂的編程接口而廣受歡迎,而且正在快速的流行開來。MXNet通過ndarray和 gluon模塊提供了非常類似 PyTorch 的編程接口。本文將簡單對比如何用這兩個框架來實現(xiàn)同樣的算法。 ![]() PyTorch是一個純命令式的深度學(xué)習(xí)框架。它因為提供簡單易懂的編程接口而廣受歡迎,而且正在快速的流行開來。例如 Caffe2 最近就并入了 PyTorch。 可能大家不是特別知道的是,MXNet 通過 ndarray和 gluon模塊提供了非常類似 PyTorch 的編程接口。本文將簡單對比如何用這兩個框架來實現(xiàn)同樣的算法。 安裝 PyTorch 默認(rèn)使用 conda 來進(jìn)行安裝,例如 而 MXNet 更常用的是使用 pip。我們這里使用了 --pre來安裝 nightly 版本 多維矩陣 對于多維矩陣,PyTorch 沿用了 Torch 的風(fēng)格稱之為 tensor,MXNet 則追隨了 NumPy 的稱呼 ndarray。下面我們創(chuàng)建一個兩維矩陣,其中每個元素初始化成 1。然后每個元素加 1 后打印。 忽略包名的不一樣的話,這里主要的區(qū)別是 MXNet 的形狀傳入?yún)?shù)跟 NumPy 一樣需要用括號括起來。 模型訓(xùn)練 下面我們看一個稍微復(fù)雜點的例子。這里我們使用一個多層感知機(jī)(MLP)來在 MINST 這個數(shù)據(jù)集上訓(xùn)練一個模型。我們將其分成 4 小塊來方便對比。 讀取數(shù)據(jù) 這里我們下載 MNIST 數(shù)據(jù)集并載入到內(nèi)存,這樣我們之后可以一個一個讀取批量。 這里的主要區(qū)別是 MXNet 使用 transform_first來表明數(shù)據(jù)變化是作用在讀到的批量的第一個元素,既 MNIST 圖片,而不是第二個標(biāo)號元素。 定義模型 下面我們定義一個只有一個單隱層的 MLP 。 我們使用了 Sequential容器來把層串起來構(gòu)造神經(jīng)網(wǎng)絡(luò)。這里 MXNet 跟 PyTorch 的主要區(qū)別是: 大家知道 Sequential下只能神經(jīng)網(wǎng)絡(luò)只能逐一執(zhí)行每個層。PyTorch 可以繼承 nn.Module來自定義 forward如何執(zhí)行。同樣,MXNet 可以繼承 nn.Block來達(dá)到類似的效果。 損失函數(shù)和優(yōu)化算法 這里我們使用交叉熵函數(shù)和最簡單隨機(jī)梯度下降并使用固定學(xué)習(xí)率 0.1 訓(xùn)練 最后我們實現(xiàn)訓(xùn)練算法,并附上了輸出結(jié)果。注意到每次我們會使用不同的權(quán)重和數(shù)據(jù)讀取順序,所以每次結(jié)果可能不一樣。 MXNet 跟 PyTorch 的不同主要在下面這幾點: |
|