PyTorch Cheatsheet

Basics

# Create tensors of a given shape; the shape may be a tuple or separate ints.
shape = (2, 3)
x = torch.empty(shape)  # uninitialized memory: contents are arbitrary
x = torch.zeros(shape)  # every element 0
x = torch.ones(shape)  # every element 1

# Build a tensor directly from Python data, then inspect it.
values = [1.0, 2.0, 3.0]
x = torch.tensor(values)  # dtype is inferred from the Python values

x.dtype  # torch.float32 here (default float dtype); others: torch.int, torch.double, torch.float16, ...
x.size()  # torch.Size([3]); the attribute x.shape gives the same result

# Element-wise arithmetic on two tensors of the same shape.
x = torch.tensor([1.0, 2.0])
y = torch.tensor([10.0, 20.0])

z = x + y
z = torch.add(x, y)  # likewise torch.sub, torch.mul, torch.div (or -, *, /)

y.add_(x)  # in-place: a trailing underscore means y itself is modified

# Extract a plain Python number from a single-element tensor.
x = torch.tensor([[1.5, 2.0], [3.0, 4.0]])
value = x[0, 0].item()  # 1.5 as a Python float; .item() needs a one-element tensor


# Reshape without copying: view() returns a tensor that shares x's storage.
x = torch.rand((4, 4))  # shape passed as a single tuple
y = x.view((16,))  # flatten all 16 elements into one dimension
y = x.view((-1, 8))  # -1 is inferred from the element count, giving (2, 8)