Files
pytorch-study/06.ipynb

288 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"### Dataset类\n",
"无论使用自定义的数据集，还是官方为我们封装好的数据集，其本质都是继承了 Dataset 类。而在继承 Dataset 类时，至少需要重写以下几个方法：\n",
"\n",
"* \\_\\_init\\_\\_():构造函数,可自定义数据读取方法以及进行数据预处理;\n",
"* \\_\\_len\\_\\_():返回数据集大小;\n",
"* \\_\\_getitem\\_\\_():索引数据集中的某一个数据。"
],
"id": "6f460291da56d507"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T01:44:55.209362Z",
"start_time": "2025-06-13T01:44:52.973675Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class MyDataset(Dataset):\n",
"    \"\"\"Map-style dataset pairing each row of a data tensor with its label.\n",
"\n",
"    Args:\n",
"        data_tensor: Samples, indexed along dimension 0.\n",
"        target_tensor: Labels, indexed along dimension 0; must contain as\n",
"            many entries as ``data_tensor``.\n",
"\n",
"    Raises:\n",
"        ValueError: If the two tensors differ in size along dimension 0.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data_tensor, target_tensor):\n",
"        # __len__ is derived from the data alone, so a shorter target tensor\n",
"        # would otherwise only fail with an IndexError part-way through iteration.\n",
"        if data_tensor.size(0) != target_tensor.size(0):\n",
"            raise ValueError(\n",
"                f\"data_tensor and target_tensor must have the same length, \"\n",
"                f\"got {data_tensor.size(0)} and {target_tensor.size(0)}\"\n",
"            )\n",
"        self.data_tensor = data_tensor\n",
"        self.target_tensor = target_tensor\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of samples (size of dimension 0 of the data).\"\"\"\n",
"        return self.data_tensor.size(0)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return the ``(data, target)`` pair stored at index ``idx``.\"\"\"\n",
"        return self.data_tensor[idx], self.target_tensor[idx]"
],
"id": "1244f4325aad0ac5",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T01:44:55.286472Z",
"start_time": "2025-06-13T01:44:55.230898Z"
}
},
"cell_type": "code",
"source": [
"# Build a toy binary-classification set: 10 samples with 3 features each\n",
"data_tensor = torch.randn((10, 3))\n",
"target_tensor = torch.randint(low=0, high=2, size=(10,))  # labels drawn from {0, 1}\n",
"\n",
"# Show the features first, then the labels, one tensor per print line\n",
"print(data_tensor, target_tensor, sep=\"\\n\")"
],
"id": "e2f8ad8dc637791a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 1.5479, 0.1374, 1.6763],\n",
" [-0.5671, -0.0821, -1.8523],\n",
" [-0.4039, -0.5871, -0.3510],\n",
" [-0.2339, 0.2773, -0.0820],\n",
" [-0.3131, 0.5911, 0.2030],\n",
" [-0.7087, -0.2614, -0.2661],\n",
" [ 0.3220, -0.0340, -1.2429],\n",
" [-0.1282, -0.2188, -0.7576],\n",
" [ 1.1233, -0.2452, 0.4664],\n",
" [ 1.2570, 0.3728, 0.9745]])\n",
"tensor([0, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n"
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T01:44:55.610922Z",
"start_time": "2025-06-13T01:44:55.606365Z"
}
},
"cell_type": "code",
"source": [
"# Wrap the two tensors in the custom dataset\n",
"my_dataset = MyDataset(data_tensor=data_tensor, target_tensor=target_tensor)\n",
"\n",
"# len() goes through MyDataset.__len__\n",
"print(\"dataset size\", len(my_dataset))\n",
"\n",
"# Indexing goes through MyDataset.__getitem__; index 0 is the first sample\n",
"first_sample = my_dataset[0]\n",
"print(\"data:\", first_sample)"
],
"id": "cd8f86852f54b0c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataset size 10\n",
"data: (tensor([1.5479, 0.1374, 1.6763]), tensor(0))\n"
]
}
],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T01:48:47.876220Z",
"start_time": "2025-06-13T01:48:47.866170Z"
}
},
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# Batches of 2, shuffled, with loading done in the main process (num_workers=0)\n",
"tensor_loader = DataLoader(my_dataset, batch_size=2, shuffle=True, num_workers=0)\n",
"\n",
"# One full pass over the loader, unpacking each batch into data and labels\n",
"for data, target in tensor_loader:\n",
"    print(\"data:\", data)\n",
"    print(\"target:\", target)\n",
"\n",
"# A new iterator draws a new shuffle, so this batch need not match the first one above\n",
"first_batch = next(iter(tensor_loader))\n",
"print(\"one batch tensor data:\", first_batch)"
],
"id": "d78ddf04e79df9ae",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"data: tensor([[ 0.3220, -0.0340, -1.2429],\n",
" [-0.2339, 0.2773, -0.0820]])\n",
"target: tensor([0, 1])\n",
"data: tensor([[-0.5671, -0.0821, -1.8523],\n",
" [-0.3131, 0.5911, 0.2030]])\n",
"target: tensor([1, 0])\n",
"data: tensor([[-0.1282, -0.2188, -0.7576],\n",
" [-0.4039, -0.5871, -0.3510]])\n",
"target: tensor([0, 1])\n",
"data: tensor([[-0.7087, -0.2614, -0.2661],\n",
" [ 1.2570, 0.3728, 0.9745]])\n",
"target: tensor([0, 0])\n",
"data: tensor([[ 1.5479, 0.1374, 1.6763],\n",
" [ 1.1233, -0.2452, 0.4664]])\n",
"target: tensor([0, 0])\n",
"one batch tensor data: [tensor([[ 0.3220, -0.0340, -1.2429],\n",
" [ 1.1233, -0.2452, 0.4664]]), tensor([0, 0])]\n"
]
}
],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T01:53:08.605768Z",
"start_time": "2025-06-13T01:53:07.616662Z"
}
},
"cell_type": "code",
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel\n",
"%pip install torchvision pillow"
],
"id": "b6845476dfd81659",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (0.17.1)\r\n",
"Requirement already satisfied: pillow in /opt/conda/lib/python3.10/site-packages (10.2.0)\r\n",
"Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision) (1.26.3)\r\n",
"Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchvision) (2.2.1)\r\n",
"Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.13.1)\r\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (4.9.0)\r\n",
"Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (1.12)\r\n",
"Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1)\r\n",
"Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1.3)\r\n",
"Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (2024.2.0)\r\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->torchvision) (2.1.3)\r\n",
"Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchvision) (1.3.0)\r\n",
"\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001B[0m\u001B[33m\r\n",
"\u001B[0m"
]
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T02:12:21.816805Z",
"start_time": "2025-06-13T02:12:21.421202Z"
}
},
"cell_type": "code",
"source": [
"# Use the MNIST dataset as an example\n",
"import torchvision\n",
"\n",
"mnist_dataset = torchvision.datasets.MNIST(\n",
"    root=\"./data\",\n",
"    train=True,  # the training split\n",
"    transform=None,  # no image transform applied\n",
"    target_transform=None,  # no label transform applied\n",
"    download=True,  # download into root if it is not already there\n",
")"
],
"id": "f8112ef53ed2d0ce",
"outputs": [],
"execution_count": 13
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-13T02:17:29.130555Z",
"start_time": "2025-06-13T02:17:28.225901Z"
}
},
"cell_type": "code",
"source": [
"# Index the dataset directly: list(mnist_dataset) would load every sample\n",
"# into memory just to look at the first one.\n",
"image, label = mnist_dataset[0]\n",
"\n",
"display(image)\n",
"print(\"Image label is:\", label)"
],
"id": "9b9c1d2993b37bcd",
"outputs": [
{
"data": {
"text/plain": [
"<PIL.Image.Image image mode=L size=28x28>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA90lEQVR4AWNgGMyAWUhIqK5jvdSy/9/rQe5kgTlWjs3KRiAYxHsyKfDzxYMgFiOIAALDvfwQBsO/pK8Mz97fhPLAlNDtvyBwbNv3j8jCUHbAnOy/f89yM2jPwiLJwMc4628UqgQTnPvp/0eGFAQXLg5lcO/764YuhuArf3y4IAfmfoQwlBX44e/fckkMYaiA7q6/f6dJ45IViP3zdzcuSQaGn39/OkBl4WEL4euFmLIwXDuETav6lKfAIPy1DYucRNFdUPCe9MOUE3e6CpI6FogZSEKrwbFyOIATQ5v5mkcgXV9auVGlwK4NDGRguL75b88HVDla8QBFF16ADQA8sQAAAABJRU5ErkJggg==",
"image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAACzBVBJJwAO9dnp/wm8damu6Dw5dRjGf9IKw/+hkVPffCnWNJa7XVNV0Kxa1hErrNe/M2cnYqgElsAHpjkc1wlAODkV694W8c654t8M6n4TuvEctrrFw0cun3c0/lq+3AMJcDK5AyOeTkd+fPvGFn4gsvEtzF4m89tUG1ZJJjuMgUBVYN/EMKOe9YVXtK0bUtdvVs9LsZ7y4YgbIULYycZPoPc8V6lpfwh0/w7p66z8RdXj0y2z8llC4aWQ+mRn8lz9RXPfE3x1pvi46TYaPZTQadpMJghluWDSyrhQM9SMBe5Oc5NcBV7Tda1XRZJJNK1O8sXkG12tZ2iLD0JUjNQ3l9eahN517dT3MvTfNIXb16n6mq9Ff/2Q=="
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Image label is: 5\n"
]
},
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute 'size'",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[23], line 6\u001B[0m\n\u001B[1;32m 3\u001B[0m display(mnist_dataset_list[\u001B[38;5;241m0\u001B[39m][\u001B[38;5;241m0\u001B[39m])\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mImage label is:\u001B[39m\u001B[38;5;124m\"\u001B[39m, mnist_dataset_list[\u001B[38;5;241m0\u001B[39m][\u001B[38;5;241m1\u001B[39m])\n\u001B[0;32m----> 6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;28;43mlist\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mmnist_dataset_list\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msize\u001B[49m)\n",
"\u001B[0;31mAttributeError\u001B[0m: 'list' object has no attribute 'size'"
]
}
],
"execution_count": 23
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}