{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "### The Dataset class\n", "Whether you define a custom dataset or use one of the ready-made datasets provided by the official libraries, it is ultimately a subclass of `Dataset`. When subclassing `Dataset`, you normally implement the following methods:\n", "\n", "* `__init__()`: the constructor, where you define how the data is read and apply any preprocessing;\n", "* `__len__()`: returns the size of the dataset;\n", "* `__getitem__()`: returns the sample at a given index.\n", "\n", "Strictly speaking, PyTorch only requires a map-style dataset to override `__getitem__()`; `__len__()` is optional, but many samplers and the default options of `DataLoader` expect it." ], "id": "6f460291da56d507" }, { "metadata": { "ExecuteTime": { "end_time": "2025-06-12T12:58:38.694696Z", "start_time": "2025-06-12T12:58:38.690734Z" } }, "cell_type": "code", "source": [ "import torch\n", "from torch.utils.data import Dataset\n", "\n", "\n", "class MyDataset(Dataset):\n", "    # Constructor: store the feature tensor and the label tensor\n", "    def __init__(self, data_tensor, target_tensor):\n", "        # Features and labels must contain the same number of samples (first dimension)\n", "        assert data_tensor.size(0) == target_tensor.size(0)\n", "        self.data_tensor = data_tensor\n", "        self.target_tensor = target_tensor\n", "\n", "    # Return the number of samples in the dataset\n", "    def __len__(self):\n", "        return self.data_tensor.size(0)\n", "\n", "    # Return the sample at index idx together with its label\n", "    def __getitem__(self, idx):\n", "        return self.data_tensor[idx], self.target_tensor[idx]" ], "id": "1244f4325aad0ac5", "outputs": [], "execution_count": 64 }, { "metadata": { "ExecuteTime": { "end_time": "2025-06-12T12:59:08.444392Z", "start_time": "2025-06-12T12:59:08.435084Z" } }, "cell_type": "code", "source": [ "# Generate random data: 10 samples with 3 features each\n", "data_tensor = torch.randn(10, 3)\n", "target_tensor = torch.randint(2, (10,))  # binary classification labels (0 or 1)\n", "\n", "print(data_tensor)\n", "print(target_tensor)" ], "id": "e2f8ad8dc637791a", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[-1.1016,  0.8803,  1.3645],\n", "        [ 0.6674,  0.4500,  0.0367],\n", "        [ 1.3528,  0.2438,  0.4721],\n", "        [-1.3651, -1.1315,  0.0299],\n", "        [ 1.4538,  1.3622, -0.5165],\n", "        [ 0.7515,  0.6168, -0.9036],\n", "        [ 1.3542, -0.3779, -1.2439],\n", "        [-0.4588, -0.2233, -0.3531],\n", "        [-0.0515,  0.8951,  0.3544],\n", "        [ 0.0181,  0.4075, -2.0888]])\n", "tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])\n" ] } ], "execution_count": 66 }, { "metadata": { "ExecuteTime": { "end_time": "2025-06-12T12:59:25.092189Z", "start_time": "2025-06-12T12:59:25.085985Z" } }, "cell_type": "code", "source": [
"# Create the dataset\n", "my_dataset = MyDataset(data_tensor, target_tensor)\n", "\n", "# Check the dataset size (len() calls __len__)\n", "print(\"dataset size:\", len(my_dataset))\n", "\n", "# Retrieve a sample by index (indexing calls __getitem__)\n", "print(\"data:\", my_dataset[0])  # the first sample and its label" ], "id": "cd8f86852f54b0c5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset size: 10\n", "data: (tensor([-1.1016,  0.8803,  1.3645]), tensor(0))\n" ] } ], "execution_count": 67 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }