W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
PyTorch 數(shù)據(jù)加載與處理詳解
torch.utils.data.DataLoader
是 PyTorch 提供的核心數(shù)據(jù)加載工具,它可以方便地從數(shù)據(jù)集中加載數(shù)據(jù),并支持多種高級功能,如多進(jìn)程加載、自動(dòng)批處理、自定義數(shù)據(jù)轉(zhuǎn)換等。
DataLoader
的基本構(gòu)造DataLoader(
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None
)
dataset
:數(shù)據(jù)集對象,可以是映射式數(shù)據(jù)集(實(shí)現(xiàn) __getitem__
和 __len__
方法)或迭代式數(shù)據(jù)集(實(shí)現(xiàn) __iter__
方法)。batch_size
:每個(gè)批次加載的樣本數(shù)量,默認(rèn)為 1。shuffle
:是否在每個(gè) epoch 開始時(shí)打亂數(shù)據(jù),默認(rèn)為 False
。sampler
:自定義采樣器,用于指定數(shù)據(jù)加載順序,不能與 shuffle
同時(shí)使用。num_workers
:加載數(shù)據(jù)時(shí)使用的子進(jìn)程數(shù)量,默認(rèn)為 0(即單進(jìn)程加載)。collate_fn
:用于將單個(gè)樣本合并成批次的函數(shù),默認(rèn)會將樣本列表轉(zhuǎn)換為張量。pin_memory
:是否將數(shù)據(jù)加載到固定內(nèi)存中,以便更快地傳輸?shù)?GPU,默認(rèn)為 False
。drop_last
:如果數(shù)據(jù)集大小不能被批次大小整除,是否丟棄最后一個(gè)不完整的批次,默認(rèn)為 False
。__getitem__
和 __len__
協(xié)議。__iter__
協(xié)議。torch.utils.data.SequentialSampler
:按順序采樣數(shù)據(jù)。torch.utils.data.RandomSampler
:隨機(jī)采樣數(shù)據(jù),可以指定是否替換采樣。torch.utils.data.SubsetRandomSampler
:從給定的索引列表中隨機(jī)采樣。torch.utils.data.WeightedRandomSampler
:根據(jù)給定的權(quán)重進(jìn)行采樣。
將 num_workers
參數(shù)設(shè)置為大于 0 的值可以啟用多進(jìn)程數(shù)據(jù)加載。每個(gè)工作進(jìn)程會加載一個(gè)子集的數(shù)據(jù),從而加速數(shù)據(jù)加載過程。
將 pin_memory
參數(shù)設(shè)置為 True
,可以將數(shù)據(jù)加載到固定內(nèi)存中,這樣在將數(shù)據(jù)傳輸?shù)?GPU 時(shí)會更快。
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
## 示例數(shù)據(jù)
data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))
dataset = CustomDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)
from torch.utils.data import IterableDataset
class CustomIterableDataset(IterableDataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __iter__(self):
for i in range(len(self.data)):
yield self.data[i], self.labels[i]
dataset = CustomIterableDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=2)
from torch.utils.data import random_split
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)
from torch.utils.data import ConcatDataset
dataset1 = CustomDataset(data1, labels1)
dataset2 = CustomDataset(data2, labels2)
combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=10, shuffle=True)
collate_fn
在某些情況下,可能需要自定義如何將單個(gè)樣本合并成一個(gè)批次。例如,對于變長序列數(shù)據(jù),可以自定義 collate_fn
來填充序列使其長度一致。
def custom_collate_fn(batch):
# batch 是一個(gè)列表,其中每個(gè)元素是一個(gè)數(shù)據(jù)樣本
# 這里可以實(shí)現(xiàn)自定義的批次合并邏輯
# 例如,填充序列使其長度一致
return batch
dataloader = DataLoader(dataset, batch_size=10, collate_fn=custom_collate_fn)
worker_init_fn
自定義工作進(jìn)程初始化
worker_init_fn
可以在每個(gè)工作進(jìn)程初始化時(shí)執(zhí)行自定義邏輯,例如設(shè)置隨機(jī)種子。
def worker_init_fn(worker_id):
import numpy as np
np.random.seed(worker_id)
dataloader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)
使用 torch.utils.data.distributed.DistributedSampler
可以在分布式訓(xùn)練中將數(shù)據(jù)集分割成多個(gè)子集,每個(gè)進(jìn)程加載不同的子集。
from torch.utils.data.distributed import DistributedSampler
sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)
通過本教程,我們詳細(xì)了解了 PyTorch 中 torch.utils.data
模塊的使用方法,包括數(shù)據(jù)加載器的核心參數(shù)與用法、數(shù)據(jù)集的創(chuàng)建與使用、采樣器的使用、多進(jìn)程數(shù)據(jù)加載、內(nèi)存固定以及高級功能如自定義 collate_fn
和分布式數(shù)據(jù)加載。合理利用這些功能可以顯著提升數(shù)據(jù)預(yù)處理和加載的效率,為模型訓(xùn)練提供有力支持。
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報(bào)電話:173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: