图片转自:https://discuss.pytorch.org/t/what-we-should-use-align-corners-false/22663/9
import torch.nn.functional as F
import torch
a=torch.arange(12,dtype=torch.float32).reshape(1,2,2,3)
b=F.interpolate(a,size=(4,4),mode='bilinear')
# 这里的(4,4)指的是将后两个维度放缩成4*4的大小
print(a)
print(b)
print('原数组尺寸:',a.shape)
print('size采样尺寸:',b.shape)
输出结果,一二维度大小不会发生变化
# 原数组
tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.]],
[[ 6., 7., 8.],
[ 9., 10., 11.]]]])
# 采样后的数组
tensor([[[[ 0.0000, 0.6250, 1.3750, 2.0000],
[ 0.7500, 1.3750, 2.1250, 2.7500],
[ 2.2500, 2.8750, 3.6250, 4.2500],
[ 3.0000, 3.6250, 4.3750, 5.0000]],
[[ 6.0000, 6.6250, 7.3750, 8.0000],
[ 6.7500, 7.3750, 8.1250, 8.7500],
[ 8.2500, 8.8750, 9.6250, 10.2500],
[ 9.0000, 9.6250, 10.3750, 11.0000]]]])
原数组尺寸: torch.Size([1, 2, 2, 3])
size采样尺寸: torch.Size([1, 2, 4, 4])
# 规定三四维度放缩成4*4大小
size与scale_factor的区别:输入序列时
import torch.nn.functional as F
import torch
a=torch.arange(4*512*14*14,dtype=torch.float32).reshape(4,512,14,14)
b=F.interpolate(a,size=(28,56),mode='bilinear')
c=F.interpolate(a,scale_factor=(4,8),mode='bilinear')
print('原数组尺寸:',a.shape)
print('size采样尺寸:',b.shape)
print('scale_factor采样尺寸:',c.shape)
原数组尺寸: torch.Size([4, 512, 14, 14])
size采样尺寸: torch.Size([4, 512, 28, 56])
# 第三维度放大成28,第四维度放大成56
scale_factor采样尺寸: torch.Size([4, 512, 56, 112])
# 第三维度放大4倍,第四维度放8倍
size与scale_factor的区别:输入整数时
import torch.nn.functional as F
import torch
a=torch.arange(4*512*14*14,dtype=torch.float32).reshape(4,512,14,14)
b=F.interpolate(a,size=28,mode='bilinear')
c=F.interpolate(a,scale_factor=4,mode='bilinear')
print('原数组尺寸:',a.shape)
print('size采样尺寸:',b.shape)
print('scale_factor采样尺寸:',c.shape)
原数组尺寸: torch.Size([4, 512, 14, 14])
size采样尺寸: torch.Size([4, 512, 28, 28])
# 三四维度数组被放大成28*28
scale_factor采样尺寸: torch.Size([4, 512, 56, 56])
# 三四维度数组被放大了4倍
align_corners=True与False的区别
import torch.nn.functional as F
import torch
a=torch.arange(18,dtype=torch.float32).reshape(1,2,3,3)
b=F.interpolate(a,size=(4,4),mode='bicubic',align_corners=True)
c=F.interpolate(a,size=(4,4),mode='bicubic',align_corners=False)
print(a)
print(b)
print(c)
输出结果,具体效果会因mode插值方法而异
tensor([[[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.]],
[[ 9., 10., 11.],
[12., 13., 14.],
[15., 16., 17.]]]])
# align_corners=True
tensor([[[[ 0.0000, 0.5741, 1.4259, 2.0000],
[ 1.7222, 2.2963, 3.1481, 3.7222],
[ 4.2778, 4.8519, 5.7037, 6.2778],
[ 6.0000, 6.5741, 7.4259, 8.0000]],
[[ 9.0000, 9.5741, 10.4259, 11.0000],
[10.7222, 11.2963, 12.1481, 12.7222],
[13.2778, 13.8519, 14.7037, 15.2778],
[15.0000, 15.5741, 16.4259, 17.0000]]]])
# align_corners=False
tensor([[[[-0.2871, 0.3145, 1.2549, 1.8564],
[ 1.5176, 2.1191, 3.0596, 3.6611],
[ 4.3389, 4.9404, 5.8809, 6.4824],
[ 6.1436, 6.7451, 7.6855, 8.2871]],
[[ 8.7129, 9.3145, 10.2549, 10.8564],
[10.5176, 11.1191, 12.0596, 12.6611],
[13.3389, 13.9404, 14.8809, 15.4824],
[15.1436, 15.7451, 16.6855, 17.2871]]]])
在计算机视觉中,interpolate函数常用于图像的放大(即上采样操作)。比如在细粒度识别领域中,注意力图有时候会对特征图进行裁剪操作,将有用的部分裁剪出来,裁剪后的图像往往尺寸小于原始特征图,这时候如果强制转换成原始图像大小,往往是无效的,会丢掉部分有用的信息。所以这时候就需要用到interpolate函数对其进行上采样操作,在保证图像信息不丢失的情况下,放大图像,从而放大图像的细节,有利于进一步的特征提取工作。
官方文档
torch.nn.functional.interpolate:https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html?highlight=interpolate#torch.nn.functional.interpolate
到此这篇关于Pytorch上下采样函数之F.interpolate数组采样操作的文章就介绍到这了,更多相关Pytorch F.interpolate数组采样内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
您可能感兴趣的文章:
电脑版 - 返回首页
2006-2023 脚本之家 JB51.Net , All Rights Reserved.
苏ICP备14036222号