3.7 KiB
3.7 KiB
Dataset类¶
其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了Dataset类。而在继承Dataset类时,至少需要重写以下几个方法:
- __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
- __len__():返回数据集大小;
- __getitem__():索引数据集中的某一个数据。
In [64]:
import torch from torch.utils.data import Dataset class MyDataset(Dataset): # 构造函数 def __init__(self, data_tensor, target_tensor): self.data_tensor = data_tensor self.target_tensor = target_tensor # 返回数据集大小 def __len__(self): return self.data_tensor.size(0) # 返回索引的数据与标签 def __getitem__(self, idx): return self.data_tensor[idx], self.target_tensor[idx]
In [66]:
# 生成数据 data_tensor = torch.randn(10, 3) target_tensor = torch.randint(2, (10,)) # 二分类任务 print(data_tensor) print(target_tensor)
tensor([[-1.1016, 0.8803, 1.3645],
[ 0.6674, 0.4500, 0.0367],
[ 1.3528, 0.2438, 0.4721],
[-1.3651, -1.1315, 0.0299],
[ 1.4538, 1.3622, -0.5165],
[ 0.7515, 0.6168, -0.9036],
[ 1.3542, -0.3779, -1.2439],
[-0.4588, -0.2233, -0.3531],
[-0.0515, 0.8951, 0.3544],
[ 0.0181, 0.4075, -2.0888]])
tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])
In [67]:
# 创建数据集 my_dataset = MyDataset(data_tensor, target_tensor) # 查看数据集大小 print("dataset size:",len(my_dataset)) # 使用索引获取数据 print("data:", my_dataset[0]) # 获取第一个数据
dataset size: 10
data: (tensor([-1.1016, 0.8803, 1.3645]), tensor(0))