返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch实现将label变成onehot编码的两种方式
  • 826
分享到

Pytorch实现将label变成onehot编码的两种方式

Pytorchlabelonehot编码onehot编码labelonehot编码 2023-02-01 09:02:34 826人浏览 八月长安

Python 官方文档:入门教程 => 点击学习

摘要

目录前言使用scatter_获得one hot 编码使用tensor.index_select获得one hot编码第二种针对分割网络的one_hot编码总结由于PyTorch不像T

由于PyTorch不像Tensorflow有谷歌巨头做维护,很多功能并没有很高级的封装,比如说没有tf.one_hot函数。

本篇介绍将一个mini batch的label向量变成形状为[batch size, class numbers]的one hot编码的两种方法,涉及到

  • tensor.scatter_
  • tensor.index_select

前言

本文将针对全连接网络和全卷积网络输出的形式不同,将one hot编码分两种情况。

  • 第一种针对网络输出是二维,即全连接层的输出形式, [Batchsize, Num_class]
  • 第二种针对输出是四维特征图,即分割网络的输出形式,[Batchsize, Num_class, H,W]

先将第一种情况

使用scatter_获得one hot 编码

我相信在CSDN上找这个函数用法的人都是看不懂官方介绍的,所以我不会像其他地方那样,搬官方教程,我也是琢磨了很久才看懂这个函数,但函数声明还是要看看的。

tensor.scatter_(dim, index, src) 
  • dim : 指定了覆盖数据是从哪个轴作为依据。后面再详细解释。值的范围是从0到 sum(tensor.shape)-1
  • index : 告诉函数要将src中对应的值放到tensor的哪个位置。index的shape要和src一致,或者src可以通过广播机制实现shape一致。
  • src : 保存了想用来覆盖tensor的值

我们先看一个例子,例子从别的博客copy过来,但我会做更加详细的介绍。觉得讲得好请留言作为鼓励。

>>> x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

注意到dim为0,代表以第一个维度作为依托。index是一个二维数组

[0,1,2,0,0]
[2,0,0,1,2]

那么我们要覆盖tensor的位置有10个,分别为

[0,0];[1,1];[2,2];[0,3];[0,4]
[2,0];[0,1];[0,2];[1,3];[2,4]

dim指定了index我们要将index的值作为哪一个轴的值。其他轴就是按照0到max shape -1变化罢了。比如说dim为0,那么index的值都作为坐标的第一个位置的值,另一个位置从0到4变换。

你们可以验证下,是不是这10个位置被覆盖了。10个位置的第一个轴是index的数字,第二个数字是index中的列数,从0到4。

要覆盖的位置有了,那么用什么值覆盖呢?别忘了我们的index的维度和src是一样的。index中选择什么位置的坐标,就对应用src对应的位置的值代替。

比如说要代替tensor中[0,0]的值,index中[0,0]就是第0行第0列对应的位置,那我们用src第0行第0列的值代替tensor的值。大家可以去验证一下。

我们看看下面的的情况,如果dim为1呢。

>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z

先分析一下

dim为1,那么index的值都作为坐标的第2个位置的值,第一个位置的值应该从0到1变化。

所以要被代替的位置有

[0,2];[1,3]

而[0,2]的位置要填入的值为1.23,[1,3]要填入的值为1.23。(广播机制将1.23这个标量扩展到了shape为(2,1))

好的,函数用法知道了。我们现在看看如何用该函数将label编码为one hot编码。

首先设想一个batch size为8的label。有10类,所以label中的数字应该是从0到9的。

import torch as t
import numpy as np

batch_size = 8
class_num = 10
label = np.random.randint(0,class_num,size=(batch_size,1))
label = t.LongTensor(label)

我们就获得了一个label,shape是(8,1),必须是2维。如果是(8,)下面的内容会报错的。

y_one_hot = t.zeros(batch_size,class_num).scatter_(1,label,1)
print(y_one_hot)

'''
tensor([[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]])
'''

搞定。下面我们看下面一种方法。

使用tensor.index_select获得one hot编码

还是先看下index_select的用法。

tensor.index_select( dim, index, out=None)
  • dim: 指定按什么维度取tensor中的向量
  • index: 是一个一维的张量。描述了按照dim维度取出tensor对应的index值的向量。

我们不看例子了,直接看方法,以此为例。

ones = torch.sparse.torch.eye(class_num)
return ones.index_select(0,label)

这里的label是一维的向量,不是二维的。因为index制定了必须是一维的

先生成一个单位矩阵,尺寸是[class_num, class_num]。

dim为0,以为这按照行来取tensor的向量。具体取哪一行呢,就是label中的值了。

这时我们应该也明白为啥这两行代码能实现one hot编码了吧。

如果label是[ 1,3,0],有四类。那我们得到就是

[0,1,0,0]
[0,0,0,1]
[1,0,0,0]

第二种针对分割网络的one_hot编码

对于分割类任务,网络的GT肯定是二维数组,而不是像分类任务那样的一维数组了。而对于分割任务,我们将其视作很多个像素值的分类任务,将ground truth 直接 reshape为向量形式,然后用上面的方法转为one hot编码,然后再reshape回来。核心是不变的。

下面举个例子。

import torch
import numpy as np

gt = np.random.randint(0,5, size=[15,15])  #先生成一个15*15的label,值在5以内,意思是5类分割任务
gt = torch.LongTensor(gt)

def get_one_hot(label, N):
    size = list(label.size())
    label = label.view(-1)   # reshape 为向量
    ones = torch.sparse.torch.eye(N)
    ones = ones.index_select(0, label)   # 用上面的办法转为换one hot
    size.append(N)  # 把类别输目添到size的尾后,准备reshape回原来的尺寸
    return ones.view(*size)


gt_one_hot = get_one_hot(gt, 5)
print(gt_one_hot)
print(gt_one_hot.shape)

print(gt_one_hot.argmax(-1) == gt)  # 判断one hot 转换方式是否正确,全是1就是正确的

另外注意,在Pytorch中,如果要和网络输出的特征图一起计算loss,还要把上面输出的one hot编码的最后一个维度使用permute转到通道维度上。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: Pytorch实现将label变成onehot编码的两种方式

本文链接: https://lsjlt.com/news/193779.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Pytorch实现将label变成onehot编码的两种方式
    目录前言使用scatter_获得one hot 编码使用tensor.index_select获得one hot编码第二种针对分割网络的one_hot编码总结由于Pytorch不像T...
    99+
    2023-02-01
    Pytorch label one hot编码 one hot编码 label one hot编码
  • Jmeter实现Base64编码的两种方式
    Jmeter实现Base64编码有两种方式: 1、如果安装的Jmeter版本内置提供了Base64加密函数,可以直接使用该内置函数,方法如下: 点击Tools --> 函数助手...
    99+
    2024-04-02
  • js将列表组装成树结构的两种实现方式分享
    目录前言背景介绍实现方案递归法资源总结前言 工作中偶尔就会遇到后端同学丢来一个列表,要我们自己组装成一个树结构渲染到页面上,本文以两种不同方式探索生成树的算法思想。 背景介绍 可组...
    99+
    2024-04-02
  • php生成用户密码的两种方式
    目录一、md5密码 二、hash密码PS:php生成随机密码的几种方法方法一:方法二:方法三:方法四:在用户系统中,生成用户的密码是很重要的,而简单的密码必然给一些不法用户...
    99+
    2024-04-02
  • go实现base64编码的四种方式
    go的encoding/base64有四种编码方式: 编码方式说明StdEncoding常规编码URLEncodingURL safe 编码,相当于替换掉字符串中的特殊字符,+ 和 ...
    99+
    2023-03-07
    go实现base64编码 go base64编码
  • JAVA实现Base64编码的三种方式
    目录定义: 二进制文件可视化sun 包下的 BASE64Encoderapache 包下的 Base64util 包下的 Base64 (jdk8)定义: 二进制文件可视化 Base...
    99+
    2024-04-02
  • 实现Vue-router编程式导航的两种方法
    这篇文章主要介绍“实现Vue-router编程式导航的两种方法”,在日常操作中,相信很多人在实现Vue-router编程式导航的两种方法问题上存在疑惑,小编查阅了各式资料,整理出简单好用的操作方法,希望对大家解答”实现Vue-router编...
    99+
    2023-06-06
  • pytorch 实现变分自动编码器的操作
    本来以为自动编码器是很简单的东西,但是也是看了好多资料仍然不太懂它的原理。先把代码记录下来,有时间好好研究。 这个例子是用MNIST数据集生成为例子 # -*- coding: ...
    99+
    2024-04-02
  • js将列表组装成树结构的两种实现方式分别是什么
    js将列表组装成树结构的两种实现方式分别是什么,很多新手对此不是很清楚,为了帮助大家解决这个难题,下面小编将为大家详细讲解,有这方面需求的人可以来学习下,希望你能有所收获。前言工作中偶尔就会遇到后端同学丢来一个列表,要我们自己组装成一个树结...
    99+
    2023-06-26
  • redis实现缓存的两种方式
    本篇文章给大家分享的是有关redis实现缓存的两种方式,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。redis实现缓存大致为两种措施:一、脚本...
    99+
    2024-04-02
  • 详解SpringMVC的两种实现方式
    目录一、方法一:实现Controller接口二、方法二:使用注解开发一、方法一:实现Controller接口 这个在我的第一个SpringMVC程序中已经学习过了,在此不作赘述,现在...
    99+
    2022-11-13
    SpringMVC实现方式 SpringMVC的两种实现方式
  • Java生成二维码的几种实现方式
    前言 本文将基于Spring Boot介绍两种生成二维码的实现方式,一种是基于Google开发工具包,另一种是基于Hutool来实现; 下面我们将基于Spring Boot,并采用两种方式实现二维码的...
    99+
    2023-09-06
    java 开发语言
  • C#格式化JSON的两种实现方式
    目录实现功能:开发环境:实现代码:当我们拿到一大段JSON字符串的时候,分析起来简直头皮发麻,相信很大一部分朋友也都会直接去BEJSON等网站去做一个JSON格式化,已方便自己查看数...
    99+
    2024-04-02
  • javascript 实现纯前端将数据导出excel两种方式
    目录前言方法一方法二前言 修改之前项目代码的时候,发现前人导出excel是用纯javascript实现的。并没有调用后台接口。 之前从来没这么用过,记录一下。以备不时之需。 方法一 ...
    99+
    2024-04-02
  • 一篇文章学会两种将python打包成exe的方式
    目录前言详细步骤图形窗口打包总结前言 python 可以做网站应用,也可以做客户端应用。但是客户端应用需要运行 py 脚本,如果用户不懂 python 就是一件比较麻烦的事情。幸好 ...
    99+
    2024-04-02
  • python实现多线程的两种方式
    目前python 提供了几种多线程实现方式 thread,threading,multithreading ,其中thread模块比较底层,而threading模块是对thread做了一些包装,可以更加方便...
    99+
    2022-06-04
    两种 多线程 方式
  • python调用excel_vba的两种实现方式
    目录方法一: 方法二:方法一:  import win32com.client xl = win32com.client.Dispatch("Excel....
    99+
    2023-01-29
    python调用excel_vba python excel_vba调用
  • Pytorch中实现CPU和GPU之间的切换的两种方法
    目录方法一:.to(device)1.不知道电脑GPU可不可用时:2.指定GPU时3.指定cpu时:方法二:总结:如何在pytorch中指定CPU和GPU进行训练,以及cpu和gpu...
    99+
    2023-01-28
    Pytorch CPU和GPU切换 Pytorch CPU GPU
  • Android实现圆形图片的两种方式
    在项目中,我们经常会用到圆形图片,但是android本身又没有提供,那我只能我们自己来完成。 第一种方式,自定义CircleImageView: public class C...
    99+
    2022-06-06
    图片 Android
  • 两种java实现二分查找的方式
    目录1、二分查找算法思想2、二分查找图示说明3、二分查找优缺点3、java代码实现3.1 使用递归实现 3.1 不使用递归实现(while循环) 3.3 测试4、时间复杂度5、空间复...
    99+
    2024-04-02
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作