返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch BCELoss和BCEWithLogitsLoss的使用
  • 125
分享到

Pytorch BCELoss和BCEWithLogitsLoss的使用

2024-04-02 19:04:59 125人浏览 独家记忆

Python 官方文档:入门教程 => 点击学习

摘要

BCELoss 在图片多标签分类时,如果3张图片分3类,会输出一个3*3的矩阵。 先用Sigmoid给这些值都搞到0~1之间: 假设Target是: 下面我们用BCELoss

BCELoss

在图片多标签分类时,如果3张图片分3类,会输出一个3*3的矩阵。

先用Sigmoid给这些值都搞到0~1之间:

假设Target是:

下面我们用BCELoss来验证一下Loss是不是0.7194!

emmm应该是我上面每次都保留4位小数,算到最后误差越来越大差了0.0001。不过也很厉害啦哈哈哈哈哈!

BCEWithLogitsLoss

BCEWithLogitsLoss就是把Sigmoid-BCELoss合成一步。我们直接用刚刚的input验证一下是不是0.7193:

嘻嘻,我可真是太厉害啦!

补充:Pytorch中BCELoss,BCEWithLogitsLoss和CrossEntropyLoss的区别

BCEWithLogitsLoss = Sigmoid+BCELoss

网络最后一层使用nn.Sigmoid时,就用BCELoss,当网络最后一层不使用nn.Sigmoid时,就用BCEWithLogitsLoss。

(BCELoss)BCEWithLogitsLoss

用于单标签二分类或者多标签二分类,输出和目标的维度是(batch,C),batch是样本数量,C是类别数量,对于每一个batch的C个值,对每个值求sigmoid到0-1之间,所以每个batch的C个值之间是没有关系的,相互独立的,所以之和不一定为1。

每个C值代表属于一类标签的概率。如果是单标签二分类,那输出和目标的维度是(batch,1)即可。

CrossEntropyLoss用于多类别分类

输出和目标的维度是(batch,C),batch是样本数量,C是类别数量,每一个C之间是互斥的,相互关联的,对于每一个batch的C个值,一起求每个C的softmax,所以每个batch的所有C个值之和是1,哪个值大,代表其属于哪一类。如果用于二分类,那输出和目标的维度是(batch,2)。

补充:Pytorch踩坑记之交叉熵(nn.CrossEntropy,nn.NLLLoss,nn.BCELoss的区别和使用)

PyTorch中的交叉熵函数的血泪史要从nn.CrossEntropyLoss()这个损失函数开始讲起。

从表面意义上看,这个函数好像是普通的交叉熵函数,但是如果你看过一些Pytorch的资料,会告诉你这个函数其实是softmax()和交叉熵的结合体。

然而如果去官方看这个函数的定义你会发现是这样子的:

哇,竟然是nn.LogSoftmax()和nn.NLLLoss()的结合体,这俩都是什么玩意儿啊。再看看你会发现甚至还有一个损失叫nn.Softmax()以及一个叫nn.nn.BCELoss()。我们来探究下这几个损失到底有何种关系。

nn.Softmax和nn.LogSoftmax

首先nn.Softmax()官网的定义是这样的:

嗯...就是我们认识的那个softmax。那nn.LogSoftmax()的定义也很直观了:

果不其然就是Softmax取了个log。可以写个代码测试一下:


import torch
import torch.nn as nn
 
a = torch.Tensor([1,2,3])
#定义Softmax
softmax = nn.Softmax()
sm_a = softmax=nn.Softmax()
print(sm)
#输出:tensor([0.0900, 0.2447, 0.6652])
 
#定义LogSoftmax
logsoftmax = nn.LogSoftmax()
lsm_a = logsoftmax(a)
print(lsm_a)
#输出tensor([-2.4076, -1.4076, -0.4076]),其中ln(0.0900)=-2.4076

nn.NLLLoss

上面说过nn.CrossEntropy()是nn.LogSoftmax()和nn.NLLLoss的结合,nn.NLLLoss官网给的定义是这样的:

The negative log likelihood loss. It is useful to train a classification problem with C classes

负对数似然损失 ,看起来好像有点晦涩难懂,写个代码测试一下:


import torch
import torch.nn
 
a = torch.Tensor([[1,2,3]])
nll = nn.NLLLoss()
target1 = torch.Tensor([0]).long()
target2 = torch.Tensor([1]).long()
target3 = torch.Tensor([2]).long()
 
#测试
n1 = nll(a,target1)
#输出:tensor(-1.)
n2 = nll(a,target2)
#输出:tensor(-2.)
n3 = nll(a,target3)
#输出:tensor(-3.)

看起来nn.NLLLoss做的事情是取出a中对应target位置的值并取负号,比如target1=0,就取a中index=0位置上的值再取负号为-1,那这样做有什么意义呢,要结合nn.CrossEntropy往下看。

nn.CrossEntropy

看下官网给的nn.CrossEntropy()的表达式:

看起来应该是softmax之后取了个对数,写个简单代码测试一下:


import torch
import torch.nn as nn
 
a = torch.Tensor([[1,2,3]])
target = torch.Tensor([2]).long()
logsoftmax = nn.LogSoftmax()
ce = nn.CrossEntropyLoss()
nll = nn.NLLLoss()
 
#测试CrossEntropyLoss
cel = ce(a,target)
print(cel)
#输出:tensor(0.4076)
 
#测试LogSoftmax+NLLLoss
lsm_a = logsoftmax(a)
nll_lsm_a = nll(lsm_a,target)
#输出tensor(0.4076)

看来直接用nn.CrossEntropy和nn.LogSoftmax+nn.NLLLoss是一样的结果。为什么这样呢,回想下交叉熵的表达式:

l(x,y)=-\sum y*logx=\left\{\begin{matrix} -logx , y=1& \\ 0,y=0& \end{matrix}\right.

其中y是label,x是prediction的结果,所以其实交叉熵损失就是负的target对应位置的输出结果x再取-log。这个计算过程刚好就是先LogSoftmax()再NLLLoss()。

------------------------------------

所以我认为nn.CrossEntropyLoss其实应该叫做softmaxloss更为合理一些,这样就不会误解了。

nn.BCELoss

你以为这就完了吗,其实并没有。还有一类损失叫做BCELoss,写全了的话就是Binary Cross Entropy Loss,就是交叉熵应用于二分类时候的特殊形式,一般都和sigmoid一起用,表达式就是二分类交叉熵:

直觉上和多酚类交叉熵的区别在于,不仅考虑了y_n=1的样本,也考虑了y_n=0的样本的损失。

总结

nn.LogSoftmax是在softmax的基础上取自然对数nn.NLLLoss是负的似然对数损失,但Pytorch的实现就是把对应target上的数取出来再加个负号,要在CrossEntropy中结合LogSoftmax来用BCELoss是二分类的交叉熵损失,Pytorch实现中和多分类有区别

Pytorch是个深坑,让我们一起扎根使用手册,结合实践踏平这些坑吧暴风哭泣

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: Pytorch BCELoss和BCEWithLogitsLoss的使用

本文链接: https://lsjlt.com/news/125813.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Pytorch BCELoss和BCEWithLogitsLoss的使用
    BCELoss 在图片多标签分类时,如果3张图片分3类,会输出一个3*3的矩阵。 先用Sigmoid给这些值都搞到0~1之间: 假设Target是: 下面我们用BCELoss...
    99+
    2024-04-02
  • BCELoss和BCEWithLogitsLoss怎么在Pytorch中使用
    BCELoss和BCEWithLogitsLoss怎么在Pytorch中使用?相信很多没有经验的人对此束手无策,为此本文总结了问题出现的原因和解决方法,通过这篇文章希望你能解决这个问题。BCELoss在图片多标签分类时,如果3张图片分3类,...
    99+
    2023-06-15
  • pytorch中的model.eval()和BN层的使用
    看代码吧~ class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvN...
    99+
    2024-04-02
  • PyTorch中的train()、eval()和no_grad()的使用
    目录什么是train()函数?什么是eval()函数?什么是no_grad()函数?train()、eval()和no_grad()函数的联系总结在PyTorch中,train()、...
    99+
    2023-05-14
    PyTorch train() eval() no_grad()
  • pytorch中F.avg_pool1d()和F.avg_pool2d()的使用操作
    F.avg_pool1d()数据是三维输入 input维度: (batch_size,channels,width)channel可以看成高度 kenerl维度:(一维:表示widt...
    99+
    2024-04-02
  • Pytorch 中net.train 和 net.eval的使用说明
    在训练模型时会在前面加上: model.train() 在测试模型时在前面使用: model.eval() 同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针...
    99+
    2024-04-02
  • Pytorch中Softmax和LogSoftmax的使用详解
    一、函数解释 1.Softmax函数常用的用法是指定参数dim就可以: (1)dim=0:对每一列的所有元素进行softmax运算,并使得每一列所有元素和为1。 (2)dim=1:对...
    99+
    2024-04-02
  • Pytorch中的model.train()和model.eval()怎么使用
    本文小编为大家详细介绍“Pytorch中的model.train()和model.eval()怎么使用”,内容详细,步骤清晰,细节处理妥当,希望这篇“Pytorch中的model.train()和model.eval()怎么使用”文章能帮助...
    99+
    2023-07-06
  • PyTorch中的train()、eval()和no_grad()怎么使用
    本篇内容介绍了“PyTorch中的train()、eval()和no_grad()怎么使用”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!什么...
    99+
    2023-07-05
  • Pytorch使用transforms
    首先,这次讲解的tansforms功能,通俗地讲,类似于在计算机视觉流程里的图像预处理部分的数据增强。 transforms的原理: 说明:图片(输入)通过工具得到结果(输出),这个...
    99+
    2024-04-02
  • pytorch中如何使用model.eval()和BN层
    这篇文章给大家分享的是有关pytorch中如何使用model.eval()和BN层的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。代码如下class ConvNet(nn.module): &n...
    99+
    2023-06-15
  • ubuntu中pytorch怎么安装和使用
    要在Ubuntu中安装PyTorch,可以使用conda或pip进行安装。以下是使用conda安装PyTorch的步骤: 首先,确...
    99+
    2024-03-01
    ubuntu pytorch
  • Pytorch BertModel的使用说明
    基本介绍 环境: Python 3.5+, Pytorch 0.4.1/1.0.0 安装: pip install pytorch-pretrained-bert 必需参数: ...
    99+
    2024-04-02
  • pytorch 中nn.Dropout的使用说明
    看代码吧~ Class USeDropout(nn.Module): def __init__(self): super(DropoutFC, se...
    99+
    2024-04-02
  • pytorch--之halfTensor的使用详解
    证明出错在dataloader里面 在pytorch当中,float16和half是一样的数据结构,都是属于half操作, 然后dataloader不能返回half值,所以在dat...
    99+
    2024-04-02
  • Pytorch中的gather使用方法
    官方说明 gather可以对一个Tensor进行聚合,声明为:torch.gather(input, dim, index, out=None) → Tensor 一般来说有三个参数...
    99+
    2024-04-02
  • pytorch中transforms的使用详解
    目录transformsToTensortransforms使用为什么需要tensor数据类型呢?常见的transforms内置方法__call__()NormalizeResize...
    99+
    2024-04-02
  • Pytorch中transforms.Resize()的简单使用
    目录transforms.Resize()的简单使用transforms.Resize([224, 224])解读transforms.Resize()的简单使用 简单来说就是调整P...
    99+
    2024-04-02
  • PyTorch中的nn.Embedding怎么使用
    这篇“PyTorch中的nn.Embedding怎么使用”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“PyTorch中的nn...
    99+
    2023-07-02
  • Pytorch基础之torch.randperm的使用
    目录Pytorch torch.randperm的使用torch.randn和torch.rand有什么区别均匀分布标准正态分布总结Pytorch torch.randperm的使用...
    99+
    2023-02-02
    Pytorch torch.randperm torch.randperm的使用 torch.randperm
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作