{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Dataset类\n",
    "在 PyTorch 中,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了Dataset类。而在继承Dataset类时,至少需要重写以下几个方法:\n",
    "\n",
    "* \\_\\_init\\_\\_():构造函数,可自定义数据读取方法以及进行数据预处理;\n",
    "* \\_\\_len\\_\\_():返回数据集大小;\n",
    "* \\_\\_getitem\\_\\_():索引数据集中的某一个数据。"
   ],
   "id": "6f460291da56d507"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T01:44:55.209362Z",
     "start_time": "2025-06-13T01:44:52.973675Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    \"\"\"Map-style dataset that pairs data_tensor[i] with target_tensor[i].\n",
    "\n",
    "    Args:\n",
    "        data_tensor: Samples, indexed along dimension 0.\n",
    "        target_tensor: Labels, indexed along dimension 0.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the two tensors differ in size along dimension 0.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_tensor, target_tensor):\n",
    "        # Fail fast: __len__ only looks at data_tensor, so a length mismatch would\n",
    "        # otherwise surface later as an IndexError in __getitem__ (labels too short)\n",
    "        # or silently ignore the extra labels (labels too long).\n",
    "        if data_tensor.size(0) != target_tensor.size(0):\n",
    "            raise ValueError(\n",
    "                \"data_tensor and target_tensor must have the same size along dim 0, \"\n",
    "                f\"got {data_tensor.size(0)} and {target_tensor.size(0)}\"\n",
    "            )\n",
    "        self.data_tensor = data_tensor\n",
    "        self.target_tensor = target_tensor\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of samples (size of data_tensor along dimension 0).\"\"\"\n",
    "        return self.data_tensor.size(0)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the (sample, label) pair stored at position idx.\"\"\"\n",
    "        return self.data_tensor[idx], self.target_tensor[idx]"
   ],
   "id": "1244f4325aad0ac5",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T01:44:55.286472Z",
     "start_time": "2025-06-13T01:44:55.230898Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Generate data\n",
    "data_tensor = torch.randn(10, 3)\n",
    "target_tensor = torch.randint(2, (10,))  # labels in {0, 1}: binary classification task\n",
    "\n",
    "print(data_tensor)\n",
    "print(target_tensor)"
   ],
   "id": "e2f8ad8dc637791a",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 1.5479, 0.1374, 1.6763],\n",
      " [-0.5671, -0.0821, -1.8523],\n",
      " [-0.4039, -0.5871, -0.3510],\n",
      " [-0.2339, 0.2773, -0.0820],\n",
      " [-0.3131, 0.5911, 0.2030],\n",
      " [-0.7087, -0.2614, -0.2661],\n",
      " [ 0.3220, -0.0340, -1.2429],\n",
      " [-0.1282, -0.2188, -0.7576],\n",
      " [ 1.1233, -0.2452, 0.4664],\n",
      " [ 1.2570, 0.3728, 0.9745]])\n",
      "tensor([0, 1, 1, 1, 0, 0, 0, 0, 0, 0])\n"
     ]
    }
   ],
   "execution_count": 2
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T01:44:55.610922Z",
     "start_time": "2025-06-13T01:44:55.606365Z"
    }
   },
   "cell_type": "code",
   "source":
[
    "# Create the dataset\n",
    "my_dataset = MyDataset(data_tensor, target_tensor)\n",
    "\n",
    "# Check the dataset size\n",
    "print(\"dataset size:\", len(my_dataset))\n",
    "\n",
    "# Fetch a sample by index\n",
    "print(\"data:\", my_dataset[0])  # first (sample, label) pair"
   ],
   "id": "cd8f86852f54b0c5",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset size: 10\n",
      "data: (tensor([1.5479, 0.1374, 1.6763]), tensor(0))\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T01:48:47.876220Z",
     "start_time": "2025-06-13T01:48:47.866170Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "tensor_loader = DataLoader(dataset=my_dataset, batch_size=2, shuffle=True, num_workers=0)\n",
    "\n",
    "# Iterate over the loader batch by batch\n",
    "for data, target in tensor_loader:\n",
    "    print(\"data:\", data)\n",
    "    print(\"target:\", target)\n",
    "\n",
    "# shuffle=True reshuffles for every new iterator, so this batch is generally\n",
    "# not the same as the first batch printed by the loop above.\n",
    "print(\"one batch tensor data:\", next(iter(tensor_loader)))"
   ],
   "id": "d78ddf04e79df9ae",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "data: tensor([[ 0.3220, -0.0340, -1.2429],\n",
      " [-0.2339, 0.2773, -0.0820]])\n",
      "target: tensor([0, 1])\n",
      "data: tensor([[-0.5671, -0.0821, -1.8523],\n",
      " [-0.3131, 0.5911, 0.2030]])\n",
      "target: tensor([1, 0])\n",
      "data: tensor([[-0.1282, -0.2188, -0.7576],\n",
      " [-0.4039, -0.5871, -0.3510]])\n",
      "target: tensor([0, 1])\n",
      "data: tensor([[-0.7087, -0.2614, -0.2661],\n",
      " [ 1.2570, 0.3728, 0.9745]])\n",
      "target: tensor([0, 0])\n",
      "data: tensor([[ 1.5479, 0.1374, 1.6763],\n",
      " [ 1.1233, -0.2452, 0.4664]])\n",
      "target: tensor([0, 0])\n",
      "one batch tensor data: [tensor([[ 0.3220, -0.0340, -1.2429],\n",
      " [ 1.1233, -0.2452, 0.4664]]), tensor([0, 0])]\n"
     ]
    }
   ],
   "execution_count": 9
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T01:53:08.605768Z",
     "start_time": "2025-06-13T01:53:07.616662Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "# Versions are pinned to the ones reported in this cell's recorded output.\n",
    "%pip install torchvision==0.17.1 pillow==10.2.0"
   ],
   "id": "b6845476dfd81659",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
"Requirement already satisfied: torchvision in /opt/conda/lib/python3.10/site-packages (0.17.1)\r\n", "Requirement already satisfied: pillow in /opt/conda/lib/python3.10/site-packages (10.2.0)\r\n", "Requirement already satisfied: numpy in /opt/conda/lib/python3.10/site-packages (from torchvision) (1.26.3)\r\n", "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (from torchvision) (2.2.1)\r\n", "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.13.1)\r\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (4.9.0)\r\n", "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (1.12)\r\n", "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1)\r\n", "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (3.1.3)\r\n", "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch->torchvision) (2024.2.0)\r\n", "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch->torchvision) (2.1.3)\r\n", "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch->torchvision) (1.3.0)\r\n", "\u001B[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. 
Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001B[0m\u001B[33m\r\n",
      "\u001B[0m"
     ]
    }
   ],
   "execution_count": 11
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T02:12:21.816805Z",
     "start_time": "2025-06-13T02:12:21.421202Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Use the MNIST dataset as an example\n",
    "import torchvision\n",
    "\n",
    "mnist_dataset = torchvision.datasets.MNIST(\n",
    "    root=\"./data\",\n",
    "    train=True,\n",
    "    transform=None,\n",
    "    target_transform=None,\n",
    "    download=True,\n",
    ")"
   ],
   "id": "f8112ef53ed2d0ce",
   "outputs": [],
   "execution_count": 13
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-06-13T02:17:29.130555Z",
     "start_time": "2025-06-13T02:17:28.225901Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Index the dataset directly: list(mnist_dataset) would materialize every\n",
    "# (image, label) pair of the split just to look at the first one.\n",
    "first_image, first_label = mnist_dataset[0]\n",
    "\n",
    "display(first_image)\n",
    "print(\"Image label is:\", first_label)"
   ],
   "id": "9b9c1d2993b37bcd",
   "outputs": [
    {
     "data": {
      "text/plain": [
       ""
      ],
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA90lEQVR4AWNgGMyAWUhIqK5jvdSy/9/rQe5kgTlWjs3KRiAYxHsyKfDzxYMgFiOIAALDvfwQBsO/pK8Mz97fhPLAlNDtvyBwbNv3j8jCUHbAnOy/f89yM2jPwiLJwMc4628UqgQTnPvp/0eGFAQXLg5lcO/764YuhuArf3y4IAfmfoQwlBX44e/fckkMYaiA7q6/f6dJ45IViP3zdzcuSQaGn39/OkBl4WEL4euFmLIwXDuETav6lKfAIPy1DYucRNFdUPCe9MOUE3e6CpI6FogZSEKrwbFyOIATQ5v5mkcgXV9auVGlwK4NDGRguL75b88HVDla8QBFF16ADQA8sQAAAABJRU5ErkJggg==",
      "image/jpeg":
"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APAACzBVBJJwAO9dnp/wm8damu6Dw5dRjGf9IKw/+hkVPffCnWNJa7XVNV0Kxa1hErrNe/M2cnYqgElsAHpjkc1wlAODkV694W8c654t8M6n4TuvEctrrFw0cun3c0/lq+3AMJcDK5AyOeTkd+fPvGFn4gsvEtzF4m89tUG1ZJJjuMgUBVYN/EMKOe9YVXtK0bUtdvVs9LsZ7y4YgbIULYycZPoPc8V6lpfwh0/w7p66z8RdXj0y2z8llC4aWQ+mRn8lz9RXPfE3x1pvi46TYaPZTQadpMJghluWDSyrhQM9SMBe5Oc5NcBV7Tda1XRZJJNK1O8sXkG12tZ2iLD0JUjNQ3l9eahN517dT3MvTfNIXb16n6mq9Ff/2Q==" }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "Image label is: 5\n" ] }, { "ename": "AttributeError", "evalue": "'list' object has no attribute 'size'", "output_type": "error", "traceback": [ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", "\u001B[0;31mAttributeError\u001B[0m Traceback (most recent call last)", "Cell \u001B[0;32mIn[23], line 6\u001B[0m\n\u001B[1;32m 3\u001B[0m display(mnist_dataset_list[\u001B[38;5;241m0\u001B[39m][\u001B[38;5;241m0\u001B[39m])\n\u001B[1;32m 4\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mImage label is:\u001B[39m\u001B[38;5;124m\"\u001B[39m, mnist_dataset_list[\u001B[38;5;241m0\u001B[39m][\u001B[38;5;241m1\u001B[39m])\n\u001B[0;32m----> 6\u001B[0m \u001B[38;5;28mprint\u001B[39m(\u001B[38;5;28;43mlist\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mmnist_dataset_list\u001B[49m\u001B[43m)\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msize\u001B[49m)\n", "\u001B[0;31mAttributeError\u001B[0m: 'list' object has no attribute 'size'" ] } ], "execution_count": 23 } ], "metadata": { 
"kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }