Set keepdim with functions that have a keepdim argument in PyTorch



This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)

You can set `keepdim` with the functions that have a `keepdim` argument, as shown below:

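*Memos:

- `keepdim` is `False` by default, so each reduced dimension is removed from the result (e.g. `tensor(10)`).
- `keepdim=True` keeps each reduced dimension with size 1 instead (e.g. `tensor([10])`).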

sum(). *My post explains sum():

# Demonstrates keepdim with torch.sum(). The comment under each group of
# calls is the result that every call in that group returns.
import torch

# 1-D tensor with shape (4,).
my_tensor = torch.tensor([1, 2, 3, 4])

# Reducing over every element (no dim) or over dim 0 removes the dimension,
# so a 0-d tensor is returned: 1 + 2 + 3 + 4 = 10.
torch.sum(input=my_tensor)
torch.sum(input=my_tensor, dim=0)
# tensor(10)

# keepdim=True keeps the reduced dim 0 with size 1, giving shape (1,).
torch.sum(input=my_tensor, dim=0, keepdim=True)
# tensor([10])

prod(). *My post explains prod():

# Demonstrates keepdim with torch.prod(). The comment under each group of
# calls is the result that every call in that group returns.
import torch

# 1-D tensor with shape (4,).
my_tensor = torch.tensor([1, 2, 3, 4])

# Reducing over every element (no dim) or over dim 0 removes the dimension,
# so a 0-d tensor is returned: 1 * 2 * 3 * 4 = 24.
torch.prod(input=my_tensor)
torch.prod(input=my_tensor, dim=0)
# tensor(24)

# keepdim=True keeps the reduced dim 0 with size 1, giving shape (1,).
torch.prod(input=my_tensor, dim=0, keepdim=True)
# tensor([24])

mean(). *My post explains mean():

# Demonstrates keepdim with torch.mean(). The comment under each group of
# calls is the result that every call in that group returns.
import torch

# 1-D tensor with shape (4,). Float literals are used because torch.mean()
# is documented to require a floating-point (or complex) input.
my_tensor = torch.tensor([5., 4., 7., 7.])

# Reducing over every element (no dim) or over dim 0 removes the dimension,
# so a 0-d tensor is returned: (5 + 4 + 7 + 7) / 4 = 5.75.
torch.mean(input=my_tensor)
torch.mean(input=my_tensor, dim=0)
# tensor(5.7500)

# keepdim=True keeps the reduced dim 0 with size 1, giving shape (1,).
torch.mean(input=my_tensor, dim=0, keepdim=True)
# tensor([5.7500])

median(). *My post explains median():

# Demonstrates keepdim with torch.median(). With a dim given, median()
# returns a (values, indices) named tuple, and keepdim applies to both.
import torch

# 1-D tensor with shape (4,).
my_tensor = torch.tensor([5, 4, 7, 7])

# Without keepdim, dim 0 is reduced away: both fields are 0-d tensors.
# With an even element count the result is the lower of the two middle
# values: sorted (4, 5, 7, 7) -> 5, found at index 0 of the input.
torch.median(input=my_tensor, dim=0)
# torch.return_types.median(
# values=tensor(5),
# indices=tensor(0))

# keepdim=True keeps dim 0 with size 1 in both values and indices.
torch.median(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.median(
# values=tensor([5]),
# indices=tensor([0]))

min(). *My post explains min():

# Demonstrates keepdim with torch.min(). With a dim given, min() returns
# a (values, indices) named tuple, and keepdim applies to both.
import torch

# 1-D tensor with shape (4,).
my_tensor = torch.tensor([5, 4, 7, 7])

# Without keepdim, dim 0 is reduced away: both fields are 0-d tensors.
# The minimum 4 is at index 1.
torch.min(input=my_tensor, dim=0)
# torch.return_types.min(
# values=tensor(4),
# indices=tensor(1))

# keepdim=True keeps dim 0 with size 1 in both values and indices.
torch.min(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.min(
# values=tensor([4]),
# indices=tensor([1]))

max(). *My post explains max():

# Demonstrates keepdim with torch.max(). With a dim given, max() returns
# a (values, indices) named tuple, and keepdim applies to both.
import torch

# 1-D tensor with shape (4,). The maximum 7 occurs twice (indices 2 and 3).
my_tensor = torch.tensor([5, 4, 7, 7])

# Without keepdim, dim 0 is reduced away: both fields are 0-d tensors.
# For the tied maximum, the index reported is 2, the first occurrence.
torch.max(input=my_tensor, dim=0)
# torch.return_types.max(
# values=tensor(7),
# indices=tensor(2))

# keepdim=True keeps dim 0 with size 1 in both values and indices.
torch.max(input=my_tensor, dim=0, keepdim=True)
# torch.return_types.max(
# values=tensor([7]),
# indices=tensor([2]))

argmin(). *My post explains argmin():

# Demonstrates keepdim with torch.argmin(), which returns only the index
# of the minimum. The comment under each pair of calls is the result that
# both calls return.
import torch

# 1-D tensor with shape (4,). The minimum 4 is at index 1.
my_tensor = torch.tensor([5, 4, 7, 7])

# Without dim the tensor is reduced as a whole. For this 1-D input that
# gives the same answer as dim=0: a 0-d index tensor.
torch.argmin(input=my_tensor)
torch.argmin(input=my_tensor, dim=0)
# tensor(1)

# keepdim=True gives a size-1 dimension, both with and without dim here.
torch.argmin(input=my_tensor, keepdim=True)
torch.argmin(input=my_tensor, dim=0, keepdim=True)
# tensor([1])

argmax(). *My post explains argmax():

# Demonstrates keepdim with torch.argmax(), which returns only the index
# of the maximum. The comment under each pair of calls is the result that
# both calls return.
import torch

# 1-D tensor with shape (4,). The maximum 7 occurs at indices 2 and 3;
# the output below shows the first occurrence (2) is returned.
my_tensor = torch.tensor([5, 4, 7, 7])

# Without dim the tensor is reduced as a whole. For this 1-D input that
# gives the same answer as dim=0: a 0-d index tensor.
torch.argmax(input=my_tensor)
torch.argmax(input=my_tensor, dim=0)
# tensor(2)

# keepdim=True gives a size-1 dimension, both with and without dim here.
torch.argmax(input=my_tensor, keepdim=True)
torch.argmax(input=my_tensor, dim=0, keepdim=True)
# tensor([2])

all(). *My post explains all():

# Demonstrates keepdim with torch.all(), which is True only if every
# element is True. The comment under each pair of calls is the result
# that both calls return.
import torch

# 1-D bool tensor with shape (4,). It contains False, so all() is False.
my_tensor = torch.tensor([True, False, True, False])

# Without keepdim the reduced dimension is removed: a 0-d bool tensor.
torch.all(input=my_tensor)
torch.all(input=my_tensor, dim=0)
# tensor(False)

# keepdim=True keeps a size-1 dimension, giving shape (1,).
# NOTE(review): passing keepdim without dim relies on the dim=None
# overload of torch.all found in newer PyTorch releases; older releases
# may reject this call. Confirm against the installed version.
torch.all(input=my_tensor, keepdim=True)
torch.all(input=my_tensor, dim=0, keepdim=True)
# tensor([False])

any(). *My post explains any():

# Demonstrates keepdim with torch.any(), which is True if at least one
# element is True. The comment under each pair of calls is the result
# that both calls return.
import torch

# 1-D bool tensor with shape (4,). It contains True, so any() is True.
my_tensor = torch.tensor([True, False, True, False])

# Without keepdim the reduced dimension is removed: a 0-d bool tensor.
torch.any(input=my_tensor)
torch.any(input=my_tensor, dim=0)
# tensor(True)

# keepdim=True keeps a size-1 dimension, giving shape (1,).
# NOTE(review): passing keepdim without dim relies on the dim=None
# overload of torch.any found in newer PyTorch releases; older releases
# may reject this call. Confirm against the installed version.
torch.any(input=my_tensor, keepdim=True)
torch.any(input=my_tensor, dim=0, keepdim=True)
# tensor([True])


This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)