返回顶部
首页 > 资讯 > 精选 >如何对pytorch中不定长序列补齐
  • 518
分享到

如何对pytorch中不定长序列补齐

2023-06-15 07:06:36 518人浏览 安东尼
摘要

小编给大家分享一下如何对PyTorch中不定长序列补齐,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.以下给出两种思路:第一种思路是比较容

小编给大家分享一下如何对PyTorch中不定长序列补齐,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!

第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.

以下给出两种思路:

第一种思路是比较容易想到的, 就是对一个batch的样本进行遍历, 然后使用np.pad对每一个样本进行补齐.

for unit in data:        mask = np.zeros(max_length)        s_len = len(unit[0])    # calculate the length of sequence in each unit        mask[: s_len] = 1        unit[0] = np.pad(unit[0], (0, max_length - s_len), 'constant', constant_values=(0, 0))        mask_batch.append(mask)

但是这种方法在batch size很大的情况下会很慢, 因为使用for循环进行了遍历. 我在实际用的时候, 当batch_size=128时, 一个batch的加载时间甚至是一个batch训练时间的几倍!

因此, 我想到如何并行地对序列进行补齐. 第二种方法的思路就是使用torch中自带的pad_sequence来并行补齐.

batch_sequence = list(map(lambda x: torch.tensor(x[findex]), x_data))batch_data[feat] = torch.nn.utils.rnn.pad_sequence(batch_sequence).T

可以看到这里使用pad_sequence一次性对整个batch进行补齐. 下面对这个函数进行详细说明.

pad_sequence详解

from torch.utils.rnn import pad_sequencea = torch.ones(10)b = torch.ones(6)c = torch.ones(20)abc = pad_sequence([a,b,c])  # shape(20, 3)

注意这个函数接收的是一个元素为tensor的列表, 而不是tensor.

最终, 这个函数会将所有tensor转换为tensor矩阵#shape(max_length, batch_size). 因此, 在使用完后通常还需要转置一下.

补充:PyTorch中用于RNN变长序列填充函数的简单使用

1、PyTorch中RNN变长序列的问题   

RNN在处理变长序列时有它的优势。在分批处理变长序列问题时,每个序列的长度往往不会完全相等,因此针对一个batch中序列长度不一的情况,需要对某些序列进行PAD(填充)操作,使得一个batch内的序列长度相等。   

PyTorch中的pack_padded_sequence和pad_packed_sequence可处理上述问题,以下用一个示例演示这两个函数的简单使用方法。

2、填充函数简介

“压缩”函数:用于将填充后的序列tensor进行压缩,方便RNN处理

pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)

(1)input->被“压缩”的tensor,维度一般为[batch_size,_max_seq_len[,embedding_size]]或者[max_seq_len,batch_size[,embedding_size]]

若input维度为:[batch_size,_max_seq_len[,embedding_size]]

要将batch_first设置为True,这表示input的第一个维度为batch的数量

若input维度为:[max_seq_len,batch_size[,embedding_size]]

要将batch_first设置为False(默认值),这表示input的第一个维度不是batch的数量

(2)lengths->lengths参数表示一个batch中序列真实长度,类型为列表,在例子中详细说明

(3)batch_first->表示batch的数量是否在input的第一维度,默认值为False

(4)enforce_sorted->input中的会自动按照lengths的情况进行排序,默认值为

“解压”函数:该函数与"压缩函数"相对应,经“压缩函数”处理的输入经过RNN得到的最终结果可以利用该函数进行“解压”

pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None):

(1)sequence->压缩函数处理过的input经RNN后得到的结果

(2)batch_first->与“压缩”函数中的batch_first一致

(3)padding_value->序列进行填充时使用的索引,默认为0

(4)total_length->暂略

3、PyTorch代码示例

代码如下(示例):

# Create by leslie_miao on 2020/11/1import torchimport torch.nn as nnd_model = 10 # 词嵌入的维度hidden_size = 20 # lstm隐藏层单元数量layer_num = 1 # lstm层数# 输入inputs,维度为[batch_size,max_seq_len]=[3,4],其中0代表填充# 该input包含3个序列,每个序列的真实长度分别为: 4 3 2inputs = torch.tensor([[1,2,3,4],[1,2,3,0],[1,2,0,0]])embedding = nn.Embedding(5,d_model)# 获取词嵌入后的inputs 当前inputs的维度为[batch_size,max_seq_len,d_model]=[3,4,10]inputs = embedding(inputs)# 查看inputs的维度print(inputs.size())# print: torch.Size([3, 4, 10])# 利用“压缩”函数对inputs进行压缩处理,[4,3,2]分别为inputs中序列的真实长度,batch_first=True表示inputs的第一维是batch_sizeinputs = nn.utils.rnn.pack_padded_sequence(inputs,lengths=[4,3,2],batch_first=True)# 查看经“压缩”函数处理过的inputs的维度print(inputs[0].size())# print: torch.Size([9, 10])# 定义RNN网络network = nn.LSTM(input_size=d_model,hidden_size=hidden_size,batch_first=True,num_layers=layer_num)# 初始化RNN相关门参数c_0 = torch.zeros((layer_num,3,hidden_size))h_0 = torch.zeros((layer_num,3,hidden_size)) # [rnn层数,batch_size,hidden_size]# inputs经过RNN网络后得到的结果outputsoutput,(h_n,c_n) = network(inputs,(h_0,c_0))#查看未经“解压函数”处理的outputs维度print(output[0].size())# print: torch.Size([9, 20])# 利用“解压函数”对outputs进行解压操作,其中batch_first设置与“压缩函数相同”,padding_value为0output = nn.utils.rnn.pad_packed_sequence(output,batch_first=True,padding_value=0)# 查看经“解压函数”处理的outputs维度print(output[0].size())# print:torch.Size([3, 4, 20])

看完了这篇文章,相信你对“如何对pytorch中不定长序列补齐”有了一定的了解,如果想了解更多相关知识,欢迎关注编程网精选频道,感谢各位的阅读!

--结束END--

本文标题: 如何对pytorch中不定长序列补齐

本文链接: https://lsjlt.com/news/278582.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 如何对pytorch中不定长序列补齐
    小编给大家分享一下如何对pytorch中不定长序列补齐,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的.以下给出两种思路:第一种思路是比较容...
    99+
    2023-06-15
  • 对pytorch中不定长序列补齐的操作
    第二种方法通常是在load一个batch数据时, 在collate_fn中进行补齐的. 以下给出两种思路: 第一种思路是比较容易想到的, 就是对一个batch的样本进行遍历, 然后使...
    99+
    2024-04-02
  • word中不同长度的文字如何对齐
    在Word中,可以使用以下方法对齐不同长度的文字:1. 使用制表符:在需要对齐的文本前插入一个制表符(Tab键)。这样,所有的文本将...
    99+
    2023-09-11
    word
  • python如何使用sample()函数从指定序列中随机获取指定长度的片段
    这篇文章给大家分享的是有关python如何使用sample()函数从指定序列中随机获取指定长度的片段的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。sample(sequence,k)从指定序列中随机获取指定长度的...
    99+
    2023-06-08
  • C#开发中如何处理对象序列化和反序列化
    C#开发中如何处理对象序列化和反序列化,需要具体代码示例在C#开发中,对象序列化和反序列化是非常重要的概念。序列化是将对象转换为可以在网络上传输或在磁盘上存储的格式,而反序列化则是将序列化后的数据重新转换为原始对象。本文将介绍在C#中如何处...
    99+
    2023-10-22
    序列化 反序列化 对象处理
  • 如何在html中自定义有序列表
    如何在html中自定义有序列表?针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。list-style-type 属性设置列表项标记的类型。语法:元素{list-style-ty...
    99+
    2023-06-15
  • Redis中如何实现自定义序列化器
    在Redis中实现自定义序列化器需要使用Redis的自定义模块功能。Redis的自定义模块功能允许用户编写自定义的功能模块,并在Re...
    99+
    2024-04-29
    Redis
  • python如何对列表中的元素进行排序
    这篇文章主要介绍了python如何对列表中的元素进行排序,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。示例:# sort:排序,对...
    99+
    2024-04-02
  • html中如何解决图片与文字垂直方向不对齐问题
    这篇文章主要介绍html中如何解决图片与文字垂直方向不对齐问题,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!比如说,现在我要做一个简单的删除按钮,只由一个icon和“删除”两个字组成...
    99+
    2024-04-02
  • 如何使用Python中的pickle和JSON进行对象序列化和反序列化
    如何使用Python中的pickle和JSON进行对象序列化和反序列化Python是一种简单而强大的编程语言,其内置了许多有用的库和模块,使开发人员能够快速进行各种任务。其中,pickle和JSON是两个常用的模块,用于对象序列化和反序列化...
    99+
    2023-10-22
    序列化 JSON pickle
  • Go Spring开发技术中,如何实现对象的序列化和反序列化?
    在Go Spring开发中,对象的序列化和反序列化是非常常见的操作。序列化是将对象转换为字节流的过程,而反序列化则是将字节流转换回对象。在本文中,我们将探讨Go Spring开发技术中如何实现对象的序列化和反序列化。 一、JSON序列化和...
    99+
    2023-07-26
    spring 开发技术 对象
  • 如何在C#项目中实现对象序列化XML
    这篇文章给大家介绍如何在C#项目中实现对象序列化XML,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。首先,需要用到的是这两个命名空间(主要)using System.Xml;using System...
    99+
    2023-06-06
  • Redis中如何使用不同的序列化机制
    在Redis中,可以通过配置参数来使用不同的序列化机制。Redis支持多种序列化格式,包括JSON、MsgPack、Protobuf...
    99+
    2024-04-29
    Redis
  • python如何检查给定的字符串是不是回文序列
    这篇文章主要介绍python如何检查给定的字符串是不是回文序列,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!回文序列以下方法会检查给定的字符串是不是回文序列,它首先会把所有字母转化为小写,并移除非英文字母符号。最后,...
    99+
    2023-06-27
  • Laravel中的对象序列化是什么?如何使用它?
    Laravel是一个流行的PHP框架,它提供了许多强大的功能和工具,以帮助开发人员快速构建高质量的Web应用程序。其中一个功能是对象序列化,它可以帮助您在应用程序中轻松地处理和存储对象。 在本文中,我们将深入探讨Laravel中的对象序列化...
    99+
    2023-09-25
    编程算法 laravel 对象
  • Python中如何通过itemgetter对字典列表进行排序
    本篇文章为大家展示了Python中如何通过itemgetter对字典列表进行排序,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。前言:我们有一个字典列表,想根据一个或多个字典中的值对列表进行排序。利用...
    99+
    2023-06-02
  • pandas中按行或列的值对数据排序如何实现
    本文小编为大家详细介绍“pandas中按行或列的值对数据排序如何实现”,内容详细,步骤清晰,细节处理妥当,希望这篇“pandas中按行或列的值对数据排序如何实现”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。在处理...
    99+
    2023-07-05
  • 如何使用Python中的pickle模块进行对象序列化
    如何使用Python中的pickle模块进行对象序列化概述:在Python编程中,我们经常需要将数据保存到文件或通过网络传输。而对象序列化是一种将对象转化为可存储或传输的格式的过程,而pickle模块正是Python中一种常用的序列化模块。...
    99+
    2023-10-22
    Python pickle 对象序列化
  • 如何使用C#中的List.Sort函数对列表进行排序
    如何使用C#中的List.Sort函数对列表进行排序在C#编程语言中,我们经常需要对列表进行排序操作。而List类的Sort函数正是为此设计的一个强大工具。本文将介绍如何使用C#中的List.Sort函数对列表进行排序,并提供具体的代码示例...
    99+
    2023-11-17
    C# list sort
  • 如何在Java项目中利用序列化与反序列化将对象文件写入与导出
    如何在Java项目中利用序列化与反序列化将对象文件写入与导出?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。Java类中对象的序列化工作是通过ObjectOutp...
    99+
    2023-05-31
    java 序列化 反序列化
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作