Files
pytorch-study/08.ipynb

28 KiB
Raw Permalink Blame History

In [8]:
import torchvision.models as models
from sympy.printing.pytorch import torch

google_net = models.googlenet(pretrained=True, )
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[8], line 3
      1 import torch
      2 import torchvision.models as models
----> 3 from sympy.printing.pytorch import torch
      5 google_net = models.googlenet(pretrained=True, )

ModuleNotFoundError: No module named 'sympy.printing.pytorch'
In [9]:
import torch

# æ<><C3A6>å<EFBFBD>分类器的输入特å¾<C3A5>æ•°é‡<C3A9>
fc_in_features = google_net.fc.in_features
print("fc_in_features: ", fc_in_features)

# 查çœåˆ†ç±»å±çš„输出å<C2BA>æ•°
fc_out_features = google_net.fc.out_features
print("fc_out_features: ", fc_out_features)

# 修改分类器
google_net.fc = torch.nn.Linear(fc_in_features, 10)
fc_in_features:  1024
fc_out_features:  1000
In [24]:
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# 加载mnistæ•°æ<C2B0>®é
mnist_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True,
                               target_transform=None)

# å<>32å¼ å¾ç‰‡
tensor_loader = DataLoader(dataset=mnist_dataset, batch_size=32)

data_iter = iter(tensor_loader)
img_tensor, label_tensor = next(data_iter)
print(img_tensor.shape)

grid_tensor = torchvision.utils.make_grid(img_tensor, nrow=8, padding=2)
grid_img = transforms.ToPILImage()(grid_tensor)
display(grid_img)

print(grid_tensor.shape)

torchvision.utils.save_image(grid_tensor, "./mnist_grid.png")
torch.Size([32, 1, 28, 28])
No description has been provided for this image
torch.Size([3, 122, 242])