139 lines
3.7 KiB
Plaintext
139 lines
3.7 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"### Dataset类\n",
|
||
"其实这就表示,无论是使用自定义的数据集,还是使用官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:\n",
|
||
"\n",
|
||
"* \\_\\_init\\_\\_():构造函数,可自定义数据读取方法以及进行数据预处理;\n",
|
||
"* \\_\\_len\\_\\_():返回数据集大小;\n",
|
||
"* \\_\\_getitem\\_\\_():根据索引获取数据集中的某一个样本。"
|
||
],
|
||
"id": "6f460291da56d507"
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-06-12T12:58:38.694696Z",
|
||
"start_time": "2025-06-12T12:58:38.690734Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import torch\n",
|
||
"from torch.utils.data import Dataset\n",
|
||
"\n",
|
||
"\n",
|
||
"class MyDataset(Dataset):\n",
|
||
" # 构造函数\n",
|
||
" def __init__(self, data_tensor, target_tensor):\n",
|
||
" self.data_tensor = data_tensor\n",
|
||
" self.target_tensor = target_tensor\n",
|
||
"\n",
|
||
" # 返回数据集大小\n",
|
||
" def __len__(self):\n",
|
||
" return self.data_tensor.size(0)\n",
|
||
"\n",
|
||
" # 返回索引的数据与标签\n",
|
||
" def __getitem__(self, idx):\n",
|
||
" return self.data_tensor[idx], self.target_tensor[idx]"
|
||
],
|
||
"id": "1244f4325aad0ac5",
|
||
"outputs": [],
|
||
"execution_count": 64
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-06-12T12:59:08.444392Z",
|
||
"start_time": "2025-06-12T12:59:08.435084Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# 生成数据\n",
|
||
"data_tensor = torch.randn(10, 3)\n",
|
||
"target_tensor = torch.randint(2, (10,)) # 二分类任务\n",
|
||
"\n",
|
||
"print(data_tensor)\n",
|
||
"print(target_tensor)"
|
||
],
|
||
"id": "e2f8ad8dc637791a",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"tensor([[-1.1016, 0.8803, 1.3645],\n",
|
||
" [ 0.6674, 0.4500, 0.0367],\n",
|
||
" [ 1.3528, 0.2438, 0.4721],\n",
|
||
" [-1.3651, -1.1315, 0.0299],\n",
|
||
" [ 1.4538, 1.3622, -0.5165],\n",
|
||
" [ 0.7515, 0.6168, -0.9036],\n",
|
||
" [ 1.3542, -0.3779, -1.2439],\n",
|
||
" [-0.4588, -0.2233, -0.3531],\n",
|
||
" [-0.0515, 0.8951, 0.3544],\n",
|
||
" [ 0.0181, 0.4075, -2.0888]])\n",
|
||
"tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 66
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-06-12T12:59:25.092189Z",
|
||
"start_time": "2025-06-12T12:59:25.085985Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# 创建数据集\n",
|
||
"my_dataset = MyDataset(data_tensor, target_tensor)\n",
|
||
"\n",
|
||
"# 查看数据集大小\n",
|
||
"print(\"dataset size:\",len(my_dataset))\n",
|
||
"\n",
|
||
"# 使用索引获取数据\n",
|
||
"print(\"data:\", my_dataset[0]) # 获取第一个数据"
|
||
],
|
||
"id": "cd8f86852f54b0c5",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"dataset size: 10\n",
|
||
"data: (tensor([-1.1016, 0.8803, 1.3645]), tensor(0))\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 67
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 2
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython2",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|