返回顶部
首页 > 资讯 > 后端开发 > Python >Pytorch:dtype不一致问题(expecteddtypeDoublebutgotdtypeFloat)
  • 594
分享到

Pytorch:dtype不一致问题(expecteddtypeDoublebutgotdtypeFloat)

Pytorch:dtype不一致Pytorch:dtypePytorchdtype 2023-02-21 12:02:19 594人浏览 薄情痞子

Python 官方文档:入门教程 => 点击学习

摘要

目录PyTorch:dtype不一致1. 说明2. 解决办法一3. 解决办法二Pytorch:盲点总结Pytorch:dtype不一致 RuntimeError: Expected

Pytorch:dtype不一致

RuntimeError: Expected object of Scalar type Double but Got scalar type Float for argument #3 ‘mat2’ in call to _th_addmm_out

1. 说明

在训练网络的过程中由于类型的冲突导致这种错误,主要是模型内部参数和输入类型不一致所导致的。主要有两个部分需要注意到:1.自己定义的变量要设置为一种数据类型;2.网络内部的变量类型也要统一。

2. 解决办法一

统一声明变量的类型。

# 将接下来创建的变量类型均为Double
torch.set_default_tensor_type(torch.DoubleTensor)

or

#将接下来创建的变量类型均为Float
torch.set_default_tensor_type(torch.FloatTensor)

一定要注意要在变量创建之间声明类型。

3. 解决办法二

在训练过程中加入一下两点即可:

# For your model
net = net.double()
# For your data
net(input_x.double)

Pytorch:盲点

1. 用conda安装pytorch-gpu时

用这个命令就够了,网上其他人说的都不好使

conda install pytorch cuda92

注意得是清华源的

2. 比较两个行向量或者列向量

以期求得布尔数组时,必须要保证两边的数据类型一样,并且返回的布尔数组类型和比较的两个向量结构保持一致。另外,所有torch.返回的东西,如果要取得里面的值,必须要加.item()

# !user/bin/python
# -*- coding: UTF-8 -*-
 
import torch
 
a = torch.arange(16).view(4, 4)
b = torch.argmax(a, dim = 1)
print([round(x.item(), 5) for x in b])
 
z = torch.tensor([3, 1, 2, 5], dtype = torch.long) # 类型必须保持一致
z = z.view(-1, 1)
b = b.view(-1, 1)
print(b)
print(z)
print(b == z)
# tensor([[ True],
#         [False],
#         [False],
#         [False]])
print(torch.sum(b == z)) # tensor(1)

3. numpy转tensor,其中,ndarray必须是等长的

x = np.array([[1, 2, 3], [4, 5, 6]]) # 正确
# x = np.array([[1, 2, 3], [4, 5]]) # 错误
print(torch.from_numpy(x))

4. unsqueeze (不改变原有数据)

import torch
import numpy as np
 
x = torch.tensor([[1, 2], [3, 4]])
print(x)
# tensor([[1, 2],
#         [3, 4]])
 
# 在第0维的地方插入一维
print(x.unsqueeze(0))
# tensor([[[1, 2],
#          [3, 4]]])
print(x.unsqueeze(0).shape) # torch.Size([1, 2, 2])
print(x.unsqueeze(1))
# tensor([[[1, 2]],
 
#         [[3, 4]]])
print(x.unsqueeze(1).shape) # torch.Size([2, 1, 2])

5. nn.embedding

# !user/bin/Python
# -*- coding: UTF-8 -*-
 
import torch
import torch.nn as nn
import torch.nn.functional as F
 
# 看看torch中的torch.nn.embedding
# embedding接收两个参数
# 第一个是num_embeddings,它表示词库的大小,则所有词的下标从0 ~ num_embeddings-1
# 第二个是embedding_dim,表示词嵌入维度
# 词嵌入层有了上面这两个必须有的参数,就形成了类,这个类可以有输入和输出
# 输入的数据结构不限,但是数据结构里面每个单元的元素必须指的是下标,即要对应0 ~ num_embeddings-1
# 输出的数据结构和输入一样,只不过将下标换成对应的词嵌入
# 最开始的时候词嵌入的矩阵是随机初始化的,但是作为嵌入层,会不断的学习参数,所以最后训练完成的参数一定是学习完成的
# embedding层还可以接受一个可选参数padding_idx,这个参数指定的维度,但凡输入的时候有这个维度,输出一律填0
 
# 下面来看一下吧
embedding = nn.Embedding(10, 3)
inputs = torch.tensor([[1, 2, 4, 5],
                       [4, 3, 2, 9]])
print(embedding(inputs))
 
# tensor([[[ 0.3721,  0.3502,  0.8029],
#          [-0.2410,  0.0723, -0.6451],
#          [-0.4488,  1.4382,  0.1060],
#          [-0.1430, -0.8969,  0.7086]],
#
#         [[-0.4488,  1.4382,  0.1060],
#          [ 1.3503, -0.0711,  1.5412],
#          [-0.2410,  0.0723, -0.6451],
#          [-0.3360, -0.7692,  2.2596]]], grad_fn=<EmbeddingBackward>)

6. eq

# !user/bin/python
# -*- coding: UTF-8 -*-
 
 
# eq
import torch
a = torch.tensor([1, 2, 2, 3])
b = torch.tensor((1, 3, 2, 3))
print(a.eq(b)) # tensor([ True, False,  True,  True])
print(a.eq(0)) # tensor([False, False, False, False])
print(a.eq(2)) # tensor([False,  True,  True, False])

7. expand

# expand
# expand不修改原有值
# 只能扩展维度是1的那个维度
# 另外,expand还能增加新的维度,不过新的维度必须在已知维度之前比如从2 * 3 到 5 * 2 * 3
a = torch.tensor([[1, 2, 3]])
print(a.size()) # torch.Size([1, 3]),第一维是1,所以只能扩展第一维
print(a.expand(3, 3))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])
print(a) # tensor([[1, 2, 3]])
 
a = torch.tensor([[1], [2], [3]])
print(a.size()) # torch.Size([3, 1]),第二维是1,只能扩展第二维
print(a.expand(-1, 4)) # 第一维用-1代表第一维不变,还是3
# tensor([[1, 1, 1, 1],
#         [2, 2, 2, 2],
#         [3, 3, 3, 3]])
 
a = torch.randn(2, 1, 1, 4) # 同理,只能扩展第2和第3维
print(a.expand(-1, 2, 3, -1))
 
a = torch.tensor([1, 2, 3])
print(a.size()) # torch.Size([3])
print(a.expand(2, 3))
# tensor([[1, 2, 3],
#         [1, 2, 3]])
# print(a.expand(3, 2)) 会报错,因为新维度跑到原有维度之后了
print(a.expand(3, 2, -1))
 
a = torch.rand(2, 3)
print(a.expand(5, 2, 3)) # 正确

8. repeat

# repeat
# repeat不改变原有值
# repeat传入的参数的个数等于原有值的维度个数,表示将对应维度的内容重复多少次
 
a = torch.tensor([1, 2, 3]) # size = [3]
print(a.repeat(2)) # 将第一个维度的重复两次 tensor([1, 2, 3, 1, 2, 3])
print(a) # tensor([1, 2, 3])
 
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 2 * 3
print(a.repeat(2, 3)) # 第一维重复两次,第二维重复三次,就变成了4 * 9
# tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
#         [4, 5, 6, 4, 5, 6, 4, 5, 6],
#         [1, 2, 3, 1, 2, 3, 1, 2, 3],
#         [4, 5, 6, 4, 5, 6, 4, 5, 6]])

9. torch.stack 这个函数接受一个由张量组成的元组或者列表

与cat的区别是,stack会先增加一维,然后再进行拼接

10. 对于一维张量a,维度为m,a[None, :]的shape为1×m,a[:, None]的shape为m×1

11. 两个不同维度的矩阵比较,利用了广播机制

12. torch.nn.CrossEntropyLoss(),这个类比较复杂,我们慢点说。

首先,这是一个类,定义如下:

class CrossEntropyLoss(_WeightedLoss):
    def __init__(self, weight=None, size_average=None, ignore_index=-100,
                 reduce=None, reduction='mean'):
        super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index
 
    def forward(self, input, target):
        return F.cross_entropy(input, target, weight=self.weight,
                               ignore_index=self.ignore_index, reduction=self.reduction)

可以看到这个类非常简单。在构造函数中,我们重点关注weight和reduction

  • weight:表示每个类的权重,这个在样本不平衡分类问题中很有用,它应该是一个长度为C的一维张量,C为类别数。
  • reduction:用于控制最终的输出结果,默认为mean,如果是mean,返回的是一个数,即shape为torch.tensor([]) ,如果是none,则返回的情况有两种,一种是(N,),一种是(N, d_1, ..., d_K),至于是哪一种,得看我们输入到forward函数中的input和target是哪种形状的。

可以看到,forward函数直接调用F.cross_entropy这个函数,这个函数中的weight和reduction我们已经讲过,不再赘述。我们重点将input和target应该是什么形状的。

这里分两种情况。

第一种,input是二维,即(N, C),N代表batch_size,C代表类别,即对于每一个batch,都有一个一维长度为C的向量,通常这里的C表示的是对应类别的得分。target表示标签,它应该是(N,),表示每一个batch的正确标签是什么,内容应该在0~C-1之中。如果reduction取默认值mean,则返回的是一个数,这个数是每个batch的损失的平均。如果是none,则返回的是(N,),即代表每一个batch的损失,没有进行平均。

第二种,input的维度是(N, C, d_1, ... d_K),这里K>=1。N表示batch_size,C表示类别数,d_1...d_K 可以看做一个整体,表示在某个批次,某个类别上,损失并不是一个数字,而是一个张量。按这种方式理解的话,第一种就可以理解为,在某个批次,某个类别上,损失是一个数,而第二种不是一个数,是一个张量,这个张量的形状是(d_1, ..., d_K),这个张量的每一个位置都代表对应位置的损失。拿NLP中seq2seq的损失函数为例,decoder的输出应是(batch_size, seq_len, vocab_size),label为(batch_size, seq_len),那么这里我们显然应该用第二种,因为在某个批次,某个类别上,我们的损失函数并不单单是一个数,而是一个(seq_len,)的张量,对于长度为seq_len的每个单词,每个位置都有一个损失,所以我们要用第二种。所以这里,我们需要将input的后两维置换,即transpose(0, 1),使其变成(batch_size, vocab_size, seq_len)。对于第二种,target,也就是label的维度应该是(N, d1, ..., d_K),表示对于每个批次,这个“损失张量”(这个名字我自己起的)的每个位置对应的标签。因此,对于seq2seq来说,label的维度应该是(N, seq_len)。对于第二种情况,如果reduction是mean,输出还是一个数,这个数表示所有批次,“损失张量”所有位置的损失的平均值。如果是none,输出为(N, d1, ...d_K),表示每一个批次,“损失张量”每一个位置的损失。

下面的例子代表第一种情况。

13. torch中mask作为torch的下标,可以不必和torch一样的shape

当然,也可以直接用下表来进行赋值

14. repeat in torch & numpy is very different.

look at the pic above, repeat in torch does not have a argument 'axis =', and it regard the whole tensor as one which can not be seperated.

Next, we will talk about repeat function in numpy. We would like to divide it into two part. The fORMer is that the array is 1-D and the other is N-D

  • part 1:

if array is one-dimensional, there is no need to specify the 'axis' argument. It will seperate each number, and repeat them respectively.

  • part 2:

Here the shape of array c is (2, 2), so we can specify the 'axis'. The condition will be as follows:

(1) If axis is not specified, it will firstly flatten the array and continue repeating operation which is like part 1.

(2) If axis = 0, it will repeat along the first dimension.

(3) If axis = 1, it will repeat along the second dimension.

15. torch的多维张量a,如果a[-1],默认代表第一axis的最后一维,即等价于a[-1,:,:,...,:]

16. torch.cat(XXX, dim = ?) 其中XXX可以是list,不一定非要tensor

17. contiguous()

亦即,对tensor进行transpose时,其实是浅拷贝,如果要深拷贝,就要在后面补上一个contiguous()

18. net是网络模型的一个object,调用zero_grad函数,表示将网络内部所有参数的梯度置0

19. torch.utils.data中的DataLoader和TensorDataset

先用TensorDataset对打包的文件进行校对,在DataLoader中会指定batch_size,假设原本data和label各N个,那么DataLoader会将其打乱后,每batch_size为一组,共N/batch_size个。

假设DataLoader返回的是iter,对iter进行for遍历的时候,假设每一轮取样为batch,则batch的长度是2,batch为一个list,这个list里面有两个元素,一个是data,一个是label,data和label的第一维大小都是batch_size。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

--结束END--

本文标题: Pytorch:dtype不一致问题(expecteddtypeDoublebutgotdtypeFloat)

本文链接: https://lsjlt.com/news/196902.html(转载时请注明来源链接)

有问题或投稿请发送至: 邮箱/279061341@qq.com    QQ/279061341

猜你喜欢
  • Pytorch:dtype不一致问题(expecteddtypeDoublebutgotdtypeFloat)
    目录Pytorch:dtype不一致1. 说明2. 解决办法一3. 解决办法二Pytorch:盲点总结Pytorch:dtype不一致 RuntimeError: Expected ...
    99+
    2023-02-21
    Pytorch:dtype不一致 Pytorch:dtype Pytorch dtype
  • Pytorch:dtype不一致问题如何解决
    这篇“Pytorch:dtype不一致问题如何解决”文章的知识点大部分人都不太理解,所以小编给大家总结了以下内容,内容详细,步骤清晰,具有一定的借鉴价值,希望大家阅读完这篇文章能有所收获,下面我们一起来看看这篇“Pytorch:dtype不...
    99+
    2023-07-05
  • MySQL主从不一致的问题分析
    这篇文章主要讲解了“MySQL主从不一致的问题分析”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“MySQL主从不一致的问题分析”吧!  &nbs...
    99+
    2024-04-02
  • PostgreSQL问题分析1:时间线不一致
    一、问题:requested timeline %u does not contain minimum recovery point %X/%X on timeline %u 该日志在代码中的位置如下: S...
    99+
    2024-04-02
  • css怎么解决高度不一致问题
    这篇文章主要为大家展示了“css怎么解决高度不一致问题”,内容简而易懂,条理清晰,希望能够帮助大家解决疑惑,下面让小编带领大家一起研究并学习一下“css怎么解决高度不一致问题”这篇文章吧。 ...
    99+
    2024-04-02
  • Java与MySQL时间不一致问题解决
    目录一、问题情况描述二、CST时区混乱1. CST有四种含义2. 什么是时区三、绝对时间与本地时间1. 绝对时间2. 本地时间3. 时区偏移量四、MySQL服务端时区1. syste...
    99+
    2023-01-05
    Java与MySQL时间不一致 MySQL时间不一致
  • 关于两次访问接口的sessionid不一致问题
    在测试验证邮箱、注册逻辑时,出现验证码错误的问题。验证码是存放在session内的,在排除了逻辑代码的问题后,检查出这两次访问接口的sessionid并不一致,而在swagger测试接口时是一致的。...
    99+
    2023-09-28
    ajax javascript 前端 java 服务器
  • redis怎么解决缓存不一致的问题
    本文小编为大家详细介绍“redis怎么解决缓存不一致的问题”,内容详细,步骤清晰,细节处理妥当,希望这篇“redis怎么解决缓存不一致的问题”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新...
    99+
    2024-04-02
  • Mysql 连接数与配置文件不一致问题
    在一次部署物理机时开发那边报无法连接数据库,登录后台查看发现登录不上,报连接数太多,重启数据库登录后查看连接数 查询Mysql 最大连接数: mysql> select @@max_c...
    99+
    2024-04-02
  • lsof处理df和du大小不一致的问题
    APP服务器根满了,一直报警df显示根分区已经使用了90%的空间,但是du根分区总和只有40G左右,该分区应该没有大量小文件,所以应该不会产生大量小文件导致的block写满的问题.网上搜了下发现有可能是有程...
    99+
    2024-04-02
  • 解决Beanutils.copyproperties实体类对象不一致的问题
    今天给大家分析一个解决Beanutils.copyproperties实体类对象名不一致的解决方法,一般我们在两个对象拷贝的问题上,我个人用的比较多的就是Beanutils.copy...
    99+
    2024-04-02
  • 解决vue前后端端口不一致的问题
    vue前后端端口不一致 在config index.js文件中 引入如下代码即可 proxyTable: { '/api': { target: 'http://local...
    99+
    2024-04-02
  • 如何解决BOX模型解释不一致问题
    小编给大家分享一下如何解决BOX模型解释不一致问题,相信大部分人都还不怎么了解,因此分享这篇文章给大家参考一下,希望大家阅读完这篇文章后大有收获,下面让我们一起去了解一下吧! 在FF和IE中的BOX模型解...
    99+
    2024-04-02
  • C语言中函数返回值不一致问题
    目录C语言函数返回值不一致函数的返回值注意事项函数的返回值注意事项总结C语言函数返回值不一致 在运行成程序上有时会发现函数内部的值与返回到主函数的值会相差很多出现随机值,但是它们的地...
    99+
    2023-02-24
    C语言函数 函数返回值不一致 C语言函数返回值
  • mybatis映射和实际类型不一致的问题
    目录mybatis映射和实际类型不一致原因分析小结一下解决方法mybatis映射器Mapper(结果映射以及解决列名不一致)结果映射:(resultMap, resultType)1...
    99+
    2024-04-02
  • redis主从数据不一致问题如何解决
    使用Redis的复制(Replication)功能来保证数据一致性。可以将主节点写入的数据同步到从节点,确保从节点的数据与主节点...
    99+
    2024-04-09
    redis
  • NoSQL怎么处理数据的不一致性问题
    NoSQL数据库通常使用多种方法来处理数据的不一致性问题,具体取决于数据库的类型和实现方式。以下是一些常见的方法: ACID属性...
    99+
    2024-05-07
    NoSQL
  • MySQL从库的列类型不一致导致的复制异常问题
    官方文档:https://dev.mysql.com/doc/refman/5.6/en/replication-features-differing-tables.htmlslave_type_conve...
    99+
    2024-04-02
  • 记一次CurrentDirectory导致的问题
    在编程中,CurrentDirectory是一个表示当前工作目录的属性。它指示了程序在运行时所在的目录。一次由CurrentDire...
    99+
    2023-09-15
    问题
  • redis 数据库主从不一致问题解决方案
     在聊数据库与缓存一致性问题之前,先聊聊数据库主库与从库的一致性问题。   问:常见的数据库集群架构如何? 答:一主多从,主从同步,读写分离。 如上图: (1)一个主库提供写服务 (2)多个从库提供读服务,可以增加从库提升读性能 (3)主...
    99+
    2020-05-07
    redis 数据库主从不一致问题解决方案
软考高级职称资格查询
编程网,编程工程师的家园,是目前国内优秀的开源技术社区之一,形成了由开源软件库、代码分享、资讯、协作翻译、讨论区和博客等几大频道内容,为IT开发者提供了一个发现、使用、并交流开源技术的平台。
  • 官方手机版

  • 微信公众号

  • 商务合作