返回顶部
首页 > 资讯 > 后端开发 > Python >dataloader各项参数详解
  • 898
分享到

dataloader各项参数详解

pytorch深度学习python 2023-10-20 15:10:51 898人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

在学习某一神经网络框架时,数据流总是能帮助大家更好地理解整个模型的运行逻辑/顺序,而其中Dataloader的作用在某些时候更是至关重要的。 笔者将自己的学习到的关于dataloader的创建,作

学习某一神经网络框架时,数据流总是能帮助大家更好地理解整个模型的运行逻辑/顺序,而其中Dataloader的作用在某些时候更是至关重要的。
笔者将自己的学习到的关于dataloader的创建,作用尽可能详细地记录下来以方便日后回顾,也欢迎各位匹配指正。

一句话概括

Dataloader本质是一个迭代器对象,也就是可以通过
for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

dataloader的初始化

以openpcdet框架下的dataloader初始化为例:

#in pcdet/datasets/__init__.py    dataloader = DataLoader(        dataset,        batch_size=batch_size,        pin_memory=True,        num_workers=workers,        shuffle=(sampler is None) and training,        collate_fn=dataset.collate_batch, #将一个list的sample组成一个mini-batch的函数        drop_last=False, sampler=sampler, timeout=0    )

下面结合pytorch官方文档来详细解释每个参数的意义
在这里插入图片描述 1. dataset (Dataset) – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. __getitem__ && __len__ )
2. batch_size (int, optional) – how many samples per batch to load(default: 1).

  1. pin_memory(bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

    当设置为True时,将会在返回**batch之前将batch**数据复制到固定的内存区域,这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

    通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用**pin_memory**可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

    需要注意的是,使用**pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory**并不会带来明显的加速效果。

  2. num_workers (int, optional) – how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)

    这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为**0**,即在主进程中进行数据加载,而不使用额外的子进程。
    下面说一下个人的理解,在初始化 dataloader 对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程数=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)
    每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个 (即一帧),并将其放到该worker独有的内存队列中。
    要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。
    当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。
    这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解
    关于num_worker的一些个人理解

  3. collate_fn (Callable, optional) – merges a list of samples to fORM a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
    整合多个样本到一个batch时需要调用的函数,当 __getitem__ 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足PyTorch的输入要求
    比如在openpcdet框架的poinpillar中, __getitem__ 返回的是一个包含标注信息、点云信息、图像信息等的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包
    在poinpillar中该函数为:

    @staticmethod    def collate_batch(batch_list, _unused=False):        """       由于训练集中不同的点云的gt框个数不同,需要重写collate_batch函数,       将不同item的boxes和labels等key放入list,返回batch_size的数据       """        # defaultdict创建一个带有默认返回值的字典,当key不存在时,返回默认值,list默认返回一个空        data_dict = defaultdict(list)        # 把batch里面的每个sample按照key-value合并        for cur_sample in batch_list:            for key, val in cur_sample.items():                data_dict[key].append(val)        batch_size = len(batch_list)        ret = {}        # 将合并后的key内的value进行拼接,先获取最大值,构造空矩阵,不足的部分补0        # 因为pytorch要求输入数据维度一致        for key, val in data_dict.items():            try:                # voxels: optional (num_voxels, max_points_per_voxel, 3 + C)                # voxel_coords: optional (num_voxels, 3)                # voxel_num_points: optional (num_voxels)                if key in ['voxels', 'voxel_num_points']:                    ret[key] = np.concatenate(val, axis=0)                elif key in ['points', 'voxel_coords']:                    coors = []                    for i, coor in enumerate(val):                        #在每个坐标前面加上序号 e.g. shape (N, 4) -> (N, 5)  [20, 30, 40, 0.4] -> [i, 20, 30, 40, 0.4]                        # 在scatter起到作用,因为这时候(生成伪图像)就是分batch操作了,需要根据batch_idx 即 下面函数的                          # constant_values 来区分voxel属于哪一帧                        coor_pad = np.pad(coor, ((0, 0), (1, 0)), mode='constant', constant_values=i)                        """((0,0),(1,0))在二维数组array第一维(此处便是行)前面填充0行,最后面填充0行;在二维数组array第二维(此处便是列)前面填充1列,最后面填充0列mode='constant'表示指定填充的参数constant_values=i 表示第一维填充i                        """                        coors.append(coor_pad)                    ret[key] = np.concatenate(coors, axis=0)  # (B, N, 5) -> (B*N, 5)                elif key in ['gt_boxes']:                    # 获取一个batch中所有帧中3D box最大的数量                    max_gt = max([len(x) for x in val])                    # 构造空的box3d矩阵(B, N, 7)                    batch_gt_boxes3d = np.zeros((batch_size, max_gt, val[0].shape[-1]), dtype=np.float32)                    for k in range(batch_size):                        batch_gt_boxes3d[k, :val[k].__len__(), :] = val[k]                    ret[key] = batch_gt_boxes3d                # gt_boxes2d同gt_boxes                elif key in ['gt_boxes2d']:                    max_boxes = 0                    max_boxes = max([len(x) for x in val])                    batch_boxes2d = np.zeros((batch_size, max_boxes, val[0].shape[-1]), dtype=np.float32)                    for k in range(batch_size):                        if val[k].size > 0:batch_boxes2d[k, :val[k].__len__(), :] = val[k]                    ret[key] = batch_boxes2d                elif key in ["images", "depth_maps"]:                    # Get largest image size (H, W)                    max_h = 0                    max_w = 0                    for image in val:                        max_h = max(max_h, image.shape[0])                        max_w = max(max_w, image.shape[1])                    # Change size of images                    images = []                    for image in val:                        pad_h = common_utils.get_pad_params(desired_size=max_h, cur_size=image.shape[0])                        pad_w = common_utils.get_pad_params(desired_size=max_w, cur_size=image.shape[1])                        pad_width = (pad_h, pad_w)                        # Pad with nan, to be replaced later in the pipeline.                        pad_value = np.nan                        if key == "images":pad_width = (pad_h, pad_w, (0, 0))                        elif key == "depth_maps":pad_width = (pad_h, pad_w)                        image_pad = np.pad(image,               pad_width=pad_width,               mode='constant',               constant_values=pad_value)                        images.append(image_pad)                    ret[key] = np.stack(images, axis=0)                else:                    ret[key] = np.stack(val, axis=0)            except:                print('Error in collate_batch: key=%s' % key)                raise TypeError        ret['batch_size'] = batch_size        return ret
  1. sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shufflemust not be specified.

    sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。
    比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。
    此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。
    因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

  2. 待续

来源地址:https://blog.csdn.net/vonct/article/details/130263743

--结束END--

本文标题: dataloader各项参数详解

本文链接: https://lsjlt.com/news/433055.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • dataloader各项参数详解
    在学习某一神经网络框架时,数据流总是能帮助大家更好地理解整个模型的运行逻辑/顺序,而其中Dataloader的作用在某些时候更是至关重要的。 笔者将自己的学习到的关于dataloader的创建,作...
    99+
    2023-10-20
    pytorch 深度学习 python
  • pytorch中dataloader的sampler参数详解
    目录1. dataloader() 初始化函数2. shuffle 与sample 之间的关系3. sample 的定义方法3.1 sampler 参数的使用4. batch 生成过...
    99+
    2024-04-02
  • Mysql explain 各参数详解
    id序号 select_type simple:即简单select 查询,不包含union及子查询; primary:最外层的 select 查询; union:表示此查询是 union 的第二或随后的查...
    99+
    2016-04-15
    Mysql explain 各参数详解
  • pytorch DataLoader的num_workers参数与设置大小详解
    Q:在给Dataloader设置worker数量(num_worker)时,到底设置多少合适?这个worker到底怎么工作的? train_loader = torch....
    99+
    2024-04-02
  • CrystalDiskInfo 各项参数说明电脑硬盘详细参数
    CrystalDiskInfo 各项参数说明 Mr_Pmc 于 2021-02-24 19:02:52 发布 27571  收藏 39 分类专栏: Apple 文章标签: 服务器 负载均衡 版权 华为云开发者联盟 该内容已被华为云开发者联盟...
    99+
    2023-09-10
    服务器 运维
  • 阿里云服务器各项参数讲解
    阿里云服务器是阿里云推出的一项云计算服务,提供了丰富的各项参数,让用户可以根据自己的需求选择最适合的服务器配置。本文将对阿里云服务器的各项参数进行详细讲解,帮助用户更好地理解并选择阿里云服务器。 一、CPU参数阿里云服务器的CPU参数主要包...
    99+
    2023-12-17
    阿里 参数 服务器
  • Pytorch使用技巧之Dataloader中的collate_fn参数详析
    以MNIST为例 from torchvision import datasets mnist = datasets.MNIST(root='./data/', train=True...
    99+
    2024-04-02
  • 怎么理解redis info memory命令的各项参数
    这篇文章主要讲解了“怎么理解redis info memory命令的各项参数”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“怎么理解redis info me...
    99+
    2024-04-02
  • iframe中各项参数的示例分析
    这篇文章主要为大家展示了“iframe中各项参数的示例分析”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“iframe中各项参数的示例分析”这篇文章吧。ifram...
    99+
    2024-04-02
  • Python 之 print 函数语法格式及各参数详解
    1、print语法格式 print() 函数具有丰富的功能,详细语法格式如下: print(value, …, sep=’ ‘, end=’\n’, file=sys.stdout, flush=False) 默认情况下,将值打印到流或sy...
    99+
    2023-09-13
    数学建模
  • spring-AOP 及 AOP获取request各项参数操作
    spring-AOP 及 AOP获取request各项参数 AOP称为面向切面编程,在程序开发中主要用来解决一些系统层面上的问题,比如日志,事务,权限等待。 一、AOP的基本概念 ...
    99+
    2024-04-02
  • spring-AOP 及 AOP怎么获取request各项参数
    这篇文章主要介绍“spring-AOP 及 AOP怎么获取request各项参数”,在日常操作中,相信很多人在spring-AOP 及 AOP怎么获取request各项参数问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对...
    99+
    2023-06-20
  • Pytorch中DataLoader的使用方法详解
    目录一:dataset类构建。二:DataLoader使用三:举例前言加载数据datasetdataloader在Pytorch中,torch.utils.data中的Dataset...
    99+
    2024-04-02
  • python正则表达式re.sub各个参数的超详细讲解
    目录一、re.sub(pattern, repl, string, count=0, flags=0)二、参数讲解1、pattern参数2、repl参数2.1、repl是字符串2.2...
    99+
    2024-04-02
  • spring task @Scheduled注解各参数的用法
    目录参数详解1. cron2. zone3. fixedDelay4. fixedDelayString5. fixedRate6. fixedRateString7. initia...
    99+
    2024-04-02
  • OGG参数详解
    一直以来对oracle goldengate许多参数比较疑惑,正好在MOS看到这个文章,转载到BLOG,以备参考 Objective: This paper provides samp...
    99+
    2024-04-02
  • PyTorch Dataset与DataLoader使用超详细讲解
    目录一、Dataset1. 在控制台进行操作①获取图片的基本信息②获取文件的基本信息2. 编写一个继承Dataset 的类加载数据①定义 MyData类②创建类的实例并调用二、Dat...
    99+
    2024-04-02
  • linux shell命令行选项与参数用法详解
    问题描述:在linux shell中如何处理tail -n 10 access.log这样的命令行选项?在bash中,可以用以下三种方式来处理命令行参数,每种方式都有自己的应用场景。1,直接处理,依次对$1...
    99+
    2022-06-04
    命令行 详解 选项
  • vue 项目优雅的对url参数加密详解
    目录实现方案:stringifyQuery 和 parseQuery更进一步:相关实现原理实现方案:stringifyQuery 和 parseQuery 近期因为公司内部的安全检查...
    99+
    2022-11-13
    vue url 参数加密 vue url
  • RMI 各参数意义sun.rmi Properties
    sun.rmi PropertiesWARNING: The properties described here are not supported, can change at any time, and only exist in ce...
    99+
    2023-06-03
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作