從手寫三層循環(huán)到標(biāo)準(zhǔn)實(shí)現(xiàn),矩陣相乘運(yùn)行效率提高三萬六千倍之路
前言
矩陣乘法可以說是最常見的運(yùn)算之一。
本文介紹不同的方式實(shí)現(xiàn)的矩陣乘法,并比較它們運(yùn)行速度的差異。
表示矩陣的方式有很多種,完善的矩陣類應(yīng)該實(shí)現(xiàn)切片取值,獲得矩陣形狀等操作,但本文并不打算直接從原生Python實(shí)現(xiàn)一個(gè)矩陣類,而是直接用 Pytorch中的tensor表示矩陣。
開始: 三層循環(huán)
根據(jù)矩陣相乘定義,可通過三層循環(huán)實(shí)現(xiàn)該運(yùn)算。
def matmul(a, b):
r1, c1 = a.shape
r2, c2 = b.shape
assert c1 == r2
rst = torch.zeros(r1, c2)
for i in range(r1):
for j in range(c2):
for k in range(c1):
rst[i][j] += a[i][k] * b[k][j]
return rst
那么這個(gè)函數(shù)的運(yùn)行效率如何呢?讓我們嘗試兩個(gè)較大的矩陣相乘,測試一下運(yùn)行時(shí)間。
m1 = torch.randn(5, 784)
m2 = torch.randn(784, 10)
%timeit -n 10 matmul(m1, m2)
得到結(jié)果如下:
624 ms ± 3.76 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
即每次矩陣相乘, 需要耗時(shí) 600ms 左右,這是一個(gè)非常非常慢的速度,慢到兩次矩陣乘法居然要耗時(shí)1秒多,這是不可能被接受的。
相同形狀的張量進(jìn)行運(yùn)算
如果兩個(gè)張量的形狀相同,則他們的運(yùn)算為同一位置的數(shù)字進(jìn)行運(yùn)算。
a = torch.tensor([1., 2, 3])
b = torch.tensor([4., 5, 6])
a + b # tensor([5., 7., 9.])
a * b # tensor([ 4., 10., 18.])
康康之前用三層循環(huán)實(shí)現(xiàn)的矩陣相乘,發(fā)現(xiàn)最里面一層循環(huán)的本質(zhì)就是兩個(gè)同樣大小的張量相乘,再進(jìn)行求和。
即第一個(gè)矩陣中的一行 跟 第二個(gè)矩陣中的一列 進(jìn)行運(yùn)算,且這行和列中的元素個(gè)數(shù)相同,則我們可以通過同樣形狀的張量運(yùn)算改寫最內(nèi)層循環(huán):
def matmul(a, b):
r1, c1 = a.shape
r2, c2 = b.shape
assert c1 == r2
rst = torch.zeros(r1, c2)
for i in range(r1):
for j in range(c2):
rst[i][j] = (a[i,:] * b[:,j]).sum() # 改了這里
return rst
%timeit -n 10 matmul(m1, m2)
得到結(jié)果如下
1.4 ms ± 92.2 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
624 / 1.4=445,只改寫了一下最內(nèi)層循環(huán),就使得矩陣乘法快了445倍!
廣播機(jī)制
廣播機(jī)制使得不同形狀的張量間可以進(jìn)行運(yùn)算:
- 兩個(gè)張量擴(kuò)充成同樣的形狀
- 再按相同形狀的張量進(jìn)行運(yùn)算
# shape: [2, 3]
a = torch.tensor([
[1, 2, 3],
[4, 5, 6],
])
# shape: [1]
b = torch.tensor([1])
# shape: [3]
c = torch.tensor([10, 20, 30])
形狀為 [2, 3] 和 [1] 的兩個(gè)張量相加:
a + b
"""輸出:
tensor([[2, 3, 4],
[5, 6, 7]])
"""
形狀為 [2, 3] 和 [3] 的兩個(gè)張量相加:
b + c
"""輸出:
tensor([[11, 22, 33],
[14, 25, 36]])
"""
這兩個(gè)例子中,維度低的張量都是暗地里先擴(kuò)充成了維度高的張量,然后再參與的運(yùn)算。
那么如何查看擴(kuò)充后的張量是啥呢?用 expand_as 函數(shù)就可以查看:
b.expand_as(a)
"""輸出
tensor([[1, 1, 1],
[1, 1, 1]])
"""
b.expand_as(a)
"""輸出
tensor([[10, 20, 30],
[10, 20, 30]])
"""
這就一目了然了,形狀不同的張量可以通過廣播機(jī)制擴(kuò)充成形狀一致的張量再進(jìn)行運(yùn)算。
那么任意形狀的兩個(gè)張量都可以運(yùn)算嗎?當(dāng)然不是了,判斷兩個(gè)張量是否能運(yùn)算的規(guī)則如下:
先從兩個(gè)張量的最后一個(gè)維度看起,如果維度的維數(shù)相同,或者其中一個(gè)維數(shù)為1,則可以繼續(xù)判斷,否則就失敗。
然后看倒數(shù)第二個(gè)維度,倒數(shù)第三個(gè)維數(shù),一直到遍歷完某個(gè)張量的維數(shù)為止,一直沒有失敗則這兩個(gè)張量可以通過廣播機(jī)制進(jìn)行運(yùn)算。
那么這個(gè)廣播機(jī)制和矩陣乘法有什么關(guān)系呢?答案就是它可以幫我們再去掉一層循環(huán)。
現(xiàn)在的最內(nèi)存循環(huán)的本質(zhì)是 一個(gè)形狀為 [c1] 的張量 和 一個(gè)形狀為 [c1, c2] 的張量做運(yùn)算,最終生成一個(gè)形狀為 [c2] 的張量。
則我們可以把矩陣運(yùn)算改寫為:
def matmul(a, b):
r1, c1 = a.shape
r2, c2 = b.shape
assert c1 == r2
rst = torch.zeros(r1, c2)
for i in range(r1):
rst[i] = (a[i, :].unsqueeze(-1) * b).sum(0)
return rst
%timeit -n 10 matmul(m1, m2)
"""輸出
249 μs ± 66.4 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""
現(xiàn)在已經(jīng)把每次矩陣運(yùn)算的時(shí)間壓縮到了 249 μs!!!,比最開始的 624ms 快了 2500倍!
對于 unsqueeze 操作不太熟悉的小伙伴請看我的另一篇文檔: Pytorch 中張量的理解
但是還沒結(jié)束。。。因?yàn)閮蓚€(gè)矩陣的相乘,就是 [r1, c1] 和 [c1, c2] 兩個(gè)張量的運(yùn)算,我們可以直接把它用廣播機(jī)制一次到位的算出結(jié)果,連唯一的那層循環(huán)也可以省去:
def matmul(a, b):
r1, c1 = a.shape
r2, c2 = b.shape
assert c1 == r2
return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(1)
%timeit -n 10 matmul(m1, m2)
"""輸出:
169 μs ± 41.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""
這個(gè) 169μs 已經(jīng)是最開始矩陣相乘版本的 3700 倍了。。( ? ^ ? )淚目,果然知識是第一生產(chǎn)力。
愛因斯坦求和
接下來就是 pytorch 自帶的矩陣運(yùn)算工具了,其中一個(gè)是愛因斯坦求和,貌似知道這個(gè)的同學(xué)不多。。
簡單來說,它能讓我們幾乎不編寫代碼就能進(jìn)行矩陣運(yùn)算,只需要確定輸入和輸出矩陣的形狀即可:
def matmul(a, b):
return torch.einsum("ik,kj->ij", a, b)
%timeit -n 10 matmul(a, b)
"""輸出
74 μs ± 25.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""
74μs 這個(gè)速度已經(jīng)是原始版本的 8000 多倍了。。。但是對于工業(yè)級別的要求似乎仍然不夠快~
pytorch 的矩陣相乘標(biāo)準(zhǔn)實(shí)現(xiàn)
最后祭出 pytorch 的矩陣相乘官方版本:
def matmul(a, b):
return a @ b
%timeit -n 10 matmul(m1, m2)
"""輸出
17.1 μs ± 28.5 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
"""
17.1 μs 是原始三層循環(huán)版本的 36000 倍,官方實(shí)現(xiàn)就是這么簡單枯燥,樸實(shí)無華~

浙公網(wǎng)安備 33010602011771號