国产gaysexchina男同gay,japanrcep老熟妇乱子伦视频,吃奶呻吟打开双腿做受动态图,成人色网站,国产av一区二区三区最新精品

PyTorch torch.utils.data

2025-07-02 16:04 更新

PyTorch 數(shù)據(jù)加載與處理詳解

一、PyTorch 數(shù)據(jù)加載器簡介

torch.utils.data.DataLoader 是 PyTorch 提供的核心數(shù)據(jù)加載工具,它可以方便地從數(shù)據(jù)集中加載數(shù)據(jù),并支持多種高級功能,如多進(jìn)程加載、自動(dòng)批處理、自定義數(shù)據(jù)轉(zhuǎn)換等。

二、數(shù)據(jù)加載器的核心參數(shù)與用法

(一)DataLoader 的基本構(gòu)造

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None
)

  1. dataset :數(shù)據(jù)集對象,可以是映射式數(shù)據(jù)集(實(shí)現(xiàn) __getitem____len__ 方法)或迭代式數(shù)據(jù)集(實(shí)現(xiàn) __iter__ 方法)。
  2. batch_size :每個(gè)批次加載的樣本數(shù)量,默認(rèn)為 1。
  3. shuffle :是否在每個(gè) epoch 開始時(shí)打亂數(shù)據(jù),默認(rèn)為 False。
  4. sampler :自定義采樣器,用于指定數(shù)據(jù)加載順序,不能與 shuffle 同時(shí)使用。
  5. num_workers :加載數(shù)據(jù)時(shí)使用的子進(jìn)程數(shù)量,默認(rèn)為 0(即單進(jìn)程加載)。
  6. collate_fn :用于將單個(gè)樣本合并成批次的函數(shù),默認(rèn)會將樣本列表轉(zhuǎn)換為張量。
  7. pin_memory :是否將數(shù)據(jù)加載到固定內(nèi)存中,以便更快地傳輸?shù)?GPU,默認(rèn)為 False。
  8. drop_last :如果數(shù)據(jù)集大小不能被批次大小整除,是否丟棄最后一個(gè)不完整的批次,默認(rèn)為 False

(二)數(shù)據(jù)集類型

  1. 映射式數(shù)據(jù)集
    • 實(shí)現(xiàn) __getitem____len__ 協(xié)議。
    • 適合于數(shù)據(jù)已經(jīng)存儲在磁盤上且可以按索引訪問的場景。

  1. 迭代式數(shù)據(jù)集
    • 實(shí)現(xiàn) __iter__ 協(xié)議。
    • 適合于數(shù)據(jù)流式讀取的場景,如實(shí)時(shí)生成的數(shù)據(jù)或從數(shù)據(jù)庫中讀取的數(shù)據(jù)。

(三)采樣器

  1. torch.utils.data.SequentialSampler :按順序采樣數(shù)據(jù)。
  2. torch.utils.data.RandomSampler :隨機(jī)采樣數(shù)據(jù),可以指定是否替換采樣。
  3. torch.utils.data.SubsetRandomSampler :從給定的索引列表中隨機(jī)采樣。
  4. torch.utils.data.WeightedRandomSampler :根據(jù)給定的權(quán)重進(jìn)行采樣。

(四)多進(jìn)程數(shù)據(jù)加載

num_workers 參數(shù)設(shè)置為大于 0 的值可以啟用多進(jìn)程數(shù)據(jù)加載。每個(gè)工作進(jìn)程會加載一個(gè)子集的數(shù)據(jù),從而加速數(shù)據(jù)加載過程。

(五)內(nèi)存固定

pin_memory 參數(shù)設(shè)置為 True,可以將數(shù)據(jù)加載到固定內(nèi)存中,這樣在將數(shù)據(jù)傳輸?shù)?GPU 時(shí)會更快。

三、數(shù)據(jù)集的創(chuàng)建與使用

(一)創(chuàng)建自定義數(shù)據(jù)集

  1. 映射式數(shù)據(jù)集示例

import torch
from torch.utils.data import Dataset


class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels


    def __len__(self):
        return len(self.data)


    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


## 示例數(shù)據(jù)
data = torch.randn(100, 3, 224, 224)
labels = torch.randint(0, 10, (100,))


dataset = CustomDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, shuffle=True)

  1. 迭代式數(shù)據(jù)集示例

from torch.utils.data import IterableDataset


class CustomIterableDataset(IterableDataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels


    def __iter__(self):
        for i in range(len(self.data)):
            yield self.data[i], self.labels[i]


dataset = CustomIterableDataset(data, labels)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10, num_workers=2)

(二)數(shù)據(jù)集的分割與合并

  1. 數(shù)據(jù)集分割

from torch.utils.data import random_split


train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])


train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=10, shuffle=False)

  1. 數(shù)據(jù)集合并

from torch.utils.data import ConcatDataset


dataset1 = CustomDataset(data1, labels1)
dataset2 = CustomDataset(data2, labels2)


combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=10, shuffle=True)

四、數(shù)據(jù)加載器的高級用法

(一)自定義 collate_fn

在某些情況下,可能需要自定義如何將單個(gè)樣本合并成一個(gè)批次。例如,對于變長序列數(shù)據(jù),可以自定義 collate_fn 來填充序列使其長度一致。

def custom_collate_fn(batch):
    # batch 是一個(gè)列表,其中每個(gè)元素是一個(gè)數(shù)據(jù)樣本
    # 這里可以實(shí)現(xiàn)自定義的批次合并邏輯
    # 例如,填充序列使其長度一致
    return batch


dataloader = DataLoader(dataset, batch_size=10, collate_fn=custom_collate_fn)

(二)使用 worker_init_fn 自定義工作進(jìn)程初始化

worker_init_fn 可以在每個(gè)工作進(jìn)程初始化時(shí)執(zhí)行自定義邏輯,例如設(shè)置隨機(jī)種子。

def worker_init_fn(worker_id):
    import numpy as np
    np.random.seed(worker_id)


dataloader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)

(三)分布式數(shù)據(jù)加載

使用 torch.utils.data.distributed.DistributedSampler 可以在分布式訓(xùn)練中將數(shù)據(jù)集分割成多個(gè)子集,每個(gè)進(jìn)程加載不同的子集。

from torch.utils.data.distributed import DistributedSampler


sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=True)
dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)

五、總結(jié)

通過本教程,我們詳細(xì)了解了 PyTorch 中 torch.utils.data 模塊的使用方法,包括數(shù)據(jù)加載器的核心參數(shù)與用法、數(shù)據(jù)集的創(chuàng)建與使用、采樣器的使用、多進(jìn)程數(shù)據(jù)加載、內(nèi)存固定以及高級功能如自定義 collate_fn 和分布式數(shù)據(jù)加載。合理利用這些功能可以顯著提升數(shù)據(jù)預(yù)處理和加載的效率,為模型訓(xùn)練提供有力支持。

以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號