"""Demonstrate where ``torch.stack`` inserts the new axis.

Two copies of a (2, 3) tensor are stacked along ``dim=0``, ``1`` and ``2``.
The resulting shapes are (2, 2, 3), (2, 2, 3) and (2, 3, 2). Each result is
cross-checked against NumPy's ``np.stack`` with the same ``axis``.
"""
import numpy as np
import torch


def stack_pair(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Return ``torch.stack([t, t], dim=dim)``.

    The output has one more dimension than *t*. The new axis (length 2)
    sits at position *dim*, and ``dim`` may range over ``0 .. t.dim()``.
    """
    return torch.stack([t, t], dim=dim)


def main() -> None:
    """Print the stacked tensors and shapes for every valid ``dim``."""
    print(torch.__version__)

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape = (2, 3)

    # Stacking can insert the new axis before, between, or after the
    # existing ones, so the valid positions are 0 .. a.dim() inclusive.
    for dim in range(a.dim() + 1):
        stacked = stack_pair(a, dim)
        print(f"dim={dim}:\n", stacked)
        print("shape:", stacked.shape)

        # The original script ended with a bare `np.max()`, which raises
        # TypeError (no array argument). Use NumPy meaningfully instead:
        # confirm torch.stack agrees with its NumPy counterpart.
        expected = np.stack([a.numpy(), a.numpy()], axis=dim)
        print("matches np.stack:", np.array_equal(stacked.numpy(), expected))


if __name__ == "__main__":
    main()