28 KiB
28 KiB
In [8]:
import torchvision.models as models from sympy.printing.pytorch import torch google_net = models.googlenet(pretrained=True, )
--------------------------------------------------------------------------- ModuleNotFoundError Traceback (most recent call last) Cell In[8], line 3 1 import torch 2 import torchvision.models as models ----> 3 from sympy.printing.pytorch import torch 5 google_net = models.googlenet(pretrained=True, ) ModuleNotFoundError: No module named 'sympy.printing.pytorch'
In [9]:
import torch # æ<><C3A6>å<EFBFBD>–分类器的输入特å¾<C3A5>æ•°é‡<C3A9> fc_in_features = google_net.fc.in_features print("fc_in_features: ", fc_in_features) # 查看分类层的输出å<C2BA>‚æ•° fc_out_features = google_net.fc.out_features print("fc_out_features: ", fc_out_features) # 修改分类器 google_net.fc = torch.nn.Linear(fc_in_features, 10)
fc_in_features: 1024 fc_out_features: 1000
In [24]:
import torchvision from torchvision import datasets, transforms from torch.utils.data import DataLoader # åŠ è½½mnistæ•°æ<C2B0>®é›† mnist_dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True, target_transform=None) # å<>–32å¼ å›¾ç‰‡ tensor_loader = DataLoader(dataset=mnist_dataset, batch_size=32) data_iter = iter(tensor_loader) img_tensor, label_tensor = next(data_iter) print(img_tensor.shape) grid_tensor = torchvision.utils.make_grid(img_tensor, nrow=8, padding=2) grid_img = transforms.ToPILImage()(grid_tensor) display(grid_img) print(grid_tensor.shape) torchvision.utils.save_image(grid_tensor, "./mnist_grid.png")
torch.Size([32, 1, 28, 28])
torch.Size([3, 122, 242])