W3Cschool
恭喜您成為首批注冊(cè)用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
在處理張量(Tensor)運(yùn)算時(shí),廣播(Broadcasting)機(jī)制是一種非常強(qiáng)大的功能,它允許我們對(duì)不同形狀的張量進(jìn)行運(yùn)算,而無(wú)需顯式地改變它們的形狀。本文將深入淺出地講解 PyTorch 廣播語(yǔ)義,并提供豐富的實(shí)例幫助你理解其在實(shí)際開(kāi)發(fā)中的應(yīng)用。無(wú)論你是初學(xué)者還是進(jìn)階開(kāi)發(fā)者,都能從中獲得啟發(fā)。
廣播語(yǔ)義允許形狀不同的張量在滿足一定條件下進(jìn)行運(yùn)算,仿佛它們具有相同的形狀。這種機(jī)制遵循特定的規(guī)則來(lái)自動(dòng)擴(kuò)展張量的形狀,無(wú)需實(shí)際復(fù)制數(shù)據(jù),既節(jié)省內(nèi)存,又提高運(yùn)算效率。
例如,當(dāng)你對(duì)一個(gè)形狀為 (3, 1) 的張量和一個(gè)形狀為 (1, 4) 的張量進(jìn)行加法運(yùn)算時(shí),PyTorch 會(huì)將它們擴(kuò)展為形狀為 (3, 4) 的張量,然后進(jìn)行逐元素相加。
你可能會(huì)好奇,為什么要使用廣播語(yǔ)義?直接改變張量形狀不是更簡(jiǎn)單嗎?
其實(shí)不然。廣播語(yǔ)義的優(yōu)勢(shì)在于:
view()
、expand()
等方法改變形狀,代碼更簡(jiǎn)潔。兩個(gè)張量是“可廣播的”,需要滿足以下規(guī)則:
代碼示例 1:
import torch
## 定義兩個(gè)可廣播的張量
x = torch.empty(5, 7, 3)
y = torch.empty(5, 7, 3)
print(x + y) # 可廣播,結(jié)果形狀為 (5, 7, 3)
代碼示例 2:
x = torch.empty(5, 3, 4, 1)
y = torch.empty(3, 1, 1)
print(x + y) # 可廣播,結(jié)果形狀為 (5, 3, 4, 1)
如果兩個(gè)張量可廣播,結(jié)果張量的形狀計(jì)算方式如下:
代碼示例 3:
x = torch.empty(5, 1, 4, 1)
y = torch.empty(3, 1, 1)
print((x + y).size()) # 結(jié)果形狀為 torch.Size([5, 3, 4, 1])
在就地操作(如 add_()
)中,不允許因廣播導(dǎo)致張量形狀改變。否則會(huì)報(bào)錯(cuò)。
代碼示例 4:
x = torch.empty(1, 3, 1)
y = torch.empty(3, 1, 7)
## 下面的代碼會(huì)報(bào)錯(cuò),因?yàn)閺V播會(huì)改變 x 的形狀
## x.add_(y)
在舊版本 PyTorch 中,某些逐點(diǎn)函數(shù)會(huì)在不同形狀但元素?cái)?shù)量相同的張量上執(zhí)行?,F(xiàn)在廣播機(jī)制引入后,這種行為可能不再適用,導(dǎo)致向后不兼容。
代碼示例 5:
import torch
from torch.utils.backcompat import broadcast_warning
torch.utils.backcompat.broadcast_warning.enabled = True
## 下面的代碼會(huì)產(chǎn)生警告,因?yàn)榕f版本行為與廣播機(jī)制行為不同
print(torch.add(torch.ones(4, 1), torch.ones(4)))
代碼示例 6:
programming_lion_data = torch.empty(3, 1)
w3cschool_weights = torch.empty(1, 4)
result = programming_lion_data + w3cschool_weights
print(result.size()) # 結(jié)果形狀為 torch.Size([3, 4])
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),我們常常需要對(duì)不同形狀的張量進(jìn)行運(yùn)算,例如將一個(gè)形狀為 (batch_size, 1) 的標(biāo)簽張量與一個(gè)形狀為 (batch_size, num_classes) 的預(yù)測(cè)張量進(jìn)行比較。
傳統(tǒng)方法:
## 假設(shè) batch_size=32,num_classes=10
labels = torch.randint(0, 10, (32, 1)) # 標(biāo)簽形狀為 (32, 1)
preds = torch.randn(32, 10) # 預(yù)測(cè)形狀為 (32, 10)
## 將標(biāo)簽展開(kāi)為 one-hot 編碼,形狀變?yōu)?(32, 10)
one_hot_labels = torch.zeros(32, 10)
one_hot_labels.scatter_(1, labels, 1)
## 計(jì)算損失
loss = torch.mean((preds - one_hot_labels) ** 2)
廣播方法:
## 直接利用廣播語(yǔ)義計(jì)算損失,無(wú)需顯式展開(kāi)標(biāo)簽
loss = torch.mean((preds - labels) ** 2) # labels 會(huì)自動(dòng)擴(kuò)展為 (32, 10)
廣播方法不僅代碼更簡(jiǎn)潔,而且避免了額外的內(nèi)存分配,提高了訓(xùn)練效率。
通過(guò)合理利用廣播語(yǔ)義,我們可以在神經(jīng)網(wǎng)絡(luò)訓(xùn)練中減少顯式操作,提高代碼可讀性和運(yùn)行效率。在實(shí)際項(xiàng)目中,建議多嘗試使用廣播語(yǔ)義來(lái)優(yōu)化代碼,但同時(shí)要注意避免因廣播導(dǎo)致的潛在問(wèn)題,如就地操作形狀改變和向后兼容性問(wèn)題。
Q1:廣播語(yǔ)義是否會(huì)影響計(jì)算結(jié)果的準(zhǔn)確性?
A1:不會(huì)。廣播語(yǔ)義只是在形狀上進(jìn)行虛擬擴(kuò)展,實(shí)際計(jì)算時(shí)仍然使用原始數(shù)據(jù),因此不會(huì)影響結(jié)果準(zhǔn)確性。
Q2:如何快速檢查兩個(gè)張量是否可廣播?
A2:可以使用以下代碼片段檢查:
def is_broadcastable(shape1, shape2):
for a, b in zip(shape1[::-1], shape2[::-1]):
if a != b and a != 1 and b != 1:
return False
return True
## 示例
print(is_broadcastable((3, 1), (1, 4))) # True
print(is_broadcastable((3, 2), (3, 4))) # False
Q3:廣播語(yǔ)義在哪些場(chǎng)景下特別有用?
A3:廣播語(yǔ)義在以下場(chǎng)景特別有用:
PyTorch 的廣播語(yǔ)義是一種高效且便捷的張量運(yùn)算機(jī)制,它在保持代碼簡(jiǎn)潔性的同時(shí),提高了計(jì)算效率。通過(guò)深入理解廣播規(guī)則,合理應(yīng)用廣播語(yǔ)義,我們可以在深度學(xué)習(xí)開(kāi)發(fā)中事半功倍。
對(duì)于初學(xué)者來(lái)說(shuō),建議多進(jìn)行廣播語(yǔ)義相關(guān)的練習(xí),嘗試不同的張量形狀組合,觀察運(yùn)算結(jié)果,從而加深對(duì)廣播機(jī)制的理解。同時(shí),關(guān)注 PyTorch 官方文檔的更新,及時(shí)了解廣播語(yǔ)義的最新發(fā)展。
在實(shí)際項(xiàng)目中,靈活運(yùn)用廣播語(yǔ)義可以顯著提升代碼質(zhì)量和運(yùn)行效率。關(guān)注編程獅(W3Cschool)平臺(tái),獲取更多優(yōu)質(zhì) PyTorch 教程和實(shí)踐案例,助力你的深度學(xué)習(xí)之旅。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號(hào)-3|閩公網(wǎng)安備35020302033924號(hào)
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號(hào)
聯(lián)系方式:
更多建議: