feat: rename Jupyter notebook and add custom Dataset class implementation

This commit is contained in:
fada
2025-06-12 21:00:01 +08:00
parent f1bc032f11
commit adeee32a87
2 changed files with 138 additions and 0 deletions

View File

138
06.ipynb Normal file
View File

@ -0,0 +1,138 @@
{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"### Dataset类\n",
"在 PyTorch 中,无论使用自定义的数据集,还是使用官方为我们封装好的数据集,其本质都是继承了 `Dataset` 类。而在继承 `Dataset` 类时,至少需要重写以下几个方法:\n",
"\n",
"* \\_\\_init\\_\\_():构造函数,可自定义数据读取方法以及进行数据预处理;\n",
"* \\_\\_len\\_\\_():返回数据集大小;\n",
"* \\_\\_getitem\\_\\_():索引数据集中的某一个数据。"
],
"id": "6f460291da56d507"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:58:38.694696Z",
"start_time": "2025-06-12T12:58:38.690734Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class MyDataset(Dataset):\n",
"    \"\"\"Map-style dataset pairing each sample of a tensor with its label.\n",
"\n",
"    Both tensors are indexed along their first dimension, so item ``idx``\n",
"    is ``(data_tensor[idx], target_tensor[idx])``.\n",
"\n",
"    Args:\n",
"        data_tensor: Tensor holding the samples; dimension 0 indexes samples.\n",
"        target_tensor: Tensor holding the labels; its dimension 0 must have\n",
"            the same size as that of ``data_tensor``.\n",
"\n",
"    Raises:\n",
"        ValueError: If the two tensors differ in size along dimension 0.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data_tensor, target_tensor):\n",
"        # Fail fast on mismatched lengths: __len__ only looks at the data,\n",
"        # so a shorter target tensor would otherwise raise IndexError late,\n",
"        # and a longer one would silently leave labels unused.\n",
"        if data_tensor.size(0) != target_tensor.size(0):\n",
"            raise ValueError(\n",
"                \"data_tensor and target_tensor must have the same size along \"\n",
"                f\"dimension 0, got {data_tensor.size(0)} and {target_tensor.size(0)}\"\n",
"            )\n",
"        self.data_tensor = data_tensor\n",
"        self.target_tensor = target_tensor\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Return the number of samples (size of dimension 0 of the data).\"\"\"\n",
"        return self.data_tensor.size(0)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"Return the ``(sample, label)`` pair stored at position ``idx``.\"\"\"\n",
"        return self.data_tensor[idx], self.target_tensor[idx]"
],
"id": "1244f4325aad0ac5",
"outputs": [],
"execution_count": 64
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:59:08.444392Z",
"start_time": "2025-06-12T12:59:08.435084Z"
}
},
"cell_type": "code",
"source": [
"# Seed the global RNG so the generated tensors are identical on every run.\n",
"SEED = 42\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Synthetic data: 10 samples with 3 features each.\n",
"data_tensor = torch.randn(10, 3)\n",
"# Integer labels drawn from {0, 1} (binary classification task).\n",
"target_tensor = torch.randint(2, (10,))\n",
"\n",
"print(data_tensor)\n",
"print(target_tensor)"
],
"id": "e2f8ad8dc637791a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-1.1016, 0.8803, 1.3645],\n",
" [ 0.6674, 0.4500, 0.0367],\n",
" [ 1.3528, 0.2438, 0.4721],\n",
" [-1.3651, -1.1315, 0.0299],\n",
" [ 1.4538, 1.3622, -0.5165],\n",
" [ 0.7515, 0.6168, -0.9036],\n",
" [ 1.3542, -0.3779, -1.2439],\n",
" [-0.4588, -0.2233, -0.3531],\n",
" [-0.0515, 0.8951, 0.3544],\n",
" [ 0.0181, 0.4075, -2.0888]])\n",
"tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])\n"
]
}
],
"execution_count": 66
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:59:25.092189Z",
"start_time": "2025-06-12T12:59:25.085985Z"
}
},
"cell_type": "code",
"source": [
"# Wrap the sample and label tensors in the custom dataset.\n",
"my_dataset = MyDataset(data_tensor, target_tensor)\n",
"\n",
"# len() dispatches to MyDataset.__len__.\n",
"dataset_size = len(my_dataset)\n",
"print(\"dataset size\", dataset_size)\n",
"\n",
"# Indexing dispatches to MyDataset.__getitem__; index 0 is the first (sample, label) pair.\n",
"first_item = my_dataset[0]\n",
"print(\"data:\", first_item)"
],
"id": "cd8f86852f54b0c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataset size 10\n",
"data: (tensor([-1.1016, 0.8803, 1.3645]), tensor(0))\n"
]
}
],
"execution_count": 67
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}