返回顶部
首页 > 资讯 > 后端开发 > Python >解决Pytorch半精度浮点型网络训练的问题
  • 777
分享到

解决Pytorch半精度浮点型网络训练的问题

2024-04-02 19:04:59 777人浏览 泡泡鱼

Python 官方文档:入门教程 => 点击学习

摘要

用PyTorch1.0进行半精度浮点型网络训练需要注意下问题: 1、网络要在GPU上跑,模型和输入样本数据都要cuda().half() 2、模型参数转换为half型,不必索引到每层

PyTorch1.0进行半精度浮点型网络训练需要注意下问题:

1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()

2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可

3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。

另外,SGD算法对于半精度和全精度计算均没有问题。

还有一个问题是不知道是不是网络结构比较小的原因,使用半精度的训练速度还没有全精度快。这个值得后续进一步探索。

对于上面的这个问题,的确是网络很小的情况下,在1080Ti上半精度浮点型没有很明显的优势,但是当网络变大之后,半精度浮点型要比全精度浮点型要快。

但具体快多少和模型的大小以及输入样本大小有关系,我测试的是要快1/6,同时,半精度浮点型在占用内存上比较有优势,对于精度的影响尚未探究。

将网络再变大些,epoch的次数也增大,半精度和全精度的时间差就表现出来了,在训练的时候。

补充:pytorch半精度,混合精度,单精度训练的区别amp.initialize

看代码吧~


mixed_precision = True
try:  # Mixed precision training https://GitHub.com/NVIDIA/apex
    from apex import amp
except:
    mixed_precision = False  # not installed

 model, optimizer = amp.initialize(model, optimizer, opt_level='O1', verbosity=1)

为了帮助提高Pytorch的训练效率,英伟达提供了混合精度训练工具Apex。号称能够在不降低性能的情况下,将模型训练的速度提升2-4倍,训练显存消耗减少为之前的一半。

文档地址是:Https://nvidia.github.io/apex/index.html

该 工具 提供了三个功能,amp、parallel和nORMalization。由于目前该工具还是0.1版本,功能还是很基础的,在最后一个normalization功能中只提供了LayerNorm层的复现,实际上在后续的使用过程中会发现,出现问题最多的是pytorch的BN层。

第二个工具是pytorch的分布式训练的复现,在文档中描述的是和pytorch中的实现等价,在代码中可以选择任意一个使用,实际使用过程中发现,在使用混合精度训练时,使用Apex复现的parallel工具,能避免一些bug。

默认训练方式是 单精度float32


import torch
model = torch.nn.Linear(D_in, D_out)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

半精度 model(img.half())


import torch
model = torch.nn.Linear(D_in, D_out).half()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
for img, label in dataloader:
 out = model(img.half())
 loss = LOSS(out, label)
 loss.backward()
 optimizer.step()
 optimizer.zero_grad()

接下来是混合精度的实现,这里主要用到Apex的amp工具。

代码修改为:

加上这一句封装,


model, optimizer = amp.initialize(model, optimizer, opt_level=“O1”)

import torch
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for img, label in dataloader:
 out = model(img)
 loss = LOSS(out, label)
 # loss.backward()
 with amp.scale_loss(loss, optimizer) as scaled_loss:
     scaled_loss.backward()

 optimizer.step()
 optimizer.zero_grad()

实际流程为:调用amp.initialize按照预定的opt_level对model和optimizer进行设置。在计算loss时使用amp.scale_loss进行回传。

需要注意以下几点:

在调用amp.initialize之前,模型需要放在GPU上,也就是需要调用cuda()或者to()。

在调用amp.initialize之前,模型不能调用任何分布式设置函数。

此时输入数据不需要在转换为半精度。

在使用混合精度进行计算时,最关键的参数是opt_level。他一共含有四种设置值:‘00',‘01',‘02',‘03'。实际上整个amp.initialize的输入参数很多:

但是在实际使用过程中发现,设置opt_level即可,这也是文档中例子的使用方法,甚至在不同的opt_level设置条件下,其他的参数会变成无效。(已知BUG:使用‘01'时设置keep_batchnorm_fp32的值会报错)

概括起来:

00相当于原始的单精度训练。01在大部分计算时采用半精度,但是所有的模型参数依然保持单精度,对于少数单精度较好的计算(如softmax)依然保持单精度。02相比于01,将模型参数也变为半精度。

03基本等于最开始实验的全半精度的运算。值得一提的是,不论在优化过程中,模型是否采用半精度,保存下来的模型均为单精度模型,能够保证模型在其他应用中的正常使用。这也是Apex的一大卖点。

在Pytorch中,BN层分为train和eval两种操作。

实现时若为单精度网络,会调用CUDNN进行计算加速。常规训练过程中BN层会被设为train。Apex优化了这种情况,通过设置keep_batchnorm_fp32参数,能够保证此时BN层使用CUDNN进行计算,达到最好的计算速度。

但是在一些fine tunning场景下,BN层会被设为eval(我的模型就是这种情况)。此时keep_batchnorm_fp32的设置并不起作用,训练会产生数据类型不正确的bug。此时需要人为的将所有BN层设置为半精度,这样将不能使用CUDNN加速。

一个设置的参考代码如下:


def fix_bn(m):
 classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
     m.eval().half()

model.apply(fix_bn)

实际测试下来,最后的模型准确度上感觉差别不大,可能有轻微下降;时间上变化不大,这可能会因不同的模型有差别;显存开销上确实有很大的降低。

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: 解决Pytorch半精度浮点型网络训练的问题

本文链接: https://lsjlt.com/news/126673.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 解决Pytorch半精度浮点型网络训练的问题
    用Pytorch1.0进行半精度浮点型网络训练需要注意下问题: 1、网络要在GPU上跑,模型和输入样本数据都要cuda().half() 2、模型参数转换为half型,不必索引到每层...
    99+
    2024-04-02
  • 使用Pytorch怎么实现半精度浮点型网络训练
    今天就跟大家聊聊有关使用Pytorch怎么实现半精度浮点型网络训练,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。用Pytorch2.0进行半精度浮点型网络训练需要注意下问题:网络要在...
    99+
    2023-06-15
  • iOS浮点类型精度问题的原因与解决办法
    目录前言如何解决浮点型精度问题四舍五入处理更优的解决方案精度丢失的原因浮点类型的存储方式有效位数指数的存储方式:移位存储double类型总结:输出结果丢失精度原因前言 相信不少人(其...
    99+
    2024-04-02
  • python浮点数运算精度问题如何解决
    在Python中,浮点数运算可能存在精度问题,可以采取以下方法解决:1. 使用Decimal模块:Decimal模块提供了精确的十进...
    99+
    2023-08-26
    python
  • pytorch训练神经网络爆内存的解决方案
    训练的时候内存一直在增加,最后内存爆满,被迫中断。 后来换了一个电脑发现还是这样,考虑是代码的问题。 检查才发现我的代码两次存了loss,只有一个地方写的是loss.item()。...
    99+
    2024-04-02
  • iOS浮点类型精度问题的原因与解决办法是什么
    这篇文章主要为大家分析了iOS浮点类型精度问题的原因与解决办法是什么的相关知识点,内容详细易懂,操作细节合理,具有一定参考价值。如果感兴趣的话,不妨跟着跟随小编一起来看看,下面跟着小编一起深入学习“iOS浮点类型精度问题的原因与解决办法是什...
    99+
    2023-06-29
  • js浮点数精度丢失的问题及解决方法
    本篇内容介绍了“js浮点数精度丢失的问题及解决方法”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!说明在数学计算中,小数会有一定的误差,这是计...
    99+
    2023-06-20
  • C语言中浮点数的精度丢失问题解决
    目录一 先来看一段代码运行结果:二 如何解决(1)浮点数的大小比较(2)含浮点数的表达式和0.0的比较总结一 先来看一段代码 #include<stdio.h> int ...
    99+
    2024-04-02
  • pytorch训练时的显存占用递增的问题解决
    目录遇到的问题:解决方法:补充:Pytorch显存不断增长问题的解决思路遇到的问题: 在pytorch训练过程中突然out of memory。 解决方法: 1. 测试的时候爆显存有...
    99+
    2023-01-15
    pytorch 显存占用递增 pytorch 显存占用
  • JS中浮点数精度问题的分析与解决方法
    目录前言问题的发现浮点数运算后的精度问题toFixed奇葩问题为什么会产生浮点数的存储浮点数的运算解决方法解决toFixed解决浮点数运算精度附:JS浮点数精度问题的一些实用建议总结...
    99+
    2024-04-02
  • Golang处理浮点数遇到的精度问题怎么解决
    这篇文章主要介绍“Golang处理浮点数遇到的精度问题怎么解决”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“Golang处理浮点数遇到的精度问题怎么解决”文章能帮助大家解决问题。一、浮点数是什么?浮...
    99+
    2023-06-29
  • Python中的浮点数计算精度问题是如何解决的?
    Python中的浮点数计算精度问题是如何解决的?在计算机科学中,浮点数计算精度问题是常见的挑战之一。由于计算机内部使用有限的比特位来表示浮点数,所以对于某些小数的表示和运算时,可能会出现精度损失的情况。Python作为一门强大的编程语言,提...
    99+
    2023-10-22
    解决 浮点数 精度
  • 如何解决C语言中浮点数的精度丢失问题
    小编给大家分享一下如何解决C语言中浮点数的精度丢失问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧!一 先来看一段代码#include<stdio.h>...
    99+
    2023-06-26
  • Golang浮点数精度丢失问题扩展包的解决方法
    小编给大家分享一下Golang浮点数精度丢失问题扩展包的解决方法,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!PS: 今天在做项目进行精度处理的时候出现精度丢失问题,在这里跟大家分享下扩展包解决方案, 注意:这个问题是可能...
    99+
    2023-06-08
  • Pytorch训练网络过程中loss突然变为0的解决方案
    问题 // loss 突然变成0 python train.py -b=8 INFO: Using device cpu INFO: Network: 1 inp...
    99+
    2024-04-02
  • pytorch网络模型构建场景的问题如何解决
    今天小编给大家分享一下pytorch网络模型构建场景的问题如何解决的相关知识点,内容详细,逻辑清晰,相信大部分人都还太了解这方面的知识,所以分享这篇文章给大家参考一下,希望大家阅读完这篇文章后有所收获,下面我们一起来了解一下吧。网络模型构建...
    99+
    2023-07-05
  • Python中的浮点数计算精度问题的原因和解决方案有哪些?
    Python中的浮点数计算精度问题的原因和解决方案有哪些?在进行浮点数计算时,我们经常会遇到精度问题。这是由于计算机采用二进制来表示浮点数,而不是十进制。由于二进制无法准确表示一些十进制小数,导致了浮点数计算的精度问题。一、浮点数计算精度问...
    99+
    2023-10-22
    解决方案 (solution) 浮点数 (float) 精度问题 (precision issue)
  • 解决Pytorch在测试与训练过程中的验证结果不一致问题
    引言 今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,经过排查后发现是forward与eval()顺序问题。 现象 此前的错误代码是 ...
    99+
    2024-04-02
  • 如何解决Pytorch在测试与训练过程中的验证结果不一致问题
    小编给大家分享一下如何解决Pytorch在测试与训练过程中的验证结果不一致问题,希望大家阅读完这篇文章之后都有所收获,下面让我们一起去探讨吧!引言今天在使用Pytorch导入此前保存的模型进行测试,在过程中发现输出的结果与验证结果差距甚大,...
    99+
    2023-06-15
  • springboot中使用FastJson解决long类型在js中失去精度的问题
    目录使用FastJson解决long类型在js中失去精度问题1.pom中需要将默认的jackson排除掉2.利用fastJson替换掉jacksonspringboot long精度...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作