返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch数据读取与预处理该如何实现
  • 487
分享到

Pytorch数据读取与预处理该如何实现

2024-04-02 19:04:59 487人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。PyTorch帮我们实现了方便的数据读取

  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。PyTorch帮我们实现了方便的数据读取与预处理方法,下面记录两个DEMO,便于加快以后的代码效率。

  根据数据是否一次性读取完,将DEMO分为:

  1、串行式读取。也就是一次性读取完所有需要的数据到内存,模型训练时不会再访问外存。通常用在内存足够的情况下使用,速度更快。

  2、并行式读取。也就是边训练边读取数据。通常用在内存不够的情况下使用,会占用计算资源,如果分配的好的话,几乎不损失速度。

  Pytorch官方的数据提取方式尽管方便编码,但由于它提取数据方式比较死板,会浪费资源,下面对其进行分析。

1  串行式读取

1.1  DEMO代码


import torch 
from torch.utils.data import Dataset,DataLoader 
  
class MyDataSet(Dataset):# ————1————
 def __init__(self):  
  self.data = torch.tensor(range(10)).reshape([5,2])
  self.label = torch.tensor(range(5))

 def __getitem__(self, index):  
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)
 
my_data_set = MyDataSet()# ————2————
my_data_loader = DataLoader(
 dataset=my_data_set,  # ————3————
 batch_size=2,     # ————4————
 shuffle=True,     # ————5————
 sampler=None,     # ————6————
 batch_sampler=None,  # ————7———— 
 num_workers=0 ,    # ————8———— 
 collate_fn=None,    # ————9———— 
 pin_memory=True,    # ————10———— 
 drop_last=True     # ————11————
)

for i in my_data_loader: # ————12————
 print(i)

  注释处解释如下:

  1、重写数据集类,用于保存数据。除了 __init__() 外,必须实现 __getitem__() 和 __len__() 两个方法。前一个方法用于输出索引对应的数据。后一个方法用于获取数据集的长度。

  2~5、 2准备好数据集后,传入DataLoader来迭代生成数据。前三个参数分别是传入的数据集对象、每次获取的批量大小、是否打乱数据集输出。

  6、采样器,如果定义这个,shuffle只能设置为False。所谓采样器就是用于生成数据索引的可迭代对象,比如列表。因此,定义了采样器,采样都按它来,shuffle再打乱就没意义了。

  7、批量采样器,如果定义这个,batch_size、shuffle、sampler、drop_last都不能定义。实际上,如果没有特殊的数据生成顺序的要求,采样器并没有必要定义。torch.utils.data 中的各种 Sampler 就是采样器类,如果需要,可以使用它们来定义。

  8、用于生成数据的子进程数。默认为0,不并行。

  9、拼接多个样本的方法,默认是将每个batch的数据在第一维上进行拼接。这样可能说不清楚,并且由于这里可以探究一下获取数据的速度,后面再详细说明。

  10、是否使用页内存。用的话会更快,内存不充足最好别用。

  11、是否把最后小于batch的数据丢掉。

  12、迭代获取数据并输出。

1.2  速度探索

  首先看一下DEMO的输出:

  输出了两个batch的数据,每组数据中data和label都正确排列,符合我们的预期。那么DataLoader是怎么把数据整合起来的呢?首先,我们把collate_fn定义为直接映射(不用它默认的方法),来查看看每次DataLoader从MyDataSet中读取了什么,将上面部分代码修改如下:


my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=2,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=lambda x:x, #修改处
 pin_memory=True,    
 drop_last=True     
)

  结果如下:

  输出还是两个batch,然而每个batch中,单个的data和label是在一个list中的。似乎可以看出,DataLoader是一个一个读取MyDataSet中的数据的,然后再进行相应数据的拼接。为了验证这点,代码修改如下:


import torch 
from torch.utils.data import Dataset,DataLoader 
  
class MyDataSet(Dataset): 
 def __init__(self):  
  self.data = torch.tensor(range(10)).reshape([5,2])
  self.label = torch.tensor(range(5))

 def __getitem__(self, index):  
  print(index)     #修改处2
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)
 
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=2,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=lambda x:x, #修改处1
 pin_memory=True,    
 drop_last=True     
)

for i in my_data_loader: 
 print(i)

  输出如下:

  验证了前面的猜想,的确是一个一个读取的。如果数据集定义的不是格式化的数据,那还好,但是我这里定义的是tensor,是可以直接通过列表来索引对应的tensor的。因此,DataLoader的操作比直接索引多了拼接这一步,肯定是会慢很多的。一两次的读取还好,但在训练中,大量的读取累加起来,就会浪费很多时间了。

  自定义一个DataLoader可以证明这一点,代码如下:


import torch 
from torch.utils.data import Dataset,DataLoader 
from time import time
  
class MyDataSet(Dataset): 
 def __init__(self):  
  self.data = torch.tensor(range(100000)).reshape([50000,2])
  self.label = torch.tensor(range(50000))

 def __getitem__(self, index):  
  return self.data[index], self.label[index]

 def __len__(self):  
  return len(self.data)

# 自定义DataLoader
class MyDataLoader():
 def __init__(self, dataset,batch_size):
  self.dataset = dataset
  self.batch_size = batch_size
 def __iter__(self):
  self.now = 0
  self.shuffle_i = np.array(range(self.dataset.__len__())) 
  np.random.shuffle(self.shuffle_i)
  return self
 
 def __next__(self): 
  self.now += self.batch_size
  if self.now <= len(self.shuffle_i):
   indexes = self.shuffle_i[self.now-self.batch_size:self.now]
   return self.dataset.__getitem__(indexes)
  else:
   raise StopIteration

# 使用官方DataLoader
my_data_set = MyDataSet() 
my_data_loader = DataLoader(
 dataset=my_data_set,  
 batch_size=256,      
 shuffle=True,      
 sampler=None,     
 batch_sampler=None,  
 num_workers=0 ,    
 collate_fn=None, 
 pin_memory=True,    
 drop_last=True     
)

start_t = time()
for t in range(10):
 for i in my_data_loader: 
  pass
print("官方:", time() - start_t)
 
 
#自定义DataLoader
my_data_set = MyDataSet() 
my_data_loader = MyDataLoader(my_data_set,256)

start_t = time()
for t in range(10):
 for i in my_data_loader: 
  pass
print("自定义:", time() - start_t)

运行结果如下:

  以上使用batch大小为256,仅各读取10 epoch的数据,都有30多倍的时间上的差距,更大的batch差距会更明显。另外,这里用于测试的每个数据只有两个浮点数,如果是图像,所需的时间可能会增加几百倍。因此,如果数据量和batch都比较大,并且数据是格式化的,最好自己写数据生成器。

2  并行式读取

2.1  DEMO代码


import matplotlib.pyplot as plt
from torch.utils.data import DataLoader 
from torchvision import transfORMs 
from torchvision.datasets import ImageFolder 
 
path = r'E:\DataSets\ImageNet\ILSVRC2012_img_train\10-19\128x128'
my_data_set = ImageFolder(      #————1————
 root = path,            #————2————
 transform = transforms.Compose([  #————3————
  transforms.ToTensor(),
  transforms.CenterCrop(64)
 ]),
 loader = plt.imread         #————4————
)
my_data_loader = DataLoader(
 dataset=my_data_set,   
 batch_size=128,       
 shuffle=True,       
 sampler=None,       
 batch_sampler=None,    
 num_workers=0,      
 collate_fn=None,      
 pin_memory=True,      
 drop_last=True 
)      

for i in my_data_loader: 
 print(i)

  注释处解释如下:

  1/2、ImageFolder类继承自DataSet类,因此可以按索引读取图像。路径必须包含文件夹,ImageFolder会给每个文件夹中的图像添加索引,并且每张图像会给予其所在文件夹的标签。举个例子,代码中my_data_set[0] 输出的是图像对象和它对应的标签组成的列表。

  3、图像到格式化数据的转换组合。更多的转换方法可以看 transform 模块。

  4、图像法的读取方式,默认是PIL.Image.open(),但我发现plt.imread()更快一些。

  由于是边训练边读取,transform会占用很多时间,因此可以先将图像转换为需要的形式存入外存再读取,从而避免重复操作。

  其中transform.ToTensor()会把正常读取的图像转换为torch.tensor,并且像素值会映射至[0,1][0,1]。由于plt.imread()读取png图像时,像素值在[0,1][0,1],而读取jpg图像时,像素值却在[0,255][0,255],因此使用transform.ToTensor()能将图像像素区间统一化。

以上就是Pytorch数据读取与预处理该如何实现的详细内容,更多关于Pytorch数据读取与预处理的资料请关注编程网其它相关文章!

--结束END--

本文标题: Pytorch数据读取与预处理该如何实现

本文链接: https://lsjlt.com/news/122443.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Pytorch数据读取与预处理该如何实现
      在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮子都我们自己写的话,是很浪费时间和精力的。Pytorch帮我们实现了方便的数据读取...
    99+
    2024-04-02
  • Pytorch数据读取与预处理的实现方法
    这篇文章给大家分享的是有关Pytorch数据读取与预处理的实现方法的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。  在炼丹时,数据的读取与预处理是关键一步。不同的模型所需要的数据以及预处理方式各不相同,如果每个轮...
    99+
    2023-06-14
  • 如何使用PyTorch实现自由的数据读取
    目录前言PyTorch数据读入函数介绍ImageFolderDatasetDataLoader问题来源自定义数据读入的举例实现总结前言 很多前人曾说过,深度学习好比炼丹,框架就是丹炉...
    99+
    2024-04-02
  • PyTorch数据读取的实现示例
    前言 PyTorch作为一款深度学习框架,已经帮助我们实现了很多很多的功能了,包括数据的读取和转换了,那么这一章节就介绍一下PyTorch内置的数据读取模块吧 模块介绍 pan...
    99+
    2024-04-02
  • pytorch如何读取csv数据集
    要使用PyTorch读取CSV数据集,可以使用Python的pandas库来加载CSV文件,并将其转换为PyTorch张量。下面是一...
    99+
    2023-10-09
    pytorch
  • 如何在Python中实现高效的数据读取和处理?
    Python是一种广泛使用的编程语言,它在数据科学和机器学习领域中非常受欢迎。在这些领域中,处理大量数据是一个常见的任务。因此,在这篇文章中,我们将介绍如何在Python中实现高效的数据读取和处理。 使用Pandas库 Pandas是P...
    99+
    2023-08-11
    日志 numy load
  • pytorch 如何用cuda处理数据
    1 设置GPU的一些操作 设置在os端哪些GPU可见,如果不可见,那肯定是不能够调用的~ import os GPU = '0,1,2' os.environ['CUDA_VIS...
    99+
    2024-04-02
  • vue如何使用ssr实现预取数据
    这篇文章主要介绍了vue如何使用ssr实现预取数据的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇vue如何使用ssr实现预取数据文章都会有所收获,下面我们一起来看看吧。Why在 Vue 的服...
    99+
    2023-07-04
  • Pytorch 如何加速Dataloader提升数据读取速度
    在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能是CPU和你的磁盘。 很多设计不当的任务,在训练神经网络的时候,大部分时间都是在从磁盘中读取数据,而不是做 B...
    99+
    2024-04-02
  • 如何在Pytorch中使用Dataset和DataLoader读取数据
    本篇文章给大家分享的是有关如何在Pytorch中使用Dataset和DataLoader读取数据,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。一、前言确保安装scikit-im...
    99+
    2023-06-15
  • Pytorch如何加速Dataloader提升数据读取速度
    这篇文章将为大家详细讲解有关Pytorch如何加速Dataloader提升数据读取速度,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。在利用DL解决图像问题时,影响训练效率最大的有时候是GPU,有时候也可能...
    99+
    2023-06-15
  • 如何使用pytorch加载并读取COCO数据集
    这篇文章主要介绍“如何使用pytorch加载并读取COCO数据集”,在日常操作中,相信很多人在如何使用pytorch加载并读取COCO数据集问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”如何使用pytorch...
    99+
    2023-06-30
  • 批处理中如何实现预处理
    这篇文章主要为大家展示了“批处理中如何实现预处理”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“批处理中如何实现预处理”这篇文章吧。一、预处理究竟要做什么? 根据我的经验,预处理要做的是变量值的替...
    99+
    2023-06-08
  • Requests库实现数据抓取与处理功能
    目录引言安装基本用法发送HTTP请求处理HTTP响应高级功能总结引言 Requests是Python中一个常用的第三方库,用于向Web服务器发起HTTP请求并获取响应。该库的使用简单...
    99+
    2023-05-20
    Requests库数据抓取与处理 Requests库数据抓取
  • 如何在R语言中实现数据预处理操作
    如何在R语言中实现数据预处理操作?很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。一、项目环境开发工具:RStudioR:3.5.2相关包:infotheo,dis...
    99+
    2023-06-15
  • GO web 数据库预处理的实现
    目录什么是预处理? 那么预处理有啥好处? Go实现 MySQL 的事务 sqlx使用 gin + mysql + rest full api  上一篇文章我们进行了数据操作...
    99+
    2024-04-02
  • 如何在Python读取与存储数据
    这篇文章将为大家详细讲解有关如何在Python读取与存储数据,文章内容质量较高,因此小编分享给大家做个参考,希望大家阅读完这篇文章后对相关知识有一定的了解。一、图示二、csv文件读取csv文件read_csv(file_path or bu...
    99+
    2023-06-15
  • Java中如何读取文件并处理数据类型?
    在Java中,读取文件并处理数据类型是一项基本任务。本文将介绍如何使用Java读取文件,并处理文件中的数据类型。 一、读取文件 Java中读取文件有多种方法,例如使用FileInputStream、BufferedReader等。下面我们...
    99+
    2023-08-15
    文件 数据类型 二维码
  • 中文维基百科文本数据获取与预处理
    照例,先讲下环境,Mac OSX 10.11.2 ,Python 3.4.3。 下载数据 方法1:使用官方dump的xml数据 最新打包的中文文档下载地址是:https://dumps.wikimedia.org/zhwiki/lates...
    99+
    2023-01-31
    中文 维基百科 文本
  • Python点云处理(一)点云数据读取与写入
    目录 0 简述1 LAS/LAZ格式1.1 las/laz数据读取1.2 las/laz数据写入 2 PCD格式2.1 pcd格式读取2.2 pcd格式写入 3 PLY格式3.1 pl...
    99+
    2023-09-27
    python 3d 算法 数据可视化
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作