"""Demonstrate how ``torch.stack`` places the new axis depending on ``dim``.

``torch.stack`` joins tensors of identical shape along a NEW dimension
inserted at position ``dim``. Stacking two copies of the (2, 3) tensor ``A``
therefore yields a length-2 axis at that position:

* ``dim=0`` -> shape (2, 2, 3); ``B0[k]`` is the k-th input tensor.
* ``dim=1`` -> shape (2, 2, 3); ``B1[i, k]`` is row ``i`` of the k-th input.
* ``dim=2`` -> shape (2, 3, 2); ``B2[i, j, k]`` is element ``(i, j)`` of the
  k-th input.
"""
import torch
import numpy as np

print(torch.__version__)

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape = (2, 3)

# New axis in front: the two inputs sit side by side as whole matrices.
B0 = torch.stack([A, A], dim=0)
print("dim=0:\n", B0)
print("shape:", B0.shape)

# New axis between rows and columns: each row is paired with its copy.
B1 = torch.stack([A, A], dim=1)
print("dim=1:\n", B1)
print("shape:", B1.shape)

# New axis last: each scalar element is paired with its copy.
B2 = torch.stack([A, A], dim=2)
print("dim=2:\n", B2)
print("shape:", B2.shape)

# np.max needs an array argument (a bare ``np.max()`` raises TypeError).
# Stacking only duplicates values, so the global max equals A's max.
max_value = np.max(B2.numpy())
print("max:", max_value)