相关文章推荐
要出家的大熊猫  ·  adb shell getevent - ...·  1 年前    · 
傲视众生的电影票  ·  sed.exe & ...·  1 年前    · 
冷冷的洋葱  ·  c# - How to pass ...·  2 年前    · 

tensor 数据类型转换

Tensor 是 PyTorch 中的重要数据类型,它主要用于深度学习模型的训练和推理。在 PyTorch 中,tensor 的数据类型可以通过 dtype 参数来指定。常见的 dtype 类型包括 float32, float64, int32, int64 等。如果需要将 tensor 的数据类型转换为其他类型,可以使用 torch.to 方法。

格式:torch.to(dtype)

x = torch.tensor([1, 2, 3], dtype=torch.float32) print("x:",x) x = x.to(torch.float64) print("x:",x)

在这个例子中,我们首先创建了一个 float32 类型的 tensor,然后使用 to 方法将其转换为 float64 类型。

除此之外,也可以使用 numpy() 函数将 tensor 转换为 numpy array,或者使用 from_numpy() 函数将 numpy array 转换为 tensor。

import numpy as np x = torch.tensor([1, 2, 3], dtype=torch.float32) print("x:",x) x_np = x.numpy() print("x_np:",x_np) x_torch = torch.from_numpy(x_np) print("x_torch:",x_torch)

在这个例子中,我们首先创建了一个 float32 类型的 tensor,然后使用 numpy() 函数将其转换为 numpy array,最后再使用 from_numpy() 将 numpy array 转换为 tensor。

  •