Files
pytorch-study/06.ipynb

3.7 KiB
Raw Blame History

Dataset类

其实这就表示无论使用自定义的数据集还是官方为我们封装好的数据集其本质都是继承了Dataset类。而在继承Dataset类时至少需要重写以下几个方法

  • __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
  • __len__():返回数据集大小;
  • __getitem__():索引数据集中的某一个数据。
In [64]:
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)

    # 返回索引的数据与标签
    def __getitem__(self, idx):
        return self.data_tensor[idx], self.target_tensor[idx]
In [66]:
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 二分类任务

print(data_tensor)
print(target_tensor)
tensor([[-1.1016,  0.8803,  1.3645],
        [ 0.6674,  0.4500,  0.0367],
        [ 1.3528,  0.2438,  0.4721],
        [-1.3651, -1.1315,  0.0299],
        [ 1.4538,  1.3622, -0.5165],
        [ 0.7515,  0.6168, -0.9036],
        [ 1.3542, -0.3779, -1.2439],
        [-0.4588, -0.2233, -0.3531],
        [-0.0515,  0.8951,  0.3544],
        [ 0.0181,  0.4075, -2.0888]])
tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])
In [67]:
# 创建数据集
my_dataset = MyDataset(data_tensor, target_tensor)

# 查看数据集大小
print("dataset size",len(my_dataset))

# 使用索引获取数据
print("data:", my_dataset[0])  # 获取第一个数据
dataset size 10
data: (tensor([-1.1016,  0.8803,  1.3645]), tensor(0))