PyTorch函数解释:cat、stack、transpose、permute、squeeze、unsqueeze_torch.stack permute-程序员宅基地

技术标签: stack  深度学习  pytorch  


原文链接请参考:PyTorch 常用函数解析 | 梦家博客


torch.cat() 张量拼接

对张量沿着某一维度进行拼接。连接后数据的总维数不变。,ps:能拼接的前提是对应的维度相同!!!

例如对两个2维tensor(分别为2*3,1*3)进行拼接,拼接完后变为3*3的2维 tensor。

In [1]: import torch

In [2]: torch.manual_seed(1)
Out[2]: <torch._C.Generator at 0x19e56f02e50>

In [3]: x = torch.randn(2,3)

In [4]: y = torch.randn(1,3)

In [5]: x
Out[5]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661]])

In [6]: y
Out[6]: tensor([[-1.5228,  0.3817, -1.0276]])

In [9]: torch.cat((x,y),0)
Out[9]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661],
        [-1.5228,  0.3817, -1.0276]])

以上dim=0 表示按列进行拼接,dim=1表示按行进行拼接。

代码如下:

In [11]: z = torch.randn(2,2)

In [12]: z
Out[12]:
tensor([[-0.5631, -0.8923],
        [-0.0583, -0.1955]])

In [13]: x
Out[13]:
tensor([[ 0.6614,  0.2669,  0.0617],
        [ 0.6213, -0.4519, -0.1661]])

In [14]: torch.cat((x,z),1)
Out[14]:
tensor([[ 0.6614,  0.2669,  0.0617, -0.5631, -0.8923],
        [ 0.6213, -0.4519, -0.1661, -0.0583, -0.1955]])

torch.stack() 张量堆叠

torch.cat()拼接不会增加新的维度,但torch.stack()则会增加新的维度。

例如对两个1*2 维的 tensor 在第0个维度上stack,则会变为2*1*2的 tensor;在第1个维度上stack,则会变为1*2*2 的tensor。

In [22]: x = torch.randn(1,2)

In [23]: y = torch.randn(1,2)

In [24]: x.shape
Out[24]: torch.Size([1, 2])

In [25]: x = torch.randn(1,2)

In [26]: y = torch.randn(1,2)

In [27]: torch.stack((x,y),0) # 维度0堆叠
Out[27]:
tensor([[[-1.8313,  1.5987]],

        [[-1.2770,  0.3255]]])

In [28]: torch.stack((x,y),0).shape
Out[28]: torch.Size([2, 1, 2])

In [29]: torch.stack((x,y),1) # 维度1堆叠
Out[29]:
tensor([[[-1.8313,  1.5987],
         [-1.2770,  0.3255]]])

In [30]: torch.stack((x,y),1).shape
Out[30]: torch.Size([1, 2, 2])

torch.transpose() 矩阵转置

举例说明

torch.manual_seed(1)
x = torch.randn(2,3)
print(x)

原来x的结果:

 0.6614  0.2669  0.0617
 0.6213 -0.4519 -0.1661
[torch.FloatTensor of size 2x3]

将x的维度互换:x.transpose(0,1) ,其实相当于转置操作!
结果

0.6614  0.6213
 0.2669 -0.4519
 0.0617 -0.1661
[torch.FloatTensor of size 3x2]

torch.permute() 多维度互换

permute是更灵活的transpose,可以灵活的对原数据的维度进行调换,而数据本身不变。

In [31]: x = torch.randn(2,3,4)

In [32]: x
Out[32]:
tensor([[[ 0.7626,  0.4415,  1.1651,  2.0154],
         [ 0.2152, -0.5242, -1.8034, -1.3083],
         [ 0.4100,  0.4085,  0.2579,  1.0950]],

        [[-0.5065,  0.0998, -0.6540,  0.7317],
         [-1.4567,  1.6089,  0.0938, -1.2597],
         [ 0.2546, -0.5020, -1.0412,  0.7323]]])

In [33]: x.shape
Out[33]: torch.Size([2, 3, 4])

In [34]: x.permute(1,0,2) # 0维和1维互换,2维不变!
Out[34]:
tensor([[[ 0.7626,  0.4415,  1.1651,  2.0154],
         [-0.5065,  0.0998, -0.6540,  0.7317]],

        [[ 0.2152, -0.5242, -1.8034, -1.3083],
         [-1.4567,  1.6089,  0.0938, -1.2597]],

        [[ 0.4100,  0.4085,  0.2579,  1.0950],
         [ 0.2546, -0.5020, -1.0412,  0.7323]]])

In [35]: x.permute(1,0,2).shape
Out[35]: torch.Size([3, 2, 4])

torch.squeeze() 和 torch.unsqueeze()

常用来增加或减少维度,如没有batch维度时,增加batch维度为1。

  • squeeze(dim_n)压缩,减少dim_n维度 ,即去掉元素数量为1的dim_n维度。
  • unsqueeze(dim_n),增加dim_n维度,元素数量为1。
In [38]: x = torch.randn(1,3,4)

In [39]: x.shape
Out[39]: torch.Size([1, 3, 4])

In [40]: x
Out[40]:
tensor([[[-0.4791,  0.2912, -0.8317, -0.5525],
         [ 0.6355, -0.3968, -0.6571, -1.6428],
         [ 0.9803, -0.0421, -0.8206,  0.3133]]])

In [41]: x.squeeze()
Out[41]:
tensor([[-0.4791,  0.2912, -0.8317, -0.5525],
        [ 0.6355, -0.3968, -0.6571, -1.6428],
        [ 0.9803, -0.0421, -0.8206,  0.3133]])

In [42]: x.squeeze().shape
Out[42]: torch.Size([3, 4])

In [43]: x.unsqueeze(0)
Out[43]:
tensor([[[[-0.4791,  0.2912, -0.8317, -0.5525],
          [ 0.6355, -0.3968, -0.6571, -1.6428],
          [ 0.9803, -0.0421, -0.8206,  0.3133]]]])

In [44]: x.unsqueeze(0).shape
Out[44]: torch.Size([1, 1, 3, 4])
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/DreamHome_S/article/details/106027070

智能推荐

python opencv 4.1.0 cv2.convertScaleAbs()函数 (通过线性变换将数据转换成8位[uint8])(用于Intel Realsense D435显示depth图像)-程序员宅基地

文章浏览阅读2.7w次,点赞20次,收藏92次。def convertScaleAbs(src, dst=None, alpha=None, beta=None): # real signature unknown; restored from __doc__ """ convertScaleAbs(src[, dst[, alpha[, beta]]]) -> dst . @brief Scales, cal..._cv2.convertscaleabs

C语言编译器(C语言编程软件)完全攻略(第九部分:VS2019使用教程(使用VS2019编写C语言程序))_visualstudio2019编译-程序员宅基地

文章浏览阅读1.1k次,点赞23次,收藏18次。现在,你就可以将 MyDemo.exe 分享给你的朋友了,告诉他们这是你编写的第一个C语言程序。虽然这个程序非常简单,但是你已经越过了第一道障碍,学会了如何编写代码,如何将代码生成可执行程序,这是一个完整的体验。在本教程的基础部分,教大家编写的程序都是这样的“黑窗口”,与我们平时使用的软件不同,它们没有漂亮的界面,没有复杂的功能,只能看到一些文字,这就是控制台程序(Console Application),它与DOS非常相似,早期的计算机程序都是这样的。_visualstudio2019编译

css3d转换_CSS的性感链接转换-程序员宅基地

文章浏览阅读436次。css3d转换 I was recently visiting MooTools Developer Christoph Pojer's website and noticed a sexy link hover effect: when you hover the link, the the link animates and tilts to the left or the ri...

php-读取excel文件_php 读取excel-程序员宅基地

文章浏览阅读373次。php-读取excel文件_php 读取excel

(2022,FreGAN)利用频率分量在有限数据下训练 GAN-程序员宅基地

文章浏览阅读522次,点赞4次,收藏7次。GAN 在拟合数据分布时,往往倾向于拟合低频信息而忽略高频信息。本文提出了一种称为 FreGAN 的频率感知模型,以提高 G 和 D 的频率感知能力。通过鼓励 G 生成更合理高频信号,从而提高有限数据下的合成性能。_fregan

pyc文件逆向_攻防世界python-trade_逆向之旅010_逆向 pyc base64-程序员宅基地

文章浏览阅读1.5k次,点赞3次,收藏8次。使用工具easy python Decompiler 将pyc文件反编译为py文件。我把这个工具放在百度云里了,按需自取,链接(永久): 链接:https://pan.baidu.com/s/14iQRcHyAEpmmycTQZLbwlQ 提取码:3naj_逆向 pyc base64

随便推点

爬虫开发(2)——序列化_爬虫的批处理序列化-程序员宅基地

文章浏览阅读195次。为什么要使用序列化?我们定义了一个字典:aDict = dict(url = 'lu &amp;amp; yi.html', content = 'They will be ...')这里我们把网页 lu &amp;amp; yi.html 作为起始的网页地址,在之后的爬取过程中,将使用新的网页url来替换它。但是当我们关闭工程,重新启动之后,该字典又重新初始化起始网页为lu &amp;amp; yi.htm..._爬虫的批处理序列化

去除数字滤波器后的相位延迟(内有实操代码和效果图)_减小相位延迟的方法-程序员宅基地

文章浏览阅读2.9k次,点赞2次,收藏14次。参考链接https://ww2.mathworks.cn/help/signal/digital-filter-analysis.htmlhttps://ww2.mathworks.cn/help/signal/ug/compensate-for-delay-and-distortion-introduced-by-filters.html?s_tid=srchtitleFIR 线性相位延迟  FIR滤波器是有限长单位冲激响应滤波器,又称为非递归型滤波器,是数字信号处理系统中基本的元件.._减小相位延迟的方法

enumerate()用法_enumerate(a)-程序员宅基地

文章浏览阅读221次。for index , item in enumerate (a , x):for index , item in enumerate (a):这里有n,v俩参数,n先不管,v为a中的元素,比较简单。a=[[8,2],[2,3],[5,4]]print(a)for n , v in enumerate(a): v += v print(v) #print(n)输出[[8, 2], [2, 3], [5, 4]][8, 2, 8, 2][2, 3, 2, 3]_enumerate(a)

计算几何讲义——计算几何中的欧拉定理-程序员宅基地

文章浏览阅读188次。在处理计算几何的问题中,有时候我们会将其看成图论中的graph图,结合我们在图论中学习过的欧拉定理,我们可以通过图形的节点数(v)和边数(e)得到不是那么好求的面数f。 平面图中的欧拉定理: 定理:设G为任意的连通的平面图,则v-e+f=2,v是G的顶点数,e是G的边数,f是G的面数。证明:其实有点类似几何学中的欧拉公式的证明方法,这里采用归纳证明的方法。对m..._怎么证明平面图欧拉定理

c语言中各种括号的作用,C语言中各种类型指针的特性与用法介绍-程序员宅基地

文章浏览阅读750次。C语言中各种类型指针的特性与用法介绍本文主要介绍了C语言中各种类型指针的特性与用法,有需要的朋友可以参考一下!想了解更多相关信息请持续关注我们应届毕业生考试网!指针为什么要区分类型:在同一种编译器环境下,一个指针变量所占用的内存空间是固定的。比如,在16位编译器环境 下,任何一个指针变量都只占用8个字节,并不会随所指向变量的类型而改变。虽然所有的指针都只占8个字节,但不同类型的变量却占不同的字节数..._c语言带括号指针

缅甸文字库 缅甸语字库 缅甸字库算法_0x103c-程序员宅基地

文章浏览阅读9.5k次。字库交流 QQ:2229691219 缅甸语比较特殊、缅甸语有官方和民间之分,二者不同的是编码机制不同,因此这2种缅甸语的字串翻译、处理引擎、字库都是不同的。我们这里只讨论官方语言。 缅文、泰文等婆罗米系文字大多是元音附标文字,一般辅音字母自带默认元音可以发音,真正拼写词句时元音像标点符号一样附标在辅音上下左右的相应位置。由于每个元音位于辅音的具体位置是有自己的规则的,当只书写..._0x103c

推荐文章

热门文章

相关标签