requires_grad=True with a tensor, backward() and retain_grad() in PyTorch



This content originally appeared on DEV Community and was authored by Super Kai (Kazuya Ito)

requires_grad (bool, optional, default: False) set to True lets autograd compute and accumulate the gradient of a tensor, as shown below:

*Memos:

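  • There are leaf tensors and non-leaf tensors. A tensor created directly (e.g. with torch.tensor()) is a leaf tensor, while a tensor produced by an operation on a tensor with requires_grad=True is a non-leaf tensor with a grad_fn. tensor.is_leaf tells which one a tensor is.
  • With requires_grad=True, data must be of a float or complex type.
  • tensor.backward() does backpropagation, computing the gradients of tensor with respect to the leaf tensors of its graph.
  • A gradient is accumulated (added to tensor.grad, not overwritten) each time tensor.backward() is called.
  • To call tensor.backward() without a gradient argument:
    • requires_grad must be True.
    • data must be a scalar (exactly one element) of float type, in a tensor of any number of dimensions (0D or more).
  • tensor.grad gets the accumulated gradient (None until one has been computed).
  • To call tensor.retain_grad(), requires_grad must be True.
  • To let a non-leaf tensor get a gradient through tensor.grad without a warning, tensor.retain_grad() must be called on it before tensor.backward() is called. Otherwise, tensor.grad is None.
  • By default, tensor.backward() frees the values the graph saved for backpropagation, so running backpropagation through the same graph again can raise a RuntimeError. Passing retain_graph=True to tensor.backward() keeps the graph and prevents that error.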
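The memos above can be checked with a short script. This is a minimal sketch assuming a recent PyTorch version; the names errors, vec, one, x and y are illustrative only. The exact wording of the error messages may differ between versions, so the script collects and prints them rather than relying on them:

```python
import warnings

import torch

errors = {}  # collects the RuntimeError messages raised below

# 1) Only float or complex data can require gradients.
try:
    torch.tensor(data=7, requires_grad=True)  # int64 data
except RuntimeError as e:
    errors["int data"] = str(e)

# 2) backward() without a `gradient` argument needs a one-element output.
vec = torch.tensor(data=[1., 2., 3.], requires_grad=True)  # leaf tensor
try:
    (vec * 2).backward()  # 3 elements, no `gradient` given
except RuntimeError as e:
    errors["non-scalar output"] = str(e)

# A one-element tensor is accepted even when it is not 0D.
one = torch.tensor(data=[[7.]], requires_grad=True)
one.backward()
print(one.grad)  # tensor([[1.]])

# For a non-scalar output, pass `gradient` with the output's shape instead.
(vec * 2).backward(gradient=torch.ones(3))
print(vec.grad)  # tensor([2., 2., 2.])

# 3) retain_grad() needs requires_grad=True.
try:
    torch.tensor(data=7.).retain_grad()
except RuntimeError as e:
    errors["retain_grad"] = str(e)

# 4) Without retain_grad(), a non-leaf tensor's .grad stays None after backward().
x = torch.tensor(data=7., requires_grad=True)  # leaf tensor
y = x * 4                                       # non-leaf tensor
y.backward()
with warnings.catch_warnings(record=True):
    warnings.simplefilter("always")
    y_grad = y.grad  # PyTorch warns about this access; the warning is recorded, not shown
print(y_grad, x.grad)  # None tensor(4.)

for case, message in errors.items():
    print(f"{case}: {message}")
```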

With 1 tensor:

import torch

my_tensor = torch.tensor(data=7., requires_grad=True) # Leaf tensor

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), None, True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

my_tensor.backward() # Accumulates: 1 + 1 = 2

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(2.), True)

my_tensor.backward()

my_tensor, my_tensor.grad, my_tensor.is_leaf
# (tensor(7., requires_grad=True), tensor(3.), True)
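Because gradients accumulate, they are usually reset between independent backward passes. A minimal sketch, assuming a recent PyTorch version; x and history are illustrative names, and x ** 2 stands in for any computation (its gradient at 7. is 2 * 7 = 14):

```python
import torch

x = torch.tensor(data=7., requires_grad=True)  # leaf tensor
history = []  # x.grad after each backward() call

(x ** 2).backward()  # d(x**2)/dx = 2 * 7 = 14
history.append(x.grad.item())

(x ** 2).backward()  # a new graph each time; 14 is added to the stored 14
history.append(x.grad.item())

x.grad.zero_()       # reset the stored gradient in place
(x ** 2).backward()
history.append(x.grad.item())

x.grad = None        # alternatively, drop the stored gradient entirely
(x ** 2).backward()
history.append(x.grad.item())

print(history)  # [14.0, 28.0, 14.0, 14.0]
```

Both resets give the same value on the next pass; torch.optim optimizers provide zero_grad() for the same purpose on model parameters.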

With 3 tensors:

import torch

tensor1 = torch.tensor(data=7., requires_grad=True) # Leaf tensor

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), None, True)

tensor1.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2 = tensor1 * 4 # Non-leaf tensor

tensor2.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(1.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), None, False)

tensor2.backward(retain_graph=True) # tensor1.grad: 1 + 4 = 5; the graph is kept for tensor3.backward()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3 = tensor2 * 5 # Non-leaf tensor

tensor3.retain_grad()

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(5.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(1.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), None, False)

tensor3.backward() # tensor2.grad: 1 + 5 = 6, tensor1.grad: 5 + 5 * 4 = 25

tensor1, tensor1.grad, tensor1.is_leaf
# (tensor(7., requires_grad=True), tensor(25.), True)

tensor2, tensor2.grad, tensor2.is_leaf
# (tensor(28., grad_fn=<MulBackward0>), tensor(6.), False)

tensor3, tensor3.grad, tensor3.is_leaf
# (tensor(140., grad_fn=<MulBackward0>), tensor(1.), False)
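The first call in the 3-tensor example passes retain_graph=True because tensor3.backward() later runs back through tensor2's graph. The sketch below shows what retain_graph=True changes. It is a hypothetical variation, not the original example: it uses t1 * t1 instead of tensor1 * 4, because a product of two tensors that require gradients is guaranteed to save them for the backward pass, and the helper chain() exists only for this illustration:

```python
import torch


def chain(retain: bool) -> torch.Tensor:
    """Hypothetical helper: backward() through t2, then again through t3 = t2 * 5."""
    t1 = torch.tensor(data=7., requires_grad=True)  # leaf tensor
    t2 = t1 * t1                      # non-leaf; MulBackward0 saves both inputs
    t2.backward(retain_graph=retain)  # t1.grad = 2 * 7 = 14
    t3 = t2 * 5                       # non-leaf built on top of t2's graph
    t3.backward()                     # goes back through t2's graph: t1.grad += 5 * 14
    return t1.grad


kept = chain(retain=True)
print(kept)  # tensor(84.)

try:
    chain(retain=False)
    failed_without_retain = False
except RuntimeError as e:
    # t2's saved inputs were freed by the first backward(), so the second pass fails.
    failed_without_retain = True
    print("RuntimeError:", e)
```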

