11 KiB
11 KiB
Dataset类
无论是使用自定义的数据集，还是使用官方已经封装好的数据集，其本质都是继承了Dataset类。在继承Dataset类时，至少需要重写以下几个方法:
- __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
- __len__():返回数据集大小;
- __getitem__():根据索引返回数据集中的某一个样本(数据及其标签)。
In [1]:
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset that pairs each row of a data tensor with a label.

    Sample ``i`` is the tuple ``(data_tensor[i], target_tensor[i])``, so both
    tensors must hold the same number of entries along dimension 0.
    """

    def __init__(self, data_tensor: torch.Tensor, target_tensor: torch.Tensor):
        """Store the feature and label tensors.

        Args:
            data_tensor: Features; dimension 0 indexes the samples.
            target_tensor: Labels; dimension 0 indexes the samples.

        Raises:
            ValueError: If the two tensors disagree on the number of samples.
        """
        # Fail early: __len__ only looks at data_tensor, so a shorter
        # target_tensor would otherwise surface as an IndexError deep inside
        # a DataLoader, and a longer one would be silently truncated.
        num_data = data_tensor.size(0)
        num_targets = target_tensor.size(0)
        if num_data != num_targets:
            raise ValueError(
                "Size mismatch between tensors: "
                f"data_tensor has {num_data} samples, "
                f"target_tensor has {num_targets}"
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self.data_tensor.size(0)

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair stored at position ``idx``."""
        return self.data_tensor[idx], self.target_tensor[idx]
In [2]:
# Generate toy data: 10 samples with 3 normally distributed features each.
data_tensor = torch.randn((10, 3))

# One integer label per sample, drawn from {0, 1} (binary classification task).
target_tensor = torch.randint(0, 2, (10,))

print(data_tensor)
print(target_tensor)
tensor([[ 1.5479, 0.1374, 1.6763],
[-0.5671, -0.0821, -1.8523],
[-0.4039, -0.5871, -0.3510],
[-0.2339, 0.2773, -0.0820],
[-0.3131, 0.5911, 0.2030],
[-0.7087, -0.2614, -0.2661],
[ 0.3220, -0.0340, -1.2429],
[-0.1282, -0.2188, -0.7576],
[ 1.1233, -0.2452, 0.4664],
[ 1.2570, 0.3728, 0.9745]])
tensor([0, 1, 1, 1, 0, 0, 0, 0, 0, 0])
In [3]:
# Build the dataset from the feature and label tensors generated earlier.
my_dataset = MyDataset(data_tensor, target_tensor)

# len() is served by MyDataset.__len__ and reports the number of samples.
print("dataset size:", len(my_dataset))

# Indexing is served by MyDataset.__getitem__; index 0 is the first sample.
print("data:", my_dataset[0])
dataset size: 10
data: (tensor([1.5479, 0.1374, 1.6763]), tensor(0))
In [9]:
from torch.utils.data import DataLoader

# Wrap the dataset in a loader that yields shuffled mini-batches of 2 samples.
tensor_loader = DataLoader(
    dataset=my_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
)

# Iterate over the loader once, printing every (data, target) batch.
for data, target in tensor_loader:
    print("data:", data)
    print("target:", target)

# Pull a single batch from a freshly created iterator over the loader.
print("one batch tensor data:", next(iter(tensor_loader)))
data: tensor([[ 0.3220, -0.0340, -1.2429],
[-0.2339, 0.2773, -0.0820]])
target: tensor([0, 1])
data: tensor([[-0.5671, -0.0821, -1.8523],
[-0.3131, 0.5911, 0.2030]])
target: tensor([1, 0])
data: tensor([[-0.1282, -0.2188, -0.7576],
[-0.4039, -0.5871, -0.3510]])
target: tensor([0, 1])
data: tensor([[-0.7087, -0.2614, -0.2661],
[ 1.2570, 0.3728, 0.9745]])
target: tensor([0, 0])
data: tensor([[ 1.5479, 0.1374, 1.6763],
[ 1.1233, -0.2452, 0.4664]])
target: tensor([0, 0])
one batch tensor data: [tensor([[ 0.3220, -0.0340, -1.2429],
[ 1.1233, -0.2452, 0.4664]]), tensor([0, 0])]
In [11]:
# Shell escape: run pip from the notebook so that torchvision and pillow are installed.
!pip install torchvision pillow
Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (0.17.1) Requirement already satisfied: pillow in /opt/conda/lib/python3.10/site-packages (10.2.0) Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision) (1.26.3) Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchvision) (2.2.1) Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.13.1) Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (4.9.0) Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (1.12) Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1) Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1.3) Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (2024.2.0) Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->torchvision) (2.1.3) Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchvision) (1.3.0) WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
In [13]:
# Use the MNIST dataset as an example.
import torchvision

# Training split stored under ./data, with download=True; no transform is
# applied to either the images or the labels.
mnist_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=None,
    target_transform=None,
    download=True,
)
In [23]:
# Materialize the whole dataset as a Python list of samples.
mnist_dataset_list = list(mnist_dataset)

# Render element 0 of the first sample inline (the image, going by the
# label printed below).
display(mnist_dataset_list[0][0])

# Element 1 of the first sample is printed as its label.
print("Image label is:", mnist_dataset_list[0][1])
Image label is: 5
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) Cell In[23], line 6 3 display(mnist_dataset_list[0][0]) 4 print("Image label is:", mnist_dataset_list[0][1]) ----> 6 print(list(mnist_dataset_list).size) AttributeError: 'list' object has no attribute 'size'