Files
pytorch-study/06.ipynb

11 KiB
Raw Permalink Blame History

Dataset类

其实，无论使用自定义的数据集，还是使用官方为我们封装好的数据集，其本质都是继承了 Dataset 类。而在继承 Dataset 类时，至少需要重写以下几个方法：

  • __init__():构造函数,可自定义数据读取方法以及进行数据预处理;
  • __len__():返回数据集大小;
  • __getitem__():索引数据集中的某一个数据。
In [1]:
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset that pairs each row of a data tensor with its target.

    Args:
        data_tensor: Tensor of samples; indexing along dimension 0 selects
            one sample.
        target_tensor: Tensor of labels; indexing along dimension 0 selects
            the label belonging to the sample at the same index.

    Raises:
        ValueError: If the two tensors differ in size along dimension 0.
    """

    def __init__(self, data_tensor: torch.Tensor, target_tensor: torch.Tensor) -> None:
        # Fail fast on mismatched inputs: otherwise __len__ (which only looks
        # at data_tensor) would promise indices that target_tensor cannot
        # serve, or would silently ignore surplus targets.
        if data_tensor.size(0) != target_tensor.size(0):
            raise ValueError(
                "data_tensor and target_tensor must have the same size along "
                f"dimension 0, got {data_tensor.size(0)} and {target_tensor.size(0)}"
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self) -> int:
        """Return the number of samples (size of data_tensor along dimension 0)."""
        return self.data_tensor.size(0)

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair stored at position ``idx``."""
        return self.data_tensor[idx], self.target_tensor[idx]
In [2]:
# Draw 10 random samples with 3 features each from a standard normal.
data_tensor = torch.randn((10, 3))
# Draw 10 integer labels from {0, 1} for a binary-classification task.
target_tensor = torch.randint(0, 2, (10,))

# Show the features first, then the labels, each on its own line.
print(data_tensor, target_tensor, sep="\n")
tensor([[ 1.5479,  0.1374,  1.6763],
        [-0.5671, -0.0821, -1.8523],
        [-0.4039, -0.5871, -0.3510],
        [-0.2339,  0.2773, -0.0820],
        [-0.3131,  0.5911,  0.2030],
        [-0.7087, -0.2614, -0.2661],
        [ 0.3220, -0.0340, -1.2429],
        [-0.1282, -0.2188, -0.7576],
        [ 1.1233, -0.2452,  0.4664],
        [ 1.2570,  0.3728,  0.9745]])
tensor([0, 1, 1, 1, 0, 0, 0, 0, 0, 0])
In [3]:
# Wrap the feature and label tensors in the custom dataset.
my_dataset = MyDataset(data_tensor=data_tensor, target_tensor=target_tensor)

# Report how many samples the dataset holds.
dataset_size = len(my_dataset)
print("dataset size", dataset_size)

# Fetch the first (data, target) pair by index and show it.
first_sample = my_dataset[0]
print("data:", first_sample)
dataset size 10
data: (tensor([1.5479, 0.1374, 1.6763]), tensor(0))
In [9]:
from torch.utils.data import DataLoader

# Wrap the dataset in a loader that yields shuffled mini-batches of two
# samples, loading in the main process (num_workers=0).
tensor_loader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=0)

# Walk through one full pass, unpacking every batch into features and labels.
for batch in tensor_loader:
    data, target = batch
    print("data:", data)
    print("target:", target)

# Build a fresh iterator over the loader and pull only its first batch.
first_batch = next(iter(tensor_loader))
print("one batch tensor data:", first_batch)
data: tensor([[ 0.3220, -0.0340, -1.2429],
        [-0.2339,  0.2773, -0.0820]])
target: tensor([0, 1])
data: tensor([[-0.5671, -0.0821, -1.8523],
        [-0.3131,  0.5911,  0.2030]])
target: tensor([1, 0])
data: tensor([[-0.1282, -0.2188, -0.7576],
        [-0.4039, -0.5871, -0.3510]])
target: tensor([0, 1])
data: tensor([[-0.7087, -0.2614, -0.2661],
        [ 1.2570,  0.3728,  0.9745]])
target: tensor([0, 0])
data: tensor([[ 1.5479,  0.1374,  1.6763],
        [ 1.1233, -0.2452,  0.4664]])
target: tensor([0, 0])
one batch tensor data: [tensor([[ 0.3220, -0.0340, -1.2429],
        [ 1.1233, -0.2452,  0.4664]]), tensor([0, 0])]
In [11]:
!pip install torchvision pillow
Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (0.17.1)
Requirement already satisfied: pillow in /opt/conda/lib/python3.10/site-packages (10.2.0)
Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision) (1.26.3)
Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchvision) (2.2.1)
Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.13.1)
Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (4.9.0)
Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (1.12)
Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1)
Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1.3)
Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (2024.2.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->torchvision) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchvision) (1.3.0)
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.
In [13]:
# 以MNIST数据集为例
import torchvision

# Load the MNIST training split, downloading it into ./data when requested
# by download=True; no transforms are applied to images or labels.
mnist_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    transform=None,
    target_transform=None,
    download=True,
)
In [23]:
# Materialise the dataset as a plain list of (image, label) pairs.
mnist_dataset_list = list(mnist_dataset)

# Unpack the first pair, render the image, then report its label.
first_image, first_label = mnist_dataset_list[0]
display(first_image)
print("Image label is:", first_label)
No description has been provided for this image
Image label is: 5
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[23], line 6
      3 display(mnist_dataset_list[0][0])
      4 print("Image label is:", mnist_dataset_list[0][1])
----> 6 print(list(mnist_dataset_list).size)

AttributeError: 'list' object has no attribute 'size'