pytorch-study/04.ipynb

In [1]:
import torch
import numpy as np

torch.__version__
Out[1]:
'2.2.1'
In [2]:
a = torch.tensor(1)
b = a.item()
print(a)
print(b)
tensor(1)
1
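
item() only works on tensors that hold exactly one element; for anything larger, tolist() is the usual way back to plain Python objects. A minimal sketch (not part of the original notebook):

a = torch.tensor([1, 2, 3])
print(a.tolist())    # [1, 2, 3]
# a.item() would raise a RuntimeError here, since the tensor has 3 elements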
In [3]:
a = [1, 2, 3]
b = torch.tensor(a)
c = b.numpy().tolist()
print(c)
[1, 2, 3]
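
For CPU tensors, .numpy() does not copy the data: the returned array shares memory with the tensor, so in-place changes show up on both sides. A small sketch under that assumption:

b = torch.tensor([1, 2, 3])
n = b.numpy()
b.add_(10)    # in-place add on the tensor
print(n)      # [11 12 13] -- the numpy array sees the change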
In [4]:
a = torch.zeros(2, 3, 5)
print(a.shape)

print(a.size())

print(a.numel())
torch.Size([2, 3, 5])
torch.Size([2, 3, 5])
30
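
numel() is just the product of the sizes along every dimension, which is why a (2, 3, 5) tensor reports 30 elements. A quick check (sketch):

import math
a = torch.zeros(2, 3, 5)
assert a.numel() == math.prod(a.shape)   # 2 * 3 * 5 == 30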
In [6]:
x = torch.rand(2, 3, 5)
print(x.shape)
print(x)
torch.Size([2, 3, 5])
tensor([[[0.1437, 0.3582, 0.4219, 0.4514, 0.6537],
         [0.0089, 0.5737, 0.0201, 0.7728, 0.1827],
         [0.6573, 0.1262, 0.0877, 0.2302, 0.0151]],

        [[0.0757, 0.7126, 0.4238, 0.0535, 0.0578],
         [0.4909, 0.5616, 0.7342, 0.7925, 0.8879],
         [0.3011, 0.1606, 0.2856, 0.8165, 0.4100]]])
In [7]:
# matrix transpose: reorder the dimensions with permute
x = x.permute(2, 1, 0)
print(x.shape)
print(x)
torch.Size([5, 3, 2])
tensor([[[0.1437, 0.0757],
         [0.0089, 0.4909],
         [0.6573, 0.3011]],

        [[0.3582, 0.7126],
         [0.5737, 0.5616],
         [0.1262, 0.1606]],

        [[0.4219, 0.4238],
         [0.0201, 0.7342],
         [0.0877, 0.2856]],

        [[0.4514, 0.0535],
         [0.7728, 0.7925],
         [0.2302, 0.8165]],

        [[0.6537, 0.0578],
         [0.1827, 0.8879],
         [0.0151, 0.4100]]])
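
permute only rewrites the stride metadata: the result shares storage with the original tensor and is usually no longer contiguous in memory, which matters for view() below. A sketch illustrating both points:

x = torch.rand(2, 3, 5)
y = x.permute(2, 1, 0)
print(y.is_contiguous())              # False
print(x.data_ptr() == y.data_ptr())   # True -- same underlying storage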
In [8]:
x = torch.rand(2, 3, 4)
x = x.transpose(1, 0)
print(x.shape)
torch.Size([3, 2, 4])
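
transpose(d0, d1) swaps exactly two dimensions, so it is the two-axis special case of permute. A sketch showing the equivalence:

x = torch.rand(2, 3, 4)
a = x.transpose(1, 0)
b = x.permute(1, 0, 2)
print(torch.equal(a, b))   # True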
In [13]:
x = torch.rand(4, 4)
x = x.view(2, 8)
x = x.permute(1, 0)
# x.view(4, 4)  # view cannot be used directly here because it requires contiguous memory
x.reshape(4, 4).shape
Out[13]:
torch.Size([4, 4])
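
Besides reshape, the other common fix is to make the memory contiguous first, after which view works again; reshape copies only when it has to, while .contiguous() forces the copy explicitly. A sketch:

x = torch.rand(4, 4).view(2, 8).permute(1, 0)
print(x.is_contiguous())        # False -- this is why view(4, 4) fails
y = x.contiguous().view(4, 4)   # explicit copy, then view succeeds
print(y.shape)                  # torch.Size([4, 4])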
In [29]:
# adding and removing dimensions
x = torch.rand(2, 1, 3)
print(x)
x = x.squeeze(1)  # remove the size-1 dimension at index 1

print(x.shape)
print(x)
tensor([[[0.0287, 0.7995, 0.4072]],

        [[0.4378, 0.6384, 0.2777]]])
torch.Size([2, 3])
tensor([[0.0287, 0.7995, 0.4072],
        [0.4378, 0.6384, 0.2777]])
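
Called without an argument, squeeze removes every size-1 dimension; with an argument it only acts when that particular dimension has size 1, otherwise it is a no-op. A sketch:

x = torch.rand(1, 2, 1, 3)
print(x.squeeze().shape)    # torch.Size([2, 3])
print(x.squeeze(1).shape)   # torch.Size([1, 2, 1, 3]) -- dim 1 has size 2, nothing removed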
In [30]:
# adding and removing dimensions
x = torch.rand(2, 1, 3)
print(x)
x = x.unsqueeze(1)  # insert a new size-1 dimension at index 1

print(x.shape)
print(x)
tensor([[[0.4243, 0.1581, 0.4620]],

        [[0.8510, 0.5490, 0.7694]]])
torch.Size([2, 1, 1, 3])
tensor([[[[0.4243, 0.1581, 0.4620]]],


        [[[0.8510, 0.5490, 0.7694]]]])
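
unsqueeze(dim) inserts a new axis of size 1 at the given position; indexing with None achieves the same thing and is common in broadcasting code. A closing sketch:

x = torch.rand(2, 3)
print(x.unsqueeze(0).shape)   # torch.Size([1, 2, 3])
print(x[:, None, :].shape)    # torch.Size([2, 1, 3]) -- same as x.unsqueeze(1)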