返回顶部
首页 > 资讯 > 后端开发 > Python >基于BCEWithLogitsLoss样本不均衡的处理方案
  • 781
分享到

基于BCEWithLogitsLoss样本不均衡的处理方案

2024-04-02 19:04:59 781人浏览 安东尼

Python 官方文档:入门教程 => 点击学习

摘要

最近在做deepfake检测任务(可以将其视为二分类问题,label为1和0),遇到了正负样本不均衡的问题,正样本数目是负样本的5倍,这样会导致FP率较高。 尝试将正样本的loss权

最近在做deepfake检测任务(可以将其视为二分类问题,label为1和0),遇到了正负样本不均衡的问题,正样本数目是负样本的5倍,这样会导致FP率较高。

尝试将正样本的loss权重增高,看BCEWithLogitsLoss的源码


Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)
 
Args:
    weight (Tensor, optional): a manual rescaling weight given to the loss
        of each batch element. If given, has to be a Tensor of size `nbatch`.
    size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
        the losses are averaged over each loss element in the batch. Note that for
        some losses, there are multiple elements per sample. If the field :attr:`size_average`
        is set to ``False``, the losses are instead summed for each minibatch. Ignored
        when reduce is ``False``. Default: ``True``
    reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
        losses are averaged or summed over observations for each minibatch depending
        on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
        batch element instead and ignores :attr:`size_average`. Default: ``True``
    reduction (string, optional): Specifies the reduction to apply to the output:
        ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
        ``'mean'``: the sum of the output will be divided by the number of
        elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
        and :attr:`reduce` are in the process of being deprecated, and in the meantime,
        specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
    pos_weight (Tensor, optional): a weight of positive examples.
            Must be a vector with length equal to the number of classes.

对其中的参数pos_weight的使用存在疑惑,BCEloss里的例子pos_weight = torch.ones([64]) # All weights are equal to 1,不懂为什么会有64个class,因为BCEloss是针对二分类问题的loss,后经过检索,得知还有多标签分类

多标签分类就是多个标签,每个标签有两个label(0和1),这类任务同样可以使用BCEloss。

现在讲一下BCEWithLogitsLoss里的pos_weight使用方法

比如我们有正负两类样本,正样本数量为100个,负样本为400个,我们想要对正负样本的loss进行加权处理,将正样本的loss权重放大4倍,通过这样的方式缓解样本不均衡问题。


criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([4]))
 
# pos_weight (Tensor, optional): a weight of positive examples.
#            Must be a vector with length equal to the number of classes.

pos_weight里是一个tensor列表,需要和标签个数相同,比如我们现在是二分类,只需要将正样本loss的权重写上即可。

如果是多标签分类,有64个标签,则


Examples::
 
    >>> target = torch.ones([10, 64], dtype=torch.float32)  # 64 classes, batch size = 10
    >>> output = torch.full([10, 64], 0.999)  # A prediction (logit)
    >>> pos_weight = torch.ones([64])  # All weights are equal to 1
    >>> criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    >>> criterion(output, target)  # -log(sigmoid(0.999))
    tensor(0.3135)

补充:Pytorch —— BCEWithLogitsLoss()的一些问题

一、等价表达

1、PyTorch


torch.sigmoid() + torch.nn.BCELoss()

2、自己编写


def ce_loss(y_pred, y_train, alpha=1):
    
    p = torch.sigmoid(y_pred)
    # p = torch.clamp(p, min=1e-9, max=0.99)  
    loss = torch.sum(- alpha * torch.log(p) * y_train \
           - torch.log(1 - p) * (1 - y_train))/len(y_train)
    return loss~

3、验证


import torch
import torch.nn as nn
torch.cuda.manual_seed(300)       # 为当前GPU设置随机种子
torch.manual_seed(300)            # 为CPU设置随机种子
def ce_loss(y_pred, y_train, alpha=1):
   # 计算loss
   p = torch.sigmoid(y_pred)
   # p = torch.clamp(p, min=1e-9, max=0.99)
   loss = torch.sum(- alpha * torch.log(p) * y_train \
          - torch.log(1 - p) * (1 - y_train))/len(y_train)
   return loss
py_lossFun = nn.BCEWithLogitsLoss()
input = torch.randn((10000,1), requires_grad=True)
target = torch.ones((10000,1))
target.requires_grad_(True)
py_loss = py_lossFun(input, target)
py_loss.backward()
print("*********BCEWithLogitsLoss***********")
print("loss: ")
print(py_loss.item())
print("梯度: ")
print(input.grad)
input = input.detach()
input.requires_grad_(True)
self_loss = ce_loss(input, target)
self_loss.backward()
print("*********SelfCELoss***********")
print("loss: ")
print(self_loss.item())
print("梯度: ")
print(input.grad)

测试结果:

在这里插入图片描述

– 由上结果可知,我编写的loss和pytorch中提供的j基本一致。

– 但是仅仅这样就可以了吗?NO! 下面介绍BCEWithLogitsLoss()的强大之处:

– BCEWithLogitsLoss()具有很好的对nan的处理能力,对于我写的代码(四层神经网络,层之间的激活函数采用的是ReLU,输出层激活函数采用sigmoid(),由于数据处理的问题,所以会导致我们编写的CE的loss出现nan:原因如下:

–首先神经网络输出的pre_target较大,就会导致sigmoid之后的p为1,则torch.log(1 - p)为nan;

– 使用clamp(函数虽然会解除这个nan,但是由于在迭代过程中,网络输出可能越来越大(层之间使用的是ReLU),则导致我们写的loss陷入到某一个数值而无法进行优化。但是BCEWithLogitsLoss()对这种情况下出现的nan有很好的处理,从而得到更好的结果。

– 我此实验的目的是为了比较CE和FL的区别,自己编写FL,则必须也要自己编写CE,不能使用BCEWithLogitsLoss()。

二、使用场景

二分类 + sigmoid()

使用sigmoid作为输出层非线性表达的分类问题(虽然可以处理多分类问题,但是一般用于二分类,并且最后一层只放一个节点)

三、注意事项

输入格式

要求输入的input和target均为float类型

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: 基于BCEWithLogitsLoss样本不均衡的处理方案

本文链接: https://lsjlt.com/news/125818.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • 基于BCEWithLogitsLoss样本不均衡的处理方案
    最近在做deepfake检测任务(可以将其视为二分类问题,label为1和0),遇到了正负样本不均衡的问题,正样本数目是负样本的5倍,这样会导致FP率较高。 尝试将正样本的loss权...
    99+
    2024-04-02
  • 怎样进行nginx部署基于http的负载均衡器
    怎样进行nginx部署基于http的负载均衡器,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。nginx跨多个应用程序实例的负载平衡是一种用于优化资源利用率,最大...
    99+
    2023-06-05
  • 基于逻辑回归的利用欠采样处理类别不平衡的
        In [2]: import pandas as pd import matplotlib.pyplot as plt import numpy as np %matplotlib inline ...
    99+
    2023-01-30
    不平衡 逻辑 类别
  • Nginx负载均衡方案的报错处理与容错策略
    引言:随着互联网的发展,Web服务的负载越来越大,为了提高系统的性能和可用性,负载均衡成为一个重要的技术手段。在负载均衡中,Nginx是一种常用的反向代理服务器,它能够将客户端请求分发到多台真实的Web服务器上。在实际的应用中,无论是硬件故...
    99+
    2023-10-21
    nginx 负载均衡 报错处理 容错策略
  • 负载均衡集群的session处理方法
    本篇内容主要讲解“负载均衡集群的session处理方法”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“负载均衡集群的session处理方法”吧!通常面临的问题从用户端来解释,就是当一个用户第一次访...
    99+
    2023-06-27
  • 基于FeignClient调用超时的处理方案
    FeignClient调用超时 出现问题的前提 SpringCloud间FeignClient调用出现ReadTimeOut的情况 FeignClient服务间调用的默认超时时间为2...
    99+
    2024-04-02
  • grpc-java k8s下的负载均衡处理方法
    目录前言现状负载均衡的方案一、客户端 dns 模式 二、客户端注册中心模式三、代理端走 ingress四、代理端 service mesh结语前言 grpc 因为是长连接的...
    99+
    2024-04-02
  • hadoop Hdfs的数据磁盘大小不均衡怎么处理
    这篇文章主要讲解了“hadoop Hdfs的数据磁盘大小不均衡怎么处理”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“hadoop Hdfs的数据磁盘大小不均衡怎么处理”吧!现象描述建集群的时...
    99+
    2023-06-19
  • 负载均衡器如何处理不同类型的流量
    负载均衡器处理不同类型的流量时通常会根据特定的算法和策略来进行分配。以下是一些常见的处理方式: 基于轮询:负载均衡器会按照事先设...
    99+
    2024-04-17
    负载均衡
  • 基于FeignClient怎么调用超时的处理方案
    这篇文章给大家分享的是有关基于FeignClient怎么调用超时的处理方案的内容。小编觉得挺实用的,因此分享给大家做个参考,一起跟随小编过来看看吧。FeignClient调用超时出现问题的前提SpringCloud间FeignClient调...
    99+
    2023-06-20
  • 关于WCF异常处理解决方案是怎样的
    这期内容当中小编将会给大家带来有关关于WCF异常处理解决方案是怎样的,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。异常处理在我们的程序中是不可缺少的,异常可以反馈我们信息,如果还不知道WCF异常的朋友请看...
    99+
    2023-06-17
  • docker swarm外部验证负载均衡时不生效的解决方案
    问题描述 我在本地创建了3个装了centos7的虚拟机, 并初始化了swarm集群, 即1个manager节点, 2个worker节点; 三台机子的ip分别是 192.168.124...
    99+
    2024-04-02
  • 如何理解Oracle 的“HA”和“LB”及怎样用脚本测试负载均衡
    本篇文章为大家展示了如何理解Oracle 的“HA”和“LB”及怎样用脚本测试负载均衡,内容简明扼要并且容易理解,绝对能使你眼前一亮,通过这篇文章的详细介绍希望你能有所收获。概述今天主要介绍一下ORACLE...
    99+
    2024-04-02
  • 基于spring-security 401 403错误自定义处理方案
    spring-security 401 403错误自定义处理 为了返回给前端统一的数据格式, 一般所有的数据都会以类似下面的方式返回: public class APIResul...
    99+
    2024-04-02
  • Springboot基于assembly的服务化打包方案是怎样的
    Springboot基于assembly的服务化打包方案是怎样的,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。    在使用assembly来打包spr...
    99+
    2023-06-19
  • 基于pgrouting的路径规划处理方法
    目录一、数据处理二、原理分析三、效率优化四、数据bug处理五、后续规划对于GIS业务来说,路径规划是非常基础的一个业务,一般公司如果处理,都会直接选择调用已经成熟的第三方的接口,比如...
    99+
    2024-04-02
  • C++中引用处理的基本方法
    目录1.引用的基本用法1.1 引用的实质1.2 引用的用法2.函数中的引用3.引用的本质4.指针的引用5.常量引用补充:引用和指针的区别(重要)总结1.引用的基本用法 引用是C++对...
    99+
    2022-12-21
    c++ 引用 c++引用调用 c++引用用法
  • ASP.NET MVC基于异常处理的解决方法
    今天就跟大家聊聊有关ASP.NET MVC基于异常处理的解决方法,可能很多人都不太了解,为了让大家更加了解,小编给大家总结了以下内容,希望大家根据这篇文章可以有所收获。EntLib的异常处理应用块(Exception Handling Ap...
    99+
    2023-06-17
  • 详解OpenMV图像处理的基本方法
    目录一、图像处理基础知识二、OpenMV图像处理的基本方法1. 感光元件相关名词解释2. 图像的基本运算3. 使用图像的统计信息4. 画图5. 寻找色块6. AprilTag实现标记...
    99+
    2024-04-02
  • History是基本原理及使用方法是怎样的
    这期内容当中小编将会给大家带来有关History是基本原理及使用方法是怎样的,文章内容丰富且以专业的角度为大家分析和叙述,阅读完这篇文章希望大家可以有所收获。当我们频繁使用 Linux 命令行时,有效地使用历史记录,可以大大提高工作效率。在...
    99+
    2023-06-15
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作