Python: Can't call numpy() on Tensor that requires grad

Borislav Hadzhiev

Last updated: Jun 16, 2023
5 min

# Table of Contents

  1. Python: Can't call numpy() on Tensor that requires grad
  2. Using the no_grad() context manager to solve the error
  3. Getting the error when drawing a scatter plot in matplotlib

# Python: Can't call numpy() on Tensor that requires grad

The Python "RuntimeError: Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead" occurs when you try to convert a tensor with a gradient to a NumPy array.

To solve the error, convert the tensor to one that doesn't require a gradient by calling detach() before numpy().

Here is an example of how the error occurs.

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

print(t)  # 👉️ tensor([1., 2., 3.], requires_grad=True)
print(type(t))  # 👉️ <class 'torch.Tensor'>

# ⛔️ RuntimeError: Can't call numpy() on Tensor that requires grad.
# Use tensor.detach().numpy() instead.
t = t.numpy()

When the requires_grad attribute is set to True, gradients need to be computed for the Tensor.
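For context, here is a small sketch (the y variable is only for illustration) of what requires_grad=True means in practice: operations on the tensor are recorded in the computation graph so that backward() can compute gradients, which is why the tensor can't be handed to NumPy directly.

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# operations on the tensor are recorded in the computation graph
y = (t * 2).sum()
print(y.grad_fn)  # 👉️ <SumBackward0 object at 0x...>

# the recorded graph is what makes backward() possible
y.backward()
print(t.grad)  # 👉️ tensor([2., 2., 2.])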

To solve the error, call the tensor.detach() method to get a tensor that doesn't require a gradient before calling numpy().

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

print(t)  # 👉️ tensor([1., 2., 3.], requires_grad=True)
print(type(t))  # 👉️ <class 'torch.Tensor'>

# ✅ Call detach() before calling numpy()
t = t.detach().numpy()

print(t)  # 👉️ [1. 2. 3.]
print(type(t))  # 👉️ <class 'numpy.ndarray'>

The tensor.detach() method returns a new Tensor that is detached from the current graph.

The result never requires a gradient.

In other words, the method returns a new tensor that shares the same storage but doesn't track gradients (requires_grad is set to False).

The new tensor can safely be converted to a NumPy ndarray by calling the tensor.numpy method.
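To illustrate the shared storage, here is a small sketch (the variable names are only for illustration): modifying the NumPy array in place also changes the original tensor's data.

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

detached = t.detach()
print(detached.requires_grad)  # 👉️ False

arr = detached.numpy()

# the detached tensor and the NumPy array share storage with `t`,
# so an in-place change is visible in all three
arr[0] = 100.0
print(t)  # 👉️ tensor([100., 2., 3.], requires_grad=True)

If you need an independent copy instead, call copy() on the resulting array, e.g. t.detach().numpy().copy().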

If you have a list of tensors, use a list comprehension to iterate over the list and call detach() on each tensor.

main.py
import torch

t1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
t2 = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)

tensors = [t1, t2]

result = [t.detach().numpy() for t in tensors]

# 👇️ [array([1., 2., 3.], dtype=float32), array([4., 5., 6.], dtype=float32)]
print(result)

We used a list comprehension to iterate over the list of tensors.

List comprehensions are used to perform an operation on every element of an iterable, or to select a subset of elements that meet a condition.

On each iteration, we call detach() before calling numpy() so no error is raised.
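Alternatively, if all the tensors have the same shape, you can stack them into a single tensor first and call detach() and numpy() once. Here is a small sketch, assuming the t1 and t2 tensors from the previous example.

main.py
import torch

t1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
t2 = torch.tensor([4.0, 5.0, 6.0], requires_grad=True)

# stack the tensors into one 2 x 3 tensor, then convert it in one call
stacked = torch.stack([t1, t2])

arr = stacked.detach().numpy()
print(arr.shape)  # 👉️ (2, 3)
print(arr)
# 👇️
# [[1. 2. 3.]
#  [4. 5. 6.]]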

# Using the no_grad() context manager to solve the error

You can also use the no_grad() context manager to solve the error.

The context manager disables gradient calculation.

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

print(t)  # 👉️ tensor([1., 2., 3.], requires_grad=True)
print(type(t))  # 👉️ <class 'torch.Tensor'>

with torch.no_grad():
    t = t.detach().numpy()

print(t)  # 👉️ [1. 2. 3.]
print(type(t))  # 👉️ <class 'numpy.ndarray'>

The no_grad context manager disables gradient calculation.

In the context manager (the indented block), the result of every computation will have requires_grad=False, even if the inputs have requires_grad=True.

Calling the numpy() method on a tensor that is attached to a computation graph is not allowed.

We first have to make sure that the tensor is detached before calling numpy().
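The distinction matters: a tensor created outside the block (like t above) still requires a gradient, while results computed inside the block don't. Here is a small sketch (the doubled variable is only for illustration).

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

with torch.no_grad():
    # computations inside the block are not tracked,
    # so their results can be converted directly
    doubled = t * 2
    print(doubled.requires_grad)  # 👉️ False
    print(doubled.numpy())  # 👉️ [2. 4. 6.]

# the original tensor still requires a gradient,
# so it still needs detach() before numpy()
print(t.detach().numpy())  # 👉️ [1. 2. 3.]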

# Getting the error when drawing a scatter plot in matplotlib

If you got the error when drawing a scatter plot in matplotlib, try using the torch.no_grad() context manager as we did in the previous section.

main.py
import torch

t = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

with torch.no_grad():
    # 👉️ YOUR CODE THAT CAUSES THE ERROR HERE
    ...

Make sure to add your code to the indented block inside the no_grad() context manager.

The context manager disables gradient calculation, which should resolve the error as long as your code is indented inside the with torch.no_grad(): block.
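For example, here is a minimal sketch with made-up data (the x and y tensors are assumptions, not part of the original example); the tensors are detached before they are passed to plt.scatter().

main.py
import torch
import matplotlib.pyplot as plt

# 👇️ hypothetical data; replace with your own tensors
x = torch.linspace(0, 1, 10, requires_grad=True)
y = x ** 2

# detach the tensors before handing them to matplotlib
plt.scatter(x.detach().numpy(), y.detach().numpy())
plt.show()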

If the error persists and you use the fastai library, try adding an import statement for the fastai.basics module at the top of your file.