Files
pytorch-study/06.ipynb

139 lines
3.7 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {},
"cell_type": "markdown",
"source": [
"### Dataset类\n",
"无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:\n",
"\n",
"* \\_\\_init\\_\\_():构造函数,可自定义数据读取方法以及进行数据预处理;\n",
"* \\_\\_len\\_\\_():返回数据集大小;\n",
"* \\_\\_getitem\\_\\_():索引数据集中的某一个数据。"
],
"id": "6f460291da56d507"
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:58:38.694696Z",
"start_time": "2025-06-12T12:58:38.690734Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class MyDataset(Dataset):\n",
" # 构造函数\n",
" def __init__(self, data_tensor, target_tensor):\n",
" self.data_tensor = data_tensor\n",
" self.target_tensor = target_tensor\n",
"\n",
" # 返回数据集大小\n",
" def __len__(self):\n",
" return self.data_tensor.size(0)\n",
"\n",
" # 返回索引的数据与标签\n",
" def __getitem__(self, idx):\n",
" return self.data_tensor[idx], self.target_tensor[idx]"
],
"id": "1244f4325aad0ac5",
"outputs": [],
"execution_count": 64
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:59:08.444392Z",
"start_time": "2025-06-12T12:59:08.435084Z"
}
},
"cell_type": "code",
"source": [
"# 生成数据\n",
"data_tensor = torch.randn(10, 3)\n",
"target_tensor = torch.randint(2, (10,)) # 二分类任务\n",
"\n",
"print(data_tensor)\n",
"print(target_tensor)"
],
"id": "e2f8ad8dc637791a",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[-1.1016, 0.8803, 1.3645],\n",
" [ 0.6674, 0.4500, 0.0367],\n",
" [ 1.3528, 0.2438, 0.4721],\n",
" [-1.3651, -1.1315, 0.0299],\n",
" [ 1.4538, 1.3622, -0.5165],\n",
" [ 0.7515, 0.6168, -0.9036],\n",
" [ 1.3542, -0.3779, -1.2439],\n",
" [-0.4588, -0.2233, -0.3531],\n",
" [-0.0515, 0.8951, 0.3544],\n",
" [ 0.0181, 0.4075, -2.0888]])\n",
"tensor([0, 0, 1, 0, 0, 1, 1, 1, 1, 0])\n"
]
}
],
"execution_count": 66
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-12T12:59:25.092189Z",
"start_time": "2025-06-12T12:59:25.085985Z"
}
},
"cell_type": "code",
"source": [
"# Wrap the two tensors in the custom dataset.\n",
"my_dataset = MyDataset(data_tensor, target_tensor)\n",
"\n",
"# len() dispatches to MyDataset.__len__.\n",
"dataset_size = len(my_dataset)\n",
"print(\"dataset size\", dataset_size)\n",
"\n",
"# Indexing dispatches to MyDataset.__getitem__; index 0 is the first (data, label) pair.\n",
"first_sample = my_dataset[0]\n",
"print(\"data:\", first_sample)"
],
"id": "cd8f86852f54b0c5",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"dataset size 10\n",
"data: (tensor([-1.1016, 0.8803, 1.3645]), tensor(0))\n"
]
}
],
"execution_count": 67
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}