返回顶部
首页 > 资讯 > 后端开发 > Python >PyTorch中关于tensor.repeat()的使用
  • 371
分享到

PyTorch中关于tensor.repeat()的使用

Python 官方文档:入门教程 => 点击学习

摘要

目录关于tensor.repeat()的使用Tensor.repeat()的简单用法关于tensor.repeat()的使用 考虑到很多人在学习这个函数,我想在这里提 一个建议: 强

关于tensor.repeat()的使用

考虑到很多人在学习这个函数,我想在这里提 一个建议:

强烈推荐 使用 einops 模块中的 repeat() 函数 替代 tensor.repeat()!

它可以摆脱 tensor.repeat() 参数的神秘主义。

einops 模块文档地址:https://nbviewer.jupyter.org/GitHub/aroGozhnikov/einops/blob/master/docs/1-einops-basics.ipynb

学习 tensor.repeat() 这个函数的功能的时候,最好还是要观察所得到的 结果的维度。

不多说,看代码:

>>> import torch
>>> 
>>> # 定义一个 33x55 张量
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>> 
>>> # 下面开始尝试 repeat 函数在不同参数情况下的效果
>>> a.repeat(1,1).size()     # 原始值:torch.Size([33, 55])
torch.Size([33, 55])
>>> 
>>> a.repeat(2,1).size()     # 原始值:torch.Size([33, 55])
torch.Size([66, 55])
>>> 
>>> a.repeat(1,2).size()     # 原始值:torch.Size([33, 55])
torch.Size([33, 110])
>>>
>>> a.repeat(1,1,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size()   # 原始值:torch.Size([33, 55])
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size() # 原始值:torch.Size([33, 55])
torch.Size([1, 1, 33, 55])
>>> 
>>> # ------------------ 割割 ------------------
>>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数,
>>> # 下面是一些错误示例
>>> a.repeat(2).size()  # 1D < 2D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> # 定义一个3维的张量,然后展示前面提到的那个错误
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>> 
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>

Tensor.repeat()的简单用法

相当于手动实现广播机制,即沿着给定的维度对tensor进行重复:

比如说对下面x的第1个通道复制三次,其余通道保持不变:

import torch

x = torch.randn(1, 3, 224, 224)
y = x.repeat(3, 1, 1, 1)
print(x.shape)
print(y.shape)

结果为:

torch.Size([1, 3, 224, 224])
torch.Size([3, 3, 224, 224])

这个在复制batch的时候用的比较多,上面的情况就相当于batch为1的3×224×224特征图复制成了batch为3

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: PyTorch中关于tensor.repeat()的使用

本文链接: https://lsjlt.com/news/171093.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作