返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch之如何dropout避免过拟合
  • 703
分享到

Pytorch之如何dropout避免过拟合

2024-04-02 19:04:59 703人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

一.做数据 二.搭建神经网络 三.训练 四.对比测试结果 注意:测试过程中,一定要注意模式切换 PyTorch的学习——过拟合 过拟合 过拟合是当数据量较小时或者输出结

一.做数据

在这里插入图片描述

二.搭建神经网络

三.训练

在这里插入图片描述

四.对比测试结果

注意:测试过程中,一定要注意模式切换

在这里插入图片描述

PyTorch学习——过拟合

过拟合

过拟合是当数据量较小时或者输出结果过于依赖某些特定的神经元,训练神经网络训练会发生一种现象。出现这种现象的神经网络预测的结果并不具有普遍意义,其预测结果极不准确。

解决方法

1.增加数据量

2.L1,L2,L3…正规化,即在计算误差值的时候加上要学习的参数值,当参数改变过大时,误差也会变大,通过这种惩罚机制来控制过拟合现象

3.dropout正规化,在训练过程中通过随机屏蔽部分神经网络连接,使神经网络不完整,这样就可以使神经网络的预测结果不会过分依赖某些特定的神经元

例子

这里小编通过dropout正规化的列子来更加形象的了解神经网络的过拟合现象


import torch
import matplotlib.pyplot as plt
N_SAMPLES = 20
N_HIDDEN = 300
# train数据
x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
y = x + 0.3*torch.nORMal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))
# test数据
test_x = torch.unsqueeze(torch.linspace(-1, 1, N_SAMPLES), 1)
test_y = test_x + 0.3*torch.normal(torch.zeros(N_SAMPLES, 1), torch.ones(N_SAMPLES, 1))
# 可视化
plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.5, label='train')
plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.5, label='test')
plt.legend(loc='upper left')
plt.ylim((-2.5, 2.5))
plt.show()
# 网络一,未使用dropout正规化
net_overfitting = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, N_HIDDEN),
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, 1),
)
# 网络二,使用dropout正规化
net_dropped = torch.nn.Sequential(
    torch.nn.Linear(1, N_HIDDEN),
    torch.nn.Dropout(0.5),  # 随机屏蔽50%的网络连接
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, N_HIDDEN),
    torch.nn.Dropout(0.5),  # 随机屏蔽50%的网络连接
    torch.nn.ReLU(),
    torch.nn.Linear(N_HIDDEN, 1),
)
# 选择优化器
optimizer_ofit = torch.optim.Adam(net_overfitting.parameters(), lr=0.01)
optimizer_drop = torch.optim.Adam(net_dropped.parameters(), lr=0.01)
# 选择计算误差的工具
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(500):
    # 神经网络训练数据的固定过程
    pred_ofit = net_overfitting(x)
    pred_drop = net_dropped(x)
    loss_ofit = loss_func(pred_ofit, y)
    loss_drop = loss_func(pred_drop, y)
    optimizer_ofit.zero_grad()
    optimizer_drop.zero_grad()
    loss_ofit.backward()
    loss_drop.backward()
    optimizer_ofit.step()
    optimizer_drop.step()
    if t % 10 == 0:
        # 脱离训练模式,这里便于展示神经网络的变化过程
        net_overfitting.eval()
        net_dropped.eval() 
        # 可视化
        plt.cla()
        test_pred_ofit = net_overfitting(test_x)
        test_pred_drop = net_dropped(test_x)
        plt.scatter(x.data.numpy(), y.data.numpy(), c='magenta', s=50, alpha=0.3, label='train')
        plt.scatter(test_x.data.numpy(), test_y.data.numpy(), c='cyan', s=50, alpha=0.3, label='test')
        plt.plot(test_x.data.numpy(), test_pred_ofit.data.numpy(), 'r-', lw=3, label='overfitting')
        plt.plot(test_x.data.numpy(), test_pred_drop.data.numpy(), 'b--', lw=3, label='dropout(50%)')
        plt.text(0, -1.2, 'overfitting loss=%.4f' % loss_func(test_pred_ofit, test_y).data.numpy(),
                 fontdict={'size': 20, 'color':  'red'})
        plt.text(0, -1.5, 'dropout loss=%.4f' % loss_func(test_pred_drop, test_y).data.numpy(),
                 fontdict={'size': 20, 'color': 'blue'})
        plt.legend(loc='upper left'); plt.ylim((-2.5, 2.5));plt.pause(0.1)
        # 重新进入训练模式,并继续上次训练
        net_overfitting.train()
        net_dropped.train()
plt.ioff()
plt.show()

效果

可以看到红色的线虽然更加拟合train数据,但是通过test数据发现它的误差反而比较大

在这里插入图片描述

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: Pytorch之如何dropout避免过拟合

本文链接: https://lsjlt.com/news/127481.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Pytorch之如何dropout避免过拟合
    一.做数据 二.搭建神经网络 三.训练 四.对比测试结果 注意:测试过程中,一定要注意模式切换 Pytorch的学习——过拟合 过拟合 过拟合是当数据量较小时或者输出结...
    99+
    2024-04-02
  • pytorch Dropout过拟合的操作
    如下所示: import torch from torch.autograd import Variable import matplotlib.pyplot as plt t...
    99+
    2024-04-02
  • pytorch Dropout过拟合的示例分析
    这篇文章主要介绍pytorch Dropout过拟合的示例分析,文中介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们一定要看完!如下所示:import torchfrom torch.autograd&nb...
    99+
    2023-06-15
  • Torch中如何避免过拟合
    数据增强(Data Augmentation):通过对训练数据进行随机变换、裁剪、翻转等操作,增加数据的多样性,从而减少模型对特...
    99+
    2024-03-08
    Torch
  • 如何通过ORM避免直接SQL拼接
    通过使用ORM(对象关系映射)框架,可以避免直接拼接SQL语句。 ORM框架可以将数据库表的结构映射为对象的属性,使开发人员可以通过...
    99+
    2024-04-29
    SQL
  • Neuroph如何解决过拟合和欠拟合问题
    Neuroph是一个开源的Java神经网络库,它提供了一些方法来解决神经网络的过拟合和欠拟合问题。 过拟合问题:过拟合是指模型在训...
    99+
    2024-04-02
  • PHP8如何通过Nullsafe Operator避免空值检查?
    PHP8如何通过Nullsafe Operator避免空值检查?在传统的PHP开发中,我们经常需要对变量进行空值检查,以避免因为变量为空而引发错误。然而,这样的空值检查代码可能会使代码变得冗长,降低代码的可读性和可维护性。幸运的是,在PHP...
    99+
    2023-10-22
    PHP Nullsafe Operator 空值检查
  • hadoop如何通过cachefile来避免数据倾斜
    这篇文章主要介绍了hadoop如何通过cachefile来避免数据倾斜,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。package hello_hadoop;imp...
    99+
    2023-06-02
  • PHP8如何通过Nullsafe Operator避免空指针异常?
    PHP8如何通过Nullsafe Operator避免空指针异常?摘要:Nullsafe Operator是PHP8版本引入的新特性之一,它提供了一种简洁而有效的方式来处理空指针异常。本文将详细介绍Nullsafe Operator的用法,...
    99+
    2023-10-22
    PHP Nullsafe Operator 空指针异常避免
  • golang 函数命名如何避免过于具体或过于抽象?
    为避免函数名过于具体或抽象,应遵循以下最佳实践:描述性:函数名应准确描述其功能,而不使用技术细节。简洁:尽可能简短,但仍能传达函数的含义。可读:容易阅读和理解。 Go 函数命名:避免过...
    99+
    2024-04-22
    golang 函数命名 作用域
  • ASP 编程中如何避免算法复杂度过高?
    ASP(Active Server Pages)是一种动态网页技术,它使用VBScript或JScript等编程语言进行编写。在ASP编程中,算法复杂度过高可能会导致网页响应时间过长,从而影响用户体验。为了避免这种情况的发生,我们需要采取一...
    99+
    2023-08-21
    编程算法 linux 文件
  • Keras中如何处理过拟合问题
    Keras提供了多种方法来处理过拟合问题,以下是一些常用的方法: 早停法(Early Stopping):在训练过程中监控验证集...
    99+
    2024-04-02
  • PHP与MySQL合作时,如何避免0值转义问题?
    PHP与MySQL是一对常见的技术组合,用于构建动态网站和应用程序。然而,有时在处理数据时会遇到一些问题,比如0值转义问题。本文将介绍如何避免在PHP与MySQL合作时出现的0值转义问...
    99+
    2024-02-29
    mysql php 转义
  • PaddlePaddle框架如何应对过拟合问题
    PaddlePaddle框架提供了一些方法来应对过拟合问题: 数据增强:通过对训练数据进行随机旋转、裁剪、缩放等操作,增加训练数...
    99+
    2024-03-08
    PaddlePaddle
  • MySQL问答系列之如何避免ibdata1文件大小暴涨
    0、导读 ibdata1文件是什么? ibdata1是一个用来构建innodb系统表空间的文件,这个文件包含了innodb表的元数据、撤销记录、修改buffer和双写buffer。如果file-per-t...
    99+
    2024-04-02
  • 如何通过celery_one避免Celery定时任务重复执行的问题
    这篇文章主要介绍了如何通过celery_one避免Celery定时任务重复执行的问题,具有一定借鉴价值,感兴趣的朋友可以参考下,希望大家阅读完这篇文章之后大有收获,下面让小编带着大家一起了解一下。在使用Celery统计每日访问数量的时候,发...
    99+
    2023-06-25
  • PHP、Git和IDE:如何在开发过程中避免常见的错误?
    PHP是一种流行的服务器端编程语言,Git是一个开源的分布式版本控制系统,IDE则是一种集成开发环境。在开发过程中,这些工具都是不可或缺的。但是,即使是最经验丰富的开发人员,也难免会犯一些常见的错误。在本文中,我们将讨论如何使用PHP、G...
    99+
    2023-09-11
    git ide 关键字
  • npm和Go:如何在开发过程中避免文件路径错误?
    随着现代前端开发的发展,npm和Go已经成为开发者们常用的工具。然而,对于初学者来说,在使用这些工具时经常会遇到文件路径错误的问题。本文将介绍如何在开发过程中避免这些错误,使开发更加高效。 什么是文件路径错误? 文件路径错误是指在使用npm...
    99+
    2023-06-03
    npm path 文件
  • Go语言中如何避免Shell缓存加载过程中的错误?
    在Go语言的开发过程中,我们经常需要使用Shell命令来完成一些任务。然而,由于Shell缓存的存在,可能会导致在加载过程中出现错误。本文将介绍如何在Go语言中避免Shell缓存加载过程中的错误,并提供一些演示代码。 Shell缓存的概念...
    99+
    2023-08-19
    load shell 缓存
  • PHP 并发处理过程中如何避免日志重定向问题?
    PHP 作为一门强大的服务器端脚本语言,越来越多地被用于处理高并发的场景。在高并发处理过程中,日志记录是非常重要的,可以帮助我们排查问题和了解系统运行情况。但是,由于 PHP 本身的特性,很容易出现日志重定向的问题,导致我们无法正确地记录...
    99+
    2023-06-30
    并发 日志 重定向
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作