import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset that pairs each entry of a data tensor with its target.

    Both tensors are indexed along their first dimension, so entry ``i`` of the
    dataset is ``(data_tensor[i], target_tensor[i])``.

    Args:
        data_tensor: Tensor holding the samples; its first dimension is the
            number of samples.
        target_tensor: Tensor holding the labels; its first dimension must
            have the same length as that of ``data_tensor``.

    Raises:
        ValueError: If the two tensors disagree on the size of their first
            dimension.
    """

    def __init__(self, data_tensor, target_tensor):
        # Fail fast: without this check a mismatch only surfaces later as an
        # IndexError inside __getitem__ (targets too short) or as silently
        # ignored targets (targets too long), because __len__ consults the
        # data tensor alone.
        if data_tensor.size(0) != target_tensor.size(0):
            raise ValueError(
                "Size mismatch between tensors: "
                f"data has {data_tensor.size(0)} entries, "
                f"target has {target_tensor.size(0)}"
            )
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor

    def __len__(self):
        """Return the number of samples (size of the data tensor's first dimension)."""
        return self.data_tensor.size(0)

    def __getitem__(self, idx):
        """Return the ``(data, target)`` pair stored at position ``idx``."""
        return self.data_tensor[idx], self.target_tensor[idx]


# Generate data
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,))  # binary classification task: labels are 0 or 1

print(data_tensor)
print(target_tensor)
# Wrap the sample tensor and the label tensor in the custom dataset.
my_dataset = MyDataset(data_tensor=data_tensor, target_tensor=target_tensor)

# len() on the dataset is answered by MyDataset.__len__.
sample_count = len(my_dataset)
print("dataset size:", sample_count)

# Integer indexing is answered by MyDataset.__getitem__; index 0 is the first entry.
first_sample = my_dataset[0]
print("data:", first_sample)