pytorch数据集与数据预处理

本文最后更新于 2025年2月19日 晚上

一. PyTorch 内置数据集

PyTorch 通过 torchvision.datasets 模块提供了许多常用的数据集,例如:

  • MNIST:手写数字图像数据集,用于图像分类任务。
  • CIFAR:包含 10 个类别、60000 张 32x32 的彩色图像数据集,用于图像分类任务。
  • COCO:通用物体检测、分割、关键点检测数据集,包含超过 330k 个图像和 2.5M 个目标实例的大规模数据集。
  • ImageNet:包含超过 1400 万张图像,用于图像分类和物体检测等任务。
  • STL-10:包含 100k 张 96x96 的彩色图像数据集,用于图像分类任务。
  • Cityscapes:包含 5000 张精细注释的城市街道场景图像,用于语义分割任务。
  • SQUAD:用于机器阅读理解任务的数据集。

以上数据集可以通过 torchvision.datasets 模块中的函数进行加载,也可以通过自定义的方式加载其他数据集。

  • torchvision: 一个图形库,提供了图片数据处理相关的 API 和数据集接口,包括数据集加载函数和常用的图像变换。
  • torchtext: 自然语言处理工具包,提供了文本数据处理和建模的工具,包括数据预处理和数据加载的方式。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from torchvision import datasets, transforms

train_dataset = datasets.MNIST(
'./data', # 数据存储路径
train=True, # 训练集
download=False, # 关闭下载
transform=transform # 应用预处理
)

test_dataset = datasets.MNIST(
'./data',
train=False, # 测试集
download=False, # 关闭下载
transform=transform
)

二. 定义数据集

1. torch.utils.data.Dataset

  • 作用:定义数据集的统一接口,需继承并实现关键方法。
  • 功能
    • __len__(self): 返回数据集的总样本数。
    • __getitem__(self, idx): 根据索引 idx 返回单个样本(数据 + 标签)。
  • 适用场景
    • 处理自定义格式数据(如非标准文件结构)。
    • 需要复杂的数据预处理逻辑(如动态生成数据)。
  • 代码示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch.utils.data import Dataset

# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, X_data, Y_data):
"""
初始化数据集,X_data 和 Y_data 是两个列表或数组
X_data: 输入特征
Y_data: 目标标签
"""
self.X_data = X_data
self.Y_data = Y_data

def __len__(self):
"""返回数据集的大小"""
return len(self.X_data)

def __getitem__(self, idx):
"""返回指定索引的数据"""
x = torch.tensor(self.X_data[idx], dtype=torch.float32) # 转换为 Tensor
y = torch.tensor(self.Y_data[idx], dtype=torch.float32)
return x, y

# 示例数据
X_data = [[1, 2], [3, 4], [5, 6], [7, 8]] # 输入特征
Y_data = [1, 0, 1, 0] # 目标标签

# 创建数据集实例
dataset = MyDataset(X_data, Y_data)

2. torch.utils.data.TensorDataset

  • 作用:基于张量(Tensor)的数据集类,适合处理数据-标签对。
  • 功能
    • 它接受多个张量作为输入(通常是数据和标签),并将它们组合成一个数据集。
    • 直接支持批处理和迭代。
  • 适用场景
    • 处理自定义格式数据(如非标准文件结构)。
    • 需要复杂的数据预处理逻辑(如动态生成数据)。
  • 代码示例
1
2
3
4
5
6
import torch
from torch.utils.data import TensorDataset

data = torch.randn(100, 3) # 100个样本,每个样本有3个特征
labels = torch.randint(0, 2, (100,)) # 100个标签,0或1
dataset = TensorDataset(data, labels)

三. 加载数据集

1. torch.utils.data.DataLoader

  • 作用:一个数据加载器,用于封装 Dataset 并提供高效的迭代功能。
  • 功能
    • 支持批处理(batch_size):将数据集分成小批量。
    • 支持数据打乱(shuffle):在每个 epoch 开始时打乱数据顺序。
    • 支持多线程加载(num_workers):加速数据加载过程。
    • drop_last:如果数据集中的样本数不能被 batch_size 整除,设置为 True 时,丢弃最后一个不完整的 batch。
  • 适用场景
    • 在训练模型时,通常会将 Dataset 传递给 DataLoader,以便高效地加载数据。
  • 代码示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torch.utils.data import DataLoader

# 创建 DataLoader 实例,batch_size 设置每次加载的样本数量
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 打印加载的数据
for epoch in range(1):
for batch_idx, (inputs, labels) in enumerate(dataloader):
print(f'Batch {batch_idx + 1}:')
print(f'Inputs: {inputs}')
print(f'Labels: {labels}')
"""
输出:
Batch 1:
Inputs: tensor([[3., 4.], [1., 2.]])
Labels: tensor([0., 1.])
Batch 2:
Inputs: tensor([[7., 8.], [5., 6.]])
Labels: tensor([0., 1.])
"""

2. torchvision.datasets.ImageFolder

  • 作用:这是一个专门用于加载图像数据的数据集类,适用于图像分类任务。
  • 功能
    • 从文件夹中加载图像数据,每个子文件夹代表一个类别。
    • 自动为每个图像分配标签(根据子文件夹名称)。
    • 支持数据预处理(通过 transform 参数)。
  • 适用场景
    • 当你的图像数据按类别存储在文件夹中时,可以直接使用 ImageFolder。
  • 代码示例
1
2
3
4
5
6
7
8
9
10
from torchvision.datasets import ImageFolder
from torchvision import transforms

transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
])

dataset = ImageFolder(root='path/to/image/folder', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 验证

1
2
3
4
5
6
7
# 检查 train_loader 的输出格式
for inputs, labels in train_loader:
print("Input shape:", inputs.shape) # 应该是 [batch_size, seq_len]
print("Label shape:", labels.shape) # 应该是 [batch_size]
print("Input type:", type(inputs)) # 应该是 <class 'torch.Tensor'>
print("Label type:", type(labels)) # 应该是 <class 'torch.Tensor'>
break

四. 数据预处理

  • transforms.Compose():将多个变换操作组合在一起。
  • transforms.Resize(size):调整图像大小。
  • transforms.ToTensor():将图像转换为 PyTorch 张量,值会被归一化到 [0, 1] 范围。
  • transforms.Normalize(mean, std):标准化图像数据,通常使用预训练模型时需要进行标准化处理。
  • transforms.CenterCrop(size):从图像中心裁剪指定大小的区域。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torchvision.transforms as transforms
from PIL import Image

# 定义数据预处理的流水线
transform = transforms.Compose([
transforms.Resize((128, 128)), # 将图像调整为 128x128
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 标准化
transform = transforms.CenterCrop(128) # 128*128的区域
])

# 加载图像
image = Image.open('image.jpg')

# 应用预处理
image_tensor = transform(image)
print(image_tensor.shape) # 输出张量的形状

五. 图像数据增强

  • transforms.RandomHorizontalFlip(p):随机水平翻转图像。
  • transforms.RandomRotation(degrees):随机旋转图像。
  • transforms.ColorJitter(brightness, contrast, saturation, hue):调整图像的亮度、对比度、饱和度和色调。
  • transforms.RandomCrop(size):随机裁剪指定大小的区域。
  • transforms.RandomResizedCrop(size):随机裁剪图像并调整到指定大小。
1
2
3
4
5
6
7
transform = transforms.Compose([
transform = transforms.RandomHorizontalFlip(p=0.5), # 50% 概率翻转
transform = transforms.RandomRotation(degrees=30), # 随机旋转 -30 到 +30 度
transform = transforms.ColorJitter(brightness=0.5, contrast=0.5),
transform = transforms.RandomCrop(128),
transform = transforms.RandomResizedCrop(224)
])

六. 用多个数据源(Multi-source Dataset)

1
2
3
4
5
from torch.utils.data import ConcatDataset

# 假设 dataset1 和 dataset2 是两个 Dataset 对象
combined_dataset = ConcatDataset([dataset1, dataset2])
combined_loader = DataLoader(combined_dataset, batch_size=64, shuffle=True)

七. 实例–加载MNIST 数据集,并应用转换。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 定义转换
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5], std=[0.5])
])

# 加载数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)

# 使用 DataLoader
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

# 查看转换后的数据
for images, labels in train_loader:
print("图像张量大小:", images.size()) # [batch_size, 1, 128, 128]
break

输出结果为:

1
图像张量大小: torch.Size([32, 1, 128, 128])

pytorch数据集与数据预处理
https://jimes.cn/2025/01/26/pytorch数据集以及数据预处理/
作者
Jimes
发布于
2025年1月26日
许可协议